Enable pickle compatibility for Posterior class
Done statement: The .__dict__ attribute of the Posterior class and its child classes (i.e. MarginalizedPosterior) can be successfully saved using pickle.dump.
Context
SPIIR's basic p_astro model for signal vs. noise is implemented as a custom class that imports from ligo.p_astro. This class also has convenience methods such as .fit, .predict, .save and .load, which are commonly used in typical statistical modelling and data science workflows. In particular, SPIIR has written save_pkl and load_pkl methods that serialize the trained model state to a .pkl file, which the pipeline can load later during inference. These methods look like the following:
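import pickle
from pathlib import Path
from typing import Union

def save_pkl(self, path: Union[str, Path]):
    with Path(path).open(mode="wb") as f:
        pickle.dump(self.__dict__, f)  # serialize the instance's attribute dict

def load_pkl(self, path: Union[str, Path]):
    with Path(path).open(mode="rb") as f:
        self.__dict__ = pickle.load(f)  # restore attributes from the saved file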
Problem
If we have a variable defined as a dict, e.g. x = dict(a=1, b=2), then we cannot pickle either x.keys() or x.values(). This is a basic limitation of Python's pickle object serialization: dictionary view objects are not picklable.
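A minimal sketch of this limitation, using only the standard library. The helper name can_pickle is ours and is used purely for illustration:

```python
import pickle


def can_pickle(obj) -> bool:
    """Return True if pickle.dumps(obj) succeeds, False if it raises TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


x = dict(a=1, b=2)

print(can_pickle(x))                 # the dict itself
print(can_pickle(x.keys()))          # dict_keys view
print(can_pickle(x.values()))        # dict_values view
print(can_pickle(list(x.keys())))    # view converted to a list
```

Only the two view objects are rejected; the dict itself and a list built from a view both serialize without error.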
Unfortunately, the ligo.p_astro.Posterior class has two attributes, self.args_astro and self.keys_fixed, that are defined as dict.values() and dict.keys() respectively. As long as these two attributes are stored as these types, we cannot use ligo.p_astro as is for our use case of conveniently saving model state with pickle.
SPIIR currently works around this problem itself (by forking this repository and/or manually changing the type of these two attributes), but in reality this is a very simple fix that can be made upstream in the ligo.p_astro repository itself.
Requested Changes
As I do not have access to push branches for a merge request into this repository, I have taken a screenshot of the necessary changes to fix this problem.
I have created a merge request from the feautre/enable_pickle_compat branch from the spiir-group/p-astro fork, see: !43.
It is a simple change on two lines, converting astro_sources.values() to list(astro_sources.values()) and fix_sources.keys() to list(fix_sources.keys()) respectively.
As self.args_astro and self.keys_fixed are only ever used as Iterable objects, the conversion to the list type does not break any functionality or change behaviour. However, it does allow us to pickle the state of a trained MarginalizedPosterior class.
Note that if we would like these attributes to be immutable rather than mutable, an alternative would be to use tuple() rather than list().
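The two conversions can be sketched as follows. The dictionary contents below are hypothetical stand-ins; the real astro_sources and fix_sources inside ligo.p_astro will hold different values:

```python
import pickle

# Hypothetical stand-ins for the dictionaries used inside Posterior.
astro_sources = {"source_a": 1.0, "source_b": 2.0}
fix_sources = {"source_c": 3.0}

# Either conversion turns the view into a picklable container.
args_astro = list(astro_sources.values())   # mutable
keys_fixed = tuple(fix_sources.keys())      # immutable alternative

# Round-trip both through pickle.
restored_args = pickle.loads(pickle.dumps(args_astro))
restored_keys = pickle.loads(pickle.dumps(keys_fixed))

# Iterating the restored containers yields the same items as the views.
same_args = restored_args == list(astro_sources.values())
same_keys = restored_keys == tuple(fix_sources.keys())
print(same_args, same_keys)
```

One difference worth keeping in mind: a dict view reflects later changes to its dictionary, whereas a list or tuple is a snapshot taken at assignment time.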
Demonstration
To illustrate this problem and solution, we'll demonstrate a minimal example based on the spiir.search.p_astro.TwoComponentModel class and methods as mentioned in the initial context of this Issue.
Let us define a class with a dictionary attribute, called MyClass, that has a save method similar to the one in the Python code snippet in the initial Context above. Then, we'll attempt to save an instance of this class to a .pkl file while it has attributes stored as dict.keys()/dict.values(). We should expect to see a TypeError preventing us from serializing my_class.
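A minimal sketch of the failing case. MyClass and its attribute names are illustrative only and mirror, rather than reproduce, the ligo.p_astro implementation:

```python
import pickle
import tempfile
from pathlib import Path
from typing import Union


class MyClass:
    def __init__(self, data: dict):
        # Store the dict views directly, mirroring the problematic attributes.
        self.keys_fixed = data.keys()
        self.args_astro = data.values()

    def save_pkl(self, path: Union[str, Path]):
        with Path(path).open(mode="wb") as f:
            pickle.dump(self.__dict__, f)


my_class = MyClass(dict(a=1, b=2))

# Write into a temporary directory so the sketch leaves no files behind.
with tempfile.TemporaryDirectory() as tmp:
    try:
        my_class.save_pkl(Path(tmp) / "my_class.pkl")
        saved = True
    except TypeError as err:
        saved = False
        print(f"TypeError: {err}")
```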
Next, we re-define the same class but this time we convert both the dict.keys() and dict.values() into lists when storing the attribute data. We'll see the code executes successfully.
For completeness, we can also show that loading this .pkl file correctly updates model state:
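A minimal sketch of the fixed class, covering both the save and the load steps. As before, MyClass and its attribute names are illustrative and are not the ligo.p_astro implementation:

```python
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Union


class MyClass:
    def __init__(self, data: Optional[dict] = None):
        data = data or {}
        # Convert the views to lists so the instance state is picklable.
        self.keys_fixed = list(data.keys())
        self.args_astro = list(data.values())

    def save_pkl(self, path: Union[str, Path]):
        with Path(path).open(mode="wb") as f:
            pickle.dump(self.__dict__, f)

    def load_pkl(self, path: Union[str, Path]):
        with Path(path).open(mode="rb") as f:
            self.__dict__ = pickle.load(f)


trained = MyClass(dict(a=1, b=2))
fresh = MyClass()  # starts with empty state

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "my_class.pkl"
    trained.save_pkl(path)   # succeeds now that the attributes are lists
    fresh.load_pkl(path)     # overwrites the empty state with the saved one

print(fresh.keys_fixed)
print(fresh.args_astro)
```

After loading, the fresh instance carries the same attribute values as the one that was saved.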
Demonstration with spiir.search.p_astro.TwoComponentModel
We repeat a minimal example with our typical workflow in SPIIR in a notebook, showing how changing the type of our nested attributes fixes this problem in the same way: