Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bilby
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lscsoft
bilby
Commits
def7f5e3
Commit
def7f5e3
authored
5 years ago
by
Gregory Ashton
Browse files
Options
Downloads
Patches
Plain Diff
Improve resume behaviour from checkpoint
parent
49581fb5
No related branches found
No related tags found
1 merge request
!750
Improve ptemcee
Pipeline
#113248
failed
5 years ago
Stage: test
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bilby/core/sampler/ptemcee.py
+42
-22
42 additions, 22 deletions
bilby/core/sampler/ptemcee.py
with
42 additions
and
22 deletions
bilby/core/sampler/ptemcee.py
+
42
−
22
View file @
def7f5e3
...
...
@@ -97,22 +97,36 @@ class Ptemcee(MCMCSampler):
if
os
.
path
.
isfile
(
self
.
resume_file
)
and
self
.
resume
is
True
:
logger
.
info
(
"
Resume data {} found
"
.
format
(
self
.
resume_file
))
with
open
(
self
.
resume_file
,
"
rb
"
)
as
file
:
import
IPython
;
IPython
.
embed
()
data
=
dill
.
load
(
file
)
self
.
sampler
=
data
[
"
sampler
"
]
self
.
sampler
.
pool
=
self
.
pool
self
.
sampler
.
threads
=
self
.
threads
self
.
tau_list
=
data
[
"
tau_list
"
]
self
.
tau_list_n
=
data
[
"
tau_list_n
"
]
self
.
time_per_check
=
data
[
"
time_per_check
"
]
self
.
sampler
.
pool
=
self
.
pool
self
.
sampler
.
threads
=
self
.
threads
pos0
=
None
logger
.
info
(
"
Resuming from previous run with time={}
"
.
format
(
self
.
sampler
.
time
))
else
:
# Initialize the PTSampler
self
.
sampler
=
ptemcee
.
Sampler
(
dim
=
self
.
ndim
,
logl
=
do_nothing_function
,
logp
=
do_nothing_function
,
pool
=
self
.
pool
,
threads
=
self
.
threads
,
**
self
.
sampler_init_kwargs
)
# Overwrite the _likeprior to improve performance with threads > 1
self
.
sampler
.
_likeprior
=
LikePriorEvaluator
(
self
.
search_parameter_keys
,
use_ratio
=
self
.
use_ratio
)
# Set up empty lists
self
.
tau_list
=
[]
self
.
tau_list_n
=
[]
self
.
time_per_check
=
[]
# Initialize the walker postitions
pos0
=
self
.
get_pos0_from_prior
()
return
self
.
sampler
,
pos0
...
...
@@ -138,9 +152,6 @@ class Ptemcee(MCMCSampler):
def
run_sampler_internal
(
self
):
import
emcee
sampler
,
pos0
=
self
.
setup_sampler
()
self
.
time_per_check
=
[]
self
.
tau_list
=
[]
self
.
tau_list_n
=
[]
t0
=
datetime
.
datetime
.
now
()
logger
.
info
(
"
Starting to sample
"
)
...
...
@@ -230,7 +241,7 @@ class Ptemcee(MCMCSampler):
last_checkpoint_s
=
np
.
sum
(
self
.
time_per_check
)
if
last_checkpoint_s
>
self
.
check_point_deltaT
:
self
.
write_current_state
()
self
.
write_current_state
(
plot
=
self
.
plot
)
# Check if we reached the end without converging
if
sampler
.
time
==
self
.
sampler_function_kwargs
[
"
iterations
"
]:
...
...
@@ -241,7 +252,7 @@ class Ptemcee(MCMCSampler):
)
# Run a final checkpoint to update the plots and samples
self
.
write_current_state
()
self
.
write_current_state
(
plot
=
self
.
plot
)
# Get 0-likelihood samples and store in the result
samples
=
sampler
.
chain
[
0
,
:,
:,
:]
# nwalkers, nsteps, ndim
...
...
@@ -268,16 +279,27 @@ class Ptemcee(MCMCSampler):
def
write_current_state_and_exit
(
self
,
signum
=
None
,
frame
=
None
):
logger
.
warning
(
"
Run terminated with signal {}
"
.
format
(
signum
))
if
self
.
pool
:
if
getattr
(
self
,
'
pool
'
,
None
):
self
.
write_current_state
(
plot
=
False
)
logger
.
warning
(
"
Closing pool
"
)
self
.
pool
.
close
()
self
.
write_current_state
()
sys
.
exit
(
77
)
def
write_current_state
(
self
):
def
write_current_state
(
self
,
plot
=
True
):
checkpoint
(
self
.
outdir
,
self
.
label
,
self
.
nsamples_effective
,
self
.
sampler
,
self
.
nburn
,
self
.
thin
,
self
.
search_parameter_keys
,
self
.
resume_file
,
self
.
tau_list
,
self
.
tau_list_n
)
self
.
tau_list_n
,
self
.
time_per_check
)
if
plot
:
# Generate the walkers plot diagnostic
plot_walkers
(
self
.
sampler
.
chain
[
0
,
:,
:
self
.
sampler
.
time
,
:],
self
.
nburn
,
self
.
search_parameter_keys
,
self
.
outdir
,
self
.
label
)
# Generate the tau plot diagnostic
plot_tau
(
self
.
tau_list_n
,
self
.
tau_list
,
self
.
outdir
,
self
.
label
)
def
print_progress
(
...
...
@@ -340,7 +362,8 @@ def print_progress(
def
checkpoint
(
outdir
,
label
,
nsamples_effective
,
sampler
,
nburn
,
thin
,
search_parameter_keys
,
resume_file
,
tau_list
,
tau_list_n
):
search_parameter_keys
,
resume_file
,
tau_list
,
tau_list_n
,
time_per_check
):
logger
.
info
(
"
Writing checkpoint and diagnostics
"
)
ndim
=
sampler
.
dim
...
...
@@ -360,20 +383,17 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy
.
_logposterior
=
sampler
.
_logposterior
[:,
:,
:
sampler
.
time
]
sampler_copy
.
_loglikelihood
=
sampler
.
_loglikelihood
[:,
:,
:
sampler
.
time
]
sampler_copy
.
_beta_history
=
sampler
.
_beta_history
[:,
:
sampler
.
time
]
data
=
dict
(
sampler
=
sampler_copy
,
tau_list
=
tau_list
,
tau_list_n
=
tau_list_n
)
data
=
dict
(
sampler
=
sampler_copy
,
tau_list
=
tau_list
,
tau_list_n
=
tau_list_n
,
time_per_check
=
time_per_check
)
with
open
(
resume_file
,
"
wb
"
)
as
file
:
dill
.
dump
(
data
,
file
,
protocol
=
4
)
del
data
,
sampler_copy
# Generate the walkers plot diagnostic
plot_walkers
(
sampler
.
chain
[
0
,
:,
:
sampler
.
time
,
:],
nburn
,
search_parameter_keys
,
outdir
,
label
)
# Generate the tau plot diagnostic
plot_tau
(
tau_list_n
,
tau_list
,
outdir
,
label
)
logger
.
info
(
"
Finished writing checkpoint
and diagnostics
"
)
logger
.
info
(
"
Finished writing checkpoint
"
)
def
plot_walkers
(
walkers
,
nburn
,
parameter_labels
,
outdir
,
label
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment