add logprob support for leaky-ReLU switch transforms #7995
Conversation
The more general requirement, which we could try to cover, is a switch where the two branches don't overlap, so when you invert you know for sure which branch a value came from.
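For the leaky-ReLU case, the non-overlap property can be written out explicitly. The derivation below is a sketch under assumptions not confirmed in this thread: the graph has the form `y = switch(x > 0, x, a * x)`, `x` has density `p_X`, and `a > 0` is a constant. The two branch images, `y > 0` and `y <= 0`, are disjoint, so each value of `y` has exactly one preimage.

```latex
% Assumed transform (not stated explicitly in the thread):
%   y = x      if x > 0
%   y = a x    if x <= 0,   with a > 0
\[
p_Y(y) =
\begin{cases}
p_X(y), & y > 0, \\[4pt]
\dfrac{1}{a}\, p_X\!\left(\dfrac{y}{a}\right), & y \le 0,
\end{cases}
\qquad
\log p_Y(y) =
\begin{cases}
\log p_X(y), & y > 0, \\[4pt]
\log p_X\!\left(\dfrac{y}{a}\right) - \log a, & y \le 0.
\end{cases}
\]
```

The `1/a` factor is the Jacobian of the scaled branch. If `a` were negative, the second branch would also map into `y > 0` and the branch could no longer be identified from `y` alone.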
That makes sense. Correct me if I'm wrong, but the key property we need is that the two branches map to non-overlapping regions in [...]. If so, we can extend the current pattern matcher to [...]
@ricardoV94 if my understanding is right, would you prefer that I extend this PR with that generalization, or should I open a follow-up PR to keep the changes scoped?
Yeah exactly! But you don't need to implement the logic to invert the branches, or the jacobian. If both branches of the switch are measurable, it means PyMC figured out the logp and you can just evaluate it at the final value (that will take care of those details). The question posed by the switch is: which of them do you need? Logp for this switch might look something like (pseudocode):

    def conditional_switch_logp(value, true_branch_rv, false_branch_rv):
        value_implies_true_branch = f(value)
        # Note the logp is evaluated at the value, the switch just gates which one is selected
        switch_logp = switch(value_implies_true_branch, logp(true_branch_rv, value), logp(false_branch_rv, value))
        return switch_logp

I think (need to confirm) that the way things are set up, the rewrite that marks the switch as being measurable only has to worry about whether we meet the constraints you mentioned. Current code can already figure out the logp(Normal, value), or logp(Normal * a, value), for you. The strategy may look something like this sequence of checks:
Examples:
But also:
(y is what becomes value in the logp function)

Restricting to the original monotonically increasing leaky ReLU case is fine, but I would like to structure the code so it's ready to extend to more cases in the future. If, once you figure it out, you want to extend it, that's awesome and welcome, but not a blocker. How does that sound?
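The gating idea in the pseudocode above can be illustrated with a small standalone sketch. It is not the PR's implementation and does not use PyMC or PyTensor: it assumes a Normal base variable, assumes the graph has the form `y = switch(x > 0, x, a * x)`, and hand-writes the two branch logps that PyMC would otherwise derive. All function names here are hypothetical.

```python
import math


def normal_logp(value, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) at value."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def scaled_normal_logp(value, a, mu=0.0, sigma=1.0):
    """Log-density of a * x with x ~ Normal(mu, sigma) and a > 0 (change of variables)."""
    return normal_logp(value / a, mu, sigma) - math.log(a)


def leaky_relu_logp(value, a, mu=0.0, sigma=1.0):
    """Log-density of y = x if x > 0 else a * x (assumed graph form), a > 0.

    Both branch logps are evaluated at the same value; the comparison only
    selects which one applies. value == 0 follows the value <= 0 branch.
    """
    if a <= 0:
        raise ValueError("a must be strictly positive so the branch images do not overlap")
    if value > 0:
        return normal_logp(value, mu, sigma)
    return scaled_normal_logp(value, a, mu, sigma)


def integrate_density(a, lo=-10.0, hi=10.0, n=20000):
    """Midpoint-rule integral of exp(leaky_relu_logp) as a sanity check."""
    step = (hi - lo) / n
    total = sum(math.exp(leaky_relu_logp(lo + (i + 0.5) * step, a)) for i in range(n))
    return total * step


if __name__ == "__main__":
    print(leaky_relu_logp(1.0, a=0.1))
    print(leaky_relu_logp(-0.1, a=0.1))
    print(integrate_density(a=0.1))
```

A quick check that the selected branches form a valid density is that `integrate_density` returns a value close to 1 for any positive `a`.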
@ricardoV94 thanks, that makes sense. I refactored the implementation to follow that approach.
This is currently scoped to the leaky ReLU pattern, but the structure is such that extending to other non-overlapping switch patterns should be straightforward (separate predicate + constraints logic). If approved, I'll go ahead and try to implement the general "non-overlapping images" framework.
Description
Added log-probability support for leaky-ReLU graphs constructed as
where
x is a single continuous measurable variable.
Notes
a must be non-measurable and strictly positive.
y == 0 follows the y <= 0 branch (measure-zero set).
Related Issue
Checklist
Type of change