Commit 060aaf21 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

WIP: Add dynamic dynesty

parent aae0a49a
Pipeline #104122 failed with stages
in 3 minutes and 57 seconds
......@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
import mpi4py
import numpy as np
import pandas as pd
from dynesty import NestedSampler
from dynesty import DynamicNestedSampler, NestedSampler
from dynesty.plotting import traceplot
from dynesty.utils import resample_equal, unitcheck
from numpy import linalg
......@@ -540,7 +540,7 @@ with MPIPool() as pool:
logger.info("Initializing sampling points")
live_points = get_initial_points_from_prior(ndim, nlive)
sampler = NestedSampler(
sampler = DynamicNestedSampler(
likelihood_function,
prior_transform_function,
ndim,
......@@ -576,25 +576,40 @@ with MPIPool() as pool:
dlogz=input_args.dlogz,
)
while True:
sampler_kwargs["add_live"] = False
sampler_kwargs["maxcall"] += input_args.n_check_point
sampler.run_nested(**sampler_kwargs)
if sampler.ncall == old_ncall:
break
old_ncall = sampler.ncall
sampling_time += (datetime.datetime.now() - t0).total_seconds()
t0 = datetime.datetime.now()
write_checkpoint(
sampler,
resume_file,
sampling_time,
sampling_keys,
no_plot=input_args.no_plot,
)
sampler_kwargs["add_live"] = True
sampler_kwargs.pop("dlogz")
sampler_kwargs["dlogz_init"] = 0.01
sampler_kwargs["nlive_init"] = 1500
sampler_kwargs["nlive_batch"] = 1000
# sampler_kwargs["maxbatch"] = 10
print(sampler_kwargs)
# while True:
# #sampler_kwargs["add_live"] = False
# sampler_kwargs["maxcall"] += input_args.n_check_point
# sampler.run_nested(**sampler_kwargs)
# if sampler.ncall == old_ncall:
# break
# old_ncall = sampler.ncall
# sampling_time += (datetime.datetime.now() - t0).total_seconds()
# t0 = datetime.datetime.now()
# write_checkpoint(
# sampler,
# resume_file,
# sampling_time,
# sampling_keys,
# no_plot=input_args.no_plot,
# )
sampler.run_nested(**sampler_kwargs)
sampler.run_nested(maxiter=50000) # (possibly) adding more samples
sampler.run_nested(maxbatch=50) # (possibly) adding more samples
filename_trace = "{}/{}_final_trace.png".format(outdir, label)
fig = traceplot(sampler.results, labels=sampling_keys)[0]
fig.tight_layout()
fig.savefig(filename_trace)
plt.close("all")
# sampler_kwargs["add_live"] = True
sampling_time += (datetime.datetime.now() - t0).total_seconds()
out = sampler.results
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment