Skip to content
Snippets Groups Projects
Commit 864a4011 authored by Lee McCuller's avatar Lee McCuller
Browse files

first pass at mapping-based precompute

parent 88db740b
No related branches found
No related tags found
1 merge request!95precomp decorator support
......@@ -7,7 +7,7 @@ from . import logger
from .trace import BudgetTrace
def precomp(*precomp_funcs):
def precomp(*precomp_funcs, **precomp_fmaps):
"""BudgetItem.calc decorator to add pre-computed functions
This is intended to decorate BudgetItem.calc() methods with
......@@ -36,10 +36,19 @@ def precomp(*precomp_funcs):
"""
def decorator(func):
if not hasattr(func, '_precomp'):
func._precomp = set()
for f in precomp_funcs:
func._precomp.add(f)
if precomp_funcs:
try:
func._precomp_list.extend(precomp_funcs)
except AttributeError:
#probably don't need to copy the **kwargs dict
func._precomp_list = list(precomp_funcs)
if precomp_fmaps:
try:
func._precomp_mapped.update(precomp_fmaps)
except AttributeError:
#probably don't need to copy the **kwargs dict
func._precomp_mapped = dict(precomp_fmaps)
return func
return decorator
......@@ -56,6 +65,38 @@ def quadsum(data):
return np.nansum(data, 0)
def _precomp_recurse_mapping(func, freq, ifo, _precomp):
"""
Recurses down functions which may themselves have precomp decorators, this
builds the **kwarg mapping to pass to the function call, and this mapping
is returned
"""
#run the prerequisite precomps first. These typically modify the ifo Struct (yuck)
for pc_func in getattr(func, '_precomp_list', []):
pc_map = _precomp_recurse_mapping(pc_func, freq, ifo, _precomp = _precomp)
#now call the function with the built mapping
pc_func(freq, ifo, **pc_map)
#now run the prerequisite mappings. These return values which get mapped
precomp_mapping = dict()
for name, pc_func in getattr(func, '_precomp_mapped', {}).items():
try:
PC = _precomp[pc_func]
except KeyError:
#not in _precomp already
pass
else:
precomp_mapping[name] = PC
continue
logger.debug("precomp {}".format(pc_func))
#build the mapping for the requisite call
pc_map = _precomp_recurse_mapping(pc_func, freq, ifo, _precomp = _precomp)
#now call the function with the built mapping
PC = pc_func(freq, ifo, **pc_map)
precomp_mapping[name] = PC
return precomp_mapping
class BudgetItem:
"""GWINC BudgetItem class
......@@ -78,11 +119,11 @@ class BudgetItem:
supplied with the `freq` and `ifo` attributes as arguments.
See the `precomp` documentation for more information.
The `_precomp` keyword argument is for internal use. If
provided, it is assumed to be a set of previously executed
precomp functions. Any function included in the set will not
be re-executed, and the set will be updated with any newly
executed functions.
The `_precomp` keyword argument is for internal use. If provided, it is
assumed to be a dictionary of previously executed precomp functions
mapped to their return values. Any function included in the dict will
not be re-executed, and the dict will be updated with any newly executed
functions.
This method can be overridden, but if it is, it's important to
make sure that the method defined in the base class is always
......@@ -94,13 +135,12 @@ class BudgetItem:
for key, val in kwargs.items():
setattr(self, key, val)
if _precomp is None:
_precomp = set()
for func in getattr(self.calc, '_precomp', []):
if func in _precomp:
continue
logger.debug("precomp {}".format(func))
func(self.freq, self.ifo)
_precomp.add(func)
_precomp = dict()
_PCmap = _precomp_recurse_mapping(self.calc, self.freq, self.ifo, _precomp = _precomp)
#PCmap is not used for this "dry run" update. _precomp could be cached?
_PCmap # I just refer to _PCmap here to appease the linter
return
def calc(self):
"""Overload method for final PSD calculation.
......@@ -203,7 +243,7 @@ class Noise(BudgetItem):
budget=budget,
)
def calc_trace(self, calibration=1, calc=True):
def calc_trace(self, calibration=1, calc=True, _precomp = None):
"""Calculate noise and return BudgetTrace object
`calibration` should either be a scalar or a len(self.freq)
......@@ -215,9 +255,14 @@ class Noise(BudgetItem):
trace style info.
"""
if _precomp is None:
_precomp = dict()
total = None
if calc:
total = self.calc() * calibration
PCmap = _precomp_recurse_mapping(self.calc, self.freq, self.ifo, _precomp)
total = self.calc(**PCmap) * calibration
return self._make_trace(psd=total)
def run(self, **kwargs):
......@@ -239,6 +284,7 @@ class Noise(BudgetItem):
self.load()
self._loaded = True
_precomp = dict()
ifo = kwargs.get('ifo', getattr(self, 'ifo'))
if ifo:
if not hasattr(ifo, '_orig_keys'):
......@@ -253,9 +299,9 @@ class Noise(BudgetItem):
kwargs['ifo'] = self.ifo
self._ifo_hash = ifo_hash
if kwargs:
self.update(**kwargs)
self.update(_precomp = precomp, **kwargs)
return self.calc_trace()
return self.calc_trace(_precomp = _precomp)
class Budget(Noise):
......@@ -459,15 +505,17 @@ class Budget(Noise):
"""
for key, val in kwargs.items():
setattr(self, key, val)
if _precomp is None:
_precomp = set()
_precomp = dict()
for name, item in itertools.chain(
self._cal_objs.items(),
self._noise_objs.items()):
logger.debug("update {}".format(item))
item.update(_precomp=_precomp, **kwargs)
def calc_noise(self, name, calibration=1, calc=True, _cals=None):
def calc_noise(self, name, calibration=1, calc=True, _cals=None, _precomp = None):
"""Return calibrated individual noise BudgetTrace.
The noise and calibration transfer functions are calculated,
......@@ -476,21 +524,25 @@ class Budget(Noise):
the noise.
"""
if _precomp is None:
_precomp = dict()
for cal in self._noise_cals[name]:
if _cals:
calibration *= _cals[cal]
else:
obj = self._cal_objs[cal]
logger.debug("calc {}".format(obj))
calibration *= obj.calc()
PCmap = _precomp_recurse_mapping(obj.calc, self.freq, self.ifo, _precomp)
calibration *= obj.calc(**PCmap)
noise = self._noise_objs[name]
logger.debug("calc {}".format(noise))
return noise.calc_trace(
calibration=calibration,
calc=calc,
_precomp=_precomp,
)
def calc_trace(self, calibration=1, calc=True):
def calc_trace(self, calibration=1, calc=True, _precomp = None):
"""Calculate all budget noises and return BudgetTrace object
`calibration` should either be a scalar or a len(self.freq)
......@@ -503,10 +555,13 @@ class Budget(Noise):
"""
_cals = {}
if _precomp is None:
_precomp = dict()
if calc:
for name, cal in self._cal_objs.items():
logger.debug("calc {}".format(cal))
_cals[name] = cal.calc()
PCmap = _precomp_recurse_mapping(cal.calc, self.freq, self.ifo, _precomp)
_cals[name] = cal.calc(**PCmap)
budget = []
for name in self._noise_objs:
......@@ -515,6 +570,7 @@ class Budget(Noise):
calibration=calibration,
calc=calc,
_cals=_cals,
_precomp=_precomp,
)
budget.append(trace)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment