pulsarpputils.py 99.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# -*- coding: utf-8 -*-
#
#       pulsarpputils.py
#
#       Copyright 2012
#       Matthew Pitkin <matthew.pitkin@ligo.org>
#
#
#       This program is free software; you can redistribute it and/or modify
#       it under the terms of the GNU General Public License as published by
#       the Free Software Foundation; either version 2 of the License, or
#       (at your option) any later version.
#
#       This program is distributed in the hope that it will be useful,
#       but WITHOUT ANY WARRANTY; without even the implied warranty of
#       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#       GNU General Public License for more details.
#
#       You should have received a copy of the GNU General Public License
#       along with this program; if not, write to the Free Software
#       Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#       MA 02110-1301, USA.

# known pulsar analysis post-processing utilities

# Many functions in this module are taken from, or derived from, equivalents
# available in the PRESTO pulsar software package http://www.cv.nrao.edu/~sransom/presto/

29
from __future__ import print_function, division
30

31
32
import sys
import math
33
import cmath
34
35
import os
import numpy as np
36
import struct
37
import re
38
import h5py
39
import urllib2
40
41
42

from scipy.integrate import cumtrapz
from scipy.interpolate import interp1d
43
from scipy.stats import hmean
44
from scipy.misc import logsumexp
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

from types import StringType, FloatType

# some common constants taken from psr_constants.py in PRESTO
# Conversion factors, copied digit-for-digit from PRESTO's psr_constants.py.
# Python rounds each literal to the nearest double, exactly as float('...')
# on the same digits would, so the stored values are unchanged.
ARCSECTORAD = 4.8481368110953599358991410235794797595635330237270e-6   # arcseconds -> radians
RADTOARCSEC = 206264.80624709635515647335733077861319665970087963      # radians -> arcseconds
SECTORAD    = 7.2722052166430399038487115353692196393452995355905e-5   # seconds of time -> radians
RADTOSEC    = 13750.987083139757010431557155385240879777313391975      # radians -> seconds of time
RADTODEG    = 57.295779513082320876798154814105170332405472466564      # radians -> degrees
DEGTORAD    = 1.7453292519943295769236907684886127134428718885417e-2   # degrees -> radians
RADTOHRS    = 3.8197186342054880584532103209403446888270314977710      # radians -> hours
HRSTORAD    = 2.6179938779914943653855361527329190701643078328126e-1   # hours -> radians
PI          = 3.1415926535897932384626433832795028841971693993751
TWOPI       = 6.2831853071795864769252867665590057683943387987502
PIBYTWO     = 1.5707963267948966192313216916397514420985846996876
SECPERDAY   = 86400.0
SECPERJULYR = 31557600.0      # seconds per Julian year (365.25 days)
KMPERPC     = 3.0856776e13    # km per parsec
KMPERKPC    = 3.0856776e16    # km per kiloparsec
Tsun        = 4.925490947e-6  # sec
Msun        = 1.9891e30       # kg
Mjup        = 1.8987e27       # kg
Rsun        = 6.9551e8        # m
Rearth      = 6.378e6         # m
SOL         = 299792458.0     # m/s
MSUN        = 1.989e+30       # kg
G           = 6.673e-11       # m^3/s^2/kg
C           = SOL
KPC         = 3.0856776e19    # kiloparsec in metres
I38         = 1e38            # moment of inertia kg m^2
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

# some angle conversion functions taken from psr_utils.py in PRESTO
def rad_to_dms(rad):
  """
  rad_to_dms(rad):
     Convert an angle in radians into a (degrees, arcminutes, arcseconds)
     tuple.

     The magnitude of the angle is reduced modulo pi before conversion. The
     sign is carried on the degrees value, except when the degrees are zero
     (an integer zero cannot hold a sign); then it is carried on the minutes
     and seconds instead.
  """
  negative = rad < 0.0
  sign = -1 if negative else 1

  degs = RADTODEG * np.fmod(np.fabs(rad), math.pi)
  d = int(degs)
  mins = (degs - d) * 60.0
  m = int(mins)
  s = (mins - m) * 60.0

  if negative and d == 0:
    # "-0" degrees is indistinguishable from 0, so push the sign down
    return (sign * d, sign * m, sign * s)
  return (sign * d, m, s)

94
def dms_to_rad(deg, mins, sec):
  """
  dms_to_rad(deg, min, sec):
     Convert degrees, minutes, and seconds of arc to radians.

     The sign is taken from deg. If deg is zero it is taken from mins or sec
     instead, so that an angle such as -00:30:00 can be passed as (0, -30, 0).
     Only the magnitudes of the three components are summed.
  """
  if deg < 0.0:
    sign = -1
  elif deg == 0.0 and (mins < 0.0 or sec < 0.0):
    # a zero degrees value cannot carry the sign, so look at mins/secs
    sign = -1
  else:
    sign = 1

  arcsecs = 60.0 * (60.0 * np.fabs(deg) + np.fabs(mins)) + np.fabs(sec)
  return sign * ARCSECTORAD * arcsecs
107

108
def dms_to_deg(deg, mins, sec):
  """
  dms_to_deg(deg, min, sec):
     Convert degrees, minutes, and seconds of arc to decimal degrees.
     The conversion goes via dms_to_rad, so the same sign conventions apply
     (the sign is taken from deg, or from mins/sec when deg is zero).
  """
  return RADTODEG * dms_to_rad(deg, mins, sec)
114
115
116
117
118
119
120
121
122
123
124
125
126
127

def rad_to_hms(rad):
  """
  rad_to_hms(rad):
     Convert an angle in radians into an (hours, minutes, seconds) tuple.

     The angle is first reduced modulo 2*pi, and a negative remainder has
     2*pi added to it, so the returned components are non-negative.
  """
  wrapped = np.fmod(rad, 2.*math.pi)
  if wrapped < 0.0:
    wrapped = wrapped + 2.*math.pi

  hours = RADTOHRS * wrapped
  h = int(hours)
  mins = (hours - h) * 60.0
  m = int(mins)
  return (h, m, (mins - m) * 60.0)
128

129
def hms_to_rad(hour, mins, sec):
  """
  hms_to_rad(hour, min, sec):
     Convert hours, minutes, and seconds of time (e.g. a right ascension)
     to radians.

     The sign is taken from hour only; the magnitudes of the three
     components are summed.
  """
  sign = -1 if hour < 0.0 else 1
  secs = 60.0 * (60.0 * np.fabs(hour) + np.fabs(mins)) + np.fabs(sec)
  return sign * SECTORAD * secs
138
139
140
141
142
143
144
145
146
147
148
149
150

def coord_to_string(h_or_d, m, s):
  """
  coord_to_string(h_or_d, m, s):
     Return a formatted string of RA or DEC values as
     'hh:mm:ss.ssss' if RA, or 'dd:mm:ss.ssss' if DEC.

     A leading '-' is added if h_or_d is negative, or if h_or_d is zero and
     either m or s is negative (the convention used by rad_to_dms).

     Seconds are rounded to the four decimal places shown before formatting.
     A seconds value that rounds up to 60 is carried into the minutes, and
     60 minutes is carried into the hours/degrees.
  """
  negative = h_or_d < 0 or (h_or_d == 0 and (m < 0.0 or s < 0.0))
  sgn = "-" if negative else ""

  h_or_d, m, s = abs(h_or_d), abs(m), abs(s)

  # round to the displayed precision first so that any carry is applied
  # (otherwise e.g. 59.99996 would be printed as '60.0000')
  s = round(s, 4)
  if s >= 60.0:
    s -= 60.0
    m += 1
  if m >= 60:
    m -= 60
    h_or_d += 1

  # %07.4f zero-pads the seconds to two integer digits (e.g. '09.9996')
  return sgn + "%.2d:%.2d:%07.4f" % (h_or_d, m, s)
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

def rad_to_string(rad, ra_or_dec):
  """
  rad_to_string(rad, ra_or_dec):
     Convert an angle in radians to hours/degrees, minutes, seconds and return
     it as a string in the format 'hh:mm:ss.ssss' if RA, or 'dd:mm:ss.ssss' if
     DEC. Whether hours or degrees are used is set by ra_or_dec being 'RA' or
     'DEC' (case-insensitive).

     Raises ValueError if ra_or_dec is neither 'RA' nor 'DEC'.
  """
  option = ra_or_dec.upper()
  if option == 'RA':
    v, m, s = rad_to_hms(rad)
  elif option == 'DEC':
    v, m, s = rad_to_dms(rad)
  else:
    # raising a plain string is itself a TypeError, so raise a real exception
    raise ValueError("Unrecognised option: Expected 'ra_or_dec' to be 'RA' or 'DEC'")

  return coord_to_string(v, m, s)
173
174
175
176
177
178

def ra_to_rad(ra_string):
  """
  ra_to_rad(ra_string):
     Given a string containing RA information as
     'hh:mm:ss.ssss', return the equivalent decimal
     radians. Also deal with cases where input
     string is just hh:mm, or hh.

     If the string has more than three ':'-separated fields, an error is
     printed to stderr and the program exits with status 1.
  """
  hms = ra_string.split(":")
  if len(hms) == 3:
    return hms_to_rad(int(hms[0]), int(hms[1]), float(hms[2]))
  elif len(hms) == 2:
    return hms_to_rad(int(hms[0]), int(hms[1]), 0.0)
  elif len(hms) == 1:
    return hms_to_rad(float(hms[0]), 0.0, 0.0)
  else:
    print("Problem parsing RA string %s" % ra_string, file=sys.stderr)
    sys.exit(1)
192
193
194
195
196
197

def dec_to_rad(dec_string):
  """
  dec_to_rad(dec_string):
     Given a string containing DEC information as
     'dd:mm:ss.ssss', return the equivalent decimal
     radians. Also deal with cases where input string
     is just dd:mm or dd.

     If the string has more than three ':'-separated fields, an error is
     printed to stderr and the program exits with status 1.
  """
  dms = dec_string.split(":")

  # A declination such as '-00:30:00' has a zero degrees field, which loses
  # its minus sign once converted to a number. Propagate the sign onto the
  # minutes and seconds instead; dms_to_rad checks those when deg == 0.
  if "-" in dms[0] and float(dms[0]) == 0.0:
    m = '-'
  else:
    m = ''

  if len(dms) == 3:
    return dms_to_rad(int(dms[0]), int(m+dms[1]), float(m+dms[2]))
  elif len(dms) == 2:
    return dms_to_rad(int(dms[0]), int(m+dms[1]), 0.0)
  elif len(dms) == 1:
    return dms_to_rad(float(dms[0]), 0.0, 0.0)
  else:
    print("Problem parsing DEC string %s" % dec_string, file=sys.stderr)
    sys.exit(1)
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232

def p_to_f(p, pd, pdd=None):
  """
  p_to_f(p, pd, pdd=None):
    Convert period, period derivative and period second
    derivative to the equivalent frequency counterparts.
    Will also convert from f to p (the relations are symmetric).

    Returns [f, fd] if pdd is None, otherwise [f, fd, fdd], where
      f   = 1/p
      fd  = -pd/p**2
      fdd = 2*pd**2/p**3 - pdd/p**2
    A pdd of exactly 0.0 gives fdd = 0.0 (rather than 2*pd**2/p**3).
  """
  f = 1.0 / p
  fd = -pd / (p * p)
  if pdd is None:
    return [f, fd]

  if pdd == 0.0:
    fdd = 0.0
  else:
    fdd = 2.0 * pd * pd / (p**3.0) - pdd / (p * p)

  return [f, fd, fdd]

def pferrs(porf, porferr, pdorfd=None, pdorfderr=None):
  """
  pferrs(porf, porferr, pdorfd=None, pdorfderr=None):
     Calculate the period or frequency errors and
     the pdot or fdot errors from the opposite one.

     If pdorfd is None, returns [1/porf, porferr/porf**2]. Otherwise returns
     [forp, forperr, fdorpd, fdorpderr]. Here forp and fdorpd come from
     p_to_f, and the errors are propagated to first order:
       forperr   = porferr / porf**2
       fdorpderr = sqrt(4 pdorfd**2 porferr**2 / porf**6 + pdorfderr**2 / porf**4)
  """
  if pdorfd is None:
    return [1.0 / porf, porferr / porf**2.0]

  forperr = porferr / porf**2.0
  fdorpderr = np.sqrt((4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 +
                      pdorfderr**2.0 / porf**4.0)
  [forp, fdorpd] = p_to_f(porf, pdorfd)

  return [forp, forperr, fdorpd, fdorpderr]

# class to read in a pulsar par file - this is heavily based on the function
# in parfile.py in PRESTO
254
# Parameter names whose values psr_par/psr_prior convert to floats.
# NOTE(review): TZRSITE is a site code and can be non-numeric; such values fail
# the float conversion in psr_par and are skipped.
# NOTE(review): psr_par also handles DMEPOCH, XPBDOT, P, P0 and P1, which are not
# listed here, so those values are never read from a par file - confirm intent.
float_keys = ["F", "F0", "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10",
              "PEPOCH", "POSEPOCH", "DM", "START", "FINISH", "NTOA",
              "TRES", "TZRMJD", "TZRFRQ", "TZRSITE", "NITS",
              "A1", "XDOT", "E", "ECC", "EDOT", "T0", "PB", "PBDOT", "OM",
              "OMDOT", "EPS1", "EPS2", "EPS1DOT", "EPS2DOT", "TASC", "LAMBDA",
              "BETA", "RA_RAD", "DEC_RAD", "GAMMA", "SINI", "M2", "MTOT",
              "FB0", "FB1", "FB2", "ELAT", "ELONG", "PMRA", "PMDEC", "DIST",
              "PB_2", "PB_3", "T0_2", "T0_3", "A1_2", "A1_3", "OM_2", "OM_3",
              "ECC_2", "ECC_3", "PX", "KIN", "KOM", "A0", "B0", "D_AOP",
              # GW PARAMETERS
              "H0", "COSIOTA", "PSI", "PHI0", "THETA", "I21", "I31", "C22", "HPLUS", "HCROSS",
              "C21", "PHI22", "PHI21", "SNR", "COSTHETA", "IOTA", "Q22"]

# Parameter names whose values psr_par keeps as strings.
str_keys = ["FILE", "PSR", "PSRJ", "NAME", "RAJ", "DECJ", "RA", "DEC", "EPHEM",
            "CLK", "BINARY", "UNITS"]
268
269
270

class psr_par:
  """Parsed contents of a TEMPO(2)-style pulsar parameter (.par) file."""

  def __init__(self, parfilenm):
    """
    This class parses a TEMPO(2)-style pulsar parameter file. If possible all parameters
    are converted into SI units and angles are in radians. Epochs will be converted from
    MJD values into GPS times.

    Values are stored only for parameters listed in the module-level str_keys
    (kept as strings) and float_keys (converted to floats). Each is stored as
    an attribute named by the upper-cased parameter name. For lines carrying
    an uncertainty, the uncertainty is stored as <NAME>_ERR and the fit flag
    (0 or 1) as <NAME>_FIT. For most unit-converted quantities the unconverted
    value is kept as <NAME>_ORIGINAL.
    """
    self.FILE = parfilenm

    with open(parfilenm) as pf:
      for line in pf:
        # ignore empty lines (i.e. containing only whitespace) and '#' comments
        if not line.strip() or line.lstrip().startswith('#'):
          continue

        # Convert any Fortran-style 'D-'/'D+' exponents (either case) to 'E-'/'E+'
        line = line.replace("D-", "E-").replace("D+", "E+")
        line = line.replace("d-", "e-").replace("d+", "e+")
        splitline = line.split()

        # get all upper case version in case lower case in par file
        key = splitline[0].upper()

        if key in str_keys:
          setattr(self, key, splitline[1])
        elif key in float_keys:
          try:
            setattr(self, key, float(splitline[1]))
          except (ValueError, IndexError):
            # unparsable value: skip the whole line (including any error/fit flag)
            continue

        try:
          if len(splitline) == 3:
            # Some parfiles don't have flags, but do have errors
            if splitline[2] not in ['0', '1']:
              setattr(self, key+'_ERR', float(splitline[2]))
              setattr(self, key+'_FIT', 0) # parameter was not fit
          elif len(splitline) == 4:
            setattr(self, key+'_FIT', 1 if splitline[2] == '1' else 0) # '1' means fit
            setattr(self, key+'_ERR', float(splitline[3]))
        except ValueError:
          # a non-numeric third/fourth field is not an uncertainty; ignore it
          pass

    # sky position
    if hasattr(self, 'RAJ'):
      self.RA_RAD = ra_to_rad(self.RAJ)

      # set RA error in rads (rather than secs)
      if hasattr(self, 'RAJ_ERR'):
        self.RA_RAD_ERR = hms_to_rad(0, 0, self.RAJ_ERR)

        if hasattr(self, 'RAJ_FIT'):
          self.RA_RAD_FIT = self.RAJ_FIT

    if hasattr(self, 'DECJ'):
      self.DEC_RAD = dec_to_rad(self.DECJ)

      # set DEC error in rads (rather than arcsecs)
      if hasattr(self, 'DECJ_ERR'):
        self.DEC_RAD_ERR = dms_to_rad(0, 0, self.DECJ_ERR)

        if hasattr(self, 'DECJ_FIT'):
          self.DEC_RAD_FIT = self.DECJ_FIT

    # convert proper motions to rads/sec from mas/year
    masyr = 180.*3600.e3*365.25*86400. # value*pi/masyr converts mas/yr to rad/s
    for pv in ['RA', 'DEC']:
      pmkey = 'PM' + pv
      if hasattr(self, pmkey):
        pmv = self[pmkey]
        setattr(self, pmkey+'_ORIGINAL', pmv) # save original value
        setattr(self, pmkey, pmv*np.pi/masyr)

        if hasattr(self, pmkey+'_ERR'):
          pmve = self[pmkey+'_ERR']
          setattr(self, pmkey+'_ERR_ORIGINAL', pmve) # save original error
          setattr(self, pmkey+'_ERR', pmve*np.pi/masyr)

    # periods and frequencies
    if hasattr(self, 'P'):
      self.P0 = self.P
    if hasattr(self, 'P0'):
      self.F0 = 1.0/self.P0
    if hasattr(self, 'F0'):
      self.P0 = 1.0/self.F0
    if hasattr(self, 'FB0'):
      self.PB = (1.0/self.FB0)/86400.0 # in days here; converted to seconds below

    # errors can be present without the corresponding value (errors are
    # recorded for any key), so check both before converting
    if hasattr(self, 'P0_ERR') and hasattr(self, 'P0'):
      if hasattr(self, 'P1_ERR') and hasattr(self, 'P1'):
        f, ferr, fd, fderr = pferrs(self.P0, self.P0_ERR, self.P1, self.P1_ERR)
        self.F0_ERR = ferr
        self.F1 = fd
        self.F1_ERR = fderr
      elif hasattr(self, 'P1'):
        f, fd = p_to_f(self.P0, self.P1)
        self.F0_ERR = self.P0_ERR/(self.P0*self.P0)
        self.F1 = fd

    if hasattr(self, 'F0_ERR') and hasattr(self, 'F0'):
      if hasattr(self, 'F1_ERR') and hasattr(self, 'F1'):
        p, perr, pd, pderr = pferrs(self.F0, self.F0_ERR, self.F1, self.F1_ERR)
        self.P0_ERR = perr
        self.P1 = pd
        self.P1_ERR = pderr
      elif hasattr(self, 'F1'):
        p, pd = p_to_f(self.F0, self.F1)
        self.P0_ERR = self.F0_ERR/(self.F0*self.F0)
        self.P1 = pd

    # convert epochs (including binary epochs) to GPS if possible
    try:
      from lalpulsar import TTMJDtoGPS

      for epoch in ['PEPOCH', 'POSEPOCH', 'DMEPOCH', 'T0', 'TASC', 'T0_2', 'T0_3']:
        if hasattr(self, epoch):
          setattr(self, epoch+'_ORIGINAL', self[epoch]) # save original value
          setattr(self, epoch, TTMJDtoGPS(self[epoch]))

          if hasattr(self, epoch+'_ERR'): # convert errors from days to seconds
            setattr(self, epoch+'_ERR_ORIGINAL', self[epoch+'_ERR']) # save original value
            setattr(self, epoch+'_ERR', self[epoch+'_ERR'] * SECPERDAY)
    except Exception:
      # NOTE(review): if a conversion fails part-way, the epochs already
      # handled will have been converted despite this message
      print("Could not convert epochs to GPS times. They are all still MJD values.", file=sys.stderr)

    # distance and parallax (distance: kpc -> metres, parallax: mas -> rads)
    convfacs = {'DIST': KPC, 'PX': 1e-3*ARCSECTORAD}
    for item, fac in convfacs.items():
      if hasattr(self, item):
        setattr(self, item+'_ORIGINAL', self[item]) # save original value
        setattr(self, item, self[item] * fac)

        if hasattr(self, item+'_ERR'):
          setattr(self, item+'_ERR_ORIGINAL', self[item+'_ERR']) # save original value
          setattr(self, item+'_ERR', self[item+'_ERR'] * fac)

    # binary parameters
    # omega (angle of periastron) parameters (or others requiring conversion from degs to rads)
    for om in ['OM', 'OM_2', 'OM_3', 'KIN', 'KOM']:
      if hasattr(self, om):
        setattr(self, om, self[om] / RADTODEG)

        if hasattr(self, om+'_ERR'):
          setattr(self, om+'_ERR', self[om+'_ERR'] / RADTODEG)

    # period
    for pb in ['PB', 'PB_2', 'PB_3']:
      if hasattr(self, pb): # convert PB from days to seconds
        setattr(self, pb+'_ORIGINAL', self[pb]) # save original value
        setattr(self, pb, self[pb] * SECPERDAY)

        if hasattr(self, pb+'_ERR'):
          setattr(self, pb+'_ERR_ORIGINAL', self[pb+'_ERR']) # save original value
          setattr(self, pb+'_ERR', self[pb+'_ERR'] * SECPERDAY)

    # OMDOT
    if hasattr(self, 'OMDOT'): # convert from deg/year to rad/sec
      self.OMDOT_ORIGINAL = self.OMDOT # save original value
      self.OMDOT = self.OMDOT / (RADTODEG * 365.25 * SECPERDAY)

      if hasattr(self, 'OMDOT_ERR'):
        self.OMDOT_ERR_ORIGINAL = self.OMDOT_ERR # save original value
        self.OMDOT_ERR = self.OMDOT_ERR / (RADTODEG * 365.25 * SECPERDAY)

    # ELL1-style parameters: derive eccentricity, omega and (if possible) T0
    if hasattr(self, 'EPS1') and hasattr(self, 'EPS2'):
      ecc = math.sqrt(self.EPS1 * self.EPS1 + self.EPS2 * self.EPS2)
      omega = math.atan2(self.EPS1, self.EPS2)
      self.E = ecc
      self.OM = omega # omega in rads
      if hasattr(self, 'TASC') and hasattr(self, 'PB'):
        self.T0 = self.TASC + self.PB * omega/TWOPI
    if hasattr(self, 'PB') and hasattr(self, 'A1') and not (hasattr(self, 'E') or hasattr(self, 'ECC')):
      self.E = 0.0
    if hasattr(self, 'T0') and not hasattr(self, 'TASC') and hasattr(self, 'OM') and hasattr(self, 'PB'):
      self.TASC = self.T0 - self.PB * self.OM/TWOPI

    # binary unit conversion for small numbers (TEMPO2 checks if these are > 1e-7 and if so then the units are in 1e-12) - errors are not converted
    for binu in ['XDOT', 'PBDOT', 'EDOT', 'EPS1DOT', 'EPS2DOT', 'XPBDOT']:
      if hasattr(self, binu):
        setattr(self, binu+'_ORIGINAL', self[binu]) # save original value
        if np.abs(self[binu]) > 1e-7:
          setattr(self, binu, self[binu] * 1.0e-12)

          # check value is not extremely large due to ATNF catalogue error
          if self[binu] > 10000.: # set to zero
            setattr(self, binu, 0.0)

    # masses
    for mass in ['M2', 'MTOT']:
      if hasattr(self, mass): # convert solar masses to kg
        setattr(self, mass+'_ORIGINAL', self[mass]) # save original value
        setattr(self, mass, self[mass]*MSUN)

        if hasattr(self, mass+'_ERR'):
          setattr(self, mass+'_ERR_ORIGINAL', self[mass+'_ERR']) # save original value
          setattr(self, mass+'_ERR', self[mass+'_ERR'] * MSUN)

    # D_AOP
    if hasattr(self, 'D_AOP'): # convert from inverse arcsec to inverse radians
      self.D_AOP_ORIGINAL = self.D_AOP # save original value
      self.D_AOP = self.D_AOP * RADTODEG * 3600.

  def __getitem__(self, key):
    """Return the attribute named key, or None if there is no such attribute."""
    try:
      return getattr(self, key)
    except (AttributeError, TypeError):
      # TypeError covers non-string keys
      return None

  def __str__(self):
    """Return one 'NAME = value' line per stored (non-dunder) attribute."""
    lines = []
    for k, v in self.__dict__.items():
      if k[:2] != "__":
        if isinstance(v, str):
          lines.append("%10s = '%s'\n" % (k, v))
        else:
          lines.append("%10s = %-20.15g\n" % (k, v))
    return "".join(lines)

491
# class to read in a nested sampling prior file
492
class psr_prior:
  """
  Prior ranges read from a nested sampling prior file.

  Each recognised line is expected to have at least four whitespace-separated
  fields, 'NAME <field> lower upper'. The second field is not used here. For
  names in the module-level str_keys or float_keys, [float(lower),
  float(upper)] is stored as the upper-cased attribute NAME.
  """

  def __init__(self, priorfilenm):
    self.FILE = priorfilenm

    with open(priorfilenm) as pf:
      for line in pf:
        splitline = line.split()

        # skip blank lines and '#' comments
        if not splitline or splitline[0].startswith('#'):
          continue

        # get all upper case version in case lower case in prior file
        key = splitline[0].upper()

        # everything in a prior file should be numeric, including entries for
        # names (such as RA and DEC) that are str_keys in par files
        if key in str_keys or key in float_keys:
          setattr(self, key, [float(splitline[2]), float(splitline[3])])

    # store the RA/DEC prior bounds (given in radians) as formatted
    # 'hh:mm:ss.ssss' / 'dd:mm:ss.ssss' strings
    if hasattr(self, 'RA'):
      hl, ml, sl = rad_to_hms(self.RA[0])
      hu, mu, su = rad_to_hms(self.RA[1])
      self.RA_STR = [coord_to_string(hl, ml, sl), coord_to_string(hu, mu, su)]

    if hasattr(self, 'DEC'):
      dl, ml, sl = rad_to_dms(self.DEC[0])
      du, mu, su = rad_to_dms(self.DEC[1])
      self.DEC_STR = [coord_to_string(dl, ml, sl), coord_to_string(du, mu, su)]

  def __getitem__(self, key):
    """Return the attribute named key, or None if there is no such attribute."""
    try:
      return getattr(self, key)
    except (AttributeError, TypeError):
      # TypeError covers non-string keys
      return None

  def __str__(self):
    """
    Return one line per stored (non-dunder) attribute. Numeric ranges are
    printed as 'NAME = lower, upper'. The FILE name and the formatted RA/DEC
    strings are printed as text rather than being forced through float().
    """
    lines = []
    for k, v in self.__dict__.items():
      if k[:2] == "__":
        continue
      if isinstance(v, str):
        lines.append("%10s = '%s'\n" % (k, v))
      else:
        try:
          lines.append("%10s = %-20.15g, %-20.15g\n" % (k, float(v[0]), float(v[1])))
        except (TypeError, ValueError):
          lines.append("%10s = %s, %s\n" % (k, v[0], v[1]))
    return "".join(lines)
540

541

542
543
544
545
# Function to return a pulsar's strain spin-down limit given its spin frequency
# (Hz), spin-down (Hz/s) and distance (kpc). The canonical moment of inertia
# value of 1e38 kg m^2 is used.
def spin_down_limit(freq, fdot, dist):
  """
  Return the gravitational-wave strain spin-down limit

    h_sd = sqrt((5/2) * (G/c^3) * I38 * |fdot| / freq) / (dist * KPC)

  freq : spin frequency (Hz)
  fdot : spin-down rate (Hz/s); only its magnitude is used
  dist : distance (kpc)

  The canonical moment of inertia I38 = 1e38 kg m^2 is assumed.
  """
  hsd = np.sqrt((5./2.)*(G/C**3)*I38*np.fabs(fdot)/freq)/(dist*KPC)
  return hsd
549
550


551
552
553
# Function to convert a pulsar strain into an ellipticity, assuming the
# canonical moment of inertia
def h0_to_ellipticity(h0, freq, dist):
  """
  Convert a gravitational-wave strain amplitude into a fiducial ellipticity

    ellipticity = h0 * c^4 * dist * KPC / (16 * pi^2 * G * I38 * freq^2)

  h0   : strain amplitude
  freq : frequency (Hz); the 16*pi^2 prefactor implies this is the rotation
         frequency (i.e. half the GW frequency) - confirm with callers
  dist : distance (kpc)

  The canonical moment of inertia I38 = 1e38 kg m^2 is assumed.
  """
  ell = h0*C**4.*dist*KPC/(16.*np.pi**2*G*I38*freq**2)
  return ell

558
559
560

# Function to convert a pulsar strain into a mass quadrupole moment
def h0_to_quadrupole(h0, freq, dist):
  """
  Convert a gravitational-wave strain amplitude into a mass quadrupole moment

    q22 = sqrt(15/(8 pi)) * h0 * c^4 * dist * KPC / (16 * pi^2 * G * freq^2)

  h0   : strain amplitude
  freq : frequency (Hz), with the same convention as h0_to_ellipticity
  dist : distance (kpc)

  Returns q22 in SI units.
  """
  q22 = np.sqrt(15./(8.*np.pi))*h0*C**4.*dist*KPC/(16.*np.pi**2*G*freq**2)
  return q22


566
567
568
569
570
571
572
# Function to convert a mass quadrupole moment into a strain
def quadrupole_to_h0(q22, freq, dist):
  """
  Convert a mass quadrupole moment (SI units) into a gravitational-wave strain
  amplitude for a source with frequency freq (Hz) at a distance dist (kpc):

    h0 = q22 * sqrt(8 pi / 15) * 16 pi^2 G freq^2 / (c^4 * dist * KPC)

  This is the inverse of the relation used in h0_to_quadrupole.
  """
  numerator = q22*np.sqrt((8.*np.pi)/15.)*16.*np.pi**2*G*freq**2
  denominator = C**4.*dist*KPC
  return numerator/denominator


573
574
575
# function to convert the psi' and phi0' coordinates used in nested sampling
# into the standard psi and phi0 coordinates (using vectors of those parameters)
def phipsiconvert(phipchain, psipchain):
  """
  Convert chains of the phi0' and psi' parameters used in nested sampling
  back into the standard phi0 and psi parameters.

  For each pair (phi', psi'), with theta = atan2(1, 2):
    phi0 = phi'/(2 sin(theta)) - psi'/(2 sin(theta))
    psi  = phi'/(2 cos(theta)) + psi'/(2 cos(theta))

  If |psi| > pi/4, psi is wrapped back into the range -pi/4 to pi/4 and
  phi0 is shifted by pi. Then phi0 is wrapped into the 0 to 2pi range.

  phipchain, psipchain : sequences of equal length (psipchain is indexed by
                         the position in phipchain)

  Returns two lists: (phi0 values, psi values).
  """
  theta = math.atan2(1, 2)
  ct = math.cos(theta)
  st = math.sin(theta)

  phichain = []
  psichain = []

  for i, phip in enumerate(phipchain):
    psip = psipchain[i]

    phi0 = (1/(2*st))*phip - (1/(2*st))*psip
    psi = (1/(2*ct))*phip + (1/(2*ct))*psip

    # put psi between +/-pi/4
    if math.fabs(psi) > math.pi/4.:
      # shift phi0 by pi
      phi0 = phi0 + math.pi

      # wrap around psi
      if psi > math.pi/4.:
        psi = -(math.pi/4.) + math.fmod(psi+(math.pi/4.), math.pi/2.)
      else:
        psi = (math.pi/4.) - math.fmod((math.pi/4.)-psi, math.pi/2.)

    # get phi0 into 0 -> 2pi range
    if phi0 > 2.*math.pi:
      phi0 = math.fmod(phi0, 2.*math.pi)
    else:
      phi0 = 2.*math.pi - math.fmod(2.*math.pi-phi0, 2.*math.pi)

    phichain.append(phi0)
    psichain.append(psi)

  return phichain, psichain
610
611


612
# function to create a histogram plot of the 1D posterior (potentially for
# multiple IFOs) for a parameter (param). If an upper limit is requested then
# it will also be calculated and returned
615
def plot_posterior_hist(poslist, param, ifos,
                        parambounds=(float("-inf"), float("inf")),
                        nbins=50, upperlimit=0, overplot=False,
                        parfile=None, mplparams=False):
  """
  Create histogram plot(s) of the 1D posterior of a parameter, potentially for
  multiple detectors, and optionally compute upper limits.

  poslist     : list of posterior objects, one per entry in ifos; for each,
                pos[param].samples gives the samples for the parameter
  param       : name of the parameter to plot
  ifos        : list of detector names used to pick line colours
                (defaults to ['H1'] if None)
  parambounds : (lower, upper) bounds passed to hist_norm_bounds; also used as
                the x-axis limits unless 'h0' is in the parameter name
  nbins       : number of histogram bins
  upperlimit  : if non-zero, the cumulative probability (e.g. 0.95) at which
                to compute an upper limit for each detector
  overplot    : if True, plot all detectors on a single figure
  parfile     : optional par-file object; if parfile[param.upper()] is not
                None, that value is overplotted as a dashed vertical line
  mplparams   : optional dict of matplotlib rcParams (defaults used if not given)

  Returns (list of figures, list of upper limit values).
  """
  import matplotlib
  from matplotlib import pyplot as plt
  from lalapps.pulsarhtmlutils import paramlatexdict

  def set_bgcolour(ax, colour):
    # Axes.set_axis_bgcolor was removed from matplotlib in favour of set_facecolor
    if hasattr(ax, 'set_facecolor'):
      ax.set_facecolor(colour)
    else:
      ax.set_axis_bgcolor(colour)

  # create list of figures
  myfigs = []

  # create a list of upper limits
  ulvals = []

  # set some matplotlib defaults for hist
  if not mplparams:
    mplparams = {
      'backend': 'Agg',
      'text.usetex': True, # use LaTeX for all text
      'axes.linewidth': 0.5, # set axes linewidths to 0.5
      'axes.grid': True, # add a grid
      'grid.linewidth': 0.5,
      'font.family': 'serif',
      'font.size': 12 }

  matplotlib.rcParams.update(mplparams)

  # ifos line colour specs
  coldict = {'H1': 'r', 'H2': 'c', 'L1': 'g', 'V1': 'b', 'G1': 'm', 'Joint':'k'}

  # param name for axis label
  try:
    paraxis = paramlatexdict[param.upper()]
  except KeyError:
    paraxis = param

  ymax = []

  # if a par file object is given expect that we have an injection file
  # containing the injected value to overplot on the histogram
  parval = None
  if parfile:
    parval = parfile[param.upper()]

  if ifos is None:
    # default to just output colour for H1
    ifos = ['H1']

  # loop over ifos
  for idx, ifo in enumerate(ifos):
    # a new figure per detector, or a single shared figure if overplotting
    # (successive plot calls draw onto the current axes, so no "hold" is needed)
    if not overplot or idx == 0:
      myfig = plt.figure(figsize=(4,4), dpi=200)

    pos_samps = poslist[idx][param].samples

    # get a normalised histogram for each
    n, bins = hist_norm_bounds(pos_samps, int(nbins), parambounds[0], parambounds[1])

    # plot histogram
    plt.step(bins, n, color=coldict[ifo])

    if 'h0' not in param:
      plt.xlim(parambounds[0], parambounds[1])

    plt.xlabel(paraxis, fontsize=14, fontweight=100)
    plt.ylabel(r'Probability Density', fontsize=14, fontweight=100)
    myfig.subplots_adjust(left=0.18, bottom=0.15) # adjust size

    if not overplot:
      plt.ylim(0, n.max()+0.1*n.max())
      # set background colour of axes
      set_bgcolour(plt.gca(), "#F2F1F0")
      # mark the injected value (if any) on every detector's figure
      if parval is not None:
        plt.axvline(parval, color='k', ls='--', linewidth=1.5)
      myfigs.append(myfig)
    else:
      ymax.append(n.max()+0.1*n.max())

    # if upper limit is needed then integrate posterior using trapezium rule
    if upperlimit != 0:
      ct = cumtrapz(n, bins)

      # prepend a zero to ct
      ct = np.insert(ct, 0, 0)

      # use linear interpolation to find the value at 'upper limit'
      # (np.unique drops flat regions so the interpolant is single-valued)
      ctu, ui = np.unique(ct, return_index=True)
      intf = interp1d(ctu, bins[ui], kind='linear')
      ulvals.append(intf(float(upperlimit)))

  if overplot:
    # plot the injected parameter value (a value of 0 is still plotted)
    if parval is not None:
      plt.axvline(parval, color='k', ls='--', linewidth=1.5)
    plt.ylim(0, max(ymax))
    set_bgcolour(plt.gca(), "#F2F1F0")
    myfigs.append(myfig)

  return myfigs, ulvals
729

730

731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
# function to return an upper limit from posterior samples: pos is an array of
# posterior samples for a particular parameter
def upper_limit(pos, upperlimit=0.95, parambounds=(float("-inf"), float("inf")), nbins=50):
  """
  Return an upper limit computed from an array of posterior samples.

  pos         : array of posterior samples for a particular parameter
  upperlimit  : cumulative probability at which to evaluate the limit
                (e.g. 0.95); if 0, no limit is calculated and 0 is returned
  parambounds : (lower, upper) bounds passed to hist_norm_bounds when
                histogramming the samples
  nbins       : number of histogram bins

  A normalised histogram of the samples is integrated with the trapezium rule.
  The parameter value at which the cumulative integral reaches upperlimit is
  found by linear interpolation.
  """
  if upperlimit == 0:
    return 0

  # get a normalised histogram of posterior samples
  n, bins = hist_norm_bounds(pos, int(nbins), parambounds[0], parambounds[1])

  # integrate posterior using trapezium rule, prepending a zero so the
  # cumulative values line up with the bin positions
  ct = np.insert(cumtrapz(n, bins), 0, 0)

  # use linear interpolation to find the value at 'upper limit'
  # (np.unique drops repeated cumulative values so the interpolant is single-valued)
  ctu, ui = np.unique(ct, return_index=True)
  intf = interp1d(ctu, bins[ui], kind='linear')
  return intf(float(upperlimit))


754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
def upper_limit_greedy(pos, upperlimit=0.95, nbins=100):
  """
  Return an upper limit from posterior samples using a simple histogram.

  pos        : sequence of posterior samples
  upperlimit : required cumulative fraction of samples (e.g. 0.95)
  nbins      : number of histogram bins

  The samples are histogrammed and the first bin whose cumulative fraction
  exceeds upperlimit is found. The limit is then linearly interpolated
  within that bin. If no bin exceeds upperlimit (e.g. upperlimit >= 1), the
  upper edge of the histogram (the maximum sample) is returned.

  Raises ValueError if pos is empty.
  """
  nsamps = len(pos)
  if nsamps == 0:
    raise ValueError("Cannot compute an upper limit from an empty set of samples")

  n, binedges = np.histogram(pos, bins=nbins)
  dbins = binedges[1] - binedges[0] # width of a histogram bin

  # cumulative fraction of samples up to the end of each bin
  cdf = np.cumsum(n) / float(nsamps)

  # index of the first bin whose cumulative fraction exceeds upperlimit
  j = int(np.searchsorted(cdf, upperlimit, side='right'))
  if j >= len(n):
    return binedges[-1]

  prevfrac = cdf[j-1] if j > 0 else 0.0
  frac = cdf[j]

  # linearly interpolate within bin j to get the upper limit
  return binedges[j] + (upperlimit - prevfrac)*(dbins/(frac - prevfrac))


773
774
775
776
777
778
# function to plot a posterior chain (be it MCMC chains or nested samples)
# the inputs should be a list of posteriors for each IFO, the required
# parameter, and the list of IFOs. grr is a list of dictionaries giving
# the Gelman-Rubin statistic for the given parameter for each IFO.
# If withhist is set then it will also output a histogram, with withhist number
# of bins
779
def plot_posterior_chain(poslist, param, ifos, grr=None, withhist=0, mplparams=False):
  """
  Plot posterior chains (MCMC chains or nested samples) for a parameter, one
  set of points per detector, optionally alongside a histogram of the samples.

  poslist   : list of posterior objects, one per entry in ifos; for each,
              pos[param].samples gives the samples
  param     : parameter name; if 'iota', cos(iota) is plotted instead
  ifos      : list of detector names, used to choose colours
  grr       : optional list (one per detector) of dicts giving the Gelman-Rubin
              statistic for each parameter; shown in the legend
  withhist  : if non-zero, also plot a histogram with this many bins
  mplparams : optional dict of matplotlib rcParams (defaults used if not given)

  Returns the figure, or None if matplotlib's gridspec cannot be imported.
  """
  import matplotlib
  from matplotlib import pyplot as plt
  from lalapps.pulsarhtmlutils import paramlatexdict

  try:
    from matplotlib import gridspec
  except ImportError:
    return None

  def set_bgcolour(ax, colour):
    # Axes.set_axis_bgcolor was removed from matplotlib in favour of set_facecolor
    if hasattr(ax, 'set_facecolor'):
      ax.set_facecolor(colour)
    else:
      ax.set_axis_bgcolor(colour)

  if not mplparams:
    mplparams = {
      'backend': 'Agg',
      'text.usetex': True, # use LaTeX for all text
      'axes.linewidth': 0.5, # set axes linewidths to 0.5
      'axes.grid': True, # add a grid
      'grid.linewidth': 0.5,
      'font.family': 'sans-serif',
      'font.sans-serif': 'Avant Garde, Helvetica, Computer Modern Sans serif',
      'font.size': 15 }

  matplotlib.rcParams.update(mplparams)

  coldict = {'H1': 'r', 'H2': 'c', 'L1': 'g', 'V1': 'b', 'G1': 'm', 'Joint': 'k'}

  # param name for axis label (cos(iota) is plotted for iota)
  p = 'cosiota' if param == 'iota' else param
  try:
    paryaxis = paramlatexdict[p.upper()]
  except KeyError:
    paryaxis = param

  legendvals = [] if grr else None

  maxiter = 0
  maxn = 0
  minsamp = float('inf')
  maxsamp = -float('inf')

  for idx, ifo in enumerate(ifos):
    if idx == 0:
      myfig = plt.figure(figsize=(12,4), dpi=200)
      myfig.subplots_adjust(bottom=0.15)

      if withhist:
        # chain on the first three quarters, histogram on the last quarter
        gs = gridspec.GridSpec(1, 4, wspace=0)
        ax1 = plt.subplot(gs[:-1])
        ax2 = plt.subplot(gs[-1])

    pos = poslist[idx]

    # check for cosiota
    if param == 'iota':
      pos_samps = np.cos(pos['iota'].samples)
    else:
      pos_samps = pos[param].samples

    if np.min(pos_samps) < minsamp:
      minsamp = np.min(pos_samps)
    if np.max(pos_samps) > maxsamp:
      maxsamp = np.max(pos_samps)

    if withhist:
      ax1.plot(pos_samps, '.', color=coldict[ifo], markersize=1)

      n, binedges = np.histogram(pos_samps, withhist, density=True)
      n = np.append(n, 0) # so n has one entry per bin edge for step()
      ax2.step(n, binedges, color=coldict[ifo])

      if np.max(n) > maxn:
        maxn = np.max(n)
    else:
      plt.plot(pos_samps, '.', color=coldict[ifo], markersize=1)

    if grr:
      try:
        legendvals.append(r'$R = %.2f$' % grr[idx][param])
      except (IndexError, KeyError, TypeError):
        # missing/invalid statistic: drop the legend entries gathered so far
        legendvals = []

    if len(pos_samps) > maxiter:
      maxiter = len(pos_samps)

  if not withhist:
    ax1 = plt.gca()

  bounds = [minsamp, maxsamp]

  ax1.set_ylabel(paryaxis, fontsize=16, fontweight=100)
  ax1.set_xlabel(r'Iterations', fontsize=16, fontweight=100)

  ax1.set_xlim(0, maxiter)
  ax1.set_ylim(bounds[0], bounds[1])

  if withhist:
    ax2.set_ylim(bounds[0], bounds[1])
    ax2.set_xlim(0, maxn+0.1*maxn)
    ax2.set_yticklabels([])
    set_bgcolour(ax2, "#F2F1F0")

  # add Gelman-Rubin stat data
  if legendvals:
    ax1.legend(legendvals, title='Gelman-Rubins test')

  return myfig
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909


# function to read in a 2D histogram from a binary file, where the file's
# structure is of the form:
#   a header of six doubles with:
#     - minimum of N-dim
#     - the step size in N-dim
#     - N - number of N-dim values
#     - minimum of M-dim
#     - the step size in M-dim
#     - M - number of M-dim values
#   followed by an NxM double array of values.
# The information returned are two lists of the x and y-axis bin centres, along
# with a numpy array containing the histogram values.
def read_hist_from_file(histfile):
  """
  Read a 2D histogram from a binary file of native-format doubles.

  The file starts with a 6-double header (x minimum, x step, N, y minimum,
  y step, M), followed by N*M doubles where the M values for each x bin
  are contiguous.

  Returns (xbins, ybins, histarr). xbins (length N) and ybins (length M) are
  numpy arrays of bin positions, and histarr is an (N, M) numpy array of
  histogram values. Returns (None, None, None), after printing a message to
  stderr, if the file cannot be opened, read, or decoded.
  """
  # read in 2D binary file
  try:
    fp = open(histfile, 'rb')
  except (IOError, OSError):
    print("Could not open prior file %s" % histfile, file=sys.stderr)
    return None, None, None

  try:
    pd = fp.read() # read in all the data
  except (IOError, OSError):
    print("Could not read in data from prior file %s" % histfile, file=sys.stderr)
    return None, None, None
  finally:
    fp.close()

  # read in the header (6 doubles)
  hdrsize = 6*8
  if len(pd) < hdrsize:
    print("Could not read in header data from prior file %s" % histfile, file=sys.stderr)
    return None, None, None
  header = struct.unpack("6d", pd[:hdrsize])

  nx = int(header[2])
  ny = int(header[5])

  # read in histogram grid (NxM doubles)
  try:
    grid = struct.unpack("%dd" % (nx*ny), pd[hdrsize:])
  except struct.error:
    print("Could not read in histogram data from prior file %s" % histfile, file=sys.stderr)
    return None, None, None

  # convert grid into an (N, M) numpy array (row i holds the M values for x bin i)
  histarr = np.array(grid).reshape(nx, ny)

  xbins = np.linspace(header[0], header[0]+header[1]*(header[2]-1), nx)
  ybins = np.linspace(header[3], header[3]+header[4]*(header[5]-1), ny)

  return xbins, ybins, histarr


# Function to plot a 2D histogram of from a binary file containing it.
# Also supply the label names for the n-dimension (x-axis) and m-dimension
# (y-axis). If margpars is true marginalised plots of both dimensions will
# also be returned.
def plot_2Dhist_from_file(histfile, ndimlabel, mdimlabel, margpars=True, \
                          mplparams=False):
955
956
  import matplotlib
  from matplotlib import pyplot as plt
957
  from lalapps.pulsarhtmlutils import paramlatexdict
958

959
960
  # read in 2D h0 vs cos(iota) binary prior file
  xbins, ybins, histarr = read_hist_from_file(histfile)
961

962
  if not xbins.any():
963
    print("Could not read binary histogram file", file=sys.stderr)
964
    return None
965

966
  figs = []
967

968
969
970
971
972
973
974
975
976
977
  # set some matplotlib defaults for amplitude spectral density
  if not mplparams:
    mplparams = { \
      'backend': 'Agg',
      'text.usetex': True, # use LaTeX for all text
      'axes.linewidth': 0.5, # set axes linewidths to 0.5
      'axes.grid': True, # add a grid
      'grid.linewidth': 0.5,
      'font.family': 'serif',
      'font.size': 12 }
978

979
  matplotlib.rcParams.update(mplparams)
980

981
  fig = plt