Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (18)
@@ -68,6 +68,9 @@ class Prior(object):
         if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
             return False
         for key in self.__dict__:
+            if key == "least_recently_sampled":
+                # ignore sample drawn from prior in comparison
+                continue
             if type(self.__dict__[key]) is np.ndarray:
                 if not np.array_equal(self.__dict__[key], other.__dict__[key]):
                     return False

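The hunk above makes Prior.__eq__ skip the least_recently_sampled attribute, so that drawing a sample from one prior no longer makes it compare unequal to an otherwise identical prior. A minimal sketch of the intended behaviour, assuming the bilby.core.prior.Uniform API:

    from bilby.core.prior import Uniform

    a = Uniform(minimum=0, maximum=1, name="x")
    b = Uniform(minimum=0, maximum=1, name="x")
    a.sample()  # populates a.least_recently_sampled with a random draw
    assert a == b  # holds once the comparison ignores "least_recently_sampled"
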
@@ -23,7 +23,8 @@ from .utils import (
     check_directory_exists_and_if_not_mkdir,
     latex_plot_format, safe_save_figure,
     BilbyJsonEncoder, load_json,
-    move_old_file, get_version_information
+    move_old_file, get_version_information,
+    decode_bilby_json,
 )
 from .prior import Prior, PriorDict, DeltaFunction

@@ -358,10 +359,27 @@ class Result(object):
         if os.path.isfile(filename):
             dictionary = deepdish.io.load(filename)
-            # Some versions of deepdish/pytables return the dictionanary as
+            # Some versions of deepdish/pytables return the dictionary as
             # a dictionary with a key 'data'
             if len(dictionary) == 1 and 'data' in dictionary:
                 dictionary = dictionary['data']
+            if "priors" in dictionary:
+                # parse priors from JSON string (allowing for backwards
+                # compatibility)
+                if not isinstance(dictionary["priors"], PriorDict):
+                    try:
+                        priordict = PriorDict()
+                        for key, value in dictionary["priors"].items():
+                            if key not in ["__module__", "__name__", "__prior_dict__"]:
+                                priordict[key] = decode_bilby_json(value)
+                        dictionary["priors"] = priordict
+                    except Exception as e:
+                        raise IOError(
+                            "Unable to parse priors from '{}':\n{}".format(
+                                filename, e,
+                            )
+                        )
             try:
                 if isinstance(dictionary.get('posterior', None), dict):
                     dictionary['posterior'] = pd.DataFrame(dictionary['posterior'])

@@ -609,8 +627,9 @@ class Result(object):
                     dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
         try:
+            # convert priors to JSON dictionary for both JSON and hdf5 files
+            dictionary["priors"] = dictionary["priors"]._get_json_dict()
             if extension == 'json':
-                dictionary["priors"] = dictionary["priors"]._get_json_dict()
                 if gzip:
                     import gzip
                     # encode to a string

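The new block in from_hdf5 rebuilds the priors entry by entry with decode_bilby_json, mirroring how save_to_file now serialises them via _get_json_dict() for both JSON and hdf5 output. A round-trip sketch of that encode/decode path (standalone, using only the helpers named in the diff, not the full file I/O):

    import json

    from bilby.core.prior import PriorDict, Uniform
    from bilby.core.utils import BilbyJsonEncoder, decode_bilby_json

    priors = PriorDict(dict(x=Uniform(0, 1)))
    # roughly what ends up on disk: a plain dict of JSON-encoded prior entries
    encoded = json.loads(json.dumps(priors._get_json_dict(), cls=BilbyJsonEncoder))

    # what from_hdf5 now does when the stored "priors" entry is not a PriorDict
    priordict = PriorDict()
    for key, value in encoded.items():
        if key not in ["__module__", "__name__", "__prior_dict__"]:
            priordict[key] = decode_bilby_json(value)
    assert priordict == priors  # the decoded priors should compare equal
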
@@ -430,6 +430,7 @@ class Sampler(object):
         likelihood evaluations.
         """
+        logger.info("Generating initial points from the prior")
         unit_cube = []
         parameters = []
         likelihood = []

@@ -257,11 +257,11 @@ class Dynesty(NestedSampler):
         # Constructing output.
         string = []
         string.append("bound:{:d}".format(bounditer))
-        string.append("nc:{:d}".format(nc))
-        string.append("ncall:{:d}".format(ncall))
+        string.append("nc:{:3d}".format(nc))
+        string.append("ncall:{:.1e}".format(ncall))
         string.append("eff:{:0.1f}%".format(eff))
         string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
-        string.append("dlogz:{:0.3f}>{:0.2f}".format(delta_logz, dlogz))
+        string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
         self.pbar.set_postfix_str(" ".join(string), refresh=False)
         self.pbar.update(niter - self.pbar.n)

@@ -793,7 +793,7 @@ def sample_rwalk_bilby(args):
         v = v_list[idx]
         logl = logl_list[idx]
     else:
-        logger.warning("Unable to find a new point using walk: returning a random point")
+        logger.debug("Unable to find a new point using walk: returning a random point")
         u = np.random.uniform(size=n)
         v = prior_transform(u)
         logl = loglikelihood(v)

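For reference, the tightened progress-bar format strings above behave as follows (dummy values, not output from a real run):

    nc, ncall, delta_logz, dlogz = 25, 1.3e6, 0.123, 0.1
    print("nc:{:3d}".format(nc))                              # "nc: 25" (fixed width)
    print("ncall:{:.1e}".format(ncall))                       # "ncall:1.3e+06"
    print("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))  # "dlogz:0.123>0.1"
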
@@ -193,7 +193,13 @@ class Ptemcee(MCMCSampler):
         )
         self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)

-        # MultiProcessing inputs
+        # Check if threads was given as an equivalent arg
+        if threads == 1:
+            for equiv in self.npool_equiv_kwargs:
+                if equiv in kwargs:
+                    threads = kwargs.pop(equiv)
+
+        # Store threads
         self.threads = threads

         # Misc inputs

@@ -221,10 +227,6 @@ class Ptemcee(MCMCSampler):
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["nwalkers"] = kwargs.pop(equiv)
-        if "threads" not in kwargs:
-            for equiv in self.npool_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs["threads"] = kwargs.pop(equiv)

     def get_pos0_from_prior(self):
         """ Draw the initial positions from the prior

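The constructor change above resolves npool-style aliases into threads up front (rather than in _translate_kwargs), and only when threads is still at its default of 1. A standalone sketch of that resolution logic; the alias list here is an assumption standing in for npool_equiv_kwargs:

    npool_equiv_kwargs = ["npool", "queue_size"]  # assumed stand-in for the class attribute

    def resolve_threads(threads, kwargs):
        # mirror of the added block: an explicit threads value wins, otherwise pop an alias
        if threads == 1:
            for equiv in npool_equiv_kwargs:
                if equiv in kwargs:
                    threads = kwargs.pop(equiv)
        return threads

    print(resolve_threads(1, {"npool": 8}))  # -> 8
    print(resolve_threads(4, {"npool": 8}))  # -> 4
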
@@ -2,7 +2,10 @@ import importlib
 import os
 import tempfile
 import shutil
+import distutils.dir_util
 import signal
+import time
+import datetime

 import numpy as np

@@ -115,8 +118,15 @@ class Pymultinest(NestedSampler):
         # for PyMultiNest >=2.9 the n_params kwarg cannot be None
         if self.kwargs["n_params"] is None:
             self.kwargs["n_params"] = self.ndim
+        if self.kwargs['dump_callback'] is None:
+            self.kwargs['dump_callback'] = self._dump_callback
         NestedSampler._verify_kwargs_against_default_kwargs(self)

+    def _dump_callback(self, *args, **kwargs):
+        if self.use_temporary_directory:
+            self._copy_temporary_directory_contents_to_proper_path()
+        self._calculate_and_save_sampling_time()
+
     def _apply_multinest_boundaries(self):
         if self.kwargs["wrapped_params"] is None:
             self.kwargs["wrapped_params"] = []

@@ -154,10 +164,6 @@ class Pymultinest(NestedSampler):
             shutil.copytree(
                 self.outputfiles_basename, self.temporary_outputfiles_basename
             )
-            if os.path.islink(self.outputfiles_basename):
-                os.unlink(self.outputfiles_basename)
-            else:
-                shutil.rmtree(self.outputfiles_basename)

     def write_current_state_and_exit(self, signum=None, frame=None):
         """ Write current state and exit on exit_code """

@@ -166,15 +172,15 @@ class Pymultinest(NestedSampler):
                 signum, self.exit_code
             )
         )
+        self._calculate_and_save_sampling_time()
         if self.use_temporary_directory:
             self._move_temporary_directory_to_proper_path()
         os._exit(self.exit_code)

-    def _move_temporary_directory_to_proper_path(self):
+    def _copy_temporary_directory_contents_to_proper_path(self):
         """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
+        Copy the temporary back to the proper path.
+        Do not delete the temporary directory.
         """
         logger.info(
             "Overwriting {} with {}".format(

@@ -185,11 +191,16 @@ class Pymultinest(NestedSampler):
             outputfiles_basename_stripped = self.outputfiles_basename[:-1]
         else:
             outputfiles_basename_stripped = self.outputfiles_basename
-        if os.path.islink(outputfiles_basename_stripped):
-            os.unlink(outputfiles_basename_stripped)
-        elif os.path.isdir(outputfiles_basename_stripped):
-            shutil.rmtree(outputfiles_basename_stripped)
-        shutil.move(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
+        distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
+
+    def _move_temporary_directory_to_proper_path(self):
+        """
+        Copy the temporary back to the proper path
+
+        Anything in the temporary directory at this point is removed
+        """
+        self._copy_temporary_directory_contents_to_proper_path()
+        shutil.rmtree(self.temporary_outputfiles_basename)

     def run_sampler(self):
         import pymultinest

@@ -197,17 +208,20 @@ class Pymultinest(NestedSampler):
         self._verify_kwargs_against_default_kwargs()
         self._setup_run_directory()
+        self._check_and_load_sampling_time_file()

         # Overwrite pymultinest's signal handling function
         pm_run = importlib.import_module("pymultinest.run")
         pm_run.interrupt_handler = self.write_current_state_and_exit

+        self.start_time = time.time()
         out = pymultinest.solve(
             LogLikelihood=self.log_likelihood,
             Prior=self.prior_transform,
             n_dims=self.ndim,
             **self.kwargs
         )
+        self._calculate_and_save_sampling_time()

         self._clean_up_run_directory()

@@ -222,26 +236,22 @@ class Pymultinest(NestedSampler):
         self.result.log_evidence_err = out["logZerr"]
         self.calc_likelihood_count()
         self.result.outputfiles_basename = self.outputfiles_basename
+        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
         return self.result

     def _setup_run_directory(self):
         """
         If using a temporary directory, the output directory is moved to the
-        temporary directory and symlinked back.
+        temporary directory.
         """
         if self.use_temporary_directory:
             temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
             self.temporary_outputfiles_basename = temporary_outputfiles_basename

             if os.path.exists(self.outputfiles_basename):
-                shutil.move(self.outputfiles_basename, self.temporary_outputfiles_basename)
+                distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename)
             check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
-            os.symlink(
-                os.path.abspath(self.temporary_outputfiles_basename),
-                os.path.abspath(self.outputfiles_basename),
-                target_is_directory=True,
-            )

             self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
             logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
         else:

@@ -249,6 +259,21 @@ class Pymultinest(NestedSampler):
             self.kwargs["outputfiles_basename"] = self.outputfiles_basename
             logger.info("Using output file {}".format(self.outputfiles_basename))

+    def _check_and_load_sampling_time_file(self):
+        self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
+        if os.path.exists(self.time_file_path):
+            with open(self.time_file_path, 'r') as time_file:
+                self.total_sampling_time = float(time_file.readline())
+        else:
+            self.total_sampling_time = 0
+
+    def _calculate_and_save_sampling_time(self):
+        current_time = time.time()
+        new_sampling_time = current_time - self.start_time
+        self.total_sampling_time += new_sampling_time
+        with open(self.time_file_path, 'w') as time_file:
+            time_file.write(str(self.total_sampling_time))
+
     def _clean_up_run_directory(self):
         if self.use_temporary_directory:
             self._move_temporary_directory_to_proper_path()

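Taken together, the pymultinest changes above persist the elapsed sampling time to a sampling_time.dat file in the output directory so that a resumed run accumulates wall time across restarts, while the dump callback periodically copies the temporary directory back to the real output path. A standalone sketch of just the timer bookkeeping (file name as in the diff, everything else simplified):

    import os
    import time

    def load_total_sampling_time(time_file_path):
        # on start-up: resume the accumulated total if a previous run left one behind
        if os.path.exists(time_file_path):
            with open(time_file_path, "r") as time_file:
                return float(time_file.readline())
        return 0.0

    def save_total_sampling_time(time_file_path, total_sampling_time, start_time):
        # on checkpoint/exit: add the elapsed chunk and write the new total back
        total_sampling_time += time.time() - start_time
        with open(time_file_path, "w") as time_file:
            time_file.write(str(total_sampling_time))
        return total_sampling_time

    total = load_total_sampling_time("sampling_time.dat")
    start = time.time()
    # ... run the sampler ...
    total = save_total_sampling_time("sampling_time.dat", total, start)
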
@@ -1236,6 +1236,13 @@ def kish_log_effective_sample_size(ln_weights):
     return log_n_eff


+def get_function_path(func):
+    if hasattr(func, "__module__") and hasattr(func, "__name__"):
+        return "{}.{}".format(func.__module__, func.__name__)
+    else:
+        return func
+
+
 class IllegalDurationAndSamplingFrequencyException(Exception):
     pass

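Usage sketch for the new helper: callables resolve to a dotted "module.name" string, and anything without those attributes is passed through unchanged (which is how the waveform-generator logging below copes with a source model of None):

    import numpy as np
    from bilby.core.utils import get_function_path

    print(get_function_path(np.sum))  # -> "numpy.sum"
    print(get_function_path(None))    # -> None, passed through unchanged
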
@@ -196,7 +196,7 @@ class GravitationalWaveTransient(Likelihood):
                     "The waveform_generator {} is None. Setting from the "
                     "provided interferometers.".format(attr))
             elif wfg_attr != ifo_attr:
-                logger.warning(
+                logger.debug(
                     "The waveform_generator {} is not equal to that of the "
                     "provided interferometers. Overwriting the "
                     "waveform_generator.".format(attr))

@@ -3,6 +3,7 @@ import numpy as np
 from ..core import utils
 from ..core.series import CoupledTimeAndFrequencySeries
 from .utils import PropertyAccessor
+from .conversion import convert_to_lal_binary_black_hole_parameters


 class WaveformGenerator(object):

@@ -57,7 +58,7 @@ class WaveformGenerator(object):
         self.time_domain_source_model = time_domain_source_model
         self.source_parameter_keys = self.__parameters_from_source_model()
         if parameter_conversion is None:
-            self.parameter_conversion = _default_parameter_conversion
+            self.parameter_conversion = convert_to_lal_binary_black_hole_parameters
         else:
             self.parameter_conversion = parameter_conversion
         if waveform_arguments is not None:

@@ -67,6 +68,15 @@ class WaveformGenerator(object):
         if isinstance(parameters, dict):
             self.parameters = parameters
         self._cache = dict(parameters=None, waveform=None, model=None)
+        utils.logger.info(
+            "Waveform generator initiated with\n"
+            "  frequency_domain_source_model: {}\n"
+            "  time_domain_source_model: {}\n"
+            "  parameter_conversion: {}"
+            .format(utils.get_function_path(self.frequency_domain_source_model),
+                    utils.get_function_path(self.time_domain_source_model),
+                    utils.get_function_path(self.parameter_conversion))
+        )

     def __repr__(self):
         if self.frequency_domain_source_model is not None:

@@ -77,7 +87,7 @@ class WaveformGenerator(object):
             tdsm_name = self.time_domain_source_model.__name__
         else:
             tdsm_name = None
-        if self.parameter_conversion.__name__ == '_default_parameter_conversion':
+        if self.parameter_conversion is None:
             param_conv_name = None
         else:
             param_conv_name = self.parameter_conversion.__name__

@@ -237,7 +247,3 @@ class WaveformGenerator(object):
             raise AttributeError('Either time or frequency domain source '
                                  'model must be provided.')
         return set(utils.infer_parameters_from_function(model))
-
-
-def _default_parameter_conversion(parmeters):
-    return parmeters, list()

@@ -41,6 +41,10 @@ def setup_command_line_args():
                         help="Convert all results.", default=False)
     parser.add_argument("-m", "--merge", action='store_true',
                         help="Merge the set of runs, output saved using the outdir and label")
+    parser.add_argument("-e", "--extension", type=str, choices=["json", "hdf5"],
+                        default=True, help="Use given extension for the merged output file.")
+    parser.add_argument("-g", "--gzip", action="store_true",
+                        help="Gzip the merged output results file if using JSON format.")
     parser.add_argument("-o", "--outdir", type=str, default=None,
                         help="Output directory.")
     parser.add_argument("-l", "--label", type=str, default=None,

@@ -131,4 +135,6 @@ def main():
         result.label = args.label
     if args.outdir is not None:
         result.outdir = args.outdir
-    result.save_to_file()
+    extension = args.extension
+    result.save_to_file(gzip=args.gzip, extension=extension)

@@ -42,6 +42,7 @@ waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
 waveform_generator = bilby.gw.WaveformGenerator(
     duration=duration, sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
     waveform_arguments=waveform_arguments)

 # Set up interferometers. In this case we'll use two interferometers

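A minimal sketch of the restored default (placeholder duration and sampling frequency): omitting parameter_conversion is now equivalent to passing convert_to_lal_binary_black_hole_parameters explicitly, and the new INFO log reports which conversion function was picked up.

    import bilby

    waveform_generator = bilby.gw.WaveformGenerator(
        duration=4, sampling_frequency=2048,
        frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole)
    assert (waveform_generator.parameter_conversion is
            bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
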
@@ -39,7 +39,7 @@ def write_version_file(version):
     except Exception as e:
         print("Unable to obtain git version information, exception: {}"
               .format(e))
-        git_status = ''
+        git_status = 'release'

     version_file = '.version'
     if os.path.isfile(version_file) is False:

@@ -767,7 +767,8 @@ class TestPymultinest(unittest.TestCase):
             n_iter_before_update=100, null_log_evidence=-1e90,
             max_modes=100, mode_tolerance=-1e90, seed=-1,
             context=0, write_output=True, log_zero=-1e100,
-            max_iter=0, init_MPI=False, dump_callback=None)
+            max_iter=0, init_MPI=False, dump_callback='dumper')
+        self.sampler.kwargs['dump_callback'] = 'dumper'  # Check like the dynesty print_func
         self.assertListEqual([1, 0], self.sampler.kwargs['wrapped_params'])  # Check this separately
         self.sampler.kwargs['wrapped_params'] = None  # The dict comparison can't handle lists
         self.assertDictEqual(expected, self.sampler.kwargs)

@@ -782,7 +783,7 @@ class TestPymultinest(unittest.TestCase):
             n_iter_before_update=100, null_log_evidence=-1e90,
             max_modes=100, mode_tolerance=-1e90, seed=-1,
             context=0, write_output=True, log_zero=-1e100,
-            max_iter=0, init_MPI=False, dump_callback=None)
+            max_iter=0, init_MPI=False, dump_callback='dumper')

         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()

@@ -790,6 +791,7 @@ class TestPymultinest(unittest.TestCase):
             new_kwargs[
                 "wrapped_params"
             ] = None  # The dict comparison can't handle lists
+            new_kwargs['dump_callback'] = 'dumper'  # Check this like Dynesty print_func
             new_kwargs[equiv] = 123
             self.sampler.kwargs = new_kwargs
             self.assertDictEqual(expected, self.sampler.kwargs)

@@ -73,7 +73,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
                 self.waveform_generator.start_time,
                 self.waveform_generator.frequency_domain_source_model.__name__,
                 self.waveform_generator.time_domain_source_model,
-                None,
+                bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters.__name__,
                 self.waveform_generator.waveform_arguments,
             )
         )

@@ -92,7 +92,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
                 self.waveform_generator.start_time,
                 self.waveform_generator.frequency_domain_source_model,
                 self.waveform_generator.time_domain_source_model.__name__,
-                None,
+                bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters.__name__,
                 self.waveform_generator.waveform_arguments,
             )
         )