Commit 9cce2abb authored by Simone Mastrogiovanni's avatar Simone Mastrogiovanni
Browse files

Merge branch 'Speed_up_old_pickles' into 'pending_review'

Speed up old pickles

See merge request cbc-cosmo/gwcosmo!26
parents fd8c2e98 5bd18066
......@@ -86,7 +86,9 @@ parser = OptionParser(
Option("--detectors", default='HLV',
help="Set the detectors to use for the pickle (default=HLV)."),
Option("--det_combination", default='True',
help="Set whether or not to consider all possible detectors combinations (default=True).")
help="Set whether or not to consider all possible detectors combinations (default=True)."),
Option("--seed", default='1000', type=int,
help="Set the random seed.")
])
opts, args = parser.parse_args()
print(opts)
......@@ -102,7 +104,7 @@ if len(missing) > 0:
mass_distribution = str(opts.mass_distribution)
psd = opts.psd
seed = int(opts.seed)
alpha = float(opts.powerlaw_slope)
Mmin = float(opts.minimum_mass)
Mmax = float(opts.maximum_mass)
......@@ -116,7 +118,7 @@ b = float(opts.b)
min_H0 = float(opts.min_H0)
max_H0 = float(opts.max_H0)
bins_H0 = float(opts.bins_H0)
bins_H0 = int(opts.bins_H0)
det_combination = str2bool(opts.det_combination)
linear = str2bool(opts.linear_cosmology)
basic = str2bool(opts.basic_pdet)
......@@ -167,7 +169,7 @@ if opts.combine is None:
pdet = gwcosmo.detection_probability.DetectionProbability(mass_distribution=mass_distribution, asd=psd, detected_masses=detected_masses, basic=basic, detectors=dets,
linear=linear, alpha=alpha, Mmin=Mmin, Mmax=Mmax, Omega_m=Omega_m, alpha_2=alpha_2, mu_g=mu_g, sigma_g=sigma_g,
lambda_peak=lambda_peak, beta=beta, full_waveform=full_waveform, Nsamps=Nsamps, det_combination = det_combination,
b=b, delta_m=delta_m, constant_H0=constant_H0, H0=H0, network_snr_threshold=network_snr_threshold, path=pdet_path)
b=b, delta_m=delta_m, constant_H0=constant_H0, H0=H0, network_snr_threshold=network_snr_threshold, path=pdet_path, seed=seed)
pickle.dump( pdet, open( pdet_path, "wb" ) )
......@@ -187,7 +189,7 @@ else:
h0 = pdets.H0vec
probs[h0] = pdets.prob
if detected_masses==True:
detected[h0] = pdets.detected==1
detected[h0] = pdets.detected
H0vec = np.array(list(probs.keys()))
H0vec = sorted(H0vec)
......@@ -229,7 +231,7 @@ else:
for i in range (len(logit_prob)):
logit_prob[i]=np.where(logit_prob[i]==float('+inf'), 100, logit_prob[i])
logit_prob[i]=np.where(logit_prob[i]==float('-inf'), -33, logit_prob[i])
interp_average = interp2d(pdet.z_array, pdet.H0vec, logit_prob, kind='linear')
interp_average = interp2d(pdet.z_array, pdet.H0vec, logit_prob, kind='cubic')
pdet.interp_average = interp_average
if pdets.asd != None:
......@@ -264,5 +266,5 @@ else:
pdet_path = '{}_Nsamps{}_{}_snr_{}'.format(pdets.mass_distribution, str(pdets.Nsamps), kind,str(pdets.snr_threshold))
if detected_masses==True:
np.savez(pdet_path+'_detected.npz',[detected,m1,m2])
np.savez(pdet_path+'_detected.npz',[detected,pdet.m1,pdet.m2])
pickle.dump( pdet, open( pdet_path+'.p', "wb" ) )
......@@ -113,14 +113,12 @@ class DetectionProbability(object):
self.seed = seed
self.detected_masses = detected_masses
self.det_combination = det_combination
np.random.seed(seed)
self.cosmo = fast_cosmology(Omega_m=self.Omega_m, linear=self.linear)
self.path = str(path)+'_checkpoint.p'
if self.full_waveform is True:
self.z_array = np.logspace(-4.0, 1., 1000)
self.z_array = np.logspace(-4.0, 1., 500)
else:
# TODO: For higher values of z (z=10) this goes
# outside the range of the psds and gives an error
......@@ -131,7 +129,7 @@ class DetectionProbability(object):
detect = np.ones(self.Nsamps)
self.detected = 0
if self.detected_masses==True:
self.detected = np.zeros((len(self.z_array),self.Nsamps))
self.detected = np.zeros((len(self.z_array),self.Nsamps),dtype=np.float32)
if os.path.isfile(self.path):
pdet_checkpoint = pickle.load(open(self.path, 'rb'))
......@@ -141,7 +139,7 @@ class DetectionProbability(object):
self.detected = pdet_checkpoint['detected']
self.seed = pdet_checkpoint['seed']
detect = pdet_checkpoint['detect']
np.random.seed(self.seed)
# set up the samples for monte carlo integral
N = self.Nsamps
self.RAs = np.random.rand(N)*2.0*np.pi
......@@ -657,20 +655,18 @@ class DetectionProbability(object):
"""
lal_detectors = [lalsim.DetectorPrefixToLALDetector(name)
for name in self.detectors]
network_rhosq = np.zeros(self.Nsamps)
bar = progressbar.ProgressBar()
for i in bar(range(checkpoint, len(self.z_array))):
z = self.z_array[i]
dl = self.cosmo.dl_zH0(z, H0)
factor = 1+z
survival = 0
np.random.seed(100)
network_rhosq = np.zeros(self.Nsamps)
for n in range(self.Nsamps):
detectors = self.dets[n]
psd = self.psds[n]
if detect[n] == 1:
if self.full_waveform is True:
if detect[n] == 1:
if self.full_waveform is True:
hp,hc = self.simulate_waveform(factor*self.m1[n], factor*self.m2[n], dl, self.incs[n], self.phis[n])
rhosqs = [self.snr_squared_waveform(hp,hc,self.RAs[n],self.Decs[n],self.psis[n], 0., det, psd)
for det in detectors]
......@@ -681,29 +677,25 @@ class DetectionProbability(object):
det, 0.0, self.z_array[i], H0, psd)
for det in detectors]
network_rhosq[n] = np.sum(rhosqs)
det_SNR = ncx2.rvs(2*len(detectors), network_rhosq[n])
if det_SNR>=self.snr_threshold**2:
survival+=1
if self.detected_masses==True:
self.detected[i][n] = 1
else:
if (1-ncx2.cdf(x=self.snr_threshold**2,df=2*len(detectors), nc=network_rhosq[n]))<=1e-2:
detect[n]=0
else:
ncx2.rvs(2*len(detectors),100)
prob[i] = survival/self.Nsamps
if i%50==0:
survival = ncx2.sf(self.snr_threshold**2, 2*len(self.detectors), network_rhosq)
prob[i] = np.sum(survival, 0)/self.Nsamps
if self.detected_masses==True:
self.detected[i] = np.float32(survival)
not_surviving_samples = np.where(survival<=1e-6)[0] #threshold to consider event undetectable
detect[not_surviving_samples] = 0.
if i%20==0:
if os.path.isfile(self.path):
os.remove(self.path)
checkpoint = self.checkpointing(detect,prob,i)
pickle.dump(checkpoint, open( self.path, "wb" ))
if os.path.isfile(self.path):
os.remove(self.path)
os.remove(self.path)
return prob
def checkpointing(self,detect,prob,i):
return {'seed':self.seed,'detect':detect,'detected':self.detected,'prob_checkpoint':prob,'checkpoint_z':i,'z_array':self.z_array}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment