Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
09e3c9f3
Commit
09e3c9f3
authored
3 years ago
by
Sylvia Biscoveanu
Committed by
Gregory Ashton
3 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Add ability to load results produced with custom priors
parent
b3c8e741
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!1010
Add ability to load results produced with custom priors
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bilby/core/utils/io.py
+75
-44
75 additions, 44 deletions
bilby/core/utils/io.py
with
75 additions
and
44 deletions
bilby/core/utils/io.py
+
75
−
44
View file @
09e3c9f3
...
...
@@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory):
class
BilbyJsonEncoder
(
json
.
JSONEncoder
):
def
default
(
self
,
obj
):
from
..prior
import
MultivariateGaussianDist
,
Prior
,
PriorDict
from
...gw.prior
import
HealPixMapPriorDist
from
...bilby_mcmc.proposals
import
ProposalCycle
if
isinstance
(
obj
,
np
.
integer
):
return
int
(
obj
)
if
isinstance
(
obj
,
np
.
floating
):
return
float
(
obj
)
if
isinstance
(
obj
,
PriorDict
):
return
{
'
__prior_dict__
'
:
True
,
'
content
'
:
obj
.
_get_json_dict
()}
return
{
"
__prior_dict__
"
:
True
,
"
content
"
:
obj
.
_get_json_dict
()}
if
isinstance
(
obj
,
(
MultivariateGaussianDist
,
HealPixMapPriorDist
,
Prior
)):
return
{
'
__prior__
'
:
True
,
'
__module__
'
:
obj
.
__module__
,
'
__name__
'
:
obj
.
__class__
.
__name__
,
'
kwargs
'
:
dict
(
obj
.
get_instantiation_dict
())}
return
{
"
__prior__
"
:
True
,
"
__module__
"
:
obj
.
__module__
,
"
__name__
"
:
obj
.
__class__
.
__name__
,
"
kwargs
"
:
dict
(
obj
.
get_instantiation_dict
()),
}
if
isinstance
(
obj
,
ProposalCycle
):
return
str
(
obj
)
try
:
from
astropy
import
cosmology
as
cosmo
,
units
if
isinstance
(
obj
,
cosmo
.
FLRW
):
return
encode_astropy_cosmology
(
obj
)
if
isinstance
(
obj
,
units
.
Quantity
):
...
...
@@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder):
except
ImportError
:
logger
.
debug
(
"
Cannot import astropy, cannot write cosmological priors
"
)
if
isinstance
(
obj
,
np
.
ndarray
):
return
{
'
__array__
'
:
True
,
'
content
'
:
obj
.
tolist
()}
return
{
"
__array__
"
:
True
,
"
content
"
:
obj
.
tolist
()}
if
isinstance
(
obj
,
complex
):
return
{
'
__complex__
'
:
True
,
'
real
'
:
obj
.
real
,
'
imag
'
:
obj
.
imag
}
return
{
"
__complex__
"
:
True
,
"
real
"
:
obj
.
real
,
"
imag
"
:
obj
.
imag
}
if
isinstance
(
obj
,
pd
.
DataFrame
):
return
{
'
__dataframe__
'
:
True
,
'
content
'
:
obj
.
to_dict
(
orient
=
'
list
'
)}
return
{
"
__dataframe__
"
:
True
,
"
content
"
:
obj
.
to_dict
(
orient
=
"
list
"
)}
if
isinstance
(
obj
,
pd
.
Series
):
return
{
'
__series__
'
:
True
,
'
content
'
:
obj
.
to_dict
()}
return
{
"
__series__
"
:
True
,
"
content
"
:
obj
.
to_dict
()}
if
inspect
.
isfunction
(
obj
):
return
{
"
__function__
"
:
True
,
"
__module__
"
:
obj
.
__module__
,
"
__name__
"
:
obj
.
__name__
}
return
{
"
__function__
"
:
True
,
"
__module__
"
:
obj
.
__module__
,
"
__name__
"
:
obj
.
__name__
,
}
if
inspect
.
isclass
(
obj
):
return
{
"
__class__
"
:
True
,
"
__module__
"
:
obj
.
__module__
,
"
__name__
"
:
obj
.
__name__
}
return
{
"
__class__
"
:
True
,
"
__module__
"
:
obj
.
__module__
,
"
__name__
"
:
obj
.
__name__
,
}
return
json
.
JSONEncoder
.
default
(
self
,
obj
)
def
encode_astropy_cosmology
(
obj
):
cls_name
=
obj
.
__class__
.
__name__
dct
=
{
key
:
getattr
(
obj
,
key
)
for
key
in
infer_args_from_method
(
obj
.
__init__
)}
dct
[
'
__cosmology__
'
]
=
True
dct
[
'
__name__
'
]
=
cls_name
dct
=
{
key
:
getattr
(
obj
,
key
)
for
key
in
infer_args_from_method
(
obj
.
__init__
)}
dct
[
"
__cosmology__
"
]
=
True
dct
[
"
__name__
"
]
=
cls_name
return
dct
def
encode_astropy_quantity
(
dct
):
dct
=
dict
(
__astropy_quantity__
=
True
,
value
=
dct
.
value
,
unit
=
str
(
dct
.
unit
))
if
isinstance
(
dct
[
'
value
'
],
np
.
ndarray
):
dct
[
'
value
'
]
=
list
(
dct
[
'
value
'
])
if
isinstance
(
dct
[
"
value
"
],
np
.
ndarray
):
dct
[
"
value
"
]
=
list
(
dct
[
"
value
"
])
return
dct
def
decode_astropy_cosmology
(
dct
):
try
:
from
astropy
import
cosmology
as
cosmo
cosmo_cls
=
getattr
(
cosmo
,
dct
[
'
__name__
'
])
del
dct
[
'
__cosmology__
'
],
dct
[
'
__name__
'
]
cosmo_cls
=
getattr
(
cosmo
,
dct
[
"
__name__
"
])
del
dct
[
"
__cosmology__
"
],
dct
[
"
__name__
"
]
return
cosmo_cls
(
**
dct
)
except
ImportError
:
logger
.
debug
(
"
Cannot import astropy, cosmological priors may not be
"
"
properly loaded.
"
)
logger
.
debug
(
"
Cannot import astropy, cosmological priors may not be
"
"
properly loaded.
"
)
return
dct
def
decode_astropy_quantity
(
dct
):
try
:
from
astropy
import
units
if
dct
[
'
value
'
]
is
None
:
if
dct
[
"
value
"
]
is
None
:
return
None
else
:
del
dct
[
'
__astropy_quantity__
'
]
del
dct
[
"
__astropy_quantity__
"
]
return
units
.
Quantity
(
**
dct
)
except
ImportError
:
logger
.
debug
(
"
Cannot import astropy, cosmological priors may not be
"
"
properly loaded.
"
)
logger
.
debug
(
"
Cannot import astropy, cosmological priors may not be
"
"
properly loaded.
"
)
return
dct
def
load_json
(
filename
,
gzip
):
if
gzip
or
os
.
path
.
splitext
(
filename
)[
1
].
lstrip
(
'
.
'
)
==
'
gz
'
:
if
gzip
or
os
.
path
.
splitext
(
filename
)[
1
].
lstrip
(
"
.
"
)
==
"
gz
"
:
import
gzip
with
gzip
.
GzipFile
(
filename
,
'
r
'
)
as
file
:
json_str
=
file
.
read
().
decode
(
'
utf-8
'
)
with
gzip
.
GzipFile
(
filename
,
"
r
"
)
as
file
:
json_str
=
file
.
read
().
decode
(
"
utf-8
"
)
dictionary
=
json
.
loads
(
json_str
,
object_hook
=
decode_bilby_json
)
else
:
with
open
(
filename
,
'
r
'
)
as
file
:
with
open
(
filename
,
"
r
"
)
as
file
:
dictionary
=
json
.
load
(
file
,
object_hook
=
decode_bilby_json
)
return
dictionary
def
decode_bilby_json
(
dct
):
if
dct
.
get
(
"
__prior_dict__
"
,
False
):
cls
=
getattr
(
import_module
(
dct
[
'
__module__
'
]),
dct
[
'
__name__
'
])
cls
=
getattr
(
import_module
(
dct
[
"
__module__
"
]),
dct
[
"
__name__
"
])
obj
=
cls
.
_get_from_json_dict
(
dct
)
return
obj
if
dct
.
get
(
"
__prior__
"
,
False
):
cls
=
getattr
(
import_module
(
dct
[
'
__module__
'
]),
dct
[
'
__name__
'
])
obj
=
cls
(
**
dct
[
'
kwargs
'
])
try
:
cls
=
getattr
(
import_module
(
dct
[
"
__module__
"
]),
dct
[
"
__name__
"
])
except
AttributeError
:
logger
.
debug
(
"
Unknown prior class for parameter {}, defaulting to base Prior object
"
.
format
(
dct
[
"
kwargs
"
][
"
name
"
]
)
)
from
..prior
import
Prior
cls
=
Prior
obj
=
cls
(
**
dct
[
"
kwargs
"
])
return
obj
if
dct
.
get
(
"
__cosmology__
"
,
False
):
return
decode_astropy_cosmology
(
dct
)
...
...
@@ -139,9 +165,9 @@ def decode_bilby_json(dct):
if
dct
.
get
(
"
__complex__
"
,
False
):
return
complex
(
dct
[
"
real
"
],
dct
[
"
imag
"
])
if
dct
.
get
(
"
__dataframe__
"
,
False
):
return
pd
.
DataFrame
(
dct
[
'
content
'
])
return
pd
.
DataFrame
(
dct
[
"
content
"
])
if
dct
.
get
(
"
__series__
"
,
False
):
return
pd
.
Series
(
dct
[
'
content
'
])
return
pd
.
Series
(
dct
[
"
content
"
])
if
dct
.
get
(
"
__function__
"
,
False
)
or
dct
.
get
(
"
__class__
"
,
False
):
default
=
"
.
"
.
join
([
dct
[
"
__module__
"
],
dct
[
"
__name__
"
]])
return
getattr
(
import_module
(
dct
[
"
__module__
"
]),
dct
[
"
__name__
"
],
default
)
...
...
@@ -225,6 +251,7 @@ def encode_for_hdf5(key, item):
Input item converted into HDF5 saveable format
"""
from
..prior.dict
import
PriorDict
if
isinstance
(
item
,
np
.
int_
):
item
=
int
(
item
)
elif
isinstance
(
item
,
np
.
float_
):
...
...
@@ -258,7 +285,9 @@ def encode_for_hdf5(key, item):
elif
isinstance
(
item
,
pd
.
Series
):
output
=
item
.
to_dict
()
elif
inspect
.
isfunction
(
item
)
or
inspect
.
isclass
(
item
):
output
=
dict
(
__module__
=
item
.
__module__
,
__name__
=
item
.
__name__
,
__class__
=
True
)
output
=
dict
(
__module__
=
item
.
__module__
,
__name__
=
item
.
__name__
,
__class__
=
True
)
elif
isinstance
(
item
,
dict
):
output
=
item
.
copy
()
elif
isinstance
(
item
,
tuple
):
...
...
@@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path):
The contents of the HDF5 file unpacked into the dictionary.
"""
import
h5py
output
=
dict
()
for
key
,
item
in
h5file
[
path
].
items
():
if
isinstance
(
item
,
h5py
.
Dataset
):
output
[
key
]
=
decode_from_hdf5
(
item
[()])
elif
isinstance
(
item
,
h5py
.
Group
):
output
[
key
]
=
recursively_load_dict_contents_from_group
(
h5file
,
path
+
key
+
'
/
'
)
output
[
key
]
=
recursively_load_dict_contents_from_group
(
h5file
,
path
+
key
+
"
/
"
)
return
output
...
...
@@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic):
for
key
,
item
in
dic
.
items
():
item
=
encode_for_hdf5
(
key
,
item
)
if
isinstance
(
item
,
dict
):
recursively_save_dict_contents_to_group
(
h5file
,
path
+
key
+
'
/
'
,
item
)
recursively_save_dict_contents_to_group
(
h5file
,
path
+
key
+
"
/
"
,
item
)
else
:
h5file
[
path
+
key
]
=
item
...
...
@@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False):
"""
if
os
.
path
.
isfile
(
filename
):
if
overwrite
:
logger
.
debug
(
'
Removing existing file {}
'
.
format
(
filename
))
logger
.
debug
(
"
Removing existing file {}
"
.
format
(
filename
))
os
.
remove
(
filename
)
else
:
logger
.
debug
(
'
Renaming existing file {} to {}.old
'
.
format
(
filename
,
filename
)
)
shutil
.
move
(
filename
,
filename
+
'
.old
'
)
"
Renaming existing file {} to {}.old
"
.
format
(
filename
,
filename
)
)
shutil
.
move
(
filename
,
filename
+
"
.old
"
)
logger
.
debug
(
"
Saving result to {}
"
.
format
(
filename
))
def
safe_save_figure
(
fig
,
filename
,
**
kwargs
):
check_directory_exists_and_if_not_mkdir
(
os
.
path
.
dirname
(
filename
))
from
matplotlib
import
rcParams
try
:
fig
.
savefig
(
fname
=
filename
,
**
kwargs
)
except
RuntimeError
:
logger
.
debug
(
"
Failed to save plot with tex labels turning off tex.
"
)
logger
.
debug
(
"
Failed to save plot with tex labels turning off tex.
"
)
rcParams
[
"
text.usetex
"
]
=
False
fig
.
savefig
(
fname
=
filename
,
**
kwargs
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment