diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 6e57ce0bca9583d5640753d35b759c38fcb63d37..b715e12fce22cc7e5da9f37280f80040e856f045 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -3,7 +3,6 @@ from io import open as ioopen
 import json
 import numpy as np
 import os
-from functools import lru_cache
 
 from future.utils import iteritems
 from matplotlib.cbook import flatten
@@ -42,6 +41,9 @@ class PriorDict(dict):
             self.from_file(filename)
         elif dictionary is not None:
             raise ValueError("PriorDict input dictionary not understood")
+        self._cached_normalizations = {}
 
         self.convert_floats_to_delta_functions()
 
@@ -379,23 +381,34 @@
                            if not isinstance(self[key], Constraint)}
             return all_samples
 
-    @lru_cache()
     def normalize_constraint_factor(self, keys):
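+        # Return the cached factor if one was already computed for this key
+        # set; repr(keys) provides a hashable dictionary key.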
-        min_accept = 1000
-        sampling_chunk = 5000
-        samples = self.sample_subset(keys=keys, size=sampling_chunk)
-        keep = np.atleast_1d(self.evaluate_constraints(samples))
-        if len(keep) == 1:
-            return 1
-        all_samples = {key: np.array([]) for key in keys}
-        _first_key = list(all_samples.keys())[0]
-        while np.count_nonzero(keep) < min_accept:
+        if repr(keys) in self._cached_normalizations:
+            return self._cached_normalizations[repr(keys)]
+        else:
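+            # Estimate the accepted fraction by rejection sampling the prior
+            # in chunks until enough samples satisfy the constraints.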
+            min_accept = 1000
+            sampling_chunk = 5000
             samples = self.sample_subset(keys=keys, size=sampling_chunk)
-            for key in samples:
-                all_samples[key] = np.hstack(
-                    [all_samples[key], samples[key].flatten()])
-            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
-        return len(keep) / np.count_nonzero(keep)
+            keep = np.atleast_1d(self.evaluate_constraints(samples))
+            if len(keep) == 1:
+                self._cached_normalizations[repr(keys)] = 1  # cache the trivial case
+                return 1
+            all_samples = {key: np.array([]) for key in keys}
+            _first_key = list(all_samples.keys())[0]
+            while np.count_nonzero(keep) < min_accept:
+                samples = self.sample_subset(keys=keys, size=sampling_chunk)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key].flatten()])
+                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
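+            # The normalization factor is total draws over accepted draws,
+            # i.e. the inverse of the constraint acceptance probability.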
+            factor = len(keep) / np.count_nonzero(keep)
+            self._cached_normalizations[repr(keys)] = factor
+            return factor
 
     def prob(self, sample, **kwargs):
         """
@@ -511,15 +524,6 @@
         """
         return self.__class__(dictionary=dict(self))
 
-    def __key(self):
-        return tuple((k, self[k]) for k in sorted(self))
-
-    def __hash__(self):
-        return hash(self.__key())
-
-    def __eq__(self, other):
-        return self.__key() == other.__key()
-
 
 class PriorSet(PriorDict):