diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 276fd0e4bfc23c0c5c664e2cb247cb0cfd2479e5..f5b87d6f21d55e73fc8901e70e52bdbef80f9b1f 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -1941,8 +1941,8 @@ class MultivariateGaussian(object): Check if all requested parameters have been filled. """ - return np.any([val is None for val in - self.requested_parameters.values()]) + return not np.any([val is None for val in + self.requested_parameters.values()]) def reset_request(self): """ @@ -1957,8 +1957,8 @@ class MultivariateGaussian(object): Check is all the rescaled parameters have been filled. """ - return np.any([val is None for val in - self.rescale_parameters.values()]) + return not np.any([val is None for val in + self.rescale_parameters.values()]) def reset_rescale(self): """ @@ -2060,7 +2060,7 @@ class MultivariateGaussian(object): self.mvn.append(scipy.stats.multivariate_normal(mean=self.mus[-1], cov=self.covs[-1])) - def rescale(self, value): + def rescale(self, value, mode=None): """ Rescale from a unit hypercube to multivariate Gaussian. Note that no bounds are applied in the rescale function. @@ -2070,6 +2070,9 @@ class MultivariateGaussian(object): value: array A vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1. + mode: int + Specify which mode to sample from. If not set then a mode is + chosen randomly based on its weight. Returns ------- @@ -2079,21 +2082,31 @@ class MultivariateGaussian(object): """ # pick a mode (with a probability given by their weights) - imode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] + if mode is None: + if self.nmodes == 1: + mode = 0 + else: + mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] # draw points from unit variance, uncorrelated Gaussian - samp = erfinv(2. * value - 1) * 2. ** 0.5 + samp = erfinv(2. * np.asarray(value) - 1) * 2. ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[imode] + self.sigmas[imode] * np.einsum('j,kj->k', - samp * self.sqeigvalues[imode], - self.eigvectors[imode]) + samp = self.mus[mode] + self.sigmas[mode] * np.einsum('j,kj->k', + samp * self.sqeigvalues[mode], + self.eigvectors[mode]) return samp - def sample(self): + def sample(self, mode=None): """ - Draw a sample from the multivariate Gaussian. + Draw, and set, a sample from the multivariate Gaussian. + + Parameters + ---------- + mode: int + Specify which mode to sample from. If not set then a mode is + chosen randomly based on its weight. """ # samples drawn from unit variance uncorrelated multivariate Gaussian @@ -2102,7 +2115,7 @@ class MultivariateGaussian(object): # sample the multivariate Gaussian keys vals = np.random.uniform(0, 1, len(self)) - samp = self.rescale(vals) + samp = self.rescale(vals, mode=mode) # check sample is in bounds (otherwise perform another draw) outbound = False @@ -2184,24 +2197,42 @@ class MultivariateGaussianPrior(Prior): maximum=mvg.bounds[name][1]) self.mvg = mvg - def rescale(self, val): + def rescale(self, val, mode=None): """ Scale a unit hypercube sample to the prior. + + Parameters + ---------- + mode: int + Specify which mode to sample from. If not set then a mode is + chosen randomly based on its weight. """ # add parameter value to multivariate Gaussian self.mvg.rescale_parameters[self.name] = val if self.mvg.filled_rescale(): - samples = self.mvg.rescale(list(self.mvg.rescale_parameters.values())) + samples = self.mvg.rescale(list(self.mvg.rescale_parameters.values()), + mode=mode) self.mvg.reset_rescale() return samples else: return [] # return empty list - def sample(self, size=None): + def sample(self, size=None, mode=None): """ Draw a sample from the prior. + + Parameters + ---------- + mode: int + Specify which mode to sample from. If not set then a mode is + chosen randomly based on its weight. + + Returns + ------- + float: + A sample from the prior paramter. """ if self.name in self.mvg.sampled_parameters: @@ -2211,7 +2242,7 @@ class MultivariateGaussianPrior(Prior): if len(self.mvg.current_sample) == 0: # generate a sample - self.mvg.sample() + self.mvg.sample(mode=mode) sample = self.mvg.current_sample[self.name]