diff --git a/bilby/core/grid.py b/bilby/core/grid.py
index a9b957f96a9ce649b4a4b1e29807f73347d433d9..d53bfa1af29313d9700e754e26dd7ff24c0ce2cc 100644
--- a/bilby/core/grid.py
+++ b/bilby/core/grid.py
@@ -7,7 +7,8 @@ from collections import OrderedDict
 
 from .prior import Prior, PriorDict
 from .utils import (logtrapzexp, check_directory_exists_and_if_not_mkdir,
-                    logger, BilbyJsonEncoder, decode_bilby_json)
+                    logger)
+from .utils import BilbyJsonEncoder, decode_bilby_json
 from .result import FileMovedError
 
 
@@ -406,12 +407,10 @@ class Grid(object):
 
         logger.debug("Saving result to {}".format(filename))
 
-        # Convert the prior to a string representation for saving on disk
         dictionary = self._get_save_data_dictionary()
-        if dictionary.get('priors', False):
-            dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
 
         try:
+            dictionary["priors"] = dictionary["priors"]._get_json_dict()
             if gzip or (os.path.splitext(filename)[-1] == '.gz'):
                 import gzip
                 # encode to a string
@@ -468,12 +467,6 @@ class Grid(object):
             else:
                 with open(fname, 'r') as file:
                     dictionary = json.load(file, object_hook=decode_bilby_json)
-            for key in dictionary.keys():
-                # Convert the loaded priors to bilby prior type
-                if key == 'priors':
-                    for param in dictionary[key].keys():
-                        dictionary[key][param] = str(dictionary[key][param])
-                    dictionary[key] = PriorDict(dictionary[key])
             try:
                 grid = cls(likelihood=None, priors=dictionary['priors'],
                            grid_size=dictionary['sample_points'],
diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 2ba628560947b6d577880d4fe34ea3af6047b150..9ea0b1934ddf35e4d021329fe6d4a7bd2036f955 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -5,16 +5,21 @@ from importlib import import_module
 import os
 from collections import OrderedDict
 from future.utils import iteritems
-from matplotlib.cbook import flatten
+import json
 
 import numpy as np
 import scipy.stats
 from scipy.integrate import cumtrapz
 from scipy.interpolate import interp1d
 from scipy.special import erf, erfinv
+from matplotlib.cbook import flatten
 
+# Keep import bilby statement, it is necessary for some eval() statements
+from .utils import BilbyJsonEncoder, decode_bilby_json
 from .utils import (
-    logger, infer_args_from_method, check_directory_exists_and_if_not_mkdir)
+    check_directory_exists_and_if_not_mkdir,
+    infer_args_from_method, logger
+)
 
 
 class PriorDict(OrderedDict):
@@ -107,6 +112,22 @@ class PriorDict(OrderedDict):
                     outfile.write(
                         "{} = {}\n".format(key, self[key]))
 
