Source code for pystatistics.mixed.solvers

"""
Solver dispatch for mixed models.

Public API:
    lmm()  — fit a linear mixed model (REML or ML)
    glmm() — fit a generalized linear mixed model (Laplace approximation)
"""

from __future__ import annotations

import warnings
import numpy as np
from numpy.typing import ArrayLike
from scipy.optimize import minimize
from scipy import stats

from pystatistics.core.result import Result
from pystatistics.core.compute.timing import Timer

from pystatistics.mixed._common import (
    LMMParams, GLMMParams, VarCompSummary,
)
from pystatistics.mixed._random_effects import (
    parse_random_effects, build_z_matrix, build_lambda,
    theta_lower_bounds, theta_start,
)
from pystatistics.mixed._pls import solve_pls
from pystatistics.mixed._deviance import (
    profiled_deviance_lmm, profiled_deviance_glmm,
)
from pystatistics.mixed._satterthwaite import satterthwaite_df
from pystatistics.mixed.design import MixedDesign
from pystatistics.mixed.solution import LMMSolution, GLMMSolution


def lmm(
    y: ArrayLike,
    X: ArrayLike,
    groups: dict[str, ArrayLike],
    *,
    random_effects: dict[str, list[str]] | None = None,
    random_data: dict[str, ArrayLike] | None = None,
    reml: bool = True,
    tol: float = 1e-8,
    max_iter: int = 200,
    compute_satterthwaite: bool = True,
) -> LMMSolution:
    """Fit a linear mixed model.

    Estimates fixed effects β, random effects variance components, and
    conditional modes (BLUPs) of random effects using the profiled
    REML/ML deviance approach of Bates et al. (2015).

    Args:
        y: Response vector (n,).
        X: Fixed effects design matrix (n, p). Should include an
            intercept column if desired.
        groups: Dict mapping grouping factor names to group label arrays.
            Example: {'subject': subject_ids}.
        random_effects: Optional dict mapping group names to lists of
            random effect terms. Default: random intercept per group.
            Example: {'subject': ['1', 'time']} for (1 + time | subject).
        random_data: Optional dict mapping variable names to data arrays
            for random slope variables. Example: {'time': time_array}.
        reml: If True (default), use REML estimation. If False, use ML.
            Use ML (reml=False) for likelihood ratio tests between models
            with different fixed effects.
        tol: Convergence tolerance for the optimizer. Default 1e-8.
        max_iter: Maximum optimizer iterations. Default 200.
        compute_satterthwaite: If True (default), compute Satterthwaite
            denominator df for fixed effects. Set to False for speed if
            p-values are not needed.

    Returns:
        LMMSolution with fixed effects, random effects, variance
        components, model fit statistics, and R-style summary().

    Examples:
        # Random intercept model
        >>> result = lmm(y, X, groups={'subject': subject_ids})

        # Random intercept + slope
        >>> result = lmm(y, X, groups={'subject': subject_ids},
        ...              random_effects={'subject': ['1', 'time']},
        ...              random_data={'time': time_array})

        # Crossed random effects
        >>> result = lmm(y, X, groups={'subject': subj, 'item': item})
    """
    timer = Timer()
    timer.start()

    # Validate inputs
    design = MixedDesign.validate(
        np.asarray(y, dtype=np.float64),
        np.asarray(X, dtype=np.float64),
        groups, random_effects, random_data,
    )

    with timer.section('setup'):
        # Parse random effects and build Z
        specs = parse_random_effects(
            design.groups, design.random_effects, design.random_data, design.n
        )
        Z = build_z_matrix(specs)

        # Starting values and bounds
        theta0 = theta_start(specs)
        lb = theta_lower_bounds(specs)
        bounds = [(lb[i], None) for i in range(len(theta0))]

    # Optimize θ — use multiple starting points for models with random
    # slopes (the profiled deviance can have local minima when q > 1)
    with timer.section('optimization'):
        has_slopes = any(s.n_terms > 1 for s in specs)
        if has_slopes:
            # Generate candidate starting values: the default [1, 0, ..., 1]
            # plus variants with smaller diagonal values for slope terms
            starts = [theta0]
            for scale in (0.2, 0.5):
                alt = theta0.copy()
                idx = 0
                for spec in specs:
                    q = spec.n_terms
                    for row in range(q):
                        for col in range(row + 1):
                            if row == col and row > 0:
                                alt[idx] = scale
                            idx += 1
                starts.append(alt)
            best_result = None
            for start in starts:
                res = minimize(
                    profiled_deviance_lmm, start,
                    args=(design.X, Z, design.y, specs, reml),
                    method='L-BFGS-B', bounds=bounds,
                    options={'maxiter': max_iter, 'ftol': tol, 'gtol': tol * 10},
                )
                if best_result is None or res.fun < best_result.fun:
                    best_result = res
            opt_result = best_result
        else:
            opt_result = minimize(
                profiled_deviance_lmm, theta0,
                args=(design.X, Z, design.y, specs, reml),
                method='L-BFGS-B', bounds=bounds,
                options={'maxiter': max_iter, 'ftol': tol, 'gtol': tol * 10},
            )

    converged = opt_result.success
    theta_hat = opt_result.x
    n_iter = opt_result.nit
    if not converged:
        warnings.warn(
            f"LMM optimizer did not converge after {n_iter} iterations. "
            f"Message: {opt_result.message}",
            RuntimeWarning, stacklevel=2,
        )

    # Final PLS solve at optimal θ
    with timer.section('final_solve'):
        Lambda_hat = build_lambda(theta_hat, specs)
        pls = solve_pls(design.X, Z, design.y, Lambda_hat, reml=reml)

    # Compute variance components
    with timer.section('variance_components'):
        var_comps = _extract_var_components(theta_hat, pls.sigma_sq, specs)
        n_groups_dict = {s.group_name: s.n_groups for s in specs}

    # Extract BLUPs
    with timer.section('blups'):
        random_effs = _extract_blups(pls.b, specs)

    # Compute Satterthwaite df and p-values
    with timer.section('satterthwaite'):
        se = _compute_se(pls, design.X, Z, Lambda_hat, design.X.shape[1])
        if compute_satterthwaite:
            df_satt = satterthwaite_df(
                theta_hat, design.X, Z, design.y, specs, reml=reml
            )
        else:
            # Use residual df as fallback
            df_satt = np.full(design.p, float(design.n - design.p))
        t_vals = pls.beta / se
        p_vals = 2.0 * stats.t.sf(np.abs(t_vals), df_satt)

    # Log-likelihood, AIC, BIC
    with timer.section('model_fit'):
        ll, aic, bic = _compute_fit_stats(
            pls, theta_hat, design.n, design.p, specs, reml
        )

    # Coefficient names
    coef_names = _make_coef_names(design.p)

    timer.stop()

    # Assemble params
    params = LMMParams(
        coefficients=pls.beta,
        coefficient_names=tuple(coef_names),
        se=se,
        df_satterthwaite=df_satt,
        t_values=t_vals,
        p_values=p_vals,
        var_components=tuple(var_comps),
        residual_variance=pls.sigma_sq,
        residual_std=np.sqrt(pls.sigma_sq),
        log_likelihood=ll,
        reml=reml,
        aic=aic,
        bic=bic,
        n_obs=design.n,
        n_groups=n_groups_dict,
        converged=converged,
        n_iter=n_iter,
        random_effects=random_effs,
        fitted_values=pls.fitted,
        residuals=pls.residuals,
        theta=theta_hat,
    )

    warn_list = []
    if not converged:
        warn_list.append(f"Optimizer did not converge: {opt_result.message}")

    result = Result(
        params=params,
        info={
            'method': 'REML' if reml else 'ML',
            'optimizer': 'L-BFGS-B',
            'converged': converged,
            'n_iter': n_iter,
            'deviance': opt_result.fun,
        },
        timing=timer.result(),
        backend_name='cpu_lmm',
        warnings=tuple(warn_list),
    )
    return LMMSolution(_result=result)


def glmm(
    y: ArrayLike,
    X: ArrayLike,
    groups: dict[str, ArrayLike],
    *,
    family: 'str | Family' = 'binomial',
    random_effects: dict[str, list[str]] | None = None,
    random_data: dict[str, ArrayLike] | None = None,
    tol: float = 1e-8,
    max_iter: int = 200,
) -> GLMMSolution:
    """Fit a generalized linear mixed model.

    Uses the Laplace approximation to the marginal likelihood, with
    penalized IRLS (PIRLS) for the inner loop and L-BFGS-B for the outer
    optimization over variance components.

    Args:
        y: Response vector (n,).
        X: Fixed effects design matrix (n, p).
        groups: Dict mapping grouping factor names to group label arrays.
        family: GLM family specification. String ('binomial', 'poisson')
            or a Family instance from pystatistics.regression.families.
        random_effects: Optional random effects specification.
        random_data: Optional data for random slope variables.
        tol: Convergence tolerance.
        max_iter: Maximum optimizer iterations.

    Returns:
        GLMMSolution with fixed effects, random effects, and model fit.
    """
    from pystatistics.regression.families import resolve_family, Family

    timer = Timer()
    timer.start()

    # Resolve family
    if not isinstance(family, Family):
        family_obj = resolve_family(family)
    else:
        family_obj = family

    # Validate inputs
    design = MixedDesign.validate(
        np.asarray(y, dtype=np.float64),
        np.asarray(X, dtype=np.float64),
        groups, random_effects, random_data,
    )

    with timer.section('setup'):
        specs = parse_random_effects(
            design.groups, design.random_effects, design.random_data, design.n
        )
        Z = build_z_matrix(specs)
        theta0 = theta_start(specs)
        lb = theta_lower_bounds(specs)
        bounds = [(lb[i], None) for i in range(len(theta0))]

    # Optimize θ via the Laplace-approximated deviance
    with timer.section('optimization'):
        opt_result = minimize(
            profiled_deviance_glmm, theta0,
            args=(design.X, Z, design.y, specs, family_obj),
            method='L-BFGS-B', bounds=bounds,
            options={'maxiter': max_iter, 'ftol': tol, 'gtol': tol * 10},
        )

    converged = opt_result.success
    theta_hat = opt_result.x
    n_iter = opt_result.nit
    if not converged:
        warnings.warn(
            f"GLMM optimizer did not converge after {n_iter} iterations. "
            f"Message: {opt_result.message}",
            RuntimeWarning, stacklevel=2,
        )

    # Final PIRLS solve at optimal θ
    with timer.section('final_solve'):
        from pystatistics.mixed._pirls import solve_pirls
        Lambda_hat = build_lambda(theta_hat, specs)
        pirls = solve_pirls(design.X, Z, design.y, Lambda_hat, family_obj)

    # Variance components (for GLMM, σ² = 1 by convention)
    with timer.section('variance_components'):
        var_comps = _extract_var_components(theta_hat, 1.0, specs)
        n_groups_dict = {s.group_name: s.n_groups for s in specs}

    # BLUPs
    with timer.section('blups'):
        random_effs = _extract_blups(pirls.pls.b, specs)

    # Fixed effect SEs and Wald z-statistics
    with timer.section('inference'):
        se = _compute_se_glmm(pirls.pls, design.X.shape[1])
        z_vals = pirls.pls.beta / se
        p_vals = 2.0 * stats.norm.sf(np.abs(z_vals))

    # Model fit
    with timer.section('model_fit'):
        wt = np.ones(design.n, dtype=np.float64)
        deviance = family_obj.deviance(design.y, pirls.mu, wt)
        n_params = len(pirls.pls.beta) + len(theta_hat)
        # Laplace-approximated marginal log-likelihood:
        #   ll = conditional_loglik - 0.5 * ||u||^2 - 0.5 * log|L_theta|^2
        # where conditional_loglik = sum(f(y_i | mu_i)) is the full
        # conditional log-likelihood including normalizing constants.
        # For GLMM, dispersion = 1.
        cond_ll = family_obj.log_likelihood(design.y, pirls.mu, wt, 1.0)
        penalty = float(pirls.pls.u @ pirls.pls.u)
        log_det_L = 2.0 * np.sum(np.log(np.maximum(np.diag(pirls.pls.L), 1e-20)))
        ll = cond_ll - 0.5 * penalty - 0.5 * log_det_L
        aic = -2.0 * ll + 2.0 * n_params
        bic = -2.0 * ll + np.log(design.n) * n_params

    coef_names = _make_coef_names(design.p)
    timer.stop()

    params = GLMMParams(
        coefficients=pirls.pls.beta,
        coefficient_names=tuple(coef_names),
        se=se,
        t_values=z_vals,
        p_values=p_vals,
        var_components=tuple(var_comps),
        log_likelihood=ll,
        deviance=deviance,
        aic=aic,
        bic=bic,
        n_obs=design.n,
        n_groups=n_groups_dict,
        family_name=family_obj.name,
        link_name=family_obj.link.name,
        converged=converged,
        n_iter=n_iter,
        random_effects=random_effs,
        fitted_values=pirls.mu,
        linear_predictor=pirls.eta,
        residuals=design.y - pirls.mu,
        theta=theta_hat,
    )

    warn_list = []
    if not converged:
        warn_list.append(f"Optimizer did not converge: {opt_result.message}")
    if not pirls.converged:
        warn_list.append(f"PIRLS did not converge after {pirls.n_iter} iterations")

    result = Result(
        params=params,
        info={
            'method': 'Laplace',
            'family': family_obj.name,
            'link': family_obj.link.name,
            'optimizer': 'L-BFGS-B',
            'converged': converged,
            'pirls_converged': pirls.converged,
            'n_iter': n_iter,
            'pirls_iter': pirls.n_iter,
            'deviance': opt_result.fun,
        },
        timing=timer.result(),
        backend_name='cpu_glmm',
        warnings=tuple(warn_list),
    )
    return GLMMSolution(_result=result)
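

# ---------------------------------------------------------------------
# Usage sketch (illustrative, not part of the public API): a Laplace-
# approximated logistic GLMM with a random intercept per cluster. All
# generative values are hypothetical; only glmm() and its documented
# arguments come from this module.
def _example_glmm_binomial() -> GLMMSolution:
    rng = np.random.default_rng(1)
    n_clusters, n_per = 40, 25
    cluster = np.repeat(np.arange(n_clusters), n_per)
    n = n_clusters * n_per
    X = np.column_stack([np.ones(n), rng.normal(size=n)])
    # Hypothetical truth: β = (-0.5, 1.0), cluster SD 1.0
    u = rng.normal(scale=1.0, size=n_clusters)
    eta = X @ np.array([-0.5, 1.0]) + u[cluster]
    y = rng.binomial(1, 1.0 / (1.0 + np.exp(-eta))).astype(np.float64)
    return glmm(y, X, groups={'cluster': cluster}, family='binomial')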


# =====================================================================
# Helpers
# =====================================================================

def _extract_var_components(
    theta: np.ndarray,
    sigma_sq: float,
    specs: list,
) -> list[VarCompSummary]:
    """Extract variance component summaries from θ and σ².

    The actual covariance of random effects is σ² × Λ Λ'. For each
    grouping factor, compute the covariance matrix and extract variance,
    std dev, and correlations.
    """
    var_comps = []
    theta_offset = 0
    for spec in specs:
        q = spec.n_terms
        n_theta = spec.theta_size

        # Reconstruct the q × q lower-triangular Cholesky factor
        theta_k = theta[theta_offset:theta_offset + n_theta]
        theta_offset += n_theta
        T = np.zeros((q, q), dtype=np.float64)
        idx = 0
        for row in range(q):
            for col in range(row + 1):
                T[row, col] = theta_k[idx]
                idx += 1

        # Covariance matrix: σ² × T T'
        cov_matrix = sigma_sq * (T @ T.T)

        # Term names
        term_names = []
        for term in spec.terms:
            if term == '1':
                term_names.append('(Intercept)')
            else:
                term_names.append(term)

        # Extract variance, std dev, correlations
        for i in range(q):
            var_i = cov_matrix[i, i]
            sd_i = np.sqrt(max(var_i, 0.0))
            # Correlation with the first term (only for 2nd+ terms)
            if i > 0 and cov_matrix[0, 0] > 0 and var_i > 0:
                corr = cov_matrix[i, 0] / (np.sqrt(cov_matrix[0, 0]) * sd_i)
                corr = np.clip(corr, -1.0, 1.0)
            else:
                corr = None
            var_comps.append(VarCompSummary(
                group=spec.group_name,
                name=term_names[i],
                variance=float(var_i),
                std_dev=float(sd_i),
                corr=float(corr) if corr is not None else None,
            ))
    return var_comps


def _extract_blups(b: np.ndarray, specs: list) -> dict[str, np.ndarray]:
    """Extract BLUPs per grouping factor from the flat b vector.

    b is structured as [b_group1, b_group2, ...] where each b_groupk has
    J_k * q_k elements laid out as
    [term0_group0, term0_group1, ..., term1_group0, ...].

    Returns dict: group_name → (J_k, q_k) array.
    """
    result = {}
    offset = 0
    for spec in specs:
        block_size = spec.n_groups * spec.n_terms
        b_block = b[offset:offset + block_size]
        # Reshape: columns are terms, rows are groups.
        # b_block layout: [term0_g0, term0_g1, ..., term1_g0, term1_g1, ...]
        b_matrix = np.zeros((spec.n_groups, spec.n_terms), dtype=np.float64)
        for t in range(spec.n_terms):
            start = t * spec.n_groups
            b_matrix[:, t] = b_block[start:start + spec.n_groups]
        result[spec.group_name] = b_matrix
        offset += block_size
    return result


def _compute_se(pls, X, Z, Lambda, p: int) -> np.ndarray:
    """Compute standard errors of fixed effects via the V matrix.

    SE = sqrt(diag(Var(β̂))) where Var(β̂) = σ² × (X'V*⁻¹X)⁻¹ and
    V* = ZΛΛ'Z' + I. This matches R's lme4 computation, avoiding
    numerical differences from the Schur complement approach.
    """
    n = X.shape[0]
    V_star = Z @ Lambda @ Lambda.T @ Z.T + np.eye(n)
    try:
        C = np.linalg.inv(X.T @ np.linalg.solve(V_star, X))
    except np.linalg.LinAlgError:
        C = np.linalg.pinv(X.T @ np.linalg.solve(V_star, X))
    vcov = pls.sigma_sq * C
    se = np.sqrt(np.maximum(np.diag(vcov), 0.0))
    return se


def _compute_se_glmm(pls, p: int) -> np.ndarray:
    """Compute standard errors for GLMM (σ² = 1 by convention)."""
    try:
        RX_inv = np.linalg.inv(pls.RX)
        vcov = RX_inv @ RX_inv.T
    except np.linalg.LinAlgError:
        # Fallback: pseudo-inverse of RX'RX, the same quantity as
        # inv(RX) @ inv(RX).T above
        RtR = pls.RX.T @ pls.RX
        vcov = np.linalg.pinv(RtR)
    se = np.sqrt(np.maximum(np.diag(vcov), 0.0))
    return se


def _compute_fit_stats(pls, theta, n, p, specs, reml):
    """Compute log-likelihood, AIC, BIC for an LMM."""
    pwrss = pls.pwrss
    # Number of variance parameters
    n_theta = len(theta)
    # Total parameters for AIC: fixed effects + variance components + σ²
    n_params = p + n_theta + 1
    if reml:
        df = n - p
        # REML log-likelihood
        log_det_L = 2.0 * np.sum(np.log(np.maximum(np.diag(pls.L), 1e-20)))
        log_det_RX = 2.0 * np.sum(np.log(np.maximum(np.abs(np.diag(pls.RX)), 1e-20)))
        ll = -0.5 * (
            log_det_L + log_det_RX
            + df * (1.0 + np.log(2.0 * np.pi * pwrss / df))
        )
        # REML AIC/BIC: R's lme4 counts all parameters (fixed + variance):
        # npar = p (fixed effects) + n_theta (RE params) + 1 (sigma)
        aic = -2.0 * ll + 2.0 * n_params
        bic = -2.0 * ll + np.log(n) * n_params
    else:
        # ML log-likelihood
        log_det_L = 2.0 * np.sum(np.log(np.maximum(np.diag(pls.L), 1e-20)))
        ll = -0.5 * (
            log_det_L + n * (1.0 + np.log(2.0 * np.pi * pwrss / n))
        )
        aic = -2.0 * ll + 2.0 * n_params
        bic = -2.0 * ll + np.log(n) * n_params
    return float(ll), float(aic), float(bic)


def _make_coef_names(p: int) -> list[str]:
    """Generate default coefficient names."""
    if p == 1:
        return ['(Intercept)']
    names = ['(Intercept)']
    for i in range(1, p):
        names.append(f'X{i}')
    return names