Commit 800e72c0 authored by Charlie Hoy's avatar Charlie Hoy
Browse files

update properties when update method is used

parent 765b3256
Pipeline #471797 failed with stages
in 147 minutes and 34 seconds
......@@ -490,8 +490,8 @@ def _1d_histogram_plot(
elif ax is None:
ax = fig.gca()
if len(set(samples)) <= max_vline:
for _ind, _sample in enumerate(set(samples)):
if len(set(samples.to_numpy())) <= max_vline:
for _ind, _sample in enumerate(set(samples.to_numpy())):
_label = None
if _ind == 0:
_label = label
......@@ -531,7 +531,6 @@ def _1d_histogram_plot(
prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
**kwargs
)
if set_labels:
ax.set_xlabel(latex_label)
ax.set_ylabel("Probability Density")
......
......@@ -325,7 +325,24 @@ class TestSamplesDict(object):
assert dataset.mean["a"] == np.mean(self.samples[0])
assert dataset.mean["b"] == np.mean(self.samples[1])
assert dataset.number_of_samples == len(self.samples[1])
assert len(dataset.downsample(10)["a"]) == 10
p = dataset.to_pandas()
assert isinstance(p, pd.core.frame.DataFrame)
def test_downsample_and_discard(self):
"""Test that the downsample and discard methods of the SamplesDict class
are correct
"""
dataset = SamplesDict(self.parameters, self.samples)
np.testing.assert_almost_equal(
dataset["a"].confidence_interval([5, 95]),
np.percentile(dataset["a"], [5, 95])
)
dataset.downsample(10)
assert len(dataset["a"]) == 10
np.testing.assert_almost_equal(
dataset["a"].confidence_interval([5, 95]),
np.percentile(dataset["a"], [5, 95])
)
dataset = SamplesDict(self.parameters, self.samples)
assert len(dataset.discard_samples(10)["a"]) == len(self.samples[0]) - 10
np.testing.assert_almost_equal(
......@@ -334,8 +351,6 @@ class TestSamplesDict(object):
np.testing.assert_almost_equal(
dataset.discard_samples(10)["b"], self.samples[1][10:]
)
p = dataset.to_pandas()
assert isinstance(p, pd.core.frame.DataFrame)
remove = dataset.pop("a")
assert list(dataset.keys()) == ["b"]
......
......@@ -133,7 +133,10 @@ class SamplesDict(Dict):
_value = Array(value)
super(SamplesDict, self).__setitem__(key, _value)
try:
if key not in self.parameters:
if key in self.parameters:
ind = self.parameters.index(key)
self.samples[ind] = value
else:
self.parameters.append(key)
try:
cond = (
......@@ -300,6 +303,14 @@ class SamplesDict(Dict):
return Parameters([key for key in original if key[0] != "_"])
return Parameters(original)
def update(self, *args, **kwargs):
super(SamplesDict, self).update(*args, **kwargs)
try:
for key, item in self.items():
self.__setitem__(key, item)
except ValueError:
self = self.__init__(self)
def write(self, **kwargs):
"""Save the stored posterior samples to file
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment