Skip to content
Snippets Groups Projects
Commit 81f5de90 authored by Rhiannon Udall's avatar Rhiannon Udall
Browse files

Used a more pythonic solution

parent 304b0cff
No related branches found
No related tags found
1 merge request!1235Bugfix Slabspike Sampling
......@@ -84,7 +84,7 @@ class SlabSpikePrior(Prior):
=======
array_like: Associated prior value with input value.
"""
original_type = type(val)
original_is_array = isinstance(val, np.ndarray)
val = np.atleast_1d(val)
lower_indices = np.where(val < self.inverse_cdf_below_spike)[0]
......@@ -97,7 +97,7 @@ class SlabSpikePrior(Prior):
res[lower_indices] = self._contracted_rescale(val[lower_indices])
res[intermediate_indices] = self.spike_location
res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height)
if original_type == int or original_type == float:
if not original_is_array:
assert res.shape == (1,)
res = res[0]
return res
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment