pymc/logprob/mixture.py (155 additions, 0 deletions)
@@ -47,7 +47,10 @@
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.scalar.basic import GE, GT, LE, LT, Mul
from pytensor.tensor.basic import Join, MakeVector, switch
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
@@ -80,7 +83,9 @@
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.transforms import MeasurableTransform
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
get_related_valued_nodes,
@@ -407,6 +412,80 @@ class MeasurableSwitchMixture(MeasurableElemwise):
measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)


class MeasurableLeakyReLUSwitch(MeasurableElemwise):
"""A placeholder for leaky-ReLU graphs built via `switch(x > 0, x, a * x)`.

    This is an invertible, piecewise-linear transform of a single continuous measurable variable.
"""

valid_scalar_types = (Switch,)


measurable_leaky_relu_switch = MeasurableLeakyReLUSwitch(scalar_switch)


def _is_x_positive_condition(cond: TensorVariable, x: TensorVariable) -> bool:
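    """Check whether `cond` is literally `x > 0`, `x >= 0`, `0 < x`, or `0 <= x`."""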
if cond.owner is None:
return False
if not isinstance(cond.owner.op, Elemwise):
return False
scalar_op = cond.owner.op.scalar_op
if not isinstance(scalar_op, GT | GE | LT | LE):
return False

left, right = cond.owner.inputs

def _is_zero(v: TensorVariable) -> bool:
try:
return pt.get_underlying_scalar_constant_value(v) == 0
except NotScalarConstantError:
return False

# x > 0 or x >= 0
if left is x and _is_zero(right) and isinstance(scalar_op, GT | GE):
return True
# 0 < x or 0 <= x
if right is x and _is_zero(left) and isinstance(scalar_op, LT | LE):
return True
return False


def _extract_leaky_relu_slope(
neg_branch: TensorVariable, x: TensorVariable
) -> TensorVariable | None:
"""Extract slope `a` from `neg_branch` assuming it represents `a * x`.

    Supports both plain `Elemwise(Mul)` and `MeasurableTransform` scale rewrites.
"""
if neg_branch is x:
return pt.constant(1.0)

if neg_branch.owner is None:
return None

# handle case where `a * x` was already rewritten into a measurable scale transform
if isinstance(neg_branch.owner.op, MeasurableTransform):
op = neg_branch.owner.op
if not isinstance(op.scalar_op, Mul):
return None
# MeasurableTransform takes (measurable_input, scale)
if len(neg_branch.owner.inputs) != 2:
return None
if neg_branch.owner.inputs[op.measurable_input_idx] is not x:
return None
scale = neg_branch.owner.inputs[1 - op.measurable_input_idx]
return cast(TensorVariable, scale)

# plain multiplication
if isinstance(neg_branch.owner.op, Elemwise) and isinstance(neg_branch.owner.op.scalar_op, Mul):
left, right = neg_branch.owner.inputs
if left is x:
return cast(TensorVariable, right)
if right is x:
return cast(TensorVariable, left)
return None
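
# Illustration (sketch): for a measurable `x` and `neg = 0.5 * x`, the helper
# above recovers the 0.5 factor either from the plain `Elemwise(Mul)` node or
# from the `MeasurableTransform` scale node that earlier logprob rewrites may
# have produced; any other form of `neg_branch` yields None.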


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
if isinstance(node.op, MeasurableOp):
@@ -431,6 +510,51 @@ def find_measurable_switch_mixture(fgraph, node):
return [measurable_switch_mixture(switch_cond, *components)]


@node_rewriter([switch])
def find_measurable_leaky_relu_switch(fgraph, node):
"""Detect `switch(x > 0, x, a * x)` and replace it by a measurable op.

This enables a change-of-variables logprob derivation instead of treating it as a mixture.
"""
if isinstance(node.op, MeasurableOp):
return None

cond, pos_branch, neg_branch = node.inputs

    # We only mark the switch measurable once both branches are already measurable,
    # so the switch logprob can simply gate between branch logps (delegating
    # inversion/Jacobian details to each branch).
if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
return None

if not filter_measurable_variables([pos_branch]):
return None
x = cast(TensorVariable, pos_branch)

if x.type.dtype.startswith("int"):
return None

if x.type.broadcastable != node.outputs[0].type.broadcastable:
return None

if not _is_x_positive_condition(cast(TensorVariable, cond), x):
return None

a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), x)
if a is None:
return None

if check_potential_measurability([a]):
return None

return [
measurable_leaky_relu_switch(
cast(TensorVariable, cond),
x,
cast(TensorVariable, neg_branch),
)
]
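
# Example of a graph this rewrite targets (illustrative sketch):
#   x = pt.random.normal(size=(3,))
#   y = pt.switch(x > 0, x, 0.25 * x)
# Once both branches have been made measurable by earlier IR rewrites, `y` is
# replaced by `measurable_leaky_relu_switch(cond, x, neg_branch)`, and its
# logprob is supplied by `logprob_leaky_relu_switch` further below.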


@_logprob.register(MeasurableSwitchMixture)
def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
[value] = values
@@ -442,6 +566,30 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
)


@_logprob.register(MeasurableLeakyReLUSwitch)
def logprob_leaky_relu_switch(op, values, cond, x, neg_branch, **kwargs):
(value,) = values

a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), cast(TensorVariable, x))
if a is None:
raise NotImplementedError("Could not extract leaky-ReLU slope")

    # Enforce `a > 0` at runtime to ensure invertibility and to make the branch
    # selection predicate depend only on the observed value.
a_is_positive = pt.all(pt.gt(a, 0))

    # For `a > 0`, `switch(x > 0, x, a * x)` maps the branches to disjoint regions
    # of `value`: true branch -> value > 0, false branch -> value <= 0.
value_implies_true_branch = pt.gt(value, 0)
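
    # Change-of-variables sketch, assuming `a > 0`:
    #   value > 0  : value = x      -> logp_x(value)
    #   value <= 0 : value = a * x  -> logp_x(value / a) - log(a)
    # The `- log(a)` Jacobian term is not spelled out here; it is contributed by
    # `_logprob_helper(neg_branch, value)`, since `neg_branch` is a measurable
    # scale transform of `x`.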

logp_expr = pt.switch(
value_implies_true_branch,
_logprob_helper(x, value, **kwargs),
_logprob_helper(neg_branch, value, **kwargs),
)

# attach the parameter check to the returned expression so it can't be optimized away.
return CheckParameterValue("leaky_relu slope > 0")(logp_expr, a_is_positive)


measurable_ir_rewrites_db.register(
"find_measurable_index_mixture",
find_measurable_index_mixture,
@@ -456,6 +604,13 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
"mixture",
)

measurable_ir_rewrites_db.register(
"find_measurable_leaky_relu_switch",
find_measurable_leaky_relu_switch,
"basic",
"transform",
)


class MeasurableIfElse(MeasurableOp, IfElse):
"""Measurable subclass of IfElse operator."""
tests/logprob/test_transforms.py (54 additions, 0 deletions)
@@ -43,6 +43,8 @@

from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
@@ -219,6 +221,7 @@ def test_exp_transform_rv():
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
)

np.testing.assert_almost_equal(
logcdf_fn(y_val),
sp.stats.lognorm(s=1).logcdf(y_val),
@@ -229,6 +232,57 @@ def test_exp_transform_rv():
)


def test_leaky_relu_switch_logp_scalar():
a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

v_pos = 1.2
np.testing.assert_allclose(
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
)

v_neg = -2.0
np.testing.assert_allclose(
pm.logp(y, v_neg, warn_rvs=False).eval(),
pm.logp(x, v_neg / a, warn_rvs=False).eval() - np.log(a),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())


def test_leaky_relu_switch_logp_vectorized():
a = 0.5
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = pm.math.switch(x > 0, x, a * x)

v = np.array([-2.0, 0.0, 1.5])
expected = pm.logp(x, np.where(v > 0, v, v / a), warn_rvs=False).eval() + np.where(
v > 0, 0.0, -np.log(a)
)
np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
a = pt.scalar("a")
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

# positive slope passes
res = pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.5})
expected = pm.logp(x, -1.0 / 0.5, warn_rvs=False).eval() - np.log(0.5)
np.testing.assert_allclose(res, expected)

    # non-positive slope raises
with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: -0.5})

with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.0})


def test_log_transform_rv():
base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv")
y_rv = pt.log(base_rv)