Skip to content
Snippets Groups Projects
Commit 002ddfe7 authored by Colm Talbot's avatar Colm Talbot
Browse files

BUGFIX: close pool for emcee saving

parent ac3b9a0e
No related branches found
No related tags found
1 merge request!1274BUGFIX: close pool for emcee saving
......@@ -2,7 +2,6 @@ import datetime
import time
import numpy as np
import pandas as pd
from ..utils import logger
from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
......@@ -150,7 +149,6 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
self.start_time = np.nan
self.sampler = None
self._information = np.nan
self._last_live_sample_info = np.nan
# Get the estimates of the prior distributions' widths and centers.
widths = []
......@@ -229,22 +227,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
self.result.log_evidence = stats["log_Z"]
self._information = stats["H"]
self.result.log_evidence_err = np.sqrt(self._information / self.num_particles)
if self._backend == "memory":
self._last_live_sample_info = pd.DataFrame(
self.sampler.backend.sample_info[-1]
)
self.result.log_likelihood_evaluations = self._last_live_sample_info[
"log_likelihood"
]
self.result.samples = np.array(self.sampler.backend.posterior_samples)
else:
sample_info_path = (
"./" + self.kwargs["outputfiles_basename"] + "/sample_info.txt"
)
sample_info = np.genfromtxt(sample_info_path, comments="#", names=True)
self.result.log_likelihood_evaluations = sample_info["log_likelihood"]
self.result.samples = np.array(self.sampler.backend.posterior_samples)
self.result.samples = np.array(self.sampler.backend.posterior_samples)
self.result.sampler_output = out
self.result.outputfiles_basename = self.outputfiles_basename
......
......@@ -403,6 +403,7 @@ class Emcee(MCMCSampler):
if self.verbose:
iterator.close()
self.write_current_state()
self._close_pool()
self.result.sampler_output = np.nan
self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
......
......@@ -134,14 +134,15 @@ class TestRunningSamplers(unittest.TestCase):
likelihood=self.likelihood,
priors=self.priors,
sampler=sampler,
save=False,
save="hdf5",
npool=pool_size,
conversion_function=self.conversion_function,
**kwargs,
**extra_kwargs,
)
assert "derived" in res.posterior
assert res.log_likelihood_evaluations is not None
if sampler != "dnest4":
assert res.log_likelihood_evaluations is not None
@parameterized.expand(_sampler_kwargs.keys())
def test_interrupt_sampler_single(self, sampler):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment