prior.py 21.2 KB
Newer Older
Colm Talbot's avatar
Colm Talbot committed
1
#!/bin/python
2
from __future__ import division
Colm Talbot's avatar
Colm Talbot committed
3 4

import numpy as np
Colm Talbot's avatar
Colm Talbot committed
5 6
from scipy.interpolate import interp1d
from scipy.integrate import cumtrapz
Colm Talbot's avatar
Colm Talbot committed
7
from scipy.special import erf, erfinv
Colm Talbot's avatar
Colm Talbot committed
8
import logging
Colm Talbot's avatar
Colm Talbot committed
9
import os
Colm Talbot's avatar
Colm Talbot committed
10 11 12


class Prior(object):
    """
    Prior class

    Methods
    -------
    __init__:
        Instantiate a prior object.
    __call__:
        Draw a single sample from the prior.
    __repr__:
        Print prior type and parameters.
    sample(size=None):
        Draw samples of size size from the prior.
    rescale(val):
        Rescale samples from a uniform distribution on [0, 1] to samples from the prior.
    test_valid_for_rescaling(val):
        Test whether val is in [0, 1] and hence valid for rescaling.

    Parameters
    ----------
    name: str
        Name associated with prior.
    latex_label: str
        Latex label associated with prior, used for plotting. If None, a default
        label is looked up from the name (falling back to the name itself).
    minimum: float, optional
        Minimum of the domain, default=-np.inf
    maximum: float, optional
        Maximum of the domain, default=np.inf
    """

    # Default LaTeX labels for known parameter names. Raw strings are used so
    # that LaTeX commands are not (invalid) Python escape sequences.
    _DEFAULT_LATEX_LABELS = {
        'mass_1': '$m_1$',
        'mass_2': '$m_2$',
        'total_mass': '$M$',
        'chirp_mass': r'$\mathcal{M}$',
        'mass_ratio': '$q$',
        'symmetric_mass_ratio': r'$\eta$',
        'a_1': '$a_1$',
        'a_2': '$a_2$',
        'tilt_1': r'$\theta_1$',
        'tilt_2': r'$\theta_2$',
        'cos_tilt_1': r'$\cos\theta_1$',
        'cos_tilt_2': r'$\cos\theta_2$',
        'phi_12': r'$\Delta\phi$',
        'phi_jl': r'$\phi_{JL}$',
        'luminosity_distance': '$d_L$',
        'dec': r'$\mathrm{DEC}$',
        'ra': r'$\mathrm{RA}$',
        'iota': r'$\iota$',
        'cos_iota': r'$\cos\iota$',
        'psi': r'$\psi$',
        'phase': r'$\phi$',
        'geocent_time': '$t_c$'
    }

    def __init__(self, name=None, latex_label=None, minimum=-np.inf, maximum=np.inf):
        # name must be set first: the latex_label setter falls back to a
        # default derived from self.name.
        self.name = name
        self.latex_label = latex_label
        self.minimum = minimum
        self.maximum = maximum

    def __call__(self):
        """Draw a single sample from the prior."""
        return self.sample()

    def sample(self, size=None):
        """Draw a sample from the prior.

        Parameters
        ----------
        size: int or tuple of ints, optional
            Output shape, passed to np.random.uniform; None draws a single value.
        """
        return self.rescale(np.random.uniform(0, 1, size))

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This should be overwritten by each subclass; the base class returns None.
        """
        return None

    @staticmethod
    def test_valid_for_rescaling(val):
        """Test if 0 <= val <= 1 (element-wise for arrays).

        Raises
        ------
        ValueError
            If any element of val lies outside [0, 1].
        """
        val = np.atleast_1d(val)
        if np.any((val < 0) | (val > 1)):
            raise ValueError("Number to be rescaled should be in [0, 1]")

    def __repr__(self):
        return self.subclass_repr_helper()

    def subclass_repr_helper(self, subclass_args=()):
        """Return 'ClassName(name=..., latex_label=..., minimum=..., maximum=..., ...)'.

        Parameters
        ----------
        subclass_args: iterable of str, optional
            Extra attribute or property names appended after the base ones.
        """
        keys = ['name', 'latex_label', 'minimum', 'maximum'] + list(subclass_args)
        # getattr resolves both plain attributes and properties, and only
        # evaluates the requested keys.
        args = ', '.join('{}={}'.format(key, repr(getattr(self, key))) for key in keys)
        return "{}({})".format(self.__class__.__name__, args)

    @property
    def is_fixed(self):
        """True if this prior is a DeltaFunction."""
        return isinstance(self, DeltaFunction)

    @property
    def latex_label(self):
        return self.__latex_label

    @latex_label.setter
    def latex_label(self, latex_label=None):
        if latex_label is None:
            self.__latex_label = self.__default_latex_label
        else:
            self.__latex_label = latex_label

    @property
    def minimum(self):
        return self.__minimum

    @minimum.setter
    def minimum(self, minimum):
        self.__minimum = minimum

    @property
    def maximum(self):
        return self.__maximum

    @maximum.setter
    def maximum(self, maximum):
        self.__maximum = maximum

    @property
    def __default_latex_label(self):
        # Unknown names (including None) fall back to the name itself.
        return self._DEFAULT_LATEX_LABELS.get(self.name, self.name)
150

Colm Talbot's avatar
Colm Talbot committed
151 152

class DeltaFunction(Prior):
    """Dirac delta function prior, this always returns peak."""

    def __init__(self, peak, name=None, latex_label=None):
        # Zero-width distribution: both bounds sit at the peak.
        Prior.__init__(self, name, latex_label, minimum=peak, maximum=peak)
        self.peak = peak

    def rescale(self, val):
        """Rescale everything to the peak with the correct shape."""
        Prior.test_valid_for_rescaling(val)
        # val ** 0 is 1 with the same shape as val, so a scalar gives a scalar
        # and an array gives an array filled with the peak.
        return self.peak * val ** 0

    def prob(self, val):
        """Return the prior probability of val.

        np.inf where val equals the peak and 0 elsewhere. Scalar input returns a
        scalar; array input is evaluated element-wise.
        """
        at_peak = np.equal(val, self.peak)
        if np.ndim(at_peak) == 0:
            return np.inf if at_peak else 0
        return np.where(at_peak, np.inf, 0)

    def __repr__(self):
        return Prior.subclass_repr_helper(self, subclass_args=['peak'])
173

Colm Talbot's avatar
Colm Talbot committed
174 175

class PowerLaw(Prior):
    """Power law prior distribution, p(x) proportional to x ** alpha on [minimum, maximum]"""

    def __init__(self, alpha, minimum, maximum, name=None, latex_label=None):
        """Power law with bounds and alpha, spectral index"""
        Prior.__init__(self, name, latex_label, minimum, maximum)
        self.alpha = alpha

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the power-law prior.

        This maps to the inverse CDF. This has been analytically solved for this case.
        """
        Prior.test_valid_for_rescaling(val)
        if self.alpha == -1:
            return self.minimum * np.exp(val * np.log(self.maximum / self.minimum))
        else:
            return (self.minimum ** (1 + self.alpha) + val *
                    (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha))

    def prob(self, val):
        """Return the prior probability of val"""
        in_prior = (val >= self.minimum) & (val <= self.maximum)
        if self.alpha == -1:
            return np.nan_to_num(1 / val / np.log(self.maximum / self.minimum)) * in_prior
        else:
            return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha)
                                                                         - self.minimum ** (1 + self.alpha))) * in_prior

    def lnprob(self, val):
        """Return the natural log of the prior probability of val.

        log(p) = alpha * log(val) + log(normalisation) inside [minimum, maximum]
        and -np.inf outside. Scalar input returns a numpy scalar, array input an array.
        """
        in_prior = (val >= self.minimum) & (val <= self.maximum)
        if self.alpha == -1:
            # The general expression is 0 / 0 here; use the log-uniform normalisation.
            normalising = 1. / np.log(self.maximum / self.minimum)
        else:
            normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha)
                                              - self.minimum ** (1 + self.alpha))
        # Out-of-prior values may hit log(0) or log(negative); they are masked below.
        with np.errstate(divide='ignore', invalid='ignore'):
            if self.alpha == 0:
                # Skip alpha * log(val) so that val = 0 does not give 0 * -inf = nan.
                ln_in_prior = np.log(normalising)
            else:
                ln_in_prior = self.alpha * np.log(val) + np.log(normalising)
        return np.where(in_prior, ln_in_prior, -np.inf)[()]

    def __repr__(self):
        return Prior.subclass_repr_helper(self, subclass_args=['alpha'])
213

214

215 216 217 218 219 220 221
class Uniform(PowerLaw):
    """Uniform prior on [minimum, maximum], i.e. a power law with alpha=0"""

    def __init__(self, minimum, maximum, name=None, latex_label=None):
        PowerLaw.__init__(self, alpha=0, minimum=minimum, maximum=maximum, name=name,
                          latex_label=latex_label)

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return PowerLaw.__repr__(self)

225 226 227 228 229 230 231

class LogUniform(PowerLaw):
    """Log-uniform prior, p(x) proportional to 1 / x on [minimum, maximum], i.e. a power law with alpha=-1"""

    def __init__(self, minimum, maximum, name=None, latex_label=None):
        PowerLaw.__init__(self, alpha=-1, minimum=minimum, maximum=maximum, name=name,
                          latex_label=latex_label)
        # log(minimum) is undefined for minimum <= 0, so such a prior cannot be normalised.
        if self.minimum <= 0:
            logging.warning('You specified a uniform-in-log prior with minimum={}'.format(self.minimum))

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return PowerLaw.__repr__(self)

238

Colm Talbot's avatar
Colm Talbot committed
239 240
class Cosine(Prior):
    """Prior with p(x) proportional to cos(x) on [minimum, maximum].

    The bounds must lie within [-pi/2, pi/2], the range on which arcsin inverts sin.
    """

    def __init__(self, name=None, latex_label=None, minimum=-np.pi / 2, maximum=np.pi / 2):
        Prior.__init__(self, name, latex_label, minimum, maximum)

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to a uniform in cosine prior.

        This maps to the inverse CDF, which has been analytically solved:
        CDF(x) = (sin(x) - sin(minimum)) / (sin(maximum) - sin(minimum)).
        """
        Prior.test_valid_for_rescaling(val)
        norm = np.sin(self.maximum) - np.sin(self.minimum)
        return np.arcsin(np.sin(self.minimum) + val * norm)

    def prob(self, val):
        """Return the prior probability of val; zero outside [minimum, maximum]."""
        in_prior = (val >= self.minimum) & (val <= self.maximum)
        return np.cos(val) / (np.sin(self.maximum) - np.sin(self.minimum)) * in_prior

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return Prior.subclass_repr_helper(self)
260

Colm Talbot's avatar
Colm Talbot committed
261 262 263

class Sine(Prior):
    """Prior with p(x) proportional to sin(x) on [minimum, maximum].

    The bounds must lie within [0, pi], the range on which arccos inverts cos.
    """

    def __init__(self, name=None, latex_label=None, minimum=0, maximum=np.pi):
        Prior.__init__(self, name, latex_label, minimum, maximum)

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to a uniform in sine prior.

        This maps to the inverse CDF, which has been analytically solved:
        CDF(x) = (cos(minimum) - cos(x)) / (cos(minimum) - cos(maximum)).
        """
        Prior.test_valid_for_rescaling(val)
        norm = np.cos(self.minimum) - np.cos(self.maximum)
        return np.arccos(np.cos(self.minimum) - val * norm)

    def prob(self, val):
        """Return the prior probability of val; zero outside [minimum, maximum]."""
        in_prior = (val >= self.minimum) & (val <= self.maximum)
        return np.sin(val) / (np.cos(self.minimum) - np.cos(self.maximum)) * in_prior

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return Prior.subclass_repr_helper(self)
283

