Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Bilby_Psi4
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Samson Leong
Bilby_Psi4
Commits
41cb703c
Commit
41cb703c
authored
7 months ago
by
Colm Talbot
Browse files
Options
Downloads
Patches
Plain Diff
FEAT: enable caching to be disabled in hyper.model.Model
parent
e37c308a
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bilby/hyper/model.py
+17
-7
17 additions, 7 deletions
bilby/hyper/model.py
with
17 additions
and
7 deletions
bilby/hyper/model.py
+
17
−
7
View file @
41cb703c
from
..core.utils
import
infer_args_from_function_except_n_args
class
Model
(
object
)
:
class
Model
:
r
"""
Population model that combines a set of factorizable models.
...
...
@@ -12,18 +12,24 @@ class Model(object):
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""
def
__init__
(
self
,
model_functions
=
None
):
def
__init__
(
self
,
model_functions
=
None
,
cache
=
True
):
"""
Parameters
==========
model_functions: list
List of callables to compute the probability.
If this includes classes, the `__call__` method
should return the
probability.
If this includes classes, the
:code:
`__call__`
:
method
should return the
probability.
The requires variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
cache: bool
Whether to cache the value returned by the model functions,
default=:code:`True`. The caching only looks at the parameters
not the data, so should be used with caution. The caching also
breaks :code:`jax` JIT compilation.
"""
self
.
models
=
model_functions
self
.
cache
=
cache
self
.
_cached_parameters
=
{
model
:
None
for
model
in
self
.
models
}
self
.
_cached_probability
=
{
model
:
None
for
model
in
self
.
models
}
...
...
@@ -48,14 +54,18 @@ class Model(object):
probability
=
1.0
for
ii
,
function
in
enumerate
(
self
.
models
):
function_parameters
=
self
.
_get_function_parameters
(
function
)
if
self
.
_cached_parameters
[
function
]
==
function_parameters
:
if
(
self
.
cache
and
self
.
_cached_parameters
[
function
]
==
function_parameters
):
new_probability
=
self
.
_cached_probability
[
function
]
else
:
new_probability
=
function
(
data
,
**
self
.
_get_function_parameters
(
function
)
)
self
.
_cached_parameters
[
function
]
=
function_parameters
self
.
_cached_probability
[
function
]
=
new_probability
if
self
.
cache
:
self
.
_cached_parameters
[
function
]
=
function_parameters
self
.
_cached_probability
[
function
]
=
new_probability
probability
*=
new_probability
return
probability
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment