"""
Solver dispatch for Monte Carlo methods.
Provides boot(), boot_ci(), and permutation_test() as the public API.
"""
from __future__ import annotations
from typing import Any, Callable, Literal, Sequence
import numpy as np
from numpy.typing import ArrayLike, NDArray
from pystatistics.montecarlo._common import BootParams
from pystatistics.montecarlo.design import BootstrapDesign, PermutationDesign
from pystatistics.montecarlo.solution import BootstrapSolution, PermutationSolution
from pystatistics.montecarlo.backends.cpu import (
CPUBootstrapBackend,
CPUPermutationBackend,
)
# Backend names accepted by the ``backend`` argument of the public functions
# in this module. In the selector helpers below, 'auto' and 'cpu' both
# resolve to the CPU backends; 'gpu' tries the GPU backends first and uses
# the CPU backends if they cannot be imported or constructed.
BackendChoice = Literal['auto', 'cpu', 'gpu']
def _get_boot_backend(backend: BackendChoice):
    """Select the bootstrap backend for the requested backend name.

    Args:
        backend: One of ``'auto'``, ``'cpu'`` or ``'gpu'``.

    Returns:
        A ``CPUBootstrapBackend`` for ``'cpu'`` and ``'auto'``. For ``'gpu'``,
        a ``GPUBootstrapBackend`` when it can be imported and constructed,
        otherwise a ``CPUBootstrapBackend``.

    Raises:
        ValueError: If ``backend`` is not one of the recognised names.

    Warns:
        RuntimeWarning: If ``'gpu'`` was requested but the GPU backend raised
            ``ImportError`` or ``RuntimeError`` and the CPU backend is used
            instead.
    """
    if backend in ('cpu', 'auto'):
        # auto defaults to CPU for bootstrap (GPU requires special handling)
        return CPUBootstrapBackend()
    if backend == 'gpu':
        try:
            from pystatistics.montecarlo.backends.gpu import GPUBootstrapBackend
            return GPUBootstrapBackend()
        except (ImportError, RuntimeError) as exc:
            # Keep the best-effort fallback, but never hide from the caller
            # that an explicit GPU request was not honoured.
            import warnings
            warnings.warn(
                f"GPU bootstrap backend unavailable ({exc}); "
                "falling back to the CPU backend.",
                RuntimeWarning,
                stacklevel=3,  # point at the caller of the public function
            )
            return CPUBootstrapBackend()
    raise ValueError(f"Unknown backend: {backend!r}")
def _get_perm_backend(backend: BackendChoice):
    """Select the permutation-test backend for the requested backend name.

    Args:
        backend: One of ``'auto'``, ``'cpu'`` or ``'gpu'``.

    Returns:
        A ``CPUPermutationBackend`` for ``'cpu'`` and ``'auto'``. For
        ``'gpu'``, a ``GPUPermutationBackend`` when it can be imported and
        constructed, otherwise a ``CPUPermutationBackend``.

    Raises:
        ValueError: If ``backend`` is not one of the recognised names.

    Warns:
        RuntimeWarning: If ``'gpu'`` was requested but the GPU backend raised
            ``ImportError`` or ``RuntimeError`` and the CPU backend is used
            instead.
    """
    if backend in ('cpu', 'auto'):
        return CPUPermutationBackend()
    if backend == 'gpu':
        try:
            from pystatistics.montecarlo.backends.gpu import GPUPermutationBackend
            return GPUPermutationBackend()
        except (ImportError, RuntimeError) as exc:
            # Keep the best-effort fallback, but never hide from the caller
            # that an explicit GPU request was not honoured.
            import warnings
            warnings.warn(
                f"GPU permutation backend unavailable ({exc}); "
                "falling back to the CPU backend.",
                RuntimeWarning,
                stacklevel=3,  # point at the caller of the public function
            )
            return CPUPermutationBackend()
    raise ValueError(f"Unknown backend: {backend!r}")
[docs]
def boot(
    data: ArrayLike,
    statistic: Callable,
    R: int = 999,
    *,
    sim: Literal["ordinary", "parametric", "balanced"] = "ordinary",
    stype: Literal["i", "f", "w"] = "i",
    strata: ArrayLike | None = None,
    ran_gen: Callable | None = None,
    mle: Any = None,
    seed: int | None = None,
    backend: BackendChoice = 'auto',
) -> BootstrapSolution:
    """
    Run a bootstrap, mirroring R's boot::boot().

    How ``statistic`` is called depends on ``sim``:

    - Nonparametric (``sim="ordinary"`` or ``"balanced"``):
      ``statistic(data, indices) -> array of shape (k,)``, where the second
      argument holds bootstrap sample indices (``stype="i"``), frequency
      counts (``stype="f"``), or weights (``stype="w"``).
    - Parametric (``sim="parametric"``):
      ``statistic(simulated_data) -> array of shape (k,)``, where
      ``simulated_data`` comes from ``ran_gen(data, mle, rng)``.

    Args:
        data: Original data, shape (n,) or (n, p).
        statistic: Function computing the statistic(s) of interest.
        R: Number of bootstrap replicates. Default 999.
        sim: Simulation type: "ordinary", "balanced", or "parametric".
        stype: Kind of second argument handed to ``statistic``:
            "i", "f", or "w".
        strata: Stratification vector (resampling happens within strata).
        ran_gen: Parametric bootstrap only: fn(data, mle, rng) -> sim_data.
        mle: Parameter estimates for the parametric bootstrap.
        seed: Random seed for reproducibility.
        backend: "auto", "cpu", or "gpu".

    Returns:
        BootstrapSolution with t0, t, bias, SE.

    Examples:
        >>> import numpy as np
        >>> from pystatistics.montecarlo import boot
        >>> data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> def mean_stat(data, indices):
        ...     return np.array([np.mean(data[indices])])
        >>> result = boot(data, mean_stat, R=999, seed=42)
        >>> result.t0    # observed mean
        >>> result.bias  # bootstrap bias estimate
        >>> result.se    # bootstrap standard error
    """
    # Gather everything that describes the resampling problem; the backend
    # choice is not part of the design and is resolved afterwards.
    design_spec = {
        "data": data,
        "statistic": statistic,
        "R": R,
        "sim": sim,
        "stype": stype,
        "strata": strata,
        "ran_gen": ran_gen,
        "mle": mle,
        "seed": seed,
    }
    problem = BootstrapDesign.for_bootstrap(**design_spec)
    raw_result = _get_boot_backend(backend).solve(problem)
    return BootstrapSolution(_result=raw_result, _design=problem)
[docs]
def _first_conf_level(conf) -> float:
    """Reduce ``conf`` to the single confidence level that is actually used."""
    if np.ndim(conf) == 0:
        return float(conf)
    levels = np.ravel(conf)
    if levels.size == 0:
        raise ValueError("conf must contain at least one confidence level")
    first = float(levels[0])
    if levels.size > 1:
        # Only one level is supported for now; say so instead of silently
        # discarding the levels the caller asked for.
        import warnings
        warnings.warn(
            "boot_ci currently supports a single confidence level; "
            f"using conf={first!r} and ignoring the remaining levels.",
            UserWarning,
            stacklevel=3,  # point at the caller of boot_ci
        )
    return first


def _expand_ci_types(ci_type, include_stud: bool) -> list[str]:
    """Turn the ``type`` argument into an explicit list of CI method names."""
    if not isinstance(ci_type, str):
        return list(ci_type)
    if ci_type != "all":
        return [ci_type]
    expanded = ["normal", "basic", "perc", "bca"]
    if include_stud:
        expanded.append("stud")
    return expanded


def boot_ci(
    boot_out: BootstrapSolution,
    *,
    conf: float | Sequence[float] = 0.95,
    type: str | Sequence[str] = "all",
    index: int = 0,
    var_t0: float | None = None,
    var_t: NDArray | None = None,
) -> BootstrapSolution:
    """
    Compute bootstrap confidence intervals. Matches R's boot::boot.ci().

    Takes a BootstrapSolution from boot() and computes confidence intervals
    using one or more methods.

    Args:
        boot_out: Result from boot().
        conf: Confidence level(s). Default 0.95. Any sequence or array is
            accepted, but only the first level is used for now; a
            ``UserWarning`` is emitted when further levels are ignored.
        type: CI type(s): "normal", "basic", "perc", "bca", "stud", or "all".
            "all" computes normal, basic, percentile, and BCa (not studentized
            unless var_t is provided).
        index: Which statistic to compute CI for (0-indexed into t0).
        var_t0: Variance of the observed statistic (for normal/studentized).
        var_t: Per-replicate variance estimates, shape (R,). Required for
            studentized CI.

    Returns:
        New BootstrapSolution with CI populated.

    Raises:
        ValueError: If ``conf`` is an empty sequence.

    Examples:
        >>> result = boot(data, mean_stat, R=999, seed=42)
        >>> ci_result = boot_ci(result, type="perc")
        >>> ci_result.ci["perc"]  # shape (k, 2) for [lower, upper]
    """
    # Validate and normalize the arguments first so that bad input fails
    # before any further work is done.
    conf_level = _first_conf_level(conf)
    types = _expand_ci_types(type, include_stud=var_t is not None)

    # Imported lazily, as in the rest of this function's history, rather
    # than at module import time.
    from pystatistics.montecarlo._ci import compute_ci
    from pystatistics.core.result import Result

    ci_dict = compute_ci(
        boot_out=boot_out,
        types=types,
        conf_level=conf_level,
        index=index,
        var_t0=var_t0,
        var_t=var_t,
    )

    # Rebuild the parameter payload with the CI attached; all other fields
    # are carried over unchanged from the input solution.
    old_result = boot_out._result
    old_params = old_result.params
    new_params = BootParams(
        t0=old_params.t0,
        t=old_params.t,
        R=old_params.R,
        bias=old_params.bias,
        se=old_params.se,
        ci=ci_dict,
        ci_conf_level=conf_level,
    )
    new_result = Result(
        params=new_params,
        info=old_result.info,
        timing=old_result.timing,
        backend_name=old_result.backend_name,
        warnings=old_result.warnings,
    )
    return BootstrapSolution(_result=new_result, _design=boot_out._design)
[docs]
def permutation_test(
    x: ArrayLike,
    y: ArrayLike,
    statistic: Callable,
    R: int = 9999,
    *,
    alternative: Literal["two.sided", "less", "greater"] = "two.sided",
    seed: int | None = None,
    backend: BackendChoice = 'auto',
) -> PermutationSolution:
    """
    Two-group permutation test.

    The combined data are shuffled R times and the test statistic is
    recomputed on each permutation. The p-value uses the Phipson-Smyth
    correction: (count + 1) / (R + 1).

    Args:
        x: Group 1 data.
        y: Group 2 data.
        statistic: fn(x, y) -> float. The test statistic.
        R: Number of permutations. Default 9999.
        alternative: "two.sided", "less", or "greater".
        seed: Random seed for reproducibility.
        backend: "auto", "cpu", or "gpu".

    Returns:
        PermutationSolution with observed_stat, perm_stats, p_value.

    Examples:
        >>> x = np.array([1, 2, 3, 4, 5])
        >>> y = np.array([6, 7, 8, 9, 10])
        >>> def mean_diff(x, y): return np.mean(x) - np.mean(y)
        >>> result = permutation_test(x, y, mean_diff, R=9999, seed=42)
        >>> result.p_value
    """
    # Gather everything that describes the test; the backend choice is not
    # part of the design and is resolved afterwards.
    design_spec = {
        "x": x,
        "y": y,
        "statistic": statistic,
        "R": R,
        "alternative": alternative,
        "seed": seed,
    }
    problem = PermutationDesign.for_permutation_test(**design_spec)
    raw_result = _get_perm_backend(backend).solve(problem)
    return PermutationSolution(_result=raw_result, _design=problem)