Skip to content
Snippets Groups Projects
Commit 3407a67f authored by Colm Talbot's avatar Colm Talbot
Browse files

Make final plot safer

parent c847c753
No related branches found
No related tags found
2 merge requests!748WIP: Improve dynesty checkpointing II,!746Pickle dump entire sampler in dynesty
Pipeline #111812 passed
......@@ -132,7 +132,8 @@ class Dynesty(NestedSampler):
def __getstate__(self):
""" For pickle: remove external_sampler, which can be an unpicklable "module" """
state = self.__dict__.copy()
del state['external_sampler']
if "external_sampler" in state:
del state['external_sampler']
return state
@property
......@@ -418,8 +419,8 @@ class Dynesty(NestedSampler):
if self.check_point_plot:
import dynesty.plotting as dyplot
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
try:
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(filename)
......@@ -428,8 +429,8 @@ class Dynesty(NestedSampler):
logger.warning('Failed to create dynesty state plot at checkpoint')
finally:
plt.close("all")
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
try:
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
fig, axs = dyplot.runplot(self.sampler.results)
fig.tight_layout()
plt.savefig(filename)
......@@ -438,15 +439,20 @@ class Dynesty(NestedSampler):
logger.warning('Failed to create dynesty run plot at checkpoint')
finally:
plt.close('all')
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
plt.close('all')
try:
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty stats plot at checkpoint')
finally:
plt.close('all')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment