Source code for do_mpc.sampling.samplingplanner

import types
import pickle
import os
import numpy as np
import pathlib
import pdb
import as sio
import copy
from import load_pickle, save_pickle

class SamplingPlanner:
    """A class for generating sampling plans.

    These sampling plans will be executed by :py:class:`do_mpc.sampling.sampler.Sampler` to generate data.

    **Configuration and sampling plan generation:**

    1. Set variables which should be sampled with :py:func:`set_sampling_var`.

    2. (Optional) Set further options of the SamplingPlanner with :py:meth:`set_param`.

    3. Generate the sampling plan with :py:func:`gen_sampling_plan`.

    4. And / or: Add specific sampling cases with :py:meth:`add_sampling_case`.

    5. Export the plan with all sampling cases with :py:meth:`export`.
    """
    def __init__(self):
        # One dict per sampling variable with keys 'name' and 'fun_var_pdf'.
        self.sampling_vars = []
        # Names of the sampling variables, in the order they were introduced.
        self.sampling_var_names = []
        # List of sampling cases; each case maps variable names (and 'id') to values.
        self.sampling_plan = []

        # Parameters that can be set for the SamplingPlanner:
        self.data_fields = [
            'overwrite',
            'id_precision',
        ]

        # Goes through the data_dir property setter, which also creates the directory.
        self.data_dir = './'
        self.overwrite = False
        self.id_precision = 3

    @property
    def data_dir(self):
        """Set the save directory for the ``samplingplan``.

        If the directory does not exist yet, it is created.
        If the directory is nested, all (non-existing) parent folders are also created.

        **Example:**

        ::

            sp = do_mpc.sampling.SamplingPlanner()
            sp.data_dir = './samples/experiment_1/'

        This will set the directory to the indicated path.
        If the path does not exist, all folders are created.
        A trailing path separator is optional.
        """
        return self._data_dir

    @data_dir.setter
    def data_dir(self, val):
        self._data_dir = val
        pathlib.Path(val).mkdir(parents=True, exist_ok=True)

    def set_param(self, **kwargs):
        """Set the parameters of the :py:class:`SamplingPlanner` class.

        Parameters must be passed as pairs of valid keywords and respective argument.
        For example:

        ::

            sp.set_param(overwrite = True)

        It is also possible and convenient to pass a dictionary with multiple parameters simultaneously
        as shown in the following example:

        ::

            setup_dict = {
                'overwrite': True,
                'id_precision': 4,
            }
            sp.set_param(**setup_dict)

        This makes use of the Python "unpack" operator. See `more details here`_.

        .. _`more details here`: https://docs.python.org/3/tutorial/controlflow.html#unpacking-argument-lists

        .. note::

            :py:func:`set_param` can be called multiple times. Previously passed arguments are overwritten by successive calls.

        Unknown keys are not applied; a warning is printed instead.

        The following parameters are available:

        :param overwrite: Overwrites existing samplingplan under the same name, if set to ``True``.
        :type overwrite: bool

        :param id_precision: Padding for IDs of created samples. Defaults to 3. This means sample 20 will be denoted as 020.
        :type id_precision: int
        """
        for key, value in kwargs.items():
            if key not in self.data_fields:
                print('Warning: Key {} does not exist for SamplingPlanner.'.format(key))
            else:
                setattr(self, key, value)

    def set_sampling_var(self, name, fun_var_pdf=None):
        """Introduce new sampling variables to the :py:class:`SamplingPlanner`.

        Define the variable name. Optionally add a function to generate values for the sampled variable
        (e.g. following some distribution). The parameter ``fun_var_pdf`` defaults to ``None``.

        .. note::

            If no value-generating function is passed (for any of the introduced variables),
            all sampling cases must be created manually with :py:meth:`add_sampling_case`.

        .. note::

            The value-generating function ``fun_var_pdf`` must not require inputs.
            Any callable is accepted (function, builtin, bound method, ``functools.partial``, ...).

        **Example:**

        ::

            sp = do_mpc.sampling.SamplingPlanner()

            # Plan with two variables alpha and beta:
            sp.set_sampling_var('alpha', np.random.randn)
            sp.set_sampling_var('beta', lambda: np.random.randint(0,5))

        In the example we have passed a ``BuiltinFunction`` for the introduced variable ``alpha``.
        It draws values from the standard normal distribution (zero mean, unit variance).
        For the variable ``beta`` we created a new lambda function that draws random integers from 0 to 4
        (the upper bound 5 of ``np.random.randint`` is exclusive).

        :param name: Name of the sampled variable. Must be unique and must not be ``'id'`` (reserved for the sample ID).
        :type name: string

        :param fun_var_pdf: Declare the value-generating function of the sampled variable
        :type fun_var_pdf: callable or None

        :raises assertion: ``name`` must be string
        :raises assertion: ``name`` must not already be introduced and must not be ``'id'``
        :raises assertion: ``fun_var_pdf`` must be callable or None
        """
        assert isinstance(name, str), 'name must be str, you have {}'.format(type(name))
        # 'id' would be silently overwritten by the generated sample ID in add_sampling_case.
        assert name != 'id', "'id' is reserved for the automatically generated sample ID."
        # Duplicate names would create two generators for one key; the second would silently win.
        assert name not in self.sampling_var_names, 'Sampling variable {} was already introduced.'.format(name)
        assert fun_var_pdf is None or callable(fun_var_pdf), 'fun_var_pdf must be a callable (without arguments) or None, you have {}'.format(type(fun_var_pdf))

        self.sampling_vars.append({'name': name, 'fun_var_pdf': fun_var_pdf})
        self.sampling_var_names.append(name)

    def add_sampling_case(self, **kwargs):
        """Manually add a sampling case with user-defined values.

        Create a sampling case by choosing values for the previously introduced sampling variables
        (with :py:meth:`set_sampling_var`).

        The method takes arbitrary (keyword, argument) pairs, where the keywords must refer to previously
        introduced sampling variables.

        :py:meth:`add_sampling_case` will automatically augment the sampling case with values for variables
        that are not passed as arguments. This only works if these variables were created with the argument
        ``fun_var_pdf``.

        Each case additionally gets an ``'id'`` entry: its index in the plan, zero-padded to ``id_precision`` digits.

        **Example:**

        ::

            sp = do_mpc.sampling.SamplingPlanner()

            # Plan with two variables alpha and beta:
            sp.set_sampling_var('alpha', np.random.randn)
            sp.set_sampling_var('beta', lambda: np.random.randint(0,5))

            # Create two new sampling cases, missing variable is auto-generated:
            sp.add_sampling_case(alpha=1)
            sp.add_sampling_case(beta= 0)

        :return: The complete (updated) sampling plan.
        :rtype: list

        :raises Exception: if a keyword is not an introduced sampling variable.
        :raises assertion: if a missing variable has no value-generating function.
        """
        # Create each sampling case as dict:
        case = {}

        # Iterate over all the key value pairs added in the method call:
        for key, value in kwargs.items():
            if key not in self.sampling_var_names:
                raise Exception('{} is not a valid sampling variable. Introduce sampling variables with set_sampling_var.'.format(key))
            case[key] = value

        # Add key value pairs for all keys that were not referenced in add_sampling_case:
        for var in self.sampling_vars:
            if var['name'] not in kwargs:
                # Augmentation is not possible if value generating function is not supplied.
                assert var['fun_var_pdf'] is not None, 'Cannot augment sampling_case for missing variable {}. Variable generating function is missing.'.format(var['name'])
                case[var['name']] = var['fun_var_pdf']()

        # Generate string ID of sampling case based on index and pad with zeros.
        # int() tolerates id_precision passed as a numeric string.
        case_index = len(self.sampling_plan)
        case['id'] = str(case_index).zfill(int(self.id_precision))

        self.sampling_plan.append(case)

        return self.sampling_plan

    def gen_sampling_plan(self, n_samples):
        """Generate the sampling plan.

        The generated plan contains ``n_samples`` samples based on the defined variables
        and the corresponding evaluation functions. New samples are appended to any existing cases.

        :param n_samples: The number of generated samples
        :type n_samples: int (Python or numpy integer)

        :raises assertion: n_samples must be int
        :raises assertion: n_samples must be larger than 0

        :return: Returns the newly created sampling plan.
        :rtype: list
        """
        assert isinstance(n_samples, (int, np.integer)), 'n_samples must be int, you have {}'.format(type(n_samples))
        assert n_samples > 0, 'n_samples must be larger than 0.'

        for _ in range(int(n_samples)):
            self.add_sampling_case()

        return self.sampling_plan

    def export(self, sampling_plan_name):
        """Export SamplingPlan in pickle format.

        Pass ``sampling_plan_name`` without any path. A file extension can be added (but will be stripped automatically).
        Change the path with :py:attr:`data_dir`.

        If a file with the same name exists and ``overwrite`` is ``False``, a numeric suffix
        (1, 2, ...) is appended to the name until a free file name is found.

        :param sampling_plan_name: Name of the exported sampling plan file.
        :type sampling_plan_name: str

        :raises assertion: ``sampling_plan_name`` must be string.
        :raises FileExistsError: if no free numbered file name (suffix 1 to 9999) is available.

        :return: Path of the written file.
        :rtype: str
        """
        assert isinstance(sampling_plan_name, str), 'sampling_plan_name must be of type str. You have {}.'.format(type(sampling_plan_name))

        # Strip file extension from name:
        base_name = os.path.splitext(sampling_plan_name)[0]

        # os.path.join works whether or not data_dir ends with a separator (and with pathlib.Path).
        full_name = os.path.join(self.data_dir, base_name + '.pkl')

        if os.path.isfile(full_name) and not self.overwrite:
            for i in range(1, 10000):
                full_name = os.path.join(self.data_dir, base_name + str(i) + '.pkl')
                if not os.path.isfile(full_name):
                    break
            else:
                # Previously the plan was silently not saved in this case.
                raise FileExistsError('No free file name found for sampling plan {} in {}.'.format(base_name, self.data_dir))

        save_pickle(full_name, self.sampling_plan)

        return full_name