Skip to content
Snippets Groups Projects
Commit b1c54cc6 authored by Lee McCuller's avatar Lee McCuller
Browse files

added tests, inheritance and some docs

parent 7c146498
No related branches found
No related tags found
No related merge requests found
......@@ -58,6 +58,11 @@ singlehtml:
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
optimizer-dcc:
$(SPHINXBUILD) -D master_doc="optimizer/overview" -E -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/optimizer-dcc
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
......
......@@ -25,7 +25,7 @@ sys.path.insert(0, os.path.abspath('..'))
#gwinc must be importable to build the docs properly anyway, using apidoc, so
#import it now for the __version__ parameter
import gwinc
import gwinc # noqa
# -- General configuration ------------------------------------------------
......
......@@ -90,7 +90,7 @@ def load_budget(name_or_path, freq=None, bname=None):
If `bname` is specified the Budget class with that name will be
loaded from the budget module. Otherwise, the Budget class with
the same name as the budget module will be load.
the same name as the budget module will be loaded.
If the budget is a package directory which includes an 'ifo.yaml'
file the ifo Struct will be loaded from that file and assigned to
......@@ -118,9 +118,23 @@ def load_budget(name_or_path, freq=None, bname=None):
if ext in Struct.STRUCT_EXT:
logger.info("loading struct {}...".format(path))
ifo = Struct.from_file(path)
bname = 'aLIGO'
modname = 'gwinc.ifo.aLIGO'
inherit_ifo = ifo.get('inherit', None)
if inherit_ifo is not None:
del ifo['inherit']
#make the inherited path relative to the loaded path
#if it is a yml file
if os.path.splitext(inherit_ifo)[1] in Struct.STRUCT_EXT:
base = os.path.split(path)[0]
inherit_ifo = os.path.join(base, inherit_ifo)
inherit_budget = load_budget(inherit_ifo, freq=freq, bname=bname)
pre_ifo = inherit_budget.ifo
pre_ifo.update(ifo, overwrite_atoms = False)
inherit_budget.update(ifo=pre_ifo)
return inherit_budget
else:
modname = 'gwinc.ifo.aLIGO'
else:
bname = bname or base
modname = path
......
......@@ -3,11 +3,18 @@ import re
import io
import yaml
import numpy as np
from collections.abc import Mapping, Sequence, MutableSequence
#base class for numbers
from numbers import Number
from scipy.io import loadmat
from scipy.io.matlab.mio5_params import mat_struct
# HACK: fix loading number in scientific notation
#this is an assumption made for the recursive update method later, so check here
assert(not issubclass(np.ndarray, Sequence))
# HACK: fix loading Number in scientific notation
#
# https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number
#
......@@ -34,12 +41,15 @@ def dictlist2recarray(l):
else:
return type(v)
# get dtypes from first element dict
dtypes = [(k, dtype(v)) for k,v in l[0].items()]
dtypes = [(k, dtype(v)) for k, v in l[0].items()]
values = [tuple(el.values()) for el in l]
out = np.array(values, dtype=dtypes)
return out.view(np.recarray)
#very unique object that serves as a key meaning "Clear this value during update"
CLEAR_KEY = ("CLEAR", lambda : None)
class Struct(object):
"""Matlab struct-like object
......@@ -56,6 +66,8 @@ class Struct(object):
file, a nested dict, or a MATLAB struct object.
"""
#very unique object that serves as a key meaning "Clear this value during update"
CLEAR_KEY = CLEAR_KEY
STRUCT_EXT = ['.yaml', '.yml', '.mat', '.m']
"""accepted extension types for struct files"""
......@@ -86,7 +98,8 @@ class Struct(object):
Struct.
"""
self.update(dict(*args, **kwargs))
#TODO, should this use the more or less permissive allow_unknown_types?
self.update(dict(*args, **kwargs), allow_unknown_types=True)
def __getitem__(self, key):
"""Get a (possibly nested) value from the struct.
......@@ -119,36 +132,139 @@ class Struct(object):
else:
self.__dict__[key] = value
def __delitem__(self, key):
del self.__dict__[key]
def setdefault(self, key, default):
return self.__dict__.setdefault(key, default)
def update(self, other):
def update(
self, other,
overwrite_atoms=False,
clear_key=CLEAR_KEY,
value_types=(str, Number, np.ndarray),
allow_unknown_types=True,
):
"""Update Struct from other Struct or dict.
This is *recursive* and will also update using lists, performing a
deepcopy of the dict/list structure. It inspects the internal types to
do this.
None's are not inserted, and are always overwritten.
if other contains any values of exactly clear_key
(default to Struct.CLEAR_KEY), then that value is cleared in the updated
self. override the argument clear_key=None to clear null values.
"""
if isinstance(other, Struct):
d = other.__dict__
else:
d = dict(other)
for k, v in d.items():
kw = dict(
overwrite_atoms=overwrite_atoms,
clear_key=clear_key,
value_types=value_types,
allow_unknown_types=allow_unknown_types,
)
def update_element(self, k, other_v,):
"""
type dispatch that assigns into self[k] based on the current type
and the type of other_v
"""
self_v = self[k]
if other_v is CLEAR_KEY:
if isinstance(self, Mapping):
del self[k]
else:
raise RuntimeError("clear_key deletions not allowed in sequences like lists")
elif other_v is None:
#don't update on None
pass
elif isinstance(other_v, value_types):
#other is a value type, not a collection
if isinstance(self_v, (Sequence, Mapping)):
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a value into a dict or list)")
else:
self[k] = other_v
elif isinstance(other_v, Mapping):
if isinstance(self_v, value_types):
if not overwrite_atoms:
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a dict into a float)")
else:
self_v = self[k] = Struct()
self_v.update(other_v, **kw)
elif isinstance(self_v, Sequence):
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a dict into a list)")
elif isinstance(self_v, Mapping):
self[k].update(other_v, **kw)
elif self_v is None:
self_v = self[k] = Struct()
self_v.update(other_v, **kw)
else:
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a dict into a list)")
elif isinstance(other_v, Sequence):
#this check MUST come after value_types, or string is included
#make mutable
if not isinstance(self_v, MutableSequence):
self_v = list(self_v)
if isinstance(self_v, value_types):
if not overwrite_atoms:
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a dict into a string)")
else:
self_v = self[k] = other_v
elif isinstance(self_v, Sequence):
#the string check MUST come before Sequence
list_update(self_v, other_v)
elif isinstance(self_v, Mapping):
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a list into a dict)")
elif self_v is None:
self_v = self[k] = other_v
else:
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a value into a list)")
else:
#other is an unknown value type, not a collection
if not allow_unknown_types:
raise RuntimeError("Unknown type assigned during recursive .update()")
if isinstance(self_v, (Sequence, Mapping)):
raise RuntimeError("struct update is an incompatible storage type (e.g. updating a value into a dict or list)")
else:
self[k] = other_v
return
def list_update(self_v, other_v,):
"""
helper function for the recursive update
"""
N_min = min(len(self_v), len(other_v))
#make self as long as other, filled with None's so that assignment occurs
self_v.extend([None] * (len(other_v) - N_min))
for idx, sub_other_v in enumerate(other_v):
update_element(self_v, idx, sub_other_v)
return
#actual code loop for the recursive update
for k, other_v in other.items():
if k in self:
if isinstance(self[k], Struct) \
and isinstance(v, (dict, Struct)):
self[k].update(v)
continue
try:
delattr(self, k)
except AttributeError:
delattr(self.__class__, k)
if isinstance(v, dict):
self[k] = Struct(v)
elif isinstance(v, (list, tuple)):
try:
self[k] = list(map(Struct, v))
except TypeError:
self[k] = v
update_element(self, k, other_v)
else:
self[k] = v
#k not in self, so just assign
if other_v is CLEAR_KEY:
pass
elif isinstance(other_v, value_types):
#value type to directly assign
self[k] = other_v
elif isinstance(other_v, Mapping):
self_v = self[k] = Struct()
#use update so that it is a deepcopy
self_v.update(other_v, **kw)
elif isinstance(other_v, Sequence):
#MUST come after the value types check, or strings included
self_v = self[k] = []
list_update(self_v, other_v)
else:
if not allow_unknown_types:
raise RuntimeError("Unknown type assigned during recursive .update()")
#value type to directly assign
self[k] = other_v
def items(self):
return self.__dict__.items()
......@@ -162,7 +278,6 @@ class Struct(object):
def __contains__(self, key):
return key in self.__dict__
def to_dict(self, array=False):
"""Return nested dictionary representation of Struct.
......@@ -216,6 +331,27 @@ class Struct(object):
def __repr__(self):
return self.__str__()
def __len__(self):
return len(self.__dict__)
def _repr_pretty_(self, s, cycle):
"""
This is the pretty print extension function for IPython's pretty printer
"""
if cycle:
s.text('GWINC Struct(...)')
return
s.begin_group(8, 'Struct({')
for idx, (k, v) in enumerate(self.items()):
s.pretty(k)
s.text(': ')
s.pretty(v)
if idx+1 < len(self):
s.text(',')
s.breakable()
s.end_group(8, '})')
return
def __iter__(self):
return iter(self.__dict__)
......@@ -223,17 +359,19 @@ class Struct(object):
"""Iterate over all leaves in the struct tree.
"""
for k,v in self.__dict__.items():
for k, v in self.__dict__.items():
if k[0] == '_':
continue
if isinstance(v, type(self)):
for sk,sv in v.walk():
if isinstance(v, (dict, Struct)):
for sk, sv in v.walk():
yield k+'.'+sk, sv
else:
try:
for i,vv in enumerate(v):
for sk,sv in vv.walk():
yield '{}[{}].{}'.format(k,i,sk), sv
for i, vv in enumerate(v):
if isinstance(vv, dict):
vv = Struct(vv)
for sk, sv in vv.walk():
yield '{}[{}].{}'.format(k, i, sk), sv
except (AttributeError, TypeError):
yield k, v
......@@ -247,12 +385,14 @@ class Struct(object):
return k in keys
else:
return True
def map_tuple(kv):
k, v = kv
if isinstance(v, list):
return k, tuple(v)
else:
return k, v
return hash(tuple(sorted(
map(map_tuple, filter(filter_keys, self.walk()))
)))
......@@ -263,16 +403,33 @@ class Struct(object):
Returns list of (key, value, other_value) tuples. Value is
None if key not present.
Note: yaml also supports putting None into dictionaries by not supplying
a value. The None values returned here and the "missing value" None's
from yaml are not distinguished in this diff
"""
diffs = []
UNIQUE = lambda x : None
if isinstance(other, dict):
other = Struct(other)
for k, ov in other.walk():
v = self.get(k, None)
if ov != v and ov is not v:
diffs.append((k, v, ov))
try:
v = self.get(k, UNIQUE)
if ov != v and ov is not v:
if v is UNIQUE:
diffs.append((k, None, ov))
else:
diffs.append((k, v, ov))
except TypeError:
#sometimes the deep keys go through unmappable objects
#which TypeError if indexed
diffs.append((k, None, ov))
for k, v in self.walk():
ov = other.get(k, None)
if ov is None:
diffs.append((k, v, ov))
try:
ov = other.get(k, UNIQUE)
if ov is UNIQUE:
diffs.append((k, v, None))
except TypeError:
diffs.append((k, v, None))
return diffs
def __eq__(self, other):
......@@ -298,7 +455,7 @@ class Struct(object):
elif isinstance(v, (list, np.ndarray)):
if isinstance(v, list):
v = np.array(v)
v = np.array2string(v, separator='', max_line_width=np.Inf, formatter={'all':lambda x: "{:0.6e} ".format(x)})
v = np.array2string(v, separator='', max_line_width=np.Inf, formatter={'all': lambda x: "{:0.6e} ".format(x)})
base = 's'
else:
base = 's'
......@@ -313,7 +470,6 @@ class Struct(object):
else:
return txt.getvalue()
@classmethod
def from_yaml(cls, y):
"""Create Struct from YAML string.
......@@ -322,7 +478,6 @@ class Struct(object):
d = yaml.load(y, Loader=yaml_loader) or {}
return cls(d)
@classmethod
def from_matstruct(cls, s):
"""Create Struct from scipy.io.matlab mat_struct object.
......@@ -331,9 +486,9 @@ class Struct(object):
c = cls()
try:
s = s['ifo']
except:
except Exception:
pass
for k,v in s.__dict__.items():
for k, v in s.__dict__.items():
if k in ['_fieldnames']:
# skip these fields
pass
......@@ -343,7 +498,7 @@ class Struct(object):
# handle lists of Structs
try:
c.__dict__[k] = list(map(Struct.from_matstruct, v))
except:
except Exception:
c.__dict__[k] = v
# try:
# c.__dict__[k] = float(v)
......@@ -351,7 +506,6 @@ class Struct(object):
# c.__dict__[k] = v
return c
@classmethod
def from_file(cls, path):
"""Load Struct from .yaml or MATLAB .mat file.
......@@ -392,3 +546,8 @@ class Struct(object):
return cls.from_matstruct(s)
else:
raise IOError("Unknown file type: {}".format(ext))
Mapping.register(Struct)
inherit: 'Aplus'
#test that list merging works
Suspension:
Stage:
# Stage1
-
# Stage2
-
# Stage3
-
# Stage4
- Mass: 30
Squeezer:
AmplitudedB: 14 # SQZ amplitude [dB]
InjectionLoss: 0.02 # power loss to sqz
inherit: 'Aplus_mod.yaml'
#test that list merging works
Suspension:
Stage:
# Stage1
-
# Stage2
-
# Stage3
- Mass: 30
Squeezer:
AmplitudedB: 12 # SQZ amplitude [dB]
"""
"""
from gwinc import load_budget
def test_load(pprint, tpath_join, fpath_join):
fpath = fpath_join('Aplus_mod.yaml')
B_inherit = load_budget(fpath)
B_orig = load_budget('Aplus')
pprint(B_inherit.ifo)
pprint("ACTUAL TEST")
pprint(B_inherit.ifo.diff(B_orig.ifo))
assert(
sorted(B_inherit.ifo.diff(B_orig.ifo))
== sorted([
('Suspension.Stage[3].Mass', 30, 22.1),
('Squeezer.AmplitudedB', 14, 12),
('Squeezer.InjectionLoss', 0.02, 0.05)])
)
fpath2 = fpath_join('Aplus_mod2.yaml')
B_inherit2 = load_budget(fpath2)
pprint(B_inherit2.ifo.diff(B_orig.ifo))
assert(
sorted(B_inherit2.ifo.diff(B_orig.ifo))
== sorted([
('Suspension.Stage[2].Mass', 30, 21.8),
('Suspension.Stage[3].Mass', 30, 22.1),
('Squeezer.InjectionLoss', 0.02, 0.05)
])
)
lists:
test_nonelist:
-
-
dicts:
test_nonedict:
A:
B:
"""
"""
import numpy as np
from os import path
from gwinc import Struct
import pylab as pyl
def test_load(pprint, tpath_join, fpath_join):
fpath = fpath_join('test_load.yml')
yml = Struct.from_file(fpath)
pprint("full yaml")
pprint(yml)
pprint("individual tests")
pprint(yml.lists.test_nonelist)
assert(yml.lists.test_nonelist == [None, None])
pprint(yml.dicts.test_nonedict)
S_cmp = Struct({'A': None, 'B': None})
pprint(yml.dicts.test_nonedict.diff(S_cmp))
assert(yml.dicts.test_nonedict == S_cmp)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment