Commit 7043077c authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton

Consistent plot formatting

parent 2c7dd519
......@@ -107,6 +107,17 @@ scheduled-python-3.7:
- pytest test/gw_example_test.py
- pytest test/sample_from_the_prior_test.py
plotting:
stage: test
image: bilbydev/bilby-test-suite-python37
only:
- schedules
script:
- python -m pip install .
- python -m pip install ligo.skymap
- pytest test/gw_plot_test.py
pages:
stage: deploy
dependencies:
......
......@@ -17,8 +17,11 @@ import scipy.stats
from scipy.special import logsumexp
from . import utils
from .utils import (logger, infer_parameters_from_function,
check_directory_exists_and_if_not_mkdir,)
from .utils import (
logger, infer_parameters_from_function,
check_directory_exists_and_if_not_mkdir,
latex_plot_format, safe_save_figure,
)
from .utils import BilbyJsonEncoder, decode_bilby_json
from .prior import Prior, PriorDict, DeltaFunction
......@@ -655,6 +658,7 @@ class Result(object):
fmt(summary.median), fmt(summary.minus), fmt(summary.plus))
return summary
@latex_plot_format
def plot_single_density(self, key, prior=None, cumulative=False,
title=None, truth=None, save=True,
file_base_name=None, bins=50, label_fontsize=16,
......@@ -734,7 +738,7 @@ class Result(object):
file_name = file_base_name + key + '_cdf'
else:
file_name = file_base_name + key + '_pdf'
fig.savefig(file_name, dpi=dpi)
safe_save_figure(fig=fig, filename=file_name, dpi=dpi)
plt.close(fig)
else:
return fig
......@@ -819,6 +823,7 @@ class Result(object):
bins=bins, label_fontsize=label_fontsize, dpi=dpi,
title_fontsize=title_fontsize, quantiles=quantiles)
@latex_plot_format
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
""" Plot a corner-plot
......@@ -976,11 +981,12 @@ class Result(object):
outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_corner)
filename = '{}/{}_corner.png'.format(outdir, self.label)
logger.debug('Saving corner plot to {}'.format(filename))
fig.savefig(filename, dpi=dpi)
safe_save_figure(fig=fig, filename=filename, dpi=dpi)
plt.close(fig)
return fig
@latex_plot_format
def plot_walkers(self, **kwargs):
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
if hasattr(self, 'walkers') is False:
......@@ -1008,9 +1014,10 @@ class Result(object):
outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_walkers)
filename = '{}/{}_walkers.png'.format(outdir, self.label)
logger.debug('Saving walkers plot to {}'.format('filename'))
fig.savefig(filename)
safe_save_figure(fig=fig, filename=filename)
plt.close(fig)
@latex_plot_format
def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
xlabel=None, ylabel=None, data_label='data',
data_fmt='o', draws_label=None, filename=None,
......@@ -1082,7 +1089,7 @@ class Result(object):
if filename is None:
outdir = self._safe_outdir_creation(outdir, self.plot_with_data)
filename = '{}/{}_plot_with_data'.format(outdir, self.label)
fig.savefig(filename, dpi=dpi)
safe_save_figure(fig=fig, filename=filename, dpi=dpi)
plt.close(fig)
@staticmethod
......@@ -1459,6 +1466,7 @@ class ResultList(list):
raise ResultListError("Inconsistent samplers between results")
@latex_plot_format
def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, evidences=False, **kwargs):
""" Generate a corner plot overlaying two sets of results
......@@ -1538,10 +1546,11 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
filename = default_filename
if save:
fig.savefig(filename)
safe_save_figure(fig=fig, filename=filename)
return fig
@latex_plot_format
def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0.95, 0.997],
lines=None, legend_fontsize='x-small', keys=None, title=True,
confidence_interval_alpha=0.1,
......@@ -1651,7 +1660,7 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0
if save:
if filename is None:
filename = 'outdir/pp.png'
fig.savefig(filename, dpi=500)
safe_save_figure(fig=fig, filename=filename, dpi=500)
return fig, pvals
......
from __future__ import division
from distutils.spawn import find_executable
import logging
import os
from math import fmod
import argparse
import traceback
import inspect
import functools
import types
import subprocess
import multiprocessing
......@@ -1099,6 +1101,43 @@ def reflect(u):
return u
def latex_plot_format(func):
"""
Wrap a plotting function to set rcParams so that text renders nicely with
latex and Computer Modern Roman font.
"""
@functools.wraps(func)
def wrapper_decorator(*args, **kwargs):
from matplotlib import rcParams
_old_tex = rcParams["text.usetex"]
_old_serif = rcParams["font.serif"]
_old_family = rcParams["font.family"]
if find_executable("latex"):
rcParams["text.usetex"] = True
else:
rcParams["text.usetex"] = False
rcParams["font.serif"] = "Computer Modern Roman"
rcParams["font.family"] = "serif"
value = func(*args, **kwargs)
rcParams["text.usetex"] = _old_tex
rcParams["font.serif"] = _old_serif
rcParams["font.family"] = _old_family
return value
return wrapper_decorator
def safe_save_figure(fig, filename, **kwargs):
from matplotlib import rcParams
try:
fig.savefig(fname=filename, **kwargs)
except RuntimeError:
logger.debug(
"Failed to save plot with tex labels turning off tex."
)
rcParams["text.usetex"] = False
fig.savefig(fname=filename, **kwargs)
class IllegalDurationAndSamplingFrequencyException(Exception):
pass
......
......@@ -9,7 +9,10 @@ from matplotlib import rcParams
import numpy as np
from ..core.result import Result as CoreResult
from ..core.utils import infft, logger, check_directory_exists_and_if_not_mkdir
from ..core.utils import (
infft, logger, check_directory_exists_and_if_not_mkdir,
latex_plot_format, safe_save_figure
)
from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series
from .detector import get_empty_interferometer, Interferometer
......@@ -133,6 +136,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
logger.info("No injection for detector {}".format(detector))
return None
@latex_plot_format
def plot_calibration_posterior(self, level=.9, format="png"):
""" Plots the calibration amplitude and phase uncertainty.
Adapted from the LALInference version in bayespputils
......@@ -148,12 +152,6 @@ class CompactBinaryCoalescenceResult(CoreResult):
"""
if format not in ["png", "pdf"]:
raise ValueError("Format should be one of png or pdf")
_old_tex = rcParams["text.usetex"]
_old_serif = rcParams["font.serif"]
_old_family = rcParams["font.family"]
rcParams["text.usetex"] = True
rcParams["font.serif"] = "Computer Modern Roman"
rcParams["font.family"] = "Serif"
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 15), dpi=500)
posterior = self.posterior
......@@ -214,17 +212,11 @@ class CompactBinaryCoalescenceResult(CoreResult):
filename = os.path.join(outdir, self.label + '_calibration.' + format)
fig.tight_layout()
try:
plt.savefig(filename, format=format, dpi=600, bbox_inches='tight')
except RuntimeError:
logger.debug(
"Failed to save waveform with tex labels turning off tex."
)
rcParams["text.usetex"] = False
plt.savefig(filename, format=format, dpi=600, bbox_inches='tight')
rcParams["text.usetex"] = _old_tex
rcParams["font.serif"] = _old_serif
rcParams["font.family"] = _old_family
safe_save_figure(
fig=fig, filename=filename,
format=format, dpi=600, bbox_inches='tight'
)
logger.debug("Calibration figure saved to {}".format(filename))
plt.close()
def plot_waveform_posterior(
......@@ -268,6 +260,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
save=True, format=format, start_time=start_time,
end_time=end_time)
@latex_plot_format
def plot_interferometer_waveform_posterior(
self, interferometer, level=0.9, n_samples=None, save=True,
format='png', start_time=None, end_time=None):
......@@ -328,13 +321,6 @@ class CompactBinaryCoalescenceResult(CoreResult):
"HTML plotting requested, but plotly cannot be imported, "
"falling back to png format for waveform plot.")
format = "png"
else:
_old_tex = rcParams["text.usetex"]
_old_serif = rcParams["font.serif"]
_old_family = rcParams["font.family"]
rcParams["text.usetex"] = True
rcParams["font.serif"] = "Computer Modern Roman"
rcParams["font.family"] = "Serif"
if isinstance(interferometer, str):
interferometer = get_empty_interferometer(interferometer)
......@@ -383,7 +369,6 @@ class CompactBinaryCoalescenceResult(CoreResult):
len(frequency_idxs))
)
plot_times = interferometer.time_array[time_idxs]
# if format == "html":
plot_times -= interferometer.strain_data.start_time
start_time -= interferometer.strain_data.start_time
end_time -= interferometer.strain_data.start_time
......@@ -663,7 +648,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
fig.update_xaxes(title_text=f_domain_x_label, type="log", row=1)
fig.update_yaxes(title_text=f_domain_y_label, type="log", row=1)
fig.update_xaxes(title_text=t_domain_x_label, type="linear", row=2)
fig.update_yaxes(title_text=t_domain_x_label, type="linear", row=2)
fig.update_yaxes(title_text=t_domain_y_label, type="linear", row=2)
else:
axs[0].set_xlim(interferometer.minimum_frequency,
interferometer.maximum_frequency)
......@@ -684,19 +669,12 @@ class CompactBinaryCoalescenceResult(CoreResult):
plot(fig, filename=filename, include_mathjax='cdn', auto_open=False)
else:
plt.tight_layout()
try:
plt.savefig(filename, format=format, dpi=600)
except RuntimeError:
logger.debug(
"Failed to save waveform with tex labels turning off tex."
)
rcParams["text.usetex"] = False
plt.savefig(filename, format=format, dpi=600)
safe_save_figure(
fig=fig, filename=filename,
format=format, dpi=600
)
plt.close()
rcParams["text.usetex"] = _old_tex
rcParams["font.serif"] = _old_serif
rcParams["font.family"] = _old_family
logger.debug("Figure saved to {}".format(filename))
logger.debug("Waveform figure saved to {}".format(filename))
else:
return fig
......@@ -714,14 +692,14 @@ class CompactBinaryCoalescenceResult(CoreResult):
Parameters
----------
maxpts: int
Number of samples to use, if None all samples are used
Maximum number of samples to use, if None all samples are used
trials: int
Number of trials at each clustering number
jobs: int
Number of multiple threads
enable_multiresolution: bool
Generate a multiresolution HEALPix map (default: True)
objid: st
objid: str
Event ID to store in FITS header
instruments: str
Name of detectors
......@@ -766,16 +744,16 @@ class CompactBinaryCoalescenceResult(CoreResult):
if load_pickle is False:
try:
pts = data[['ra', 'dec', 'luminosity_distance']].values
cls = kde.Clustered2Plus1DSkyKDE
confidence_levels = kde.Clustered2Plus1DSkyKDE
distance = True
except KeyError:
logger.warning("The results file does not contain luminosity_distance")
pts = data[['ra', 'dec']].values
cls = kde.Clustered2DSkyKDE
confidence_levels = kde.Clustered2DSkyKDE
distance = False
logger.info('Initialising skymap class')
skypost = cls(pts, trials=trials, multiprocess=jobs)
skypost = confidence_levels(pts, trials=trials, jobs=jobs)
logger.info('Pickling skymap to {}'.format(default_obj_filename))
with open(default_obj_filename, 'wb') as out:
pickle.dump(skypost, out)
......@@ -788,7 +766,8 @@ class CompactBinaryCoalescenceResult(CoreResult):
logger.info('Reading from pickle {}'.format(obj_filename))
with open(obj_filename, 'rb') as file:
skypost = pickle.load(file)
skypost.multiprocess = jobs
skypost.jobs = jobs
distance = isinstance(skypost, kde.Clustered2Plus1DSkyKDE)
logger.info('Making skymap')
hpmap = skypost.as_healpix()
......@@ -844,12 +823,12 @@ class CompactBinaryCoalescenceResult(CoreResult):
cb.set_label(r'prob. per deg$^2$')
if contour is not None:
cls = 100 * postprocess.find_greedy_credible_levels(skymap)
cs = ax.contour_hpx(
(cls, 'ICRS'), nested=metadata['nest'],
confidence_levels = 100 * postprocess.find_greedy_credible_levels(skymap)
contours = ax.contour_hpx(
(confidence_levels, 'ICRS'), nested=metadata['nest'],
colors='k', linewidths=0.5, levels=contour)
fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%'
plt.clabel(cs, fmt=fmt, fontsize=6, inline=True)
plt.clabel(contours, fmt=fmt, fontsize=6, inline=True)
# Add continents.
if geo:
......@@ -875,7 +854,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
text.append('event ID: {}'.format(objid))
if contour:
pp = np.round(contour).astype(int)
ii = np.round(np.searchsorted(np.sort(cls), contour) *
ii = np.round(np.searchsorted(np.sort(confidence_levels), contour) *
deg2perpix).astype(int)
for i, p in zip(ii, pp):
text.append(
......@@ -884,15 +863,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
filename = os.path.join(self.outdir, "{}_skymap.png".format(self.label))
logger.info("Generating 2D projected skymap to {}".format(filename))
plt.savefig(filename, dpi=500)
class CompactBinaryCoalesenceResult(CompactBinaryCoalescenceResult):
def __init__(self, **kwargs):
logger.warning('CompactBinaryCoalesenceResult is deprecated use '
'CompactBinaryCoalescenceResult')
super(CompactBinaryCoalesenceResult, self).__init__(**kwargs)
safe_save_figure(fig=plt.gcf(), filename=filename, dpi=dpi)
CBCResult = CompactBinaryCoalescenceResult
......@@ -9,6 +9,7 @@ addopts =
--ignore test/gw_example_test.py
--ignore test/example_test.py
--ignore test/sample_from_the_prior_test.py
--ignore test/gw_plot_test.py
--ignore test/sampler_test.py
[metadata]
......
import os
import shutil
import unittest
import pandas as pd
import bilby
class TestCBCResult(unittest.TestCase):
def setUp(self):
bilby.utils.command_line_args.bilby_test_mode = False
priors = bilby.gw.prior.BBHPriorDict()
priors['geocent_time'] = 2
injection_parameters = priors.sample()
self.meta_data = dict(
likelihood=dict(
phase_marginalization=True,
distance_marginalization=False,
time_marginalization=True,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=20.0,
waveform_approximant='IMRPhenomPv2'),
interferometers=dict(
H1=dict(optimal_SNR=1, parameters=injection_parameters),
L1=dict(optimal_SNR=1, parameters=injection_parameters)),
sampling_frequency=4096,
duration=4,
start_time=0,
waveform_generator_class=bilby.gw.waveform_generator.WaveformGenerator,
parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
)
)
self.result = bilby.gw.result.CBCResult(
label='label', outdir='outdir', sampler='nestle',
search_parameter_keys=list(priors.keys()), fixed_parameter_keys=list(),
priors=priors, sampler_kwargs=dict(test='test', func=lambda x: x),
injection_parameters=injection_parameters,
meta_data=self.meta_data,
posterior=pd.DataFrame(priors.sample(100))
)
if not os.path.isdir(self.result.outdir):
os.mkdir(self.result.outdir)
pass
def tearDown(self):
bilby.utils.command_line_args.bilby_test_mode = True
try:
shutil.rmtree(self.result.outdir)
except OSError:
pass
del self.result
pass
def test_calibration_plot(self):
calibration_prior = bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline(
amplitude_sigma=0.1,
phase_sigma=0.1,
minimum_frequency=20,
maximum_frequency=2048,
label="recalib_H1_",
n_nodes=5,
)
calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.png"
for key in calibration_prior:
self.result.posterior[key] = calibration_prior[key].sample(100)
self.result.plot_calibration_posterior()
self.assertTrue(os.path.exists(calibration_filename))
def test_calibration_plot_returns_none_with_no_calibration_parameters(self):
self.assertIsNone(self.result.plot_calibration_posterior())
calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.png"
self.assertFalse(os.path.exists(calibration_filename))
def test_calibration_pdf_plot(self):
calibration_prior = bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline(
amplitude_sigma=0.1,
phase_sigma=0.1,
minimum_frequency=20,
maximum_frequency=2048,
label="recalib_H1_",
n_nodes=5,
)
calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.pdf"
for key in calibration_prior:
self.result.posterior[key] = calibration_prior[key].sample(100)
self.result.plot_calibration_posterior(format="pdf")
self.assertTrue(os.path.exists(calibration_filename))
def test_calibration_invalid_format_raises_error(self):
with self.assertRaises(ValueError):
self.result.plot_calibration_posterior(format="bilby")
def test_waveform_plotting_png(self):
self.result.plot_waveform_posterior(n_samples=200)
for ifo in self.result.interferometers:
self.assertTrue(os.path.exists(
f"{self.result.outdir}/{self.result.label}_{ifo}_waveform.png")
)
def test_plot_skymap_meta_data(self):
from ligo.skymap import io
expected_keys = {
"HISTORY", "build_date", "creator", "distmean", "diststd",
"gps_creation_time", "gps_time", "nest", "objid", "origin",
"vcs_revision", "vcs_version", "instruments"
}
self.result.plot_skymap(
maxpts=50, geo=False, objid="test", instruments="H1L1"
)
fits_filename = f"{self.result.outdir}/{self.result.label}_skymap.fits"
skymap_filename = f"{self.result.outdir}/{self.result.label}_skymap.png"
pickle_filename = f"{self.result.outdir}/{self.result.label}_skypost.obj"
hpmap, meta = io.read_sky_map(fits_filename)
self.assertEqual(expected_keys, set(meta.keys()))
self.assertTrue(os.path.exists(skymap_filename))
self.assertTrue(os.path.exists(pickle_filename))
self.result.plot_skymap(
maxpts=50, geo=False, objid="test", instruments="H1L1",
load_pickle=True, colorbar=True
)
if __name__ == '__main__':
unittest.main()
from __future__ import absolute_import, division
import os
import shutil
import unittest
import pandas as pd
import bilby
import unittest
import shutil
class TestCBCResult(unittest.TestCase):
def setUp(self):
bilby.utils.command_line_args.bilby_test_mode = False
priors = bilby.prior.PriorDict(dict(
x=bilby.prior.Uniform(0, 1, 'x', latex_label='$x$', unit='s'),
y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
c=1,
d=2))
priors = bilby.gw.prior.BBHPriorDict()
priors['geocent_time'] = 2
injection_parameters = priors.sample()
self.meta_data = dict(
likelihood=dict(
phase_marginalization=True, distance_marginalization=False,
phase_marginalization=True,
distance_marginalization=False,
time_marginalization=True,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=20.0,
waveform_approximant='IMRPhenomPv2'),
interferometers=dict(
H1=dict(optimal_SNR=1, parameters=dict(x=0.1, y=0.3)),
L1=dict(optimal_SNR=1, parameters=dict(x=0.1, y=0.3)))))
H1=dict(optimal_SNR=1, parameters=injection_parameters),
L1=dict(optimal_SNR=1, parameters=injection_parameters)),
sampling_frequency=4096,
duration=4,
start_time=0,
waveform_generator_class=bilby.gw.waveform_generator.WaveformGenerator,
parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
)
)
self.result = bilby.gw.result.CBCResult(
label='label', outdir='outdir', sampler='nestle',
search_parameter_keys=['x', 'y'], fixed_parameter_keys=['c', 'd'],
search_parameter_keys=list(priors.keys()), fixed_parameter_keys=list(),
priors=priors, sampler_kwargs=dict(test='test', func=lambda x: x),
injection_parameters=dict(x=0.5, y=0.5),
meta_data=self.meta_data)
injection_parameters=injection_parameters,
meta_data=self.meta_data,
posterior=pd.DataFrame(priors.sample(100))
)
if not os.path.isdir(self.result.outdir):
os.mkdir(self.result.outdir)
pass
def tearDown(self):
......@@ -82,6 +94,36 @@ class TestCBCResult(unittest.TestCase):
with self.assertRaises(AttributeError):
self.result.reference_frequency
def test_sampling_frequency(self):
self.assertEqual(
self.result.sampling_frequency,
self.meta_data['likelihood']['sampling_frequency'])
def test_sampling_frequency_unset(self):
self.result.meta_data['likelihood'].pop('sampling_frequency')
with self.assertRaises(AttributeError):
self.result.sampling_frequency
def test_duration(self):
self.assertEqual(
self.result.duration,
self.meta_data['likelihood']['duration'])
def test_duration_unset(self):
self.result.meta_data['likelihood'].pop('duration')
with self.assertRaises(AttributeError):
self.result.duration
def test_start_time(self):
self.assertEqual(
self.result.start_time,
self.meta_data['likelihood']['start_time'])
def test_start_time_unset(self):