Skip to content
Snippets Groups Projects
Commit 7210c7f3 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'clean-up-saved-output' into 'master'

Clean up saved output

See merge request Monash/tupak!120
parents 0dd2f4cd 56c8bac1
No related branches found
No related tags found
1 merge request!120Clean up saved output
Pipeline #
......@@ -362,6 +362,8 @@ class Result(dict):
if conversion_function is not None:
data_frame = conversion_function(data_frame, likelihood, priors)
self.posterior = data_frame
# We save the samples in the posterior and remove the array of samples
del self.samples
def construct_cbc_derived_parameters(self):
""" Construct widely used derived parameters of CBCs """
......
......@@ -945,7 +945,7 @@ class Ptemcee(Emcee):
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='dynesty', use_ratio=None, injection_parameters=None,
conversion_function=None, plot=False, default_priors_file=None,
clean=None, meta_data=None, **kwargs):
clean=None, meta_data=None, save=True, **kwargs):
"""
The primary interface to easy parameter estimation
......@@ -984,6 +984,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case
of conflict with keys saved by tupak, the meta_data keys will be
overwritten.
save: bool
If true, save the priors and results to disk.
**kwargs:
All kwargs are passed directly to the samplers `run` function
......@@ -996,7 +998,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if clean:
utils.command_line_args.clean = clean
utils.check_directory_exists_and_if_not_mkdir(outdir)
implemented_samplers = get_implemented_samplers()
if priors is None:
......@@ -1010,7 +1011,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
raise ValueError
priors.fill_priors(likelihood, default_priors_file=default_priors_file)
priors.write_to_file(outdir, label)
if save:
utils.check_directory_exists_and_if_not_mkdir(outdir)
priors.write_to_file(outdir, label)
if implemented_samplers.__contains__(sampler.title()):
sampler_class = globals()[sampler.title()]
......@@ -1049,14 +1053,14 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if conversion_function is not None:
result.injection_parameters = conversion_function(result.injection_parameters)
result.fixed_parameter_keys = sampler.fixed_parameter_keys
# result.prior = prior # Removed as this breaks the saving of the data
result.samples_to_posterior(likelihood=likelihood, priors=priors,
conversion_function=conversion_function)
result.kwargs = sampler.kwargs
result.save_to_file()
if save:
result.save_to_file()
logger.info("Results saved to {}/".format(outdir))
if plot:
result.plot_corner()
logger.info("Sampling finished, results saved to {}/".format(outdir))
logger.info("Summary of results:\n{}".format(result))
return result
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment