Skip to content
Snippets Groups Projects
Commit 4af071b1 authored by Michael Williams's avatar Michael Williams Committed by Colm Talbot
Browse files

Update for nessai v0.4.0

parent 6cf9a859
No related branches found
No related tags found
1 merge request!1042Update for nessai v0.4.0
import numpy as np
import os
from pandas import DataFrame
from .base_sampler import NestedSampler
......@@ -15,55 +16,42 @@ class Nessai(NestedSampler):
Documentation: https://nessai.readthedocs.io/
"""
default_kwargs = dict(
output=None,
nlive=1000,
stopping=0.1,
resume=True,
max_iteration=None,
checkpointing=True,
seed=1234,
acceptance_threshold=0.01,
analytic_priors=False,
maximum_uninformed=1000,
uninformed_proposal=None,
uninformed_proposal_kwargs=None,
flow_class=None,
flow_config=None,
training_frequency=None,
reset_weights=False,
reset_permutations=False,
reset_acceptance=False,
train_on_empty=True,
cooldown=100,
memory=False,
poolsize=None,
drawsize=None,
max_poolsize_scale=10,
update_poolsize=False,
latent_prior='truncated_gaussian',
draw_latent_kwargs=None,
compute_radius_with_all=False,
min_radius=False,
max_radius=50,
check_acceptance=False,
fuzz=1.0,
expansion_fraction=1.0,
rescale_parameters=True,
rescale_bounds=[-1, 1],
update_bounds=False,
boundary_inversion=False,
inversion_type='split', detect_edges=False,
detect_edges_kwargs=None,
reparameterisations=None,
n_pool=None,
max_threads=1,
pytorch_threads=None,
plot=None,
proposal_plots=False
)
_default_kwargs = None
seed_equiv_kwargs = ['sampling_seed']
@property
def default_kwargs(self):
"""Default kwargs for nessai.
Retrieves default values from nessai directly and then includes any
bilby specific defaults. This avoids the need to update bilby when the
defaults change or new kwargs are added to nessai.
"""
if not self._default_kwargs:
from inspect import signature
from nessai.flowsampler import FlowSampler
from nessai.nestedsampler import NestedSampler
from nessai.proposal import AugmentedFlowProposal, FlowProposal
kwargs = {}
classes = [
AugmentedFlowProposal,
FlowProposal,
NestedSampler,
FlowSampler,
]
for c in classes:
kwargs.update(
{k: v.default for k, v in signature(c).parameters.items() if v.default is not v.empty}
)
# Defaults for bilby that will override nessai defaults
bilby_defaults = dict(
output=None,
)
kwargs.update(bilby_defaults)
self._default_kwargs = kwargs
return self._default_kwargs
def log_prior(self, theta):
"""
......@@ -194,9 +182,9 @@ class Nessai(NestedSampler):
self.kwargs['n_pool'] = None
if not self.kwargs['output']:
self.kwargs['output'] = self.outdir + '/{}_nessai/'.format(self.label)
if self.kwargs['output'].endswith('/') is False:
self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
self.kwargs['output'] = os.path.join(
self.outdir, '{}_nessai'.format(self.label), ''
)
check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
NestedSampler._verify_kwargs_against_default_kwargs(self)
......@@ -22,53 +22,8 @@ class TestNessai(unittest.TestCase):
plot=False,
skip_import_verification=True,
)
self.expected = dict(
output="outdir/label_nessai/",
nlive=1000,
stopping=0.1,
resume=True,
max_iteration=None,
checkpointing=True,
seed=1234,
acceptance_threshold=0.01,
analytic_priors=False,
maximum_uninformed=1000,
uninformed_proposal=None,
uninformed_proposal_kwargs=None,
flow_class=None,
flow_config=None,
training_frequency=None,
reset_weights=False,
reset_permutations=False,
reset_acceptance=False,
train_on_empty=True,
cooldown=100,
memory=False,
poolsize=None,
drawsize=None,
max_poolsize_scale=10,
update_poolsize=False,
latent_prior='truncated_gaussian',
draw_latent_kwargs=None,
compute_radius_with_all=False,
min_radius=False,
max_radius=50,
check_acceptance=False,
fuzz=1.0,
expansion_fraction=1.0,
rescale_parameters=True,
rescale_bounds=[-1, 1],
update_bounds=False,
boundary_inversion=False,
inversion_type='split', detect_edges=False,
detect_edges_kwargs=None,
reparameterisations=None,
n_pool=None,
max_threads=1,
pytorch_threads=None,
plot=False,
proposal_plots=False
)
self.expected = self.sampler.default_kwargs
self.expected['output'] = 'outdir/label_nessai/'
def tearDown(self):
del self.likelihood
......@@ -76,16 +31,16 @@ class TestNessai(unittest.TestCase):
del self.sampler
del self.expected
def test_default_kwargs(self):
expected = self.expected.copy()
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs_nlive(self):
expected = self.expected.copy()
# nlive in the default kwargs is not a fixed but depends on the
# version of nessai, so get the value here and use it when setting
# the equivalent kwargs.
nlive = expected["nlive"]
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs["nlive"]
new_kwargs[equiv] = 1000
new_kwargs[equiv] = nlive
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
......@@ -117,10 +72,10 @@ class TestNessai(unittest.TestCase):
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
@patch("builtins.open", mock_open(read_data='{"nlive": 2000}'))
@patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
def test_update_from_config_file(self):
expected = self.expected.copy()
expected["nlive"] = 2000
expected["nlive"] = 4000
new_kwargs = self.expected.copy()
new_kwargs["config_file"] = "config_file.json"
self.sampler.kwargs = new_kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment