diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 1d30227fed..b8b71d3205 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -512,6 +512,13 @@ def logcdf(value, mu, sigma): msg="sigma > 0", ) + def logccdf(value, mu, sigma): + return check_parameters( + normal_lccdf(mu, sigma, value), + sigma > 0, + msg="sigma > 0", + ) + def icdf(value, mu, sigma): res = mu + sigma * -np.sqrt(2.0) * pt.erfcinv(2 * value) res = check_icdf_value(res, value) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 27d53c8687..d9e54c648e 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -50,7 +50,7 @@ rv_size_is_none, shape_from_dims, ) -from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob +from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob from pymc.logprob.basic import logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.printing import str_for_dist @@ -150,6 +150,17 @@ def logcdf(op, value, *dist_params, **kwargs): dist_params = [dist_params[i] for i in params_idxs] return class_logcdf(value, *dist_params) + class_logccdf = clsdict.get("logccdf") + if class_logccdf: + + @_logccdf.register(rv_type) + def logccdf(op, value, *dist_params, **kwargs): + if isinstance(op, RandomVariable): + rng, size, *dist_params = dist_params + elif params_idxs: + dist_params = [dist_params[i] for i in params_idxs] + return class_logccdf(value, *dist_params) + class_icdf = clsdict.get("icdf") if class_icdf: diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 4b984e4c41..deec06cd2f 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -44,7 +44,7 @@ from pymc.distributions.transforms import _default_transform from pymc.exceptions import TruncationError from pymc.logprob.abstract import _logcdf, _logprob -from 
pymc.logprob.basic import icdf, logcdf, logp +from pymc.logprob.basic import icdf, logccdf, logcdf, logp from pymc.math import logdiffexp from pymc.pytensorf import collect_default_updates from pymc.util import check_dist_not_registered @@ -211,6 +211,23 @@ def _create_logcdf_exprs( upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value}) return lower_logcdf, upper_logcdf + @staticmethod + def _create_lower_logccdf_expr( + base_rv: TensorVariable, + value: TensorVariable, + lower: TensorVariable, + ) -> TensorVariable: + """Create logccdf expression at lower bound for base_rv. + + Uses `value` as a template for broadcasting. This is numerically more + stable than computing log(1 - exp(logcdf)) for distributions that have + a registered logccdf method. + """ + # For left truncated discrete RVs, we need to include the whole lower bound. + lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower + lower_value = pt.full_like(value, lower_value, dtype=config.floatX) + return logccdf(base_rv, lower_value, warn_rvs=False) + def update(self, node: Apply): """Return the update mapping for the internal RNGs. 
@@ -401,7 +418,7 @@ def truncated_logprob(op, values, *inputs, **kwargs): if is_lower_bounded and is_upper_bounded: lognorm = logdiffexp(upper_logcdf, lower_logcdf) elif is_lower_bounded: - lognorm = pt.log1mexp(lower_logcdf) + lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower) elif is_upper_bounded: lognorm = upper_logcdf @@ -438,7 +455,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs): if is_lower_bounded and is_upper_bounded: lognorm = logdiffexp(upper_logcdf, lower_logcdf) elif is_lower_bounded: - lognorm = pt.log1mexp(lower_logcdf) + lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower) elif is_upper_bounded: lognorm = upper_logcdf diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 2e67a6c55b..5f0dc65c6b 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -39,6 +39,7 @@ from pymc.logprob.basic import ( conditional_logp, icdf, + logccdf, logcdf, logp, transformed_conditional_logp, @@ -59,6 +60,7 @@ __all__ = ( "icdf", + "logccdf", "logcdf", "logp", ) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 4b8808a3bd..32c9690861 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -42,7 +42,7 @@ from pytensor.graph import Apply, Op, Variable from pytensor.graph.utils import MetaType -from pytensor.tensor import TensorVariable +from pytensor.tensor import TensorVariable, log1mexp from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable @@ -108,6 +108,45 @@ def _logcdf_helper(rv, value, **kwargs): return logcdf +@singledispatch +def _logccdf( + op: Op, + value: TensorVariable, + *inputs: TensorVariable, + **kwargs, +): + """Create a graph for the log complementary CDF (log survival function) of a ``RandomVariable``. + + This function dispatches on the type of ``op``, which should be a subclass + of ``RandomVariable``. 
If you want to implement new logccdf graphs + for a ``RandomVariable``, register a new function on this dispatcher. + + The log complementary CDF is defined as log(1 - CDF(x)), also known as the + log survival function. For distributions with a numerically stable implementation, + this should be used instead of computing log(1 - exp(logcdf)). + """ + raise NotImplementedError(f"LogCCDF method not implemented for {op}") + + +def _logccdf_helper(rv, value, **kwargs): + """Helper that calls `_logccdf` dispatcher with fallback to log1mexp(logcdf). + + If a numerically stable `_logccdf` implementation is registered for the + distribution, it will be used. Otherwise, falls back to computing + `log(1 - exp(logcdf))` which may be numerically unstable in the tails. + """ + try: + logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs) + except NotImplementedError: + logcdf = _logcdf_helper(rv, value, **kwargs) + logccdf = log1mexp(logcdf) + + if rv.name: + logccdf.name = f"{rv.name}_logccdf" + + return logccdf + + @singledispatch def _icdf( op: Op, diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index e45e14a723..348e513a77 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -53,6 +53,7 @@ from pymc.logprob.abstract import ( MeasurableOp, _icdf_helper, + _logccdf_helper, _logcdf_helper, _logprob, _logprob_helper, @@ -302,6 +303,70 @@ def normal_logcdf(value, mu, sigma): return expr +def logccdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable: + """Create a graph for the log complementary CDF (log survival function) of a random variable. + + The log complementary CDF is defined as log(1 - CDF(x)), also known as the + log survival function. For distributions with a numerically stable implementation, + this is more accurate than computing log(1 - exp(logcdf)). + + Parameters + ---------- + rv : TensorVariable + value : tensor_like + Should be the same type (shape and dtype) as the rv. 
+ warn_rvs : bool, default True + Warn if RVs were found in the logccdf graph. + This can happen when a variable has other random variables as inputs. + In that case, those random variables should be replaced by their respective values. + + Returns + ------- + logccdf : TensorVariable + + Raises + ------ + RuntimeError + If the logccdf cannot be derived. + + Examples + -------- + Create a compiled function that evaluates the logccdf of a variable + + .. code-block:: python + + import pymc as pm + import pytensor.tensor as pt + + mu = pt.scalar("mu") + rv = pm.Normal.dist(mu, 1.0) + + value = pt.scalar("value") + rv_logccdf = pm.logccdf(rv, value) + + # Use .eval() for debugging + print(rv_logccdf.eval({value: 0.9, mu: 0.0})) # -1.6924928 + + # Compile a function for repeated evaluations + rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf) + print(rv_logccdf_fn(value=0.9, mu=0.0)) # -1.6924928 + + """ + value = pt.as_tensor_variable(value, dtype=rv.dtype) + try: + return _logccdf_helper(rv, value, **kwargs) + except NotImplementedError: + # Try to rewrite rv + fgraph = construct_ir_fgraph({rv: value}) + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs + expr = _logccdf_helper(ir_rv, ir_value, **kwargs) + [expr] = cleanup_ir([expr]) + if warn_rvs: + _warn_rvs_in_inferred_graph([expr]) + return expr + + def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable: """Create a graph for the inverse CDF of a random variable. 
diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 27449e2d2c..9dfd273365 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -25,6 +25,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, + _logccdf_helper, _logcdf_helper, _logprob, _logprob_helper, @@ -95,7 +96,7 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): base_rv_op = base_rv.owner.op logcdf = _logcdf_helper(base_rv, operand, **kwargs) - logccdf = pt.log1mexp(logcdf) + logccdf = _logccdf_helper(base_rv, operand, **kwargs) condn_exp = pt.eq(value, np.array(True)) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 411b8162a8..ab7f84e12a 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -47,7 +47,7 @@ from pytensor.tensor.math import ceil, clip, floor, round_half_to_even from pytensor.tensor.variable import TensorConstant -from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob +from pymc.logprob.abstract import MeasurableElemwise, _logccdf_helper, _logcdf, _logprob from pymc.logprob.rewriting import measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables @@ -119,7 +119,8 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs): if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))): is_upper_bounded = True - logccdf = pt.log1mexp(logcdf) + logccdf = _logccdf_helper(base_rv, value, **kwargs) + # For right clipped discrete RVs, we need to add an extra term # corresponding to the pmf at the upper bound if base_rv.dtype.startswith("int"): diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8d2bbacd26..625587f91d 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -111,6 +111,7 @@ MeasurableOp, _icdf, _icdf_helper, + _logccdf_helper, _logcdf, _logcdf_helper, _logprob, @@ -248,9 +249,10 @@ def measurable_transform_logcdf(op: MeasurableTransform, 
value, *inputs, **kwarg logcdf = _logcdf_helper(measurable_input, backward_value) if is_discrete: - logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1)) + # For discrete distributions, P(X >= t) = P(X > t-1) + logccdf = _logccdf_helper(measurable_input, backward_value - 1) else: - logccdf = pt.log1mexp(logcdf) + logccdf = _logccdf_helper(measurable_input, backward_value) if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS): pass diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py index 5201f16a82..127e38dc2a 100644 --- a/tests/distributions/test_censored.py +++ b/tests/distributions/test_censored.py @@ -213,3 +213,27 @@ def test_censored_logcdf_discrete(self): logcdf(censored_cat, eval_points).eval(), expected_interval, ) + + @pytest.mark.parametrize( + "censoring_side,bound_value", + [ + ("right", 100.0), + ("left", -100.0), + ], + ) + def test_censored_logp_numerical_stability(self, censoring_side, bound_value): + """Censored logp at 100 sigma should be finite, not -inf.""" + ref_scipy = sp.stats.norm(0, 1) + + normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0) + if censoring_side == "right": + censored = pm.Censored.dist(normal_dist, lower=None, upper=bound_value) + expected_logp = ref_scipy.logsf(bound_value) + else: + censored = pm.Censored.dist(normal_dist, lower=bound_value, upper=None) + expected_logp = ref_scipy.logcdf(bound_value) + + logp_at_bound = logp(censored, bound_value).eval() + + assert np.isfinite(logp_at_bound) + assert np.isclose(logp_at_bound, expected_logp, rtol=1e-6) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 48ad0579b6..d827a392a2 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -237,6 +237,46 @@ def rv_op(cls, size=None, rng=None): resized_rv = change_dist_size(rv, new_size=5, expand=True) assert resized_rv.type.shape == (5,) + def 
test_logccdf_with_extended_signature(self): + """Test logccdf registration for SymbolicRandomVariable with extended_signature.""" + from pymc.distributions.dist_math import normal_lccdf + from pymc.distributions.distribution import Distribution + + class TestDistWithLogccdf(Distribution): + # Create a SymbolicRandomVariable type with extended_signature + rv_type = type( + "TestRVWithLogccdf", + (SymbolicRandomVariable,), + {"extended_signature": "[rng],[size],(),()->[rng],()"}, + ) + + @classmethod + def dist(cls, mu, sigma, **kwargs): + mu = pt.as_tensor(mu) + sigma = pt.as_tensor(sigma) + return super().dist([mu, sigma], **kwargs) + + @classmethod + def rv_op(cls, mu, sigma, size=None, rng=None): + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + # Internally uses Normal, but wrapped in SymbolicRandomVariable + next_rng, draws = Normal.dist(mu, sigma, size=size, rng=rng).owner.outputs + return cls.rv_type( + inputs=[rng, size, mu, sigma], + outputs=[next_rng, draws], + ndim_supp=0, + )(rng, size, mu, sigma) + + # This logccdf will be registered via params_idxs path + def logccdf(value, mu, sigma): + return normal_lccdf(mu, sigma, value) + + rv = TestDistWithLogccdf.dist(0, 1) + result = pm.logccdf(rv, 0.5).eval() + expected = st.norm(0, 1).logsf(0.5) # ≈ -1.176 + npt.assert_allclose(result, expected) + def test_distribution_op_registered(): """Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions.""" diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 5d8024cdca..96cea02356 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -46,7 +46,7 @@ import pymc as pm from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logcdf_helper -from pymc.logprob.basic import logcdf +from pymc.logprob.basic import logccdf, logcdf def assert_equal_hash(classA, classB): @@ -95,3 +95,67 @@ def test_logcdf_transformed_argument(): 
pm.TruncatedNormal.dist(0, sigma_value, lower=None, upper=1.0), x_value ).eval() assert np.isclose(observed, expected) + + +def test_logccdf(): + value = pt.vector("value") + x = pm.Normal.dist(0, 1) + + x_logccdf = logccdf(x, value) + np.testing.assert_almost_equal(x_logccdf.eval({value: [0, 1]}), sp.norm(0, 1).logsf([0, 1])) + + +def test_logccdf_numerical_stability(): + """Logccdf at 100 sigma should be finite, not -inf.""" + x = pm.Normal.dist(0, 1) + + result = logccdf(x, 100.0).eval() + expected = sp.norm(0, 1).logsf(100.0) + + assert np.isfinite(result) + np.testing.assert_allclose(result, expected, rtol=1e-6) + + +def test_logccdf_fallback(): + """Distributions without logccdf should fall back to log1mexp(logcdf). + + This test assumes Uniform does not implement logccdf. Implementing one would + not be very useful since the logcdf is very simple and there are no numerical + stability concerns. If Uniform ever gets a logccdf implementation, this test + should be updated to use a different distribution without one. + + Before rewrites, the logccdf graph for Uniform should contain log1mexp. + + Normal implements a specialized logccdf using erfc/erfcx, so its graph, even + before rewrites, should not contain log1mexp. 
+ """ + from pytensor.graph.traversal import ancestors + from pytensor.scalar.math import Log1mexp + from pytensor.tensor.elemwise import Elemwise + + def graph_contains_log1mexp(var): + return any( + v.owner + and isinstance(v.owner.op, Elemwise) + and isinstance(v.owner.op.scalar_op, Log1mexp) + for v in ancestors([var]) + ) + + # Uniform has no logccdf - should use fallback + uniform_logccdf = logccdf(pm.Uniform.dist(0, 1), 0.5) + assert graph_contains_log1mexp(uniform_logccdf) + + # Normal has logccdf - should NOT use fallback + normal_logccdf = logccdf(pm.Normal.dist(0, 1), 0.5) + assert not graph_contains_log1mexp(normal_logccdf) + + +def test_logccdf_discrete(): + mu = 3.0 + x = pm.Poisson.dist(mu=mu) + + test_values = np.array([0, 1, 2, 3, 5, 10]) + result = logccdf(x, test_values).eval() + expected = sp.poisson(mu).logsf(test_values) + + np.testing.assert_allclose(result, expected, rtol=1e-6) diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index c9aeaa8abf..7eebb54cda 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -45,7 +45,7 @@ from pymc.distributions.continuous import Cauchy, ChiSquared from pymc.distributions.discrete import Bernoulli -from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp +from pymc.logprob.basic import conditional_logp, icdf, logccdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform, ArcsinhTransform, @@ -549,6 +549,31 @@ def test_extra_bijective_rv_transforms(pt_transform, transform): ) +@pytest.mark.parametrize( + "pt_transform, transform", + [ + (pt.erfc, ErfcTransform()), + (pt.erfcx, ErfcxTransform()), + ], +) +def test_monotonically_decreasing_transform_logcdf(pt_transform, transform): + """Test logcdf for monotonically decreasing transforms (Erfc, Erfcx).""" + base_rv = pt.random.normal(0.5, 1, name="base_rv") + rv = pt_transform(base_rv) + + vv = rv.clone() + rv_logcdf = logcdf(rv, vv) + + # For decreasing transform: P(Y 
<= y) = P(X >= backward(y)) = 1 - P(X < backward(y)) + expected_logcdf = logccdf(base_rv, transform.backward(vv)) + + vv_test = np.array(0.25) + np.testing.assert_allclose( + rv_logcdf.eval({vv: vv_test}), + expected_logcdf.eval({vv: vv_test}), + ) + + def test_cosh_rv_transform(): # Something not centered around 0 is usually better base_rv = pt.random.normal(0.5, 1, size=(2,), name="base_rv") @@ -709,6 +734,10 @@ def test_negated_discrete_rv_transform(): logcdf_fn = pytensor.function([vv], logcdf(rv, vv)) np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0]) + # logccdf: P(Y > y) + logccdf_fn = pytensor.function([vv], logccdf(rv, vv)) + np.testing.assert_allclose(logccdf_fn([-2, -1, 0, 1]), [0, np.log(1 - p), -np.inf, -np.inf]) + with pytest.raises(NotImplementedError): icdf(rv, [-2, -1, 0, 1]) @@ -730,6 +759,13 @@ def test_shifted_discrete_rv_transform(): np.testing.assert_allclose(rv_logcdf_fn(6), 0) assert rv_logcdf_fn(7) == 0 + # logccdf: P(Y > y) + rv_logccdf_fn = pytensor.function([vv], logccdf(rv, vv)) + np.testing.assert_allclose(rv_logccdf_fn(4), 0) + np.testing.assert_allclose(rv_logccdf_fn(5), np.log(p)) + assert rv_logccdf_fn(6) == -np.inf + assert rv_logccdf_fn(7) == -np.inf + # icdf not supported yet with pytest.raises(NotImplementedError): icdf(rv, 0)