Skip to content
Snippets Groups Projects
Commit 902952c0 authored by Edward Fauchon-Jones's avatar Edward Fauchon-Jones
Browse files

Merge branch 'fix-romspline-validation-#1' into 'master'

- Fix romspline validation #1
- See merge request !2
parents dbe55b36 a2b8e484
No related branches found
No related tags found
1 merge request!2Fix romspline validation #1
......@@ -65,12 +65,12 @@ class InvalidSubFields(Error):
sim: lvcnrpy.Sim.Sim
Simulation being tested that has this error in formating.
"""
name = "INVALID VALUE"
name = "INVALID SUBFIELDS"
msg = "(Field has subfields [{0:s}] but should have [{1:s}])"
def __init__(self, groupSpec, sim):
definedSubFields = ", ".join(sim[groupSpec.name].keys())
requiredSubFields = ", ".join(groupSpec.components)
requiredSubFields = ", ".join(groupSpec.subfields)
self.msg = self.msg.format(definedSubFields, requiredSubFields)
......
......@@ -15,6 +15,7 @@
import h5py as h5
from . import errors as err
import numpy as np
class Spec(object):
......@@ -82,6 +83,24 @@ class Spec(object):
else:
return err.Missing(self)
@classmethod
def createField(self, sim):
"""Create valid field in simulation of field represented by spec
Parameters
----------
sim: lvcnrpy.Sim.Sim
LVCNR Waveform HDF5 Sim object in which to validate the represented
format specification field.
"""
if self.dtype == int:
val = self.values[0] if self.values is not None else 1
elif self.dtype == float:
val = self.values[0] if self.values is not None else 1.0
elif self.dtype == basestring:
val = self.values[0] if self.values is not None else "1.0"
sim.attrs[self.name] = val
class GroupSpec(Spec):
"""Specification for `h5.Group` fields"""
......@@ -101,25 +120,30 @@ class GroupSpec(Spec):
else:
return err.Missing(self)
@classmethod
def createField(self, sim):
sim.create_group(self.name)
class ROMSplineSpec(Spec):
"""Specification for ROMSpline group fields
Attributes
----------
components: list of str
List of component subfields of the represented field.
subfields: list of str
List of subfields a ROMSpline field must contain. A standard ROMSpline
field should contain `(X, Y, deg, errors, tol)` subfields.
"""
dtype = h5.Group
components = ['t_horizon']
subfields = ['X', 'Y', 'deg', 'errors', 'tol']
@classmethod
def valid(self, sim):
"""Validate represented field against format specification.
This will validate that the type of the represented field agrees with
the format specifictaion and that the names of the components in the
dataset agrees with `self.components`.
the format specifictaion and that the names of the subfields in the
group agrees with `self.subfields`.
"""
try:
value = sim[self.name]
......@@ -128,7 +152,7 @@ class ROMSplineSpec(Spec):
if value is not None:
if isinstance(value, self.dtype):
s1 = set(self.components)
s1 = set(self.subfields)
s2 = set(value.keys())
diff = s1.symmetric_difference(s2)
if len(diff) == 0:
......@@ -140,6 +164,13 @@ class ROMSplineSpec(Spec):
else:
return err.Missing(self)
@classmethod
def createField(self, sim):
group = sim.create_group(self.name)
for sub in self.subfields:
data = np.array([float(i) for i in range(10)])
group.create_dataset(sub, data=data)
# General Fields
class Type(Spec):
......@@ -382,149 +413,124 @@ class MeanAnomaly(Spec):
class Mass1VsTime(ROMSplineSpec):
"""Specification for the `mass1-vs-time` field"""
name = 'mass1-vs-time'
components = ['t_horizon', 'M1']
class Mass2VsTime(ROMSplineSpec):
"""Specification for the `mass2-vs-time` field"""
name = 'mass2-vs-time'
components = ['t_horizon', 'M2']
class Spin1xVsTime(ROMSplineSpec):
"""Specification for the `spin1x-vs-time` field"""
name = 'spin1x-vs-time'
components = ['t_horizon', 'chi1x']
class Spin1yVsTime(ROMSplineSpec):
"""Specification for the `spin1y-vs-time` field"""
name = 'spin1y-vs-time'
components = ['t_horizon', 'chi1y']
class Spin1zVsTime(ROMSplineSpec):
"""Specification for the `spin1z-vs-time` field"""
name = 'spin1z-vs-time'
components = ['t_horizon', 'chi1z']
class Spin2xVsTime(ROMSplineSpec):
"""Specification for the `spin2x-vs-time` field"""
name = 'spin2x-vs-time'
components = ['t_horizon', 'chi2x']
class Spin2yVsTime(ROMSplineSpec):
"""Specification for the `spin2y-vs-time` field"""
name = 'spin2y-vs-time'
components = ['t_horizon', 'chi2y']
class Spin2zVsTime(ROMSplineSpec):
"""Specification for the `spin2z-vs-time` field"""
name = 'spin2z-vs-time'
components = ['t_horizon', 'chi2z']
class Position1xVsTime(ROMSplineSpec):
"""Specification for the `position1x-vs-time` field"""
name = 'position1x-vs-time'
components = ['t_horizon', 'c1x']
class Position1yVsTime(ROMSplineSpec):
"""Specification for the `position1y-vs-time` field"""
name = 'position1y-vs-time'
components = ['t_horizon', 'c1y']
class Position1zVsTime(ROMSplineSpec):
"""Specification for the `position1z-vs-time` field"""
name = 'position1z-vs-time'
components = ['t_horizon', 'c1z']
class Position2xVsTime(ROMSplineSpec):
"""Specification for the `position2x-vs-time` field"""
name = 'position2x-vs-time'
components = ['t_horizon', 'c2x']
class Position2yVsTime(ROMSplineSpec):
"""Specification for the `position2y-vs-time` field"""
name = 'position2y-vs-time'
components = ['t_horizon', 'c2y']
class Position2zVsTime(ROMSplineSpec):
"""Specification for the `position2z-vs-time` field"""
name = 'position2z-vs-time'
components = ['t_horizon', 'c2z']
class LNhatxVsTime(ROMSplineSpec):
"""Specification for the `LNhatx-vs-time` field"""
name = 'LNhatx-vs-time'
components = ['t_horizon', 'Lhatx']
class LNhatyVsTime(ROMSplineSpec):
"""Specification for the `LNhaty-vs-time` field"""
name = 'LNhaty-vs-time'
components = ['t_horizon', 'Lhaty']
class LNhatzVsTime(ROMSplineSpec):
"""Specification for the `LNhatz-vs-time` field"""
name = 'LNhatz-vs-time'
components = ['t_horizon', 'Lhatz']
class OmegaVsTime(ROMSplineSpec):
"""Specification for the `Omega-vs-time` field"""
name = 'Omega-vs-time'
components = ['t_horizon', 'Omega']
# Format 3
class RemnantMassVsTime(ROMSplineSpec):
"""Specification for the `remnant-mass-vs-time`` field"""
name = 'remnant-mass-vs-time'
components = ['t_horizon', 'Mr']
class RemnantSpinxVsTime(ROMSplineSpec):
"""Specification for the `remnant-spinx-vs-time` field"""
name = 'remnant-spinx-vs-time'
components = ['t_horizon', 'chix']
class RemnantSpinyVsTime(ROMSplineSpec):
"""Specification for the `remnant-spiny-vs-time` field"""
name = 'remnant-spiny-vs-time'
components = ['t_horizon', 'chiy']
class RemnantSpinzVsTime(ROMSplineSpec):
"""Specification for the `remnant-spinz-vs-time` field"""
name = 'remnant-spinz-vs-time'
components = ['t_horizon', 'chiz']
class RemnantPositionxVsTime(ROMSplineSpec):
"""Specification for the `remnant-positionx-vs-time` field"""
name = 'remnant-positionx-vs-time'
components = ['t_horizon', 'cx']
class RemnantPositionyVsTime(ROMSplineSpec):
"""Specification for the `remnant-positiony-vs-time` field"""
name = 'remnant-positiony-vs-time'
components = ['t_horizon', 'cy']
class RemnantPositionzVsTime(ROMSplineSpec):
"""Specification for the `remnant-positionz-vs-time` field"""
name = 'remnant-positionz-vs-time'
components = ['t_horizon', 'cz']
This diff is collapsed.
......@@ -21,6 +21,7 @@ from lvcnrpy.format.specs import Spec, GroupSpec, ROMSplineSpec
import h5py as h5
from collections import OrderedDict
import numpy as np
from lvcnrpy.Sim import Sim
def lvcnrcheck(args):
......@@ -46,44 +47,29 @@ def lvcnrcheck(args):
return output
def getFormat1HDF5():
"""Create a temporary HDF5 file that is a valid format 1 file.
def createValidSim():
"""Create a temporary HDF5 file that is a valid LVC NR simulation.
Returns
-------
file: NamedTemporaryFile
`NamedTemporaryFile` object that is initialized as a HDF5 file that is
a valid format 1 file.
a valid LVC NR simulation.
"""
f = NamedTemporaryFile()
h5f = h5.File(f.name, 'w')
sim = Sim(f.name, 'w')
def createLeaves(nodes):
for nodeKey, node in nodes.items():
if isinstance(node, OrderedDict):
createLeaves(node)
elif isinstance(node(), Spec):
if isinstance(node(), GroupSpec):
h5f.create_group(node.name)
continue
elif isinstance(node(), ROMSplineSpec):
group = h5f.create_group(node.name)
for comp in node.components:
data = np.array([float(i) for i in range(10)])
group.create_dataset(comp, data=data)
continue
elif node.dtype == int:
val = node.values[0] if node.values is not None else 1
elif node.dtype == float:
val = node.values[0] if node.values is not None else 1.0
elif node.dtype == basestring:
val = node.values[0] if node.values is not None else "1.0"
h5f.attrs[node.name] = val
node.createField(sim)
createLeaves(format1)
createLeaves(format2)
createLeaves(format3)
h5f.close()
sim.close()
return f
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment