Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog

## Unreleased
* fix pickle and deep copy for classification models inheriting from EP #1108 [olamarre]

* update prior `__new__` methods #1098 [MartinBubel]

* fix invalid escape sequence #1011 [janmayer]
Expand Down
27 changes: 10 additions & 17 deletions GPy/inference/latent_function_inference/expectation_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,17 @@ def _stop_criteria(self, ga_approx):
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))

def __setstate__(self, state):
super(EPBase, self).__setstate__(state[0])
self.epsilon, self.eta, self.delta = state[1]
self.reset()

def __getstate__(self):
return [super(EPBase, self).__getstate__() , [self.epsilon, self.eta, self.delta]]

def _save_to_input_dict(self):
input_dict = super(EPBase, self)._save_to_input_dict()
input_dict["epsilon"]=self.epsilon
input_dict["eta"]=self.eta
input_dict["delta"]=self.delta
input_dict["always_reset"]=self.always_reset
input_dict["max_iters"]=self.max_iters
input_dict["ep_mode"]=self.ep_mode
input_dict["parallel_updates"]=self.parallel_updates
input_dict["loading"]=True
input_dict = {
"epsilon": self.epsilon,
"eta": self.eta,
"delta": self.delta,
"always_reset": self.always_reset,
"max_iters": self.max_iters,
"ep_mode": self.ep_mode,
"parallel_updates": self.parallel_updates,
"loading": True
}
return input_dict

class EP(EPBase, ExactGaussianInference):
Expand Down
28 changes: 28 additions & 0 deletions GPy/testing/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
The test cases for various inference algorithms
"""

import copy
import pickle
import numpy as np
import GPy

Expand Down Expand Up @@ -146,6 +148,32 @@ def test_inference_EP(self):
< 1e6
)

def test_pickle_copy_EP(self):
"""Pickling and deep-copying a classification model employing EP"""

# Dummy binary classification dataset
X = np.array([0, 1, 2, 3]).reshape(-1, 1)
Y = np.array([0, 0, 1, 1]).reshape(-1, 1)

# Some classification model
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
max_iters=30, delta=0.5
)
m = GPy.core.GP(
X=X,
Y=Y,
kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
inference_method = inf,
likelihood=GPy.likelihoods.Bernoulli(),
mean_function=None
)
m.optimize()

m_pickled = pickle.dumps(m)
assert pickle.loads(m_pickled) is not None

assert copy.deepcopy(m) is not None

# NOTE: adding a test like above for parameterized likelihood- the above test is
# only for probit likelihood which does not have any tunable hyperparameter which is why
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for
Expand Down
Loading