Source code for do_mpc.sampling._samplingplanner

import types
import pickle
import os
import numpy as np
import pathlib
import pdb
import scipy.io as sio
import copy
import itertools
from do_mpc.tools import load_pickle, save_pickle
from typing import Union,Callable

[docs] class SamplingPlanner: """A class for generating sampling plans. These sampling plans will be executed by :py:class:`do_mpc.sampling.Sampler` to generate data. The class can be created with optional keyword arguments which are passed to :py:meth:`set_param`. **Configuration and sampling plan generation:** 1. Set variables which should be sampled with :py:func:`set_sampling_var`. 2. (Optional) Set further options of the SamplingPlanner with :py:meth:`set_param` 3. Generate the sampling plan with :py:func:`gen_sampling_plan`. 4. And / or: Add specific sampling case with :py:meth:`add_sampling_case`. 5. Export the plan with all sampling cases with :py:meth:`export` """ def __init__(self, **kwargs): self.sampling_vars = [] self.sampling_var_names = [] self.sampling_plan = [] # Parameters that can be set for the SamplingPlanner: self.data_fields = [ 'overwrite', 'id_precision', ] self.data_dir = './' self.overwrite = False self.id_precision = 3 if kwargs: self.set_param(**kwargs) @property def data_dir(self): """Set the save directory for the ``samplingplan``. If the directory does not exist yet, it is created. If the directory is nested all (non-existing) parent folders are also created. **Example:** :: sp = do_mpc.sampling.SamplingPlanner() sp.data_dir = './samples/experiment_1/' This will set the directory to the indicated path. If the path does not exist, all folders are created. """ return self._data_dir @data_dir.setter def data_dir(self, val): self._data_dir = val pathlib.Path(val).mkdir(parents=True, exist_ok=True) def set_param(self, **kwargs)->None: """Set the parameters of the :py:class:`SamplingPlanner` class. Parameters must be passed as pairs of valid keywords and respective argument. For example: :: sp.set_param(overwrite = True) It is also possible and convenient to pass a dictionary with multiple parameters simultaneously as shown in the following example: :: setup_dict = { 'overwrite': True, 'save_format': pickle, } sp.set_param(**setup_dict) This makes use of thy python "unpack" operator. See `more details here`_. .. _`more details here`: https://codeyarns.github.io/tech/2012-04-25-unpack-operator-in-python.html Note: :py:func:`set_param` can be called multiple times. Previously passed arguments are overwritten by successive calls. The following parameters are available: Args: overwrite(bool): Overwrites existing samplingplan under the same name, if set to ``True``. id_precision(str): Padding for IDs of created samples. Defaults to 3. This means sample 20 will be denoted as 020. """ for key, value in kwargs.items(): if not (key in self.data_fields): print('Warning: Key {} does not exist for SamplingPlanner.'.format(key)) else: setattr(self, key, value) def set_sampling_var(self, name:str, fun_var_pdf:Callable[[],Union[float,int]]=None)->None: """Introduce new sampling variables to the :py:class:`SamplingPlanner`. Define variable name. Optionally add a function to generate values for the sampled variable (e.g. following some distribution). The parameter ``fun_var_pdf`` defaults to ``None``. Note: If no value-generating function is passed (for any of the introduced variables), all sampling cases must be created manually with :py:meth:`add_sampling_case`. Note: Value generating function ``fun_var_pdf`` must not require inputs. **Example:** :: sp = do_mpc.sampling.SamplingPlanner() # Plan with two variables alpha and beta: sp.set_sampling_var('alpha', np.random.randn) sp.set_sampling_var('beta', lambda: np.random.randint(0,5)) In the example we have passed a ``BuiltinFunction`` for the introduced variable ``alpha``. We use the function that created values from the random normal distribution with zero mean and unity covariance. For the variable ``beta`` we created a new lambda function that draws random integers from 0 to 5. Args: name: Name of the sampled variable fun_var_pdf: Declare the value-generating function of the sampled variable Raises: assertion: ``name`` must be string assertion: ``fun_var_pdf`` must be Function or BuiltinFunction """ assert isinstance(name, str), 'name must be str, you have {}'.format(type(name)) assert isinstance(fun_var_pdf, (types.FunctionType, types.BuiltinFunctionType, type(None))), 'fun_var_pdf must be either Function or BuiltinFunction_or_Method or None, you have {}'.format(type(fun_var_pdf)) self.sampling_vars.append({'name':name, 'fun_var_pdf':fun_var_pdf}) self.sampling_var_names.append(name) def add_sampling_case(self, **kwargs)->list: """ Manually add sampling case with user-defined values. Create a sampling case by choosing values for the previously introduced sampling variables (with :py:meth:`set_sampling_var`). Method takes arbitrary (keyword, argument) pairs, where the keywords must refer to previously introduced sampling variables. :py:meth:`add_sampling_case` will automatically augment the sampling case with values for variables that are not passed as arguments. This only works if these variables were created with the argument ``fun_var_pdf``. **Example:** :: sp = do_mpc.sampling.SamplingPlanner() # Plan with two variables alpha and beta: sp.set_sampling_var('alpha', np.random.randn) sp.set_sampling_var('beta', lambda: np.random.randint(0,5)) # Create two new sampling cases, missing variable is auto-generated: sp.add_sampling_case(alpha=1) sp.add_sampling_case(beta= 0) Returns: Returns the newly created sampling plan. """ # Create each sampling case as dict: temp_dic = {} # Iterate over all the key value pairs added in the method call: for key, value in kwargs.items(): if not (key in self.sampling_var_names): raise Exception('{} is not a valid sampling variable. Introduce sampling variables with set_sampling_var.'.format(key)) else: temp_dic[key] = value # Add key value pairs for all keys that were not referenced in add_sampling_case: for var in self.sampling_vars: if var['name'] not in kwargs.keys(): # Augmentation is not possible if value generating function is not supplied. assert var['fun_var_pdf'] is not None, 'Cannot augment sampling_case for missing variable {}. Variable generating function is missing.'.format(var['name']) temp_dic[var['name']] = var['fun_var_pdf']() # Generate string ID of sampling case based on index and pad with zeros: id = len(self.sampling_plan) temp_dic['id'] = str(id).zfill(self.id_precision) self.sampling_plan.append(temp_dic) return self.sampling_plan def gen_sampling_plan(self, n_samples:int)->list: """Generate the sampling plan. The generated plan contains ``n_samples`` samples based on the defined variables and the corresponding evaluation functions. Args: n_samples: The number of generated samples Raises: assertion: n_samples must be int Returns: Returns the newly created sampling plan. """ assert isinstance(n_samples, int), 'n_samples must be int, you have {}'.format(type(n_samples)) assert n_samples>0, 'n_samples must be larger than 0.' for i in range(n_samples): self.add_sampling_case() return self.sampling_plan def product(self, **kwargs:dict)->list: """Cartesian product of input variables. This method is inspired by `itertools.product <https://docs.python.org/3/library/itertools.html#itertools.product>`_. Must pass a list for each ``sampling_var`` that should be considered. Not all ``sampling_vars`` must be referenced. Sampling vars that are excluded, will generate a value according to their assigned ``fun_var_pdf`` (see :py:meth:`set_sampling_var`). Args: kwargs: Keyword arguments of the form ``var_name=var_values``. Returns: Returns the newly created sampling plan. """ # Check if all key word values are lists: check = np.alltrue([isinstance(v, list) for v in kwargs.values()]) if not check: raise ValueError('keyword values must be lists') # Check if all key words are existing sampling variables: keys = kwargs.keys() check = np.alltrue([v in self.sampling_var_names for v in keys]) if not check: raise ValueError('keyword names must be existing sampling variables') # Create cartesian product of all values passen in kwargs: values = list(itertools.product(*list(kwargs.values()))) # Create new sampling cases: for value in values: # Zip together the value(s) of the current case with the respective keys: case = dict(zip(keys, value)) # Add sampling case self.add_sampling_case(**case) return self.sampling_plan def export(self, sampling_plan_name:str)->None: """Export SamplingPlan in pickle format. Pass ``sampling_plan_name`` without any path. File extension can be added (but will be stripped automatically). Change the path with :py:attr:`data_dir`. Args: sampling_plan_name: Name of the exported sampling plan file. Raises: assertion: ``sampling_plan_name`` must be string. """ assert isinstance(sampling_plan_name, str), 'sampling_plan_name must be of type str. You have {}.'.format(type(sampling_plan_name)) # Strip file extension from name: sampling_plan_name = os.path.splitext(sampling_plan_name)[0] full_name = self.data_dir + sampling_plan_name + '.pkl' if not os.path.isfile(full_name) or self.overwrite: save_pickle(full_name, self.sampling_plan) else: for i in range(1,10000): full_name = self.data_dir + sampling_plan_name + str(i) + '.pkl' if not os.path.isfile(full_name): save_pickle(full_name, self.sampling_plan) break