+    def _get_json_dict(self):
+        self.convert_floats_to_delta_functions()
+        total_dict = {key: json.loads(self[key].to_json()) for key in self}
+        total_dict["__prior_dict__"] = True
+        total_dict["__module__"] = self.__module__
+        total_dict["__name__"] = self.__class__.__name__
+        return total_dict
+
+    def to_json(self, outdir, label):
+        check_directory_exists_and_if_not_mkdir(outdir)
+        prior_file = os.path.join(outdir, "{}_prior.json".format(label))
+        logger.debug("Writing priors to {}".format(prior_file))
+        with open(prior_file, "w") as outfile:
+            json.dump(self._get_json_dict(), outfile, cls=BilbyJsonEncoder,
+                      indent=2)
+
     def from_file(self, filename):
         """ Reads in a prior from a file specification
 
@@ -150,7 +171,7 @@ class PriorDict(OrderedDict):
                     cls = cls.split('.')[-1]
                 else:
                     module = __name__
-                cls = getattr(import_module(module), cls)
+                cls = getattr(import_module(module), cls, cls)
                 if key.lower() == "conversion_function":
                     setattr(self, key, cls)
                 elif (cls.__name__ in ['MultivariateGaussianDist',
@@ -170,6 +191,38 @@ class PriorDict(OrderedDict):
                                 filename, key, val, e))
         self.update(prior)
 
+    @classmethod
+    def _get_from_json_dict(cls, prior_dict):
+        try:
+            cls == getattr(
+                import_module(prior_dict["__module__"]),
+                prior_dict["__name__"])
+        except ImportError:
+            logger.debug("Cannot import prior module {}.{}".format(
+                prior_dict["__module__"], prior_dict["__name__"]
+            ))
+        except KeyError:
+            logger.debug("Cannot find module name to load")
+        for key in ["__module__", "__name__", "__prior_dict__"]:
+            if key in prior_dict:
+                del prior_dict[key]
+        obj = cls(dict())
+        obj.from_dictionary(prior_dict)
+        return obj
+
+    @classmethod
+    def from_json(cls, filename):
+        """ Reads in a prior from a json file
+
+        Parameters
+        ----------
+        filename: str
+            Name of the file to be read in
+        """
+        with open(filename, "r") as ff:
+            obj = json.load(ff, object_hook=decode_bilby_json)
+        return obj
+
     def from_dictionary(self, dictionary):
         for key, val in iteritems(dictionary):
             if isinstance(val, str):
@@ -182,6 +235,10 @@ class PriorDict(OrderedDict):
                         "Failed to load dictionary value {} correctly"
                         .format(key))
                     pass
+            elif isinstance(val, dict):
+                logger.warning(
+                    'Cannot convert {} into a prior object. '
+                    'Leaving as dictionary.'.format(key))
             self[key] = val
 
     def convert_floats_to_delta_functions(self):
@@ -628,7 +685,9 @@ class Prior(object):
 
         """
         prior_name = self.__class__.__name__
-        args = ', '.join(['{}={}'.format(key, repr(self._repr_dict[key])) for key in self._repr_dict])
+        instantiation_dict = self._get_instantiation_dict()
+        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
+                          for key in instantiation_dict])
         return "{}({})".format(prior_name, args)
 
     @property
@@ -709,6 +768,18 @@ class Prior(object):
     def maximum(self, maximum):
         self._maximum = maximum
 
+    def _get_instantiation_dict(self):
+        subclass_args = infer_args_from_method(self.__init__)
+        property_names = [p for p in dir(self.__class__)
+                          if isinstance(getattr(self.__class__, p), property)]
+        dict_with_properties = self.__dict__.copy()
+        for key in property_names:
+            dict_with_properties[key] = getattr(self, key)
+        instantiation_dict = OrderedDict()
+        for key in subclass_args:
+            instantiation_dict[key] = dict_with_properties[key]
+        return instantiation_dict
+
     @property
     def boundary(self):
         return self._boundary
@@ -727,6 +798,13 @@ class Prior(object):
             label = self.name
         return label
 
+    def to_json(self):
+        return json.dumps(self, cls=BilbyJsonEncoder)
+
+    @classmethod
+    def from_json(cls, dct):
+        return decode_bilby_json(dct)
+
     @classmethod
     def from_repr(cls, string):
         """Generate the prior from it's __repr__"""
@@ -2914,6 +2992,22 @@ class MultivariateGaussianDist(object):
 
         return np.exp(self.ln_prob(samp))
 
+    def _get_instantiation_dict(self):
+        subclass_args = infer_args_from_method(self.__init__)
+        property_names = [p for p in dir(self.__class__)
+                          if isinstance(getattr(self.__class__, p), property)]
+        dict_with_properties = self.__dict__.copy()
+        for key in property_names:
+            dict_with_properties[key] = getattr(self, key)
+        instantiation_dict = OrderedDict()
+        for key in subclass_args:
+            if isinstance(dict_with_properties[key], list):
+                value = np.asarray(dict_with_properties[key]).tolist()
+            else:
+                value = dict_with_properties[key]
+            instantiation_dict[key] = value
+        return instantiation_dict
+
     def __len__(self):
         return len(self.names)
 
@@ -2928,23 +3022,10 @@ class MultivariateGaussianDist(object):
         str: A string representation of this instance
 
         """
-        subclass_args = infer_args_from_method(self.__init__)
         dist_name = self.__class__.__name__
-
-        property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)]
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names:
-            dict_with_properties[key] = getattr(self, key)
-
-        argslist = []
-        for key in subclass_args:
-            # make sure lists containing arrays are returned just as lists
-            if isinstance(dict_with_properties[key], list):
-                argsval = np.asarray(dict_with_properties[key]).tolist()
-            else:
-                argsval = dict_with_properties[key]
-            argslist.append('{}={}'.format(key, repr(argsval)))
-        args = ', '.join(argslist)
+        instantiation_dict = self._get_instantiation_dict()
+        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
+                          for key in instantiation_dict])
         return "{}({})".format(dist_name, args)
 
     def __eq__(self, other):
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 8b4d97dda30734a85eef823b55fa0787f77cc9fb..5ac7615e78e8a6f1204fd138c5ec0728e86b0a8c 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -18,8 +18,8 @@ from scipy.special import logsumexp
 
 from . import utils
 from .utils import (logger, infer_parameters_from_function,
-                    check_directory_exists_and_if_not_mkdir,
-                    BilbyJsonEncoder, decode_bilby_json)
+                    check_directory_exists_and_if_not_mkdir,)
+from .utils import BilbyJsonEncoder, decode_bilby_json
 from .prior import Prior, PriorDict, DeltaFunction
 
 
@@ -264,12 +264,6 @@ class Result(object):
             else:
                 with open(filename, 'r') as file:
                     dictionary = json.load(file, object_hook=decode_bilby_json)
-            for key in dictionary.keys():
-                # Convert the loaded priors to bilby prior type
-                if key == 'priors':
-                    for param in dictionary[key].keys():
-                        dictionary[key][param] = str(dictionary[key][param])
-                    dictionary[key] = PriorDict(dictionary[key])
             try:
                 return cls(**dictionary)
             except TypeError as e:
@@ -467,8 +461,6 @@ class Result(object):
 
         # Convert the prior to a string representation for saving on disk
         dictionary = self._get_save_data_dictionary()
-        if dictionary.get('priors', False):
-            dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
 
         # Convert callable sampler_kwargs to strings
         if dictionary.get('sampler_kwargs', None) is not None:
@@ -478,6 +470,7 @@ class Result(object):
 
         try:
             if extension == 'json':
+                dictionary["priors"] = dictionary["priors"]._get_json_dict()
                 if gzip:
                     import gzip
                     # encode to a string
diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index 5506a443819016c905760c9a0cc55ecfad93341a..6f53601f79d80a942c6317720c838583d2c9c979 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -7,8 +7,9 @@ import argparse
 import traceback
 import inspect
 import subprocess
-import json
 import multiprocessing
+from importlib import import_module
+import json
 
 import numpy as np
 from scipy.interpolate import interp2d
@@ -907,7 +908,25 @@ else:
 
 
 class BilbyJsonEncoder(json.JSONEncoder):
+
     def default(self, obj):
+        from .prior import MultivariateGaussianDist, Prior, PriorDict
+        if isinstance(obj, PriorDict):
+            return {'__prior_dict__': True, 'content': obj._get_json_dict()}
+        if isinstance(obj, (MultivariateGaussianDist, Prior)):
+            return {'__prior__': True, '__module__': obj.__module__,
+                    '__name__': obj.__class__.__name__,
+                    'kwargs': dict(obj._get_instantiation_dict())}
+        try:
+            from astropy import cosmology as cosmo, units
+            if isinstance(obj, cosmo.FLRW):
+                return encode_astropy_cosmology(obj)
+            if isinstance(obj, units.Quantity):
+                return encode_astropy_quantity(obj)
+            if isinstance(obj, units.PrefixUnit):
+                return str(obj)
+        except ImportError:
+            logger.info("Cannot import astropy, cannot write cosmological priors")
         if isinstance(obj, np.ndarray):
             return {'__array__': True, 'content': obj.tolist()}
         if isinstance(obj, complex):
@@ -917,7 +936,35 @@ class BilbyJsonEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
 
 
+def encode_astropy_cosmology(obj):
+    cls_name = obj.__class__.__name__
+    dct = {key: getattr(obj, key) for
+           key in infer_args_from_method(obj.__init__)}
+    dct['__cosmology__'] = True
+    dct['__name__'] = cls_name
+    return dct
+
+
+def encode_astropy_quantity(dct):
+    dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit))
+    if isinstance(dct['value'], np.ndarray):
+        dct['value'] = list(dct['value'])
+    return dct
+
+
 def decode_bilby_json(dct):
+    if dct.get("__prior_dict__", False):
+        cls = getattr(import_module(dct['__module__']), dct['__name__'])
+        obj = cls._get_from_json_dict(dct)
+        return obj
+    if dct.get("__prior__", False):
+        cls = getattr(import_module(dct['__module__']), dct['__name__'])
+        obj = cls(**dct['kwargs'])
+        return obj
+    if dct.get("__cosmology__", False):
+        return decode_astropy_cosmology(dct)
+    if dct.get("__astropy_quantity__", False):
+        return decode_astropy_quantity(dct)
     if dct.get("__array__", False):
         return np.asarray(dct["content"])
     if dct.get("__complex__", False):
@@ -927,5 +974,31 @@ def decode_bilby_json(dct):
     return dct
 
 
+def decode_astropy_cosmology(dct):
+    try:
+        from astropy import cosmology as cosmo
+        cosmo_cls = getattr(cosmo, dct['__name__'])
+        del dct['__cosmology__'], dct['__name__']
+        return cosmo_cls(**dct)
+    except ImportError:
+        logger.info("Cannot import astropy, cosmological priors may not be "
+                    "properly loaded.")
+        return dct
+
+
+def decode_astropy_quantity(dct):
+    try:
+        from astropy import units
+        if dct['value'] is None:
+            return None
+        else:
+            del dct['__astropy_quantity__']
+            return units.Quantity(**dct)
+    except ImportError:
+        logger.info("Cannot import astropy, cosmological priors may not be "
+                    "properly loaded.")
+        return dct
+
+
 class IllegalDurationAndSamplingFrequencyException(Exception):
     pass
diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 73585744ab81af1a86b1980a37d23c2d3b451776..14f14b7646262a641f4a6ade6c79e57478bd8a3e 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -16,10 +16,10 @@ except ImportError:
 from scipy.special import i0e
 
 from ..core import likelihood
+from ..core.utils import BilbyJsonEncoder, decode_bilby_json
 from ..core.utils import (
-    logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json,
-    create_frequency_series, create_time_series, speed_of_light,
-    radius_of_earth)
+    logger, UnsortedInterp2d, create_frequency_series, create_time_series,
+    speed_of_light, radius_of_earth)
 from ..core.prior import Interped, Prior, Uniform
 from .detector import InterferometerList
 from .prior import BBHPriorDict
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 83dcf925d0fd33581746feee7d055d3eafaa7547..890659561bda732a96e1b01d5798eea159612ead 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -42,8 +42,8 @@ class Cosmological(Interped):
         if latex_label is not None:
             label_args['latex_label'] = latex_label
         if unit is not None:
-            if isinstance(unit, str):
-                unit = units.__dict__[unit]
+            if not isinstance(unit, units.Unit):
+                unit = units.Unit(unit)
             label_args['unit'] = unit
         self.unit = label_args['unit']
         self._minimum = dict()
diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py
index 57748331680547a4cd4551ca90772285da3f24e4..9195aebed61948e8759b26c642f4351eaec17f7a 100644
--- a/test/gw_prior_test.py
+++ b/test/gw_prior_test.py
@@ -1,4 +1,5 @@
 from __future__ import division, absolute_import
+from collections import OrderedDict
 import unittest
 import os
 import sys
@@ -144,7 +145,7 @@ class TestPackagedPriors(unittest.TestCase):
 class TestBNSPriorDict(unittest.TestCase):
 
     def setUp(self):
-        self.prior_dict = dict()
+        self.prior_dict = OrderedDict()
         self.base_directory =\
             '/'.join(os.path.dirname(
                 os.path.abspath(sys.argv[0])).split('/')[:-1])
diff --git a/test/prior_test.py b/test/prior_test.py
index a042cdbe7032e5d401e8f9fa40ea164ba6e69750..852031a7286cdf68d55653f49f10a6dd4c8fcc14 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -816,5 +816,65 @@ class TestCreateDefaultPrior(unittest.TestCase):
         self.assertIsNone(bilby.core.prior.create_default_prior(name='name', default_priors_file=prior_file))
 
 
+class TestJsonIO(unittest.TestCase):
+
+    def setUp(self):
+        mvg = bilby.core.prior.MultivariateGaussianDist(names=['testa', 'testb'],
+                                                        mus=[1, 1],
+                                                        covs=np.array([[2., 0.5], [0.5, 2.]]),
+                                                        weights=1.)
+        mvn = bilby.core.prior.MultivariateGaussianDist(names=['testa', 'testb'],
+                                                        mus=[1, 1],
+                                                        covs=np.array([[2., 0.5], [0.5, 2.]]),
+                                                        weights=1.)
+
+        self.priors = bilby.core.prior.PriorDict(dict(
+            a=bilby.core.prior.DeltaFunction(name='test', unit='unit', peak=1),
+            b=bilby.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1),
+            c=bilby.core.prior.Normal(name='test', unit='unit', mu=0, sigma=1),
+            d=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=0, minimum=0, maximum=1),
+            e=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=-1, minimum=0.5, maximum=1),
+            f=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=2, minimum=1, maximum=1e2),
+            g=bilby.core.prior.Uniform(name='test', unit='unit', minimum=0, maximum=1),
+            h=bilby.core.prior.LogUniform(name='test', unit='unit', minimum=5e0, maximum=1e2),
+            i=bilby.gw.prior.UniformComovingVolume(name='redshift', minimum=0.1, maximum=1.0),
+            j=bilby.gw.prior.UniformSourceFrame(name='luminosity_distance', minimum=1.0, maximum=1000.0),
+            k=bilby.core.prior.Sine(name='test', unit='unit'),
+            l=bilby.core.prior.Cosine(name='test', unit='unit'),
+            m=bilby.core.prior.Interped(name='test', unit='unit', xx=np.linspace(0, 10, 1000),
+                                        yy=np.linspace(0, 10, 1000) ** 4,
+                                        minimum=3, maximum=5),
+            n=bilby.core.prior.TruncatedGaussian(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1),
+            o=bilby.core.prior.TruncatedNormal(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1),
+            p=bilby.core.prior.HalfGaussian(name='test', unit='unit', sigma=1),
+            q=bilby.core.prior.HalfNormal(name='test', unit='unit', sigma=1),
+            r=bilby.core.prior.LogGaussian(name='test', unit='unit', mu=0, sigma=1),
+            s=bilby.core.prior.LogNormal(name='test', unit='unit', mu=0, sigma=1),
+            t=bilby.core.prior.Exponential(name='test', unit='unit', mu=1),
+            u=bilby.core.prior.StudentT(name='test', unit='unit', df=3, mu=0, scale=1),
+            v=bilby.core.prior.Beta(name='test', unit='unit', alpha=2.0, beta=2.0),
+            x=bilby.core.prior.Logistic(name='test', unit='unit', mu=0, scale=1),
+            y=bilby.core.prior.Cauchy(name='test', unit='unit', alpha=0, beta=1),
+            z=bilby.core.prior.Lorentzian(name='test', unit='unit', alpha=0, beta=1),
+            aa=bilby.core.prior.Gamma(name='test', unit='unit', k=1, theta=1),
+            ab=bilby.core.prior.ChiSquared(name='test', unit='unit', nu=2),
+            ac=bilby.gw.prior.AlignedSpin(name='test', unit='unit'),
+            ad=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'),
+            ae=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'),
+            af=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'),
+            ag=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit')
+        ))
+
+    def test_read_write_to_json(self):
+        """ Interped prior is removed as there is numerical error in the recovered prior."""
+        self.priors.to_json(outdir="prior_files", label="json_test")
+        new_priors = bilby.core.prior.PriorDict.from_json(filename="prior_files/json_test_prior.json")
+        old_interped = self.priors.pop("m")
+        new_interped = new_priors.pop("m")
+        self.assertDictEqual(self.priors, new_priors)
+        self.assertLess(max(abs(old_interped.xx - new_interped.xx)), 1e-15)
+        self.assertLess(max(abs(old_interped.yy - new_interped.yy)), 1e-15)
+
+
 if __name__ == '__main__':
     unittest.main()