diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index d3d38304f7c47616ca86a1df840458dbeba234b4..74ed3494386b3fa11ca8f47b5dc3d255eb94be72 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -84,7 +84,7 @@ class SlabSpikePrior(Prior): ======= array_like: Associated prior value with input value. """ - original_type = type(val) + original_is_array = isinstance(val, np.ndarray) val = np.atleast_1d(val) lower_indices = np.where(val < self.inverse_cdf_below_spike)[0] @@ -97,7 +97,7 @@ class SlabSpikePrior(Prior): res[lower_indices] = self._contracted_rescale(val[lower_indices]) res[intermediate_indices] = self.spike_location res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height) - if original_type == int or original_type == float: + if not original_is_array: assert res.shape == (1,) res = res[0] return res