diff --git a/CHANGELOG.md b/CHANGELOG.md index 047568dd7d62ca86f202bb7c5fa1b1c67b07aeca..19cc2068fa9a69aa7c42ea02e7554921d60b30ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## Unreleased ### Added -- +- A plot_skymap method to the CBCResult object based on ligo.skymap ### Changed - The `periodic_boundary` option to the prior classes has been changed to `boundary`. @@ -16,7 +16,7 @@ function to create trace plots during the dynesty checkpoints - Dynesty now prints the progress to STDOUT rather than STDERR ### Removed -- +- Obsolete (and potentially incorrect) plot_skymap methods from gw.utils ## [0.4.5] 2019-04-03 diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 5e6894e740a2c272a7dac7d8c8eac7e1ee1cbaae..aca8e710e1f0fb741ff281e1053a00b2c24253d8 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -115,7 +115,12 @@ class PriorDict(OrderedDict): elements = line.split('=') key = elements[0].replace(' ', '') val = '='.join(elements[1:]) - prior[key] = eval(val) + try: + prior[key] = eval(val) + except TypeError as e: + raise TypeError( + "Unable to parse dictionary file {}, bad line: {} = {}. Error message {}" + .format(filename, key, val, e)) self.update(prior) def from_dictionary(self, dictionary): diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 2cff34828fceb77a0551ead61c085f2195bdc186..b8b4c61a314e44c3319b18de8855b80414731dc1 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -1,11 +1,15 @@ from __future__ import division +import json +import pickle +import os + import matplotlib.pyplot as plt +from matplotlib import rcParams import numpy as np -import os from ..core.result import Result as CoreResult -from ..core.utils import logger +from ..core.utils import logger, check_directory_exists_and_if_not_mkdir from .utils import (plot_spline_pos, spline_angle_xform) @@ -180,5 +184,191 @@ class CompactBinaryCoalesenceResult(CoreResult): fig.savefig(filename, bbox_inches='tight') plt.close(fig) + def plot_skymap( + self, maxpts=None, trials=5, jobs=1, enable_multiresolution=True, + objid=None, instruments=None, geo=False, dpi=600, + transparent=False, colorbar=False, contour=[50, 90], + annotate=True, cmap='cylon', load_pickle=False): + """ Generate a fits file and sky map from a result + + Code adapted from ligo.skymap.tool.ligo_skymap_from_samples and + ligo.skymap.tool.plot_skymap. Note, the use of this additionally + required the installation of ligo.skymap. + + Parameters + ---------- + maxpts: int + Number of samples to use, if None all samples are used + trials: int + Number of trials at each clustering number + jobs: int + Number of multiple threads + enable_multiresolution: bool + Generate a multiresolution HEALPix map (default: True) + objid: st + Event ID to store in FITS header + instruments: str + Name of detectors + geo: bool + Plot in geographic coordinates (lat, lon) instead of RA, Dec + dpi: int + Resolution of figure in fots per inch + transparent: bool + Save image with transparent background + colorbar: bool + Show colorbar + contour: list + List of contour levels to use + annotate: bool + Annotate image with details + cmap: str + Name of the colormap to use + load_pickle: bool, str + If true, load the cached pickle file (default name), or the + pickle-file give as a path. + """ + + try: + from astropy.time import Time + from ligo.skymap import io, version, plot, postprocess, bayestar, kde + import healpy as hp + except ImportError as e: + logger.info("Unable to generate skymap: error {}".format(e)) + return + + check_directory_exists_and_if_not_mkdir(self.outdir) + + logger.info('Reading samples for skymap') + data = self.posterior + + if maxpts is not None and maxpts < len(data): + logger.info('Taking random subsample of chain') + data = data.sample(maxpts) + + default_obj_filename = os.path.join(self.outdir, '{}_skypost.obj'.format(self.label)) + + if load_pickle is False: + try: + pts = data[['ra', 'dec', 'luminosity_distance']].values + cls = kde.Clustered2Plus1DSkyKDE + distance = True + except KeyError: + logger.warning("The results file does not contain luminosity_distance") + pts = data[['ra', 'dec']].values + cls = kde.Clustered2DSkyKDE + distance = False + + logger.info('Initialising skymap class') + skypost = cls(pts, trials=trials, multiprocess=jobs) + logger.info('Pickling skymap to {}'.format(default_obj_filename)) + with open(default_obj_filename, 'wb') as out: + pickle.dump(skypost, out) + + else: + if isinstance(load_pickle, str): + obj_filename = load_pickle + else: + obj_filename = default_obj_filename + logger.info('Reading from pickle {}'.format(obj_filename)) + with open(obj_filename, 'rb') as file: + skypost = pickle.load(file) + skypost.multiprocess = jobs + + logger.info('Making skymap') + hpmap = skypost.as_healpix() + if not enable_multiresolution: + hpmap = bayestar.rasterize(hpmap) + + hpmap.meta.update(io.fits.metadata_for_version_module(version)) + hpmap.meta['creator'] = "bilby" + hpmap.meta['origin'] = 'LIGO/Virgo' + hpmap.meta['gps_creation_time'] = Time.now().gps + hpmap.meta['history'] = "" + if objid is not None: + hpmap.meta['objid'] = objid + if instruments: + hpmap.meta['instruments'] = instruments + if distance: + hpmap.meta['distmean'] = np.mean(data['luminosity_distance']) + hpmap.meta['diststd'] = np.std(data['luminosity_distance']) + + try: + time = data['geocent_time'] + hpmap.meta['gps_time'] = time.mean() + except KeyError: + logger.warning('Cannot determine the event time from geocent_time') + + fits_filename = os.path.join(self.outdir, "{}_skymap.fits".format(self.label)) + logger.info('Saving skymap fits-file to {}'.format(fits_filename)) + io.write_sky_map(fits_filename, hpmap, nest=True) + + skymap, metadata = io.fits.read_sky_map(fits_filename, nest=None) + nside = hp.npix2nside(len(skymap)) + + # Convert sky map from probability to probability per square degree. + deg2perpix = hp.nside2pixarea(nside, degrees=True) + probperdeg2 = skymap / deg2perpix + + if geo: + obstime = Time(metadata['gps_time'], format='gps').utc.isot + ax = plt.axes(projection='geo degrees mollweide', obstime=obstime) + else: + ax = plt.axes(projection='astro hours mollweide') + ax.grid() + + # Plot sky map. + vmax = probperdeg2.max() + img = ax.imshow_hpx( + (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax, + cmap=cmap) + + # Add colorbar. + if colorbar: + cb = plot.colorbar(img) + cb.set_label(r'prob. per deg$^2$') + + if contour is not None: + cls = 100 * postprocess.find_greedy_credible_levels(skymap) + cs = ax.contour_hpx( + (cls, 'ICRS'), nested=metadata['nest'], + colors='k', linewidths=0.5, levels=contour) + fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%' + plt.clabel(cs, fmt=fmt, fontsize=6, inline=True) + + # Add continents. + if geo: + geojson_filename = os.path.join( + os.path.dirname(plot.__file__), 'ne_simplified_coastline.json') + with open(geojson_filename, 'r') as geojson_file: + geoms = json.load(geojson_file)['geometries'] + verts = [coord for geom in geoms + for coord in zip(*geom['coordinates'])] + plt.plot(*verts, color='0.5', linewidth=0.5, + transform=ax.get_transform('world')) + + # Add a white outline to all text to make it stand out from the background. + plot.outline_text(ax) + + if annotate: + text = [] + try: + objid = metadata['objid'] + except KeyError: + pass + else: + text.append('event ID: {}'.format(objid)) + if contour: + pp = np.round(contour).astype(int) + ii = np.round(np.searchsorted(np.sort(cls), contour) * + deg2perpix).astype(int) + for i, p in zip(ii, pp): + text.append( + u'{:d}% area: {:d} deg$^2$'.format(p, i, grouping=True)) + ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right') + + filename = os.path.join(self.outdir, "{}_skymap.png".format(self.label)) + logger.info("Generating 2D projected skymap to {}".format(filename)) + plt.savefig(filename, dpi=500) + CBCResult = CompactBinaryCoalesenceResult diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index ff71b55520b2173acaf1c6ac8dbe841726773d9a..548be0a72d89e124f9dc8364c481a52c6182fdb6 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -568,84 +568,6 @@ def gw_data_find(observatory, gps_start_time, duration, calibration, return output_cache_file -def save_to_fits(posterior, outdir, label): - """ Generate a fits file from a posterior array """ - from astropy.io import fits - from astropy.units import pixel - from astropy.table import Table - import healpy as hp - nside = hp.get_nside(posterior) - npix = hp.nside2npix(nside) - logger.debug('Generating table') - m = Table([posterior], names=['PROB']) - m['PROB'].unit = pixel ** -1 - - ordering = 'RING' - extra_header = [('PIXTYPE', 'HEALPIX', - 'HEALPIX pixelisation'), - ('ORDERING', ordering, - 'Pixel ordering scheme: RING, NESTED, or NUNIQ'), - ('COORDSYS', 'C', - 'Ecliptic, Galactic or Celestial (equatorial)'), - ('NSIDE', hp.npix2nside(npix), - 'Resolution parameter of HEALPIX'), - ('INDXSCHM', 'IMPLICIT', - 'Indexing: IMPLICIT or EXPLICIT')] - - fname = '{}/{}_{}.fits'.format(outdir, label, nside) - hdu = fits.table_to_hdu(m) - hdu.header.extend(extra_header) - hdulist = fits.HDUList([fits.PrimaryHDU(), hdu]) - logger.debug('Writing to a fits file') - hdulist.writeto(fname, overwrite=True) - - -def plot_skymap(result, center='120d -40d', nside=512): - """ Generate a sky map from a result """ - import scipy - from astropy.units import deg - import healpy as hp - import ligo.skymap.plot # noqa - import matplotlib.pyplot as plt - logger.debug('Generating skymap') - - logger.debug('Reading in ra and dec, creating kde and converting') - ra_dec_radians = result.posterior[['ra', 'dec']].values - kde = scipy.stats.gaussian_kde(ra_dec_radians.T) - npix = hp.nside2npix(nside) - ipix = range(npix) - theta, phi = hp.pix2ang(nside, ipix) - ra = phi - dec = 0.5 * np.pi - theta - - logger.debug('Generating posterior') - post = kde(np.row_stack([ra, dec])) - post /= np.sum(post * hp.nside2pixarea(nside)) - - fig = plt.figure(figsize=(5, 5)) - ax = plt.axes([0.05, 0.05, 0.9, 0.9], - projection='astro globe', - center=center) - ax.coords.grid(True, linestyle='--') - lon = ax.coords[0] - lat = ax.coords[1] - lon.set_ticks(exclude_overlapping=True, spacing=45 * deg) - lat.set_ticks(spacing=30 * deg) - - lon.set_major_formatter('dd') - lat.set_major_formatter('hh') - lon.set_ticklabel(color='k') - lat.set_ticklabel(color='k') - - logger.debug('Plotting sky map') - ax.imshow_hpx(post) - - lon.set_ticks_visible(False) - lat.set_ticks_visible(False) - - fig.savefig('{}/{}_skymap.png'.format(result.outdir, result.label)) - - def build_roq_weights(data, basis, deltaF): """