Skip to content

Commit 591ed88

Browse files
authored
Merge pull request #65 from simon-hirsch:dist_add_log
Add log pdf/cdf/pmf
2 parents 51a07d7 + 8009f89 commit 591ed88

2 files changed

Lines changed: 81 additions & 6 deletions

File tree

src/rolch/base/distribution.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,21 @@ def initial_values(
114114
) -> np.ndarray:
115115
"""Calculate the initial values for the GAMLSS fit."""
116116

117+
def quantile(self, q: np.ndarray, theta: np.ndarray) -> np.ndarray:
118+
"""
119+
Compute the quantile function at the given probabilities.
120+
121+
This is an alias for the `ppf` method.
122+
123+
Parameters:
124+
q (np.ndarray): The probabilities at which to compute the quantiles.
125+
theta (np.ndarray): The parameters of the distribution.
126+
127+
Returns:
128+
np.ndarray: The quantiles corresponding to the given probabilities.
129+
"""
130+
return self.ppf(q, theta)
131+
117132
@abstractmethod
118133
def cdf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
119134
"""
@@ -179,6 +194,30 @@ def rvs(self, size: int, theta: np.ndarray) -> np.ndarray:
179194
np.ndarray: A 2D array of random variates with shape (theta.shape[0], size).
180195
"""
181196

197+
@abstractmethod
198+
def logpmf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
199+
raise NotImplementedError(
200+
"Log PMF is not implemented for continuous distributions."
201+
)
202+
203+
@abstractmethod
204+
def logpdf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
205+
raise NotImplementedError(
206+
"Log PDF is not implemented for discrete distributions."
207+
)
208+
209+
@abstractmethod
210+
def logcdf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
211+
"""Compute the log of the cumulative distribution function (CDF) for the given data points.
212+
213+
Parameters:
214+
y (np.ndarray): An array of data points at which to evaluate the log CDF.
215+
theta (np.ndarray): An array of parameters for the distribution.
216+
217+
Returns:
218+
np.ndarray: An array of log CDF values corresponding to the data points in `y`.
219+
"""
220+
182221

183222
class ScipyMixin(ABC):
184223

@@ -236,6 +275,42 @@ def pmf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
236275
def ppf(self, q: np.ndarray, theta: np.ndarray) -> np.ndarray:
237276
return self.scipy_dist(**self.theta_to_scipy_params(theta)).ppf(q)
238277

278+
def logpmf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
279+
"""Compute the log of the probability mass function (PMF) for the given data points.
280+
281+
Parameters:
282+
y (np.ndarray): An array of data points at which to evaluate the log PMF.
283+
theta (np.ndarray): An array of parameters for the distribution.
284+
285+
Returns:
286+
np.ndarray: An array of log PMF values corresponding to the data points in `y`.
287+
"""
288+
return self.scipy_dist(**self.theta_to_scipy_params(theta)).logpmf(y)
289+
290+
def logpdf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
291+
"""Compute the log of the probability density function (PDF) for the given data points.
292+
293+
Parameters:
294+
y (np.ndarray): An array of data points at which to evaluate the log PDF.
295+
theta (np.ndarray): An array of parameters for the distribution.
296+
297+
Returns:
298+
np.ndarray: An array of log PDF values corresponding to the data points in `y`.
299+
"""
300+
return self.scipy_dist(**self.theta_to_scipy_params(theta)).logpdf(y)
301+
302+
def logcdf(self, y: np.ndarray, theta: np.ndarray) -> np.ndarray:
303+
"""Compute the log of the cumulative distribution function (CDF) for the given data points.
304+
305+
Parameters:
306+
y (np.ndarray): An array of data points at which to evaluate the log CDF.
307+
theta (np.ndarray): An array of parameters for the distribution.
308+
309+
Returns:
310+
np.ndarray: An array of log CDF values corresponding to the data points in `y`.
311+
"""
312+
return self.scipy_dist(**self.theta_to_scipy_params(theta)).logcdf(y)
313+
239314
def rvs(self, size: int, theta: np.ndarray) -> np.ndarray:
240315
return (
241316
self.scipy_dist(**self.theta_to_scipy_params(theta))

src/rolch/estimators/online_gamlss.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def update(
646646

647647
def _outer_update(self, X, y, w):
648648
## for new observations:
649-
global_di = -2 * np.log(self.distribution.pdf(y, self.fv))
649+
global_di = -2 * self.distribution.logpdf(y, self.fv)
650650
global_dev = (1 - self.forget[0]) * self.global_dev + global_di
651651
global_dev_old = global_dev + 1000
652652
iteration_outer = 0
@@ -696,7 +696,7 @@ def _outer_update(self, X, y, w):
696696

697697
def _outer_fit(self, X, y, w):
698698

699-
global_di = -2 * np.log(self.distribution.pdf(y, self.fv))
699+
global_di = -2 * self.distribution.logpdf(y, self.fv)
700700
global_dev = np.sum(w * global_di)
701701
global_dev_old = global_dev + 1000
702702
iteration_outer = 0
@@ -761,7 +761,7 @@ def _inner_fit(
761761
dv,
762762
):
763763

764-
di = -2 * np.log(self.distribution.pdf(y, self.fv))
764+
di = -2 * self.distribution.logpdf(y, self.fv)
765765
dv = np.sum(di * w)
766766
olddv = dv + 1
767767

@@ -837,7 +837,7 @@ def _inner_fit(
837837
eta = X[param] @ self.beta[param].T
838838
self.fv[:, param] = self.distribution.link_inverse(eta, param=param)
839839

840-
di = -2 * np.log(self.distribution.pdf(y, self.fv))
840+
di = -2 * self.distribution.logpdf(y, self.fv)
841841
olddv = dv
842842
dv = np.sum(di * w)
843843

@@ -866,7 +866,7 @@ def _inner_update(
866866
dv,
867867
param,
868868
):
869-
di = -2 * np.log(self.distribution.pdf(y, self.fv))
869+
di = -2 * self.distribution.logpdf(y, self.fv)
870870
dv = (1 - self.forget[0]) * self.global_dev + np.sum(di * w)
871871
olddv = dv + 1
872872

@@ -949,7 +949,7 @@ def _inner_update(
949949

950950
olddv = dv
951951

952-
di = -2 * np.log(self.distribution.pdf(y, self.fv))
952+
di = -2 * self.distribution.logpdf(y, self.fv)
953953
dv = np.sum(di * w) + (1 - self.forget[0]) * self.global_dev
954954

955955
message = f"Outer iteration {iteration_outer}: Fitting Parameter {param}: Inner iteration {iteration_inner}: Current LL {dv}"

0 commit comments

Comments
 (0)