diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 7baa0a9a65ffce828aab4ef19b7b725130787380..da13250a48b46a2fc45dffa375bfb8594a924c8f 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -193,7 +193,13 @@ class Ptemcee(MCMCSampler):
         )
         self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
 
-        # MultiProcessing inputs
+        # Check if threads was given as an equivalent arg
+        if threads == 1:
+            for equiv in self.npool_equiv_kwargs:
+                if equiv in kwargs:
+                    threads = kwargs.pop(equiv)
+
+        # Store threads
         self.threads = threads
 
         # Misc inputs
@@ -221,10 +227,6 @@ class Ptemcee(MCMCSampler):
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["nwalkers"] = kwargs.pop(equiv)
-        if "threads" not in kwargs:
-            for equiv in self.npool_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs["threads"] = kwargs.pop(equiv)
 
     def get_pos0_from_prior(self):
         """ Draw the initial positions from the prior