Source code for imagine.simulators.simulator

# %% IMPORTS
# Built-in imports
import abc

# Package imports
import numpy as np

# IMAGINE imports
from imagine.observables import Measurements, Simulations
from imagine.tools import BaseClass, req_attr

# All declaration
__all__ = ['Simulator']


# %% CLASS DEFINITIONS
class Simulator(BaseClass, metaclass=abc.ABCMeta):
    """
    Simulator base class

    New Simulators must be introduced by sub-classing the present class,
    overriding the method :py:meth:`simulate` to convert a list of fields
    into simulated observables.

    For more details see :ref:`components:Simulators` section of the
    documentation.

    Parameters
    ----------
    measurements : imagine.Measurements
        An observables dictionary containing the set of measurements that will
        be used to prepare the mock observables

    Attributes
    ----------
    grid : imagine.Basegrid
        Grid object where the fields were evaluated (NB if a common grid is
        not being used, this is set to None)
    grids : dict or None
        Grid objects of the individual fields, using the field types as keys.
        Only populated for fields which are not handled through the common
        grid (None otherwise)
    fields : dict
        Dictionary containing field types as keys and the sum of evaluated
        fields as values
    observables : list
        List of Observable keys
    output_units : dict
        Output units used in the simulator, one entry per Observable key
    """
    def __init__(self, measurements):
        # Call super constructor
        super().__init__()

        # State which is rebuilt by `prepare_fields` for each realisation
        self.grid = None
        self.grids = None
        self.fields = None
        self.field_checklist = {}
        self.field_parameters = {}
        self.controllist = {}

        # State which is filled once by `register_observables`
        self.observables = []
        self.output_coords = {}
        self.output_units = {}
        self.output_type = {}

        # Set by `register_ensemble_size` when the simulator is called
        self._ensemble_size = None

        self.register_observables(measurements)

    def register_observables(self, measurements):
        """
        Called during initialization to store the relevant information in the
        measurements dictionary

        Only the entries whose key starts with one of the
        `simulated_quantities` are registered.

        Parameters
        ----------
        measurements : imagine.Measurements
            An observables dictionary containing the set of measurements that
            will be used to prepare the mock observables

        Raises
        ------
        AssertionError
            If `measurements` is not a `Measurements` instance or if none of
            its entries can be produced by this Simulator.
        """
        assert isinstance(measurements, Measurements)

        for key in measurements.keys():
            if key[0] in self.simulated_quantities:
                observable = measurements[key]
                # Using a list for keys as preserving the _order_ may,
                # perhaps, be useful later
                self.observables.append(key)
                self.output_coords[key] = observable.coords
                self.output_type[key] = observable.otype
                self.output_units[key] = observable.unit

        assert len(self.observables) > 0, 'No valid observable was requested!'

    @property
    def ensemble_size(self):
        """
        Ensemble size shared by the fields of the last call (None before the
        first call to :py:meth:`register_ensemble_size`)
        """
        return self._ensemble_size

    def register_ensemble_size(self, field_list):
        """
        Checks whether fields have consistent ensemble size and stores this
        information

        Parameters
        ----------
        field_list : list
            List containing Field objects

        Raises
        ------
        AssertionError
            If the fields do not all report a single, common `ensemble_size`
            (this includes the case of an empty `field_list`).
        """
        set_of_field_sizes = {f.ensemble_size for f in field_list}
        assert len(set_of_field_sizes) == 1, 'All fields should have the same ensemble size'
        self._ensemble_size = set_of_field_sizes.pop()

    def prepare_fields(self, field_list, i):
        """
        Registers the available fields checking whether all requirements
        are satisfied. All data is saved on a dictionary, simulator.fields,
        where field_types are keys.

        The `fields` dictionary is reconstructed for *each realisation* of the
        ensemble. It relies on caching within the Field objects to avoid
        computing the same quantities multiple times.

        If there is more than one field of the same type, they are summed
        together.

        Parameters
        ----------
        field_list : list
            List containing Field objects
        i : int
            Index of the realisation of the fields that is being registered

        Raises
        ------
        AssertionError
            If a grid type is not allowed, if the grids are inconsistent or
            if a required field type is missing.
        """
        self.fields = {}
        self.grid = None
        self.grids = None
        self.field_checklist = {}
        self.controllist = {}

        sorted_field_list = self._sort_field_dependencies(field_list)

        # Field types this simulator knows how to use (loop invariant)
        usable_field_types = (*self.required_field_types,
                              *self.optional_field_types)

        for field in sorted_field_list:
            # Fields of types unknown to this simulator are silently ignored
            if field.type not in usable_field_types:
                continue

            # Checks whether the grid_type is correct
            if ((field.grid is not None)
                    and (self.allowed_grid_types is not None)):
                assert field.grid.grid_type in self.allowed_grid_types, 'Grid type not allowed'

            # Checks whether the grids are consistent
            # (if fields were evaluated on the same grid)
            if self.use_common_grid and (field.type != 'dummy'):
                if self.grid is None:
                    self.grid = field.grid
                assert self.grid is field.grid, 'Multiple grids when a common grid is required'
            else:
                # The dictionary is created on demand and, *separately*,
                # the grid is registered: chaining both steps in a single
                # if/elif/else would drop the grid of the first field
                if self.grids is None:
                    self.grids = {}
                if field.type in self.grids:
                    assert self.grids[field.type] is field.grid, 'Fields of the same type must have the same grid'
                else:
                    self.grids[field.type] = field.grid

            # Organises dependencies
            dependencies = {}
            for dep in field.dependencies_list:
                if isinstance(dep, str):
                    # If a string is used, dep is actually a field_type.
                    # The sorting guarantees it was already evaluated
                    dependencies[dep] = self.fields[dep]
                else:
                    # Otherwise, dep corresponds to a class: the first
                    # *other* field which is an instance of it is used
                    for other_field in sorted_field_list:
                        if other_field is field:
                            continue
                        # Stores the requested object in the dictionary
                        if isinstance(other_field, dep):
                            dependencies[dep] = other_field
                            break

            # Finally, stores the field
            if field.type not in self.fields:
                # Stores the data
                self.fields[field.type] = field.get_data(i, dependencies)
                # Stores the parameters list
                self.field_parameters[field.type] = field.parameter_names
                # Stores the checklist (for dummies only)
                if field.type == 'dummy':
                    self.field_checklist = field.field_checklist.copy()
            elif field.type != 'dummy':
                # If multiple fields of the same type are present, sums them
                # up. NB the '+=' has *not* been used to avoid changes in the
                # original data due to its 'inplace' nature
                self.fields[field.type] = (self.fields[field.type]
                                           + field.get_data(i, dependencies))
            else:
                # For multiple dummies, parameters provided by get_data are
                # combined (taking care to avoid modifying the original object)
                self.fields[field.type] = self.fields[field.type].copy()
                self.fields[field.type].update(field.get_data(i, dependencies))
                # The checklists are also combined
                self.field_checklist.update(field.field_checklist)

            if field.type == 'dummy':
                self.controllist[field.name] = field.simulator_controllist

        # Makes sure all required fields were included
        assert set(self.required_field_types) <= set(self.fields.keys()), 'Missing required field'

    def _sort_field_dependencies(self, fields):
        """
        Reorders a fields list so that dependencies are evaluated before
        the dependent fields.

        Parameters
        ----------
        fields : list
            List of Fields which may contain dependencies

        Returns
        -------
        sorted_fields : list
            List of sorted Fields
        """
        independent_fields, dependencies = self._find_dependencies(fields)
        return self._solve_dependencies(independent_fields, dependencies)

    def _find_dependencies(self, fields):
        """
        Reads a list of Fields and constructs a list of independent fields
        and dictionary containing all field dependencies. Dependencies on
        'field_type' are converted to dependencies on classes.

        Parameters
        ----------
        fields : list
            Initial list of fields

        Returns
        -------
        independent_fields_list : list
            List containing fields with no dependencies
        dependencies : dict
            Dictionary containing field objects as keys and the classes they
            depend on as values.
        """
        field_types = {}
        dependencies = {}
        independent_fields_list = []

        # Prepares field_type and dependencies dictionaries
        for field in fields:
            fdep = field.dependencies_list

            # Keeps track of which classes implement each field type
            field_types.setdefault(field.type, set()).add(type(field))

            if len(fdep) == 0:
                independent_fields_list.append(field)
            else:
                dependencies.setdefault(field, set()).update(fdep)

        # Substitutes any field type string by field classes
        for deps in dependencies.values():
            # Iterates over a copy, as the set is modified along the way
            for dep in set(deps):
                if isinstance(dep, str):
                    deps.remove(dep)
                    deps.update(field_types[dep])

        return independent_fields_list, dependencies

    def _solve_dependencies(self, independent_fields, dependencies,
                            max_iter=100, overwrite=True):
        """
        Applies basic topological sorting to the field dependencies

        Parameters
        ----------
        independent_fields : list
            List of field objects with no dependencies
        dependencies : dict
            Dictionary containing field objects as keys and the classes they
            depend on as values.
        max_iter : int, optional
            Maximum number of iterations while trying to solve the
            dependencies
        overwrite : bool
            If True (default), `independent_fields` and `dependencies` will
            be modified by this method

        Returns
        -------
        L : list
            Sorted list of field objects

        Raises
        ------
        AssertionError
            If `max_iter` is reached or if there is a cyclical dependency.
        """
        L = []  # Empty list that will contain the sorted elements

        if overwrite:
            S = independent_fields  # Set of all nodes with no incoming edge
            deps = dependencies
        else:
            S = list(independent_fields)
            # Only the containers are copied. A deepcopy would also clone the
            # Field objects used as keys, and the sorted list would then
            # contain copies instead of the fields supplied by the caller
            deps = {field: set(edges) for field, edges in dependencies.items()}

        counter = 0
        while S:
            counter += 1
            assert counter < max_iter, 'Error: too many iterations'

            # Removes a node n from S
            n = S.pop()
            # Appends n to tail of L
            L.append(n)

            # Goes through all the nodes (over a snapshot of the keys, as
            # solved entries are deleted inside the loop)
            for m in list(deps.keys()):
                edges = deps[m]
                # If n is in the edges, removes it
                if type(n) in edges:
                    edges.remove(type(n))
                    # If there are no edges, add it to the
                    # independent nodes list
                    if not edges:
                        S.append(m)
                        del deps[m]

        assert not deps, 'There is a cyclical Field dependency!'

        return L

    @property
    @req_attr
    def simulated_quantities(self):
        """
        Must be overridden with a list or set of simulated quantities this
        Simulator produces.
        Example: ['fd', 'sync']
        """
        return self.SIMULATED_QUANTITIES

    @property
    @req_attr
    def required_field_types(self):
        """
        Must be overridden with a list or set of required field types that
        the Simulator needs.
        Example: ['magnetic_field', 'cosmic_ray_electron_density']
        """
        return self.REQUIRED_FIELD_TYPES

    @property
    def optional_field_types(self):
        """
        Can be overridden with a list or set of field types that Simulator
        can use if available (defaults to an empty list).
        Example: ['magnetic_field', 'cosmic_ray_electron_density']
        """
        return getattr(self, 'OPTIONAL_FIELD_TYPES', [])

    @property
    @req_attr
    def allowed_grid_types(self):
        """
        Must be overridden with a list or set of allowed grid types that
        work with this Simulator.
        Example: ['cartesian']
        """
        return self.ALLOWED_GRID_TYPES

    @property
    def use_common_grid(self):
        """
        Whether all the (non-dummy) fields must be evaluated on a single,
        common, grid. Can be overridden through the `USE_COMMON_GRID`
        attribute (defaults to True).
        """
        return getattr(self, 'USE_COMMON_GRID', True)

    @abc.abstractmethod
    def simulate(self, key, coords_dict, realization_id, output_units):
        """
        Must be overridden with a function that returns the observable
        described by `key` using the fields in self.fields, in units
        `output_units`.

        Parameters
        ----------
        key : tuple
            Observable key in the standard form
            ``(data-name,str(data-freq),str(data-Nside)/"tab",str(ext))``
        coords_dict : dictionary
            Dictionary containing coordinates associated with the observable
            (or None for HEALPix datasets).
        realization_id : int
            Index of the ensemble realisation being simulated
        output_units : astropy.units.Unit
            The physical units that should be used for this mock observable

        Returns
        -------
        array-like
            1D array of length compatible with Nside or coords_dict
            containing the mock observable in the output_units.
            NOTE(review): :py:meth:`__call__` invokes ``.to(output_units)``
            on the returned object, so it presumably has to be an
            astropy Quantity rather than a bare numpy array -- confirm.
        """
        raise NotImplementedError

    def __call__(self, field_list):
        """
        Runs the simulator over a Fields list

        Parameters
        ----------
        field_list : list
            List of imagine.Field object which must include all the
            `required_field_types`

        Returns
        -------
        sims : imagine.Simulations
            A Simulations object containing all the specified mock data
        """
        sims = Simulations()
        self.register_ensemble_size(field_list)

        for i in range(self._ensemble_size):
            # Prepares all fields
            self.prepare_fields(field_list, i)

            for key in self.observables:
                units = self.output_units[key]
                coords = self.output_coords[key]

                sim = self.simulate(key=key, coords_dict=coords,
                                    realization_id=i, output_units=units)

                # A leading axis is added so that each realisation
                # is appended as a single row
                sims.append(name=key,
                            data=sim[np.newaxis, :].to(units),
                            coords=coords,
                            otype=self.output_type[key])

        return sims