Colm Talbot's avatar
Colm Talbot committed
284

Colm Talbot's avatar
Colm Talbot committed
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
class Gaussian(Prior):
    """Gaussian prior with mean mu and standard deviation sigma"""

    def __init__(self, mu, sigma, name=None, latex_label=None):
        """Gaussian with mean mu and standard deviation sigma.

        The domain keeps the Prior defaults (-np.inf, np.inf).
        """
        Prior.__init__(self, name, latex_label)
        self.mu = mu
        self.sigma = sigma

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the appropriate Gaussian prior.

        This maps to the inverse CDF. This has been analytically solved for this case.
        """
        Prior.test_valid_for_rescaling(val)
        return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma

    def prob(self, val):
        """Return the prior probability density of val"""
        return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma

    def lnprob(self, val):
        """Return the natural log of the prior probability density of val"""
        return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2))

    def __repr__(self):
        return Prior.subclass_repr_helper(self, subclass_args=['mu', 'sigma'])
312

Colm Talbot's avatar
Colm Talbot committed
313

Colm Talbot's avatar
Colm Talbot committed
314 315 316 317 318 319 320
class TruncatedGaussian(Prior):
    """
    Truncated Gaussian prior

    https://en.wikipedia.org/wiki/Truncated_normal_distribution
    """

    def __init__(self, mu, sigma, minimum, maximum, name=None, latex_label=None):
        """Gaussian with mean mu and standard deviation sigma, truncated to [minimum, maximum]"""
        Prior.__init__(self, name=name, latex_label=latex_label, minimum=minimum, maximum=maximum)
        self.mu = mu
        self.sigma = sigma

    @property
    def normalisation(self):
        """Probability mass of the untruncated Gaussian inside [minimum, maximum].

        Computed on access so it stays correct if mu, sigma or the bounds change.
        """
        return (erf((self.maximum - self.mu) / 2 ** 0.5 / self.sigma) - erf(
            (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the appropriate truncated Gaussian prior.

        This maps to the inverse CDF. This has been analytically solved for this case.
        """
        Prior.test_valid_for_rescaling(val)
        return erfinv(2 * val * self.normalisation + erf(
            (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu

    def prob(self, val):
        """Return the prior probability density of val; zero outside [minimum, maximum]."""
        in_prior = (val >= self.minimum) & (val <= self.maximum)
        return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (
                2 * np.pi) ** 0.5 / self.sigma / self.normalisation * in_prior

    def __repr__(self):
        return Prior.subclass_repr_helper(self, subclass_args=['mu', 'sigma'])
348

Colm Talbot's avatar
Colm Talbot committed
349

Colm Talbot's avatar
Colm Talbot committed
350 351
class Interped(Prior):
    """Prior defined by tabulated density values, evaluated by linear interpolation."""

    def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None, latex_label=None):
        """Initialise object from arrays of x and y=p(x)

        Parameters
        ----------
        xx, yy: array_like
            Tabulated x values and the (not necessarily normalised) density p(x).
        minimum, maximum: float, optional
            Bounds of the prior. nan or None means use min(xx) / max(xx); bounds
            outside the range of xx are clipped to it.
        name, latex_label: str, optional
            See Prior.
        """
        # Treat None as "no bound"; np.nanmax / np.nanmin cannot compare None
        # with a float.
        if minimum is None:
            minimum = np.nan
        if maximum is None:
            maximum = np.nan
        self.xx = xx
        self.yy = yy
        # Interpolant of the full input table, kept so the prior can be
        # re-tabulated whenever the bounds change (see __update_instance).
        self.all_interpolated = interp1d(x=xx, y=yy, bounds_error=False, fill_value=0)
        Prior.__init__(self, name, latex_label,
                       minimum=np.nanmax(np.array((min(xx), minimum))),
                       maximum=np.nanmin(np.array((max(xx), maximum))))
        self.__initialize_attributes()

    def prob(self, val):
        """Return the prior probability of val"""
        return self.probability_density(val)

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This maps to the inverse CDF. This is done using interpolation.
        """
        Prior.test_valid_for_rescaling(val)
        rescaled = self.inverse_cumulative_distribution(val)
        if rescaled.shape == ():
            rescaled = float(rescaled)
        return rescaled

    def __repr__(self):
        return Prior.subclass_repr_helper(self, subclass_args=['xx', 'yy'])

    @property
    def minimum(self):
        return self.__minimum

    @minimum.setter
    def minimum(self, minimum):
        self.__minimum = minimum
        # Only rebuild once both bounds exist (Prior.__init__ sets minimum first).
        if '_Interped__maximum' in self.__dict__ and self.__maximum < np.inf:
            self.__update_instance()

    @property
    def maximum(self):
        return self.__maximum

    @maximum.setter
    def maximum(self, maximum):
        self.__maximum = maximum
        if '_Interped__minimum' in self.__dict__ and self.__minimum < np.inf:
            self.__update_instance()

    def __update_instance(self):
        # Re-tabulate the original density on an evenly spaced grid spanning the
        # current bounds, keeping the same number of points.
        self.xx = np.linspace(self.minimum, self.maximum, len(self.xx))
        self.yy = self.all_interpolated(self.xx)
        self.__initialize_attributes()

    def __initialize_attributes(self):
        normalisation = np.trapz(self.yy, self.xx)
        if normalisation != 1:
            logging.info('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
        # Divide into a new float array rather than in place, so caller-supplied
        # arrays are never modified and integer input works.
        self.yy = np.asarray(self.yy, dtype=float) / normalisation
        self.YY = cumtrapz(self.yy, self.xx, initial=0)
        # Need last element of cumulative distribution to be exactly one.
        self.YY[-1] = 1
        self.probability_density = interp1d(x=self.xx, y=self.yy, bounds_error=False, fill_value=0)
        # The CDF is 0 below the support and 1 above it.
        self.cumulative_distribution = interp1d(x=self.xx, y=self.YY, bounds_error=False, fill_value=(0, 1))
        self.inverse_cumulative_distribution = interp1d(x=self.YY, y=self.xx, bounds_error=True)
416

Colm Talbot's avatar
Colm Talbot committed
417 418 419

class FromFile(Interped):
    """Interped prior whose x and p(x) columns are read from a text file."""

    def __init__(self, file_name, minimum=None, maximum=None, name=None, latex_label=None):
        """
        Parameters
        ----------
        file_name: str
            Path to the file. A bare file name (containing no '/') is looked up
            in the 'prior_files' directory next to this module.
        minimum, maximum: float, optional
            Bounds passed to Interped; None means use the range of the file.
        name, latex_label: str, optional
            See Prior.

        If the file cannot be read, warnings are logged and the object is left
        uninitialised.
        """
        self.id = file_name
        if '/' not in self.id:
            self.id = os.path.join(os.path.dirname(__file__), 'prior_files', self.id)
        try:
            xx, yy = np.genfromtxt(self.id).T
        except IOError:
            logging.warning("Can't load {}.".format(self.id))
            logging.warning("Format should be:")
            logging.warning(r"x\tp(x)")
            return
        # Interped expects nan, not None, for "no bound".
        Interped.__init__(self, xx=xx, yy=yy,
                          minimum=np.nan if minimum is None else minimum,
                          maximum=np.nan if maximum is None else maximum,
                          name=name, latex_label=latex_label)

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return Prior.subclass_repr_helper(self, subclass_args=['xx', 'yy', 'id'])
434

435

436 437 438
class UniformComovingVolume(FromFile):
    """Tabulated prior loaded from the bundled 'comoving.txt' file.

    NOTE(review): presumably a luminosity-distance prior that is uniform in
    comoving volume (it is the default for 'luminosity_distance'); confirm
    against the data file.
    """

    def __init__(self, minimum=None, maximum=None, name=None, latex_label=None):
        FromFile.__init__(self, file_name='comoving.txt', minimum=minimum, maximum=maximum, name=name,
                          latex_label=latex_label)

    def __repr__(self, subclass_keys=None, subclass_names=None):
        # The arguments are unused; they are kept only for signature compatibility.
        return FromFile.__repr__(self)
444

445

446
def create_default_prior(name):
    """
    Make a default prior for a parameter with a known name.

    This is currently set up for binary black holes.

    Parameters
    ----------
    name: str
        Parameter name

    Return
    ------
    prior: Prior
        Default prior distribution for that parameter, if unknown None is returned.
    """
    # Zero-argument factories, so only the requested prior is constructed.
    # Building all of them eagerly would, among other things, read
    # 'comoving.txt' from disk on every call.
    default_priors = {
        'mass_1': lambda: Uniform(name=name, minimum=5, maximum=100),
        'mass_2': lambda: Uniform(name=name, minimum=5, maximum=100),
        'chirp_mass': lambda: Uniform(name=name, minimum=5, maximum=100),
        'total_mass': lambda: Uniform(name=name, minimum=10, maximum=200),
        'mass_ratio': lambda: Uniform(name=name, minimum=0.125, maximum=1),
        'symmetric_mass_ratio': lambda: Uniform(name=name, minimum=8 / 81, maximum=0.25),
        'a_1': lambda: Uniform(name=name, minimum=0, maximum=0.8),
        'a_2': lambda: Uniform(name=name, minimum=0, maximum=0.8),
        'tilt_1': lambda: Sine(name=name),
        'tilt_2': lambda: Sine(name=name),
        'cos_tilt_1': lambda: Uniform(name=name, minimum=-1, maximum=1),
        'cos_tilt_2': lambda: Uniform(name=name, minimum=-1, maximum=1),
        'phi_12': lambda: Uniform(name=name, minimum=0, maximum=2 * np.pi),
        'phi_jl': lambda: Uniform(name=name, minimum=0, maximum=2 * np.pi),
        'luminosity_distance': lambda: UniformComovingVolume(name=name, minimum=1e2, maximum=5e3),
        'dec': lambda: Cosine(name=name),
        'ra': lambda: Uniform(name=name, minimum=0, maximum=2 * np.pi),
        'iota': lambda: Sine(name=name),
        'cos_iota': lambda: Uniform(name=name, minimum=-1, maximum=1),
        'psi': lambda: Uniform(name=name, minimum=0, maximum=2 * np.pi),
        'phase': lambda: Uniform(name=name, minimum=0, maximum=2 * np.pi)
    }
    if name in default_priors:
        return default_priors[name]()
    logging.info(
        "No default prior found for variable {}.".format(name))
    return None


494
def fill_priors(prior, likelihood):
    """
    Fill dictionary of priors based on required parameters of likelihood

    Any floats (or ints) in prior will be converted to delta function prior. Any
    required, non-specified parameters will use the default.

    Parameters
    ----------
    prior: dict
        dictionary of prior objects and floats; modified in place
    likelihood: tupak.likelihood.GravitationalWaveTransient instance
        Used to infer the set of parameters to fill the prior with

    Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then this
    will set-up default priors for those as well, unless a prior for them is
    already present. Keys without a known default are skipped with a warning.

    Returns
    -------
    prior: dict
        The filled prior dictionary
    """
    for key in prior:
        if isinstance(prior[key], Prior):
            continue
        elif isinstance(prior[key], (float, int)):
            prior[key] = DeltaFunction(prior[key])
            logging.info(
                "{} converted to delta function prior.".format(key))
        else:
            logging.info(
                "{} cannot be converted to delta function prior.".format(key))

    # Computed before the non-standard keys below are added.
    missing_keys = set(likelihood.parameters) - set(prior.keys())

    non_standard_keys = getattr(likelihood, 'non_standard_sampling_parameter_keys', None)
    if non_standard_keys is not None:
        for parameter in non_standard_keys:
            if parameter in prior:
                # Respect a prior the caller has already supplied.
                continue
            default_prior = create_default_prior(parameter)
            if default_prior is None:
                logging.warning(
                    "No default prior for sampling parameter {}.".format(parameter))
            else:
                prior[parameter] = default_prior

    for missing_key in missing_keys:
        default_prior = create_default_prior(missing_key)
        if default_prior is None:
            set_val = likelihood.parameters[missing_key]
            logging.warning(
                "Parameter {} has no default prior and is set to {}, this will"
                " not be sampled and may cause an error."
                .format(missing_key, set_val))
        elif not test_redundancy(missing_key, prior):
            prior[missing_key] = default_prior

    # Called for its warnings about over-specified parameter groups.
    for key in prior:
        test_redundancy(key, prior)

    return prior

552

553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
def test_redundancy(key, prior):
    """
    Test whether adding the key would be redundant.

    Each key belongs to at most one parameter group. The key is redundant when
    the number of that group's members already in `prior` reaches the group's
    limit (2 for the mass, spin-magnitude and spin-azimuth groups, 1 for the
    others). A warning is logged when the limit is exceeded.

    Parameters
    ----------
    key: str
        The string to test.
    prior: dict
        Current prior dictionary.

    Return
    ------
    redundant: bool
        Whether the key is redundant
    """
    # (parameter group, number of members at which further members are redundant)
    parameter_groups = [
        ({'mass_1', 'mass_2', 'chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio'}, 2),
        ({'a_1', 'a_2'}, 2),
        ({'phi_1', 'phi_2', 'phi_12', 'phi_jl'}, 2),
        ({'iota', 'cos_iota'}, 1),
        ({'luminosity_distance', 'comoving_distance', 'redshift'}, 1),
        ({'tilt_1', 'cos_tilt_1'}, 1),
        ({'tilt_2', 'cos_tilt_2'}, 1),
    ]
    for parameter_set, limit in parameter_groups:
        if key not in parameter_set:
            continue
        present = parameter_set.intersection(prior)
        if len(present) > limit:
            logging.warning('{} in prior. This may lead to unexpected behaviour.'.format(present))
        # The groups are disjoint, so no other group can contain key.
        return len(present) >= limit
    return False


# The name starts with "test_"; stop pytest from collecting this library
# function as a test when it is imported into a test module.
test_redundancy.__test__ = False


Gregory Ashton's avatar
Gregory Ashton committed
602
def write_priors_to_file(priors, outdir, label):
    """
    Write the prior distribution to file.

    The file is ``<outdir>/<label>_prior.txt`` and is overwritten if it exists.
    It gets one line per entry, ``prior['<key>'] = <str(value)>``, in the
    iteration order of `priors`.

    Parameters
    ----------
    priors: dict
        priors used
    outdir, label: str
        output directory and label
    """
    path = os.path.join(outdir, "{}_prior.txt".format(label))
    logging.debug("Writing priors to {}".format(path))
    with open(path, "w") as handle:
        handle.writelines(
            "prior['{}'] = {}\n".format(name, distribution)
            for name, distribution in priors.items()
        )