Skip to content
Snippets Groups Projects
Commit 67571bf5 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'minor-fixes-needed-for-ptemcee' into 'master'

Minor fixes required for stable use of ptemcee

See merge request !625
parents 38c93820 635de241
No related branches found
Tags v0.6.6
1 merge request!625Minor fixes required for stable use of ptemcee
Pipeline #85138 passed with warnings
......@@ -1361,7 +1361,7 @@ class ResultList(list):
self.check_consistent_priors()
# check which kind of sampler was used: MCMC or Nested Sampling
if result.nested_samples is not None:
if result._nested_samples is not None:
posteriors, result = self._combine_nested_sampled_runs(result)
else:
posteriors = [res.posterior for res in self]
......
......@@ -2,6 +2,8 @@ from __future__ import absolute_import, division, print_function
import os
from shutil import copyfile
import signal
import sys
import numpy as np
......@@ -47,6 +49,10 @@ class Ptemcee(Emcee):
nburn=nburn, burn_in_fraction=burn_in_fraction,
burn_in_act=burn_in_act, resume=resume, **kwargs)
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
@property
def sampler_function_kwargs(self):
keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
......@@ -82,7 +88,10 @@ class Ptemcee(Emcee):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + '.temp'
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
try:
copyfile(chain_file, temp_chain_file)
except OSError:
logger.warning("Failed to write temporary chain file {}".format(temp_chain_file))
with open(temp_chain_file, "a") as ff:
loglike = np.squeeze(loglike[0, :])
......@@ -92,6 +101,10 @@ class Ptemcee(Emcee):
ff.write(self.checkpoint_info.chain_template.format(ii, *line))
os.rename(temp_chain_file, chain_file)
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
sys.exit(130)
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment