diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py
index d3d38304f7c47616ca86a1df840458dbeba234b4..74ed3494386b3fa11ca8f47b5dc3d255eb94be72 100644
--- a/bilby/core/prior/slabspike.py
+++ b/bilby/core/prior/slabspike.py
@@ -84,7 +84,7 @@ class SlabSpikePrior(Prior):
         =======
         array_like: Associated prior value with input value.
         """
-        original_type = type(val)
+        original_is_array = isinstance(val, np.ndarray)
         val = np.atleast_1d(val)
 
         lower_indices = np.where(val < self.inverse_cdf_below_spike)[0]
@@ -97,7 +97,7 @@ class SlabSpikePrior(Prior):
         res[lower_indices] = self._contracted_rescale(val[lower_indices])
         res[intermediate_indices] = self.spike_location
         res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height)
-        if original_type == int or original_type == float:
+        if not original_is_array:
             assert res.shape == (1,)
             res = res[0]
         return res