Commit ea2a7aa7 authored by Gregory Ashton's avatar Gregory Ashton

Close #17 move output format to h5

Removes use of pickle in favour of h5 format. Uses deepdish
(https://github.com/uchicago-cs/deepdish) over the more standard
h5py for the following reasons

- h5py is built around saving 'arrays' while deepdish around saving
  dictionaries
- provides useful tools to check what data is in the h5 file
- is well written, provides good documentation and no strange error
  messages (personal experience with trying to implement h5py).
parent 768831bf
Pipeline #17458 passed with stages
in 5 minutes and 51 seconds
import logging import logging
import os import os
import pickle import deepdish
class Result(dict): class Result(dict):
...@@ -24,7 +24,7 @@ class Result(dict): ...@@ -24,7 +24,7 @@ class Result(dict):
self.logzerr)) self.logzerr))
def save_to_file(self, outdir, label): def save_to_file(self, outdir, label):
file_name = '{}/{}_results.p'.format(outdir, label) file_name = '{}/{}_result.h5'.format(outdir, label)
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.makedirs(outdir) os.makedirs(outdir)
if os.path.isfile(file_name): if os.path.isfile(file_name):
...@@ -34,5 +34,4 @@ class Result(dict): ...@@ -34,5 +34,4 @@ class Result(dict):
os.rename(file_name, file_name + '.old') os.rename(file_name, file_name + '.old')
logging.info("Saving result to {}".format(file_name)) logging.info("Saving result to {}".format(file_name))
with open(file_name, 'wb+') as f: deepdish.io.save(file_name, self)
pickle.dump(self, f)
...@@ -247,7 +247,7 @@ class Pymultinest(Sampler): ...@@ -247,7 +247,7 @@ class Pymultinest(Sampler):
def run_sampler(likelihood, priors, label='label', outdir='outdir', def run_sampler(likelihood, priors, label='label', outdir='outdir',
sampler='nestle', use_ratio=False, sampler='nestle', use_ratio=False, injection_parameters=None,
**sampler_kwargs): **sampler_kwargs):
""" """
The primary interface to easy parameter estimation The primary interface to easy parameter estimation
...@@ -268,6 +268,9 @@ def run_sampler(likelihood, priors, label='label', outdir='outdir', ...@@ -268,6 +268,9 @@ def run_sampler(likelihood, priors, label='label', outdir='outdir',
use_ratio: bool (False) use_ratio: bool (False)
If True, use the likelihood's loglikelihood_ratio, rather than just If True, use the likelihood's loglikelihood_ratio, rather than just
the loglikelhood. the loglikelhood.
injection_parameters: dict
A dictionary of injection parameters used in creating the data (if
using simulated data). Appended to the result object and saved.
**sampler_kwargs: **sampler_kwargs:
All kwargs are passed directly to the samplers `run` functino All kwargs are passed directly to the samplers `run` functino
...@@ -289,7 +292,7 @@ def run_sampler(likelihood, priors, label='label', outdir='outdir', ...@@ -289,7 +292,7 @@ def run_sampler(likelihood, priors, label='label', outdir='outdir',
result = sampler.run_sampler() result = sampler.run_sampler()
result.noise_logz = likelihood.noise_log_likelihood() result.noise_logz = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.logz - result.noise_logz result.log_bayes_factor = result.logz - result.noise_logz
print("") result.injection_parameters = injection_parameters
result.save_to_file(outdir=outdir, label=label) result.save_to_file(outdir=outdir, label=label)
return result, sampler return result, sampler
else: else:
......
...@@ -12,3 +12,4 @@ gwsurrogate ...@@ -12,3 +12,4 @@ gwsurrogate
NRSur7dq2 NRSur7dq2
chainconsumer chainconsumer
nestle nestle
deepdish
...@@ -84,6 +84,6 @@ sampling_parameters['luminosity_distance'] = peyote.prior.Uniform(lower=30, uppe ...@@ -84,6 +84,6 @@ sampling_parameters['luminosity_distance'] = peyote.prior.Uniform(lower=30, uppe
result, sampler = peyote.sampler.run_sampler( result, sampler = peyote.sampler.run_sampler(
likelihood, priors=sampling_parameters, label='BasicTutorial', likelihood, priors=sampling_parameters, label='BasicTutorial',
sampler='nestle', verbose=True) sampler='nestle', verbose=True, injection_parameters=injection_parameters)
sampler.plot_corner() sampler.plot_corner()
print(result) print(result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment