@eclipse1605

Description

added log-probability support for leaky-ReLU graphs constructed as

y = switch(x > 0, x, a * x)

where x is a single continuous measurable variable.

notes

  • only supports a single continuous measurable variable.
  • the slope a must be non-measurable and strictly positive.
  • behavior at y == 0 follows the y <= 0 branch (measure-zero set).
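For one concrete case, the target density can be written down by hand. The sketch below assumes x ~ Normal(0, 1) and a fixed positive slope a; `normal_logpdf` and `leaky_relu_logp` are illustrative names, not PyMC functions, and this is not the PR's implementation.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def leaky_relu_logp(y, a):
    """Log-density of y = switch(x > 0, x, a * x), assuming x ~ Normal(0, 1)."""
    if a <= 0:
        # Mirrors the "slope must be strictly positive" note above.
        raise ValueError("leaky_relu slope must be > 0")
    if y > 0:
        # Identity branch: no change of variables needed.
        return normal_logpdf(y)
    # y <= 0 branch (includes y == 0): x = y / a, Jacobian term -log(a).
    return normal_logpdf(y / a) - math.log(a)
```

With a = 0.1, for example, `leaky_relu_logp(-0.1, 0.1)` evaluates the normal log-density at x = -1 and adds -log(0.1).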

@ricardoV94
Member

The more general requirement, which we could try to cover, is a switch whose two branches don't overlap, so that when you invert you know for sure which branch the value came from.

@eclipse1605
Author

that makes sense. correct me if I'm wrong, but the key property we need is that the two branches map to non-overlapping regions in y, so that when we observe a y we can tell which branch it came from and apply the correct inverse + Jacobian.

if so, we can extend the current pattern matcher to y = switch(x > k, f1(x), f2(x)) with affine branches f(x) = m * x + b. then we can extract (m1, b1) and (m2, b2) for the two branches, check that each branch is monotone (m != 0), and check that the images don't overlap once you restrict to the domains x > k and x <= k. if the check passes, build the inverse per branch, x = (y - b) / m, and add the Jacobian term -log|m|, using a switch on y with the boundary induced by k.
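A rough sketch of the affine proposal above, in plain Python rather than PyTensor graphs. The function name, its signature, and the choice to return -inf for values that fall in a gap between the two images are assumptions made for illustration only.

```python
import math


def normal_logpdf(x):
    """Log-density of Normal(0, 1) at x."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def affine_switch_logp(y, k, m1, b1, m2, b2, base_logpdf):
    """logp of y = switch(x > k, m1 * x + b1, m2 * x + b2) for x with density base_logpdf."""
    if m1 == 0 or m2 == 0:
        raise ValueError("each branch must be strictly monotone (m != 0)")
    if m1 * m2 < 0:
        # Opposite slopes send both images toward the same side, so they always overlap.
        raise ValueError("branch images overlap")
    c1 = m1 * k + b1  # open boundary of the true-branch image (x > k)
    c2 = m2 * k + b2  # closed boundary of the false-branch image (x <= k)
    if (m1 > 0 and c2 > c1) or (m1 < 0 and c2 < c1):
        raise ValueError("branch images overlap")
    # Boundary on y induced by k; its direction flips for negative slopes.
    from_true = y > c1 if m1 > 0 else y < c1
    m, b = (m1, b1) if from_true else (m2, b2)
    x = (y - b) / m
    if not from_true and x > k:
        # y lies in the gap between the two images: unreachable value.
        return -math.inf
    return base_logpdf(x) - math.log(abs(m))
```

The leaky-ReLU case corresponds to `affine_switch_logp(y, 0, 1, 0, a, 0, normal_logpdf)` with a > 0.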

@eclipse1605
Author

@ricardoV94 if my understanding is right, would you prefer that I extend this PR with that generalization, or should I open a follow-up PR to keep the changes scoped?

@ricardoV94
Member

ricardoV94 commented Dec 15, 2025

Yeah, exactly! But you don't need to implement the logic to invert the branches, or the Jacobian. If both branches of the switch are measurable, it means PyMC has already figured out their logp and you can just evaluate it at the final value (that will take care of those details). The question posed by the switch is: which of them do you need?

Logp for this switch might look something like (pseudocode):

def conditional_switch_logp(value, true_branch_rv, false_branch_rv):
    # Placeholder: f maps the observed value to "this value came from the true branch"
    value_implies_true_branch = f(value)
    # Note the logp is evaluated at the value; the switch just gates which one is selected
    return switch(
        value_implies_true_branch,
        logp(true_branch_rv, value),
        logp(false_branch_rv, value),
    )

I think (need to confirm) that, the way things are set up, the rewrite that marks the switch as measurable only has to worry about whether we meet the constraints you mentioned.

Current code can already figure out the logp(Normal, value), or logp(Normal * a, value), for you.
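The gating idea in the pseudocode can be mimicked outside PyMC with plain callables. In this sketch the per-branch logps are written by hand for y = switch(x > 0, exp(x) - 1, x) with x ~ Normal(0, 1), standing in for what PyMC would derive on its own; every name here is hypothetical.

```python
import math


def normal_logpdf(x):
    """Log-density of Normal(0, 1) at x."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def gated_logp(value, value_implies_true_branch, true_logp, false_logp):
    """Evaluate the selected branch logp at the observed value; the predicate only gates."""
    if value_implies_true_branch(value):
        return true_logp(value)
    return false_logp(value)


def true_branch_logp(y):
    # logp of exp(x) - 1 evaluated at y: x = log1p(y), Jacobian term -log(1 + y).
    return normal_logpdf(math.log1p(y)) - math.log1p(y)


def false_branch_logp(y):
    # logp of x itself evaluated at y.
    return normal_logpdf(y)


def logp(y):
    # For this switch, the true branch was taken exactly when y > 0.
    return gated_logp(y, lambda v: v > 0, true_branch_logp, false_branch_logp)
```

Neither branch function knows about the switch; all the switch-specific logic lives in the predicate `v > 0`.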

The strategy may look something like this sequence of checks:

  1. You have a switch
  2. It's not yet Measurable
  3. Both true and false branches are measurable (that means PyMC already knows how to get their logp)
  4. We have a simple condition expression, x > k or something along those lines. This is where you find what x even is.
  5. Both branches are related to x
    5.1 If neither is connected, this is already handled by the switch mixture machinery switch(cond, x, y); do nothing
    5.2 If only one is connected and the other is a constant, it's actually a censored process, something like switch(x > 0, k, x). But this will never be the case given requirement 3.
    5.3 If only one is connected and the other is a measurable variable, it's also fine, but the checks for invertibility may be trickier? Didn't want to think about this right now. If it seems simple to you, let me know.
  6. The setup passes the constraints for invertibility (from what we can infer)
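The sequence of checks can be read as a decision function. The toy below works on plain booleans instead of a real PyTensor graph; `Branch`, `classify_switch`, and the returned strings are invented for illustration and do not correspond to PyMC internals.

```python
from dataclasses import dataclass


@dataclass
class Branch:
    measurable: bool    # PyMC already knows how to get its logp
    depends_on_x: bool  # connected to the x found in the condition


def classify_switch(already_measurable, simple_condition, true_b, false_b, invertible):
    """Walk checks 2 to 6 in order and report the first one that stops the rewrite."""
    if already_measurable:
        return "skip: already measurable"
    if not (true_b.measurable and false_b.measurable):
        return "skip: a branch is not measurable"
    if not simple_condition:
        return "skip: condition is not of the form x > k"
    connected = int(true_b.depends_on_x) + int(false_b.depends_on_x)
    if connected == 0:
        return "skip: handled by the switch mixture machinery"
    if connected == 1:
        return "skip: one-sided case not handled yet"
    if not invertible:
        return "skip: branch images may overlap"
    return "mark measurable"
```

Case 5.2 never shows up as its own outcome here because a constant branch already fails the measurability check.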

Examples:
y = switch(x > 0, x, x * a), a > 0 -> true_branch = y > 0
y = switch(x > 0, x * a, x * b), a,b > 0 -> true_branch = y > 0
y = switch(x > 0, x * a, x * b), a, b < 0 -> true_branch = y < 0

But also:
y = switch(x > 1, x ** 3, x) -> true_branch = y > 1
y = switch(x > 0, exp(x) - 1, x) -> true_branch = y > 0
y = switch(x > -1, x, exp(x+1) - 2) -> true_branch = y > -1
y = switch(x > 1, exp(x - 1), log(x) + 1) -> true_branch = y > 1

(y is what becomes value in the logp function)
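The predicates listed in these examples can be spot-checked numerically: push a grid of x values through each switch and confirm that the predicate on y recovers which branch was taken. The slopes (0.5, -2.0, -0.5) and the grids are arbitrary choices for this sketch.

```python
import math

GRID = [i / 10 for i in range(-50, 51)]
POSITIVE_GRID = [x for x in GRID if x > 0]  # log(x) needs x > 0

# (k, true_branch, false_branch, predicate_on_y, xs)
EXAMPLES = [
    (0, lambda x: x, lambda x: 0.5 * x, lambda y: y > 0, GRID),
    (0, lambda x: -2.0 * x, lambda x: -0.5 * x, lambda y: y < 0, GRID),
    (1, lambda x: x ** 3, lambda x: x, lambda y: y > 1, GRID),
    (0, lambda x: math.exp(x) - 1, lambda x: x, lambda y: y > 0, GRID),
    (-1, lambda x: x, lambda x: math.exp(x + 1) - 2, lambda y: y > -1, GRID),
    (1, lambda x: math.exp(x - 1), lambda x: math.log(x) + 1, lambda y: y > 1, POSITIVE_GRID),
]


def predicate_recovers_branch(k, true_fn, false_fn, pred, xs):
    """Return True if pred(y) matches the branch actually taken for every x in xs."""
    for x in xs:
        took_true = x > k
        y = true_fn(x) if took_true else false_fn(x)
        if pred(y) != took_true:
            return False
    return True
```

The same helper shows why the slope constraint matters: for switch(x > 0, x, -0.5 * x) the predicate y > 0 fails, since negative x also lands on positive y.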

Restricting to the original monotonically increasing leaky ReLU case is fine, but I would like to structure the code so it's ready to extend to more cases in the future.

If, once you figure it out, you want to extend it, that's awesome and welcome, but not a blocker.

How does that sound?

@eclipse1605
Author

@ricardoV94 thanks, that makes sense. I refactored the implementation to follow that approach:

  • rewrote find_measurable_leaky_relu_switch so it now only tags the switch as measurable when both branches are already measurable, so we can delegate all inversion/Jacobian details to existing logprob rules for each branch.
  • the _logprob for MeasurableLeakyReLUSwitch now just gates between branch logps evaluated at the observed value:
    switch(value > 0, _logprob_helper(x, value), _logprob_helper(neg_branch, value))
  • kept the runtime CheckParameterValue("leaky_relu slope > 0") guard to ensure the “value implies branch” predicate is valid, and attached it to the returned expression so it can’t get optimized away.

this is currently scoped to the leaky ReLU pattern, but the structure is such that extending to other non-overlapping switch patterns should be straightforward (separate predicate + constraints logic). if approved, I'll go ahead and try to implement the general "non-overlapping images" framework.

Development

Successfully merging this pull request may close these issues.

Derive logprob of Leaky ReLU transform
