Source code for do_mpc.sampling._sampler

import types
import pickle
import os
import inspect
import numpy as np
import pathlib
import pdb
import scipy.io as sio
import copy
from do_mpc.tools import load_pickle, save_pickle, printProgressBar
from typing import Union,Callable

class Sampler:
    """Generate samples based on a sampling plan.

    Initiate the class by passing a :py:class:`do_mpc.sampling.SamplingPlanner`
    (``sampling_plan``) object. The class can be configured to create samples
    based on the defined cases in the ``sampling_plan``.

    The class can be created with optional keyword arguments which are passed
    to :py:meth:`set_param`.

    **Configuration and sampling:**

    1. (Optional) use :py:meth:`set_param` to configure the class.
       Use :py:attr:`data_dir` to choose the save location for the samples.

    2. Set the sample generating function with :py:meth:`set_sample_function`.
       This function is executed for each of the samples in the ``sampling_plan``.

    3. Use :py:meth:`sample_data` to generate all samples defined in the
       ``sampling_plan``. A new file is written for each sample.

    4. **Or:** Create an individual sample result with :py:meth:`sample_idx`,
       where an index (``int``) referring to the ``sampling_plan`` determines
       the sampled case.

    Note:
        By default, the :py:class:`Sampler` will only create samples that do
        not already exist in the chosen :py:attr:`data_dir`.

    **Example:**

    ::

        sp = do_mpc.sampling.SamplingPlanner()

        # Plan with two variables alpha and beta:
        sp.set_sampling_var('alpha', np.random.randn)
        sp.set_sampling_var('beta', lambda: np.random.randint(0,5))

        plan = sp.gen_sampling_plan(n_samples=10)

        sampler = do_mpc.sampling.Sampler(plan)

        # Sampler computes the product of two variables alpha and beta
        # that were created in the SamplingPlanner:

        def sample_function(alpha, beta):
            return alpha*beta

        sampler.set_sample_function(sample_function)

        sampler.sample_data()

    Args:
        sampling_plan: List of dictionaries, one per sample. Each dictionary
            holds the keyword arguments for the sample function plus an
            ``'id'`` entry that is used to name the result file.
        **kwargs: Optional settings forwarded to :py:meth:`set_param`.
    """

    def __init__(self, sampling_plan: list, **kwargs):
        assert isinstance(sampling_plan, list), 'sampling_plan must be a list'
        assert all(isinstance(plan_i, dict) for plan_i in sampling_plan), 'All elements of sampling plan must be a dictionary.'

        self.sampling_plan = sampling_plan
        # The keys of the first case define the names the sample function may
        # use. An empty plan has no cases and therefore no sampling variables.
        self.sampling_vars = list(sampling_plan[0].keys()) if sampling_plan else []
        self.n_samples = len(sampling_plan)

        # Ids of all samples processed so far (drives the progress bar).
        self.completion_list = []

        self.flags = {
            'set_sample_function': False,
        }

        # Parameters that can be set for the Sampler:
        self.data_fields = [
            'overwrite',
            'sample_name',
            'save_format',
            'print_progress',
        ]

        self.data_dir = './'
        self.sample_name = 'sample'
        self.save_format = 'pickle'
        self.overwrite = False
        self.print_progress = True
        self.n_processes = 1

        if kwargs:
            self.set_param(**kwargs)

    @property
    def data_dir(self):
        """Set the save directory for the results.

        If the directory does not exist yet, it is created. If the directory
        is nested all (non-existing) parent folders are also created.

        **Example:**

        ::

            sampler = do_mpc.sampling.Sampler(plan)
            sampler.data_dir = './samples/experiment_1/'

        This will set the directory to the indicated path. If the path does
        not exist, all folders are created.

        Note:
            The file name of each sample is appended directly to this string,
            so the path should end with a path separator (as in the example).
        """
        return self._data_dir

    @data_dir.setter
    def data_dir(self, val):
        self._data_dir = val
        pathlib.Path(val).mkdir(parents=True, exist_ok=True)

    def set_param(self, **kwargs) -> None:
        """Configure the :py:class:`do_mpc.sampling.Sampler` class.

        Parameters must be passed as pairs of valid keywords and respective
        argument. For example:

        ::

            sampler.set_param(overwrite = True)

        Unknown keywords are not set; a warning is printed for each of them.

        Args:
            overwrite(bool): Should previously created results be overwritten. Default is ``False``
            sample_name(str): Naming scheme for samples.
            save_format(str): Choose either ``pickle`` or ``mat``.
            print_progress(bool): Print progress-bar to terminal. Default is ``True``.
        """
        for key, value in kwargs.items():
            if key not in self.data_fields:
                print('Warning: Key {} does not exist for Sampler.'.format(key))
            else:
                setattr(self, key, value)

    def set_sample_function(self, sample_function: Callable) -> None:
        """Set sample generating function.

        The sampling function produces a sample result for each sample
        definition in the ``sampling_plan`` and is called in the method
        :py:meth:`sample_data`.

        It is important that the sample function only uses keyword arguments
        **with the same name as previously defined** in the ``sampling_plan``.

        **Example:**

        ::

            sp = do_mpc.sampling.SamplingPlanner()

            sp.set_sampling_var('alpha', np.random.randn)
            sp.set_sampling_var('beta', lambda: np.random.randint(0,5))

            plan = sp.gen_sampling_plan(n_samples=10)

            sampler = do_mpc.sampling.Sampler(plan)

            def sample_function(alpha, beta):
                return alpha*beta

            sampler.set_sample_function(sample_function)

        Args:
            sample_function: Function to create each sample of the sampling plan.

        Raises:
            assertion: ``sample_function`` must be a (builtin) function.
            assertion: ``sample_function`` has arguments that are not part of the ``sampling_plan``.
        """
        assert isinstance(sample_function, (types.FunctionType, types.BuiltinFunctionType)), 'sample_function must be a function'

        # Every named argument of the function must be a known sampling variable.
        dset = set(inspect.getfullargspec(sample_function).args) - set(self.sampling_vars)
        assert len(dset) == 0, 'sample_function must only contain keyword arguments that appear as sample vars in the sampling_plan. You have the unknown arguments: {}'.format(dset)

        self.sample_function = sample_function
        self.flags['set_sample_function'] = True

    def _save_name(self, sample_id):
        """Private method. Used in :py:meth:`sample_data`.

        Creates the file name for a given sample based on ``sample_name``,
        the sample id and the file extension of the chosen ``save_format``.

        Raises:
            ValueError: ``save_format`` is neither ``'pickle'`` nor ``'mat'``.
        """
        name = '{sample_name}_{id}'.format(sample_name=self.sample_name, id=sample_id)

        if self.save_format == 'pickle':
            save_name = name + '.pkl'
        elif self.save_format == 'mat':
            save_name = name + '.mat'
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "Unknown save_format '{}'. Choose either 'pickle' or 'mat'.".format(self.save_format)
            )

        return save_name

    def _save(self, save_name, result):
        """Private method. Saves the result for a single sample in the defined format.

        Considers the ``overwrite`` parameter to check if existing results
        should be overwritten.
        """
        if not os.path.isfile(self.data_dir + save_name) or self.overwrite:
            if self.save_format == 'pickle':
                save_pickle(self.data_dir + save_name, result)
            elif self.save_format == 'mat':
                sio.savemat(self.data_dir + save_name, {'res': result})

    def sample_idx(self, idx: int) -> None:
        """Sample case based on the index of the sample.

        Args:
            idx: Index of the ``sampling_plan`` for which the sample should be created.

        Raises:
            assertion: Index must be between 0 and ``n_samples - 1``.
            assertion: sample_function must be set prior to sampling data.
        """
        assert self.flags['set_sample_function'], 'Cannot sample before setting the sample function with Sampler.set_sample_function'
        n_plan = len(self.sampling_plan)
        # Valid indices are 0 .. n_plan-1 (idx == n_plan is already out of range).
        assert 0 <= idx < n_plan, 'Invalid value for idx. Must be between 0 and {}. You have {}'.format(n_plan - 1, idx)

        # Pop sample id from dictionary (not an argument to the sample function).
        # Work on a shallow copy so that the sampling plan itself keeps its id.
        sample_i = copy.copy(self.sampling_plan[idx])
        sample_id = sample_i.pop('id')

        # Create and save result if sample result does not exist:
        save_name = self._save_name(sample_id)
        if not os.path.isfile(self.data_dir + save_name) or self.overwrite:
            # Call sample function to create sample (pass sample information)
            result = self.sample_function(**sample_i)

            # Save true results:
            self._save(save_name, result)

        # Add id to completion list:
        self.completion_list.append(sample_id)

        if self.print_progress:
            printProgressBar(len(self.completion_list), self.n_samples, prefix = 'Progress:', suffix = 'Complete', length = 50)

    def sample_data(self) -> None:
        """Sample data after having configured the :py:class:`Sampler`.

        No user input is required and the method will iterate through all the
        items defined in the ``sampling_plan`` (obtained with
        :py:class:`do_mpc.sampling.SamplingPlanner`).

        Note:
            Depending on your ``sample_function`` (set with
            :py:meth:`set_sample_function`) and the total number of samples,
            executing this method may take some time.

        Note:
            If ``sampler.set_param(overwrite = False)`` (default) data will
            only be sampled for instances that do not yet exist.
        """
        for i in range(len(self.sampling_plan)):
            self.sample_idx(i)