At present, Zygote will use forward mode AD (outsourced to ForwardDiff) under 2 circumstances:
- Based on a heuristic for broadcasting
- Upon an explicit call to `Zygote.forwarddiff`
As shown by https://github.com/oschulz/ForwardDiffPullbacks.jl, there are a number of cases where letting an `rrule` run forward mode AD internally would be a boon for performance. One particularly salient example from Flux is RNN pointwise broadcasts, which Zygote currently unfuses at a massive compute and memory cost. However, since we are simultaneously moving away from Zygote-specific APIs downstream, defining `rrule(::typeof(pointwise_op), xs...) = Zygote.forwarddiff(...)` is a non-starter. Hence, my proposal is to expose the standard `frule_via_ad` so that downstream code can remain AD-agnostic. Under the hood, this would work much the same as `Zygote.forwarddiff` or `ForwardDiffPullbacks.fwddiff` do now, and it may even be possible to share some implementation details with one of those functions.
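A minimal sketch of what this could enable downstream, assuming ChainRulesCore's `frule_via_ad(config, ȧrgs, f, args...)` convention (the tangent tuple includes one for `f` itself, and the call returns `(Ω, Ω̇)`). `kernel`, `pointwise_op` and `ComplexStepConfig` are hypothetical: the config only stands in for a forward-mode-capable AD, using complex-step differentiation so the example runs with ChainRulesCore alone. It is not how Zygote or Diffractor would implement `frule_via_ad`.

```julia
using ChainRulesCore

# Hypothetical scalar kernel, standing in for one fused RNN gate computation.
kernel(h, x) = tanh(h + x) * x

# Hypothetical array-level op that a downstream library wants a custom rrule for.
pointwise_op(h::AbstractArray, x::AbstractArray) = kernel.(h, x)

# AD-agnostic rrule: per-element partials come from whatever forward mode the
# calling AD provides through frule_via_ad; no Zygote-specific API is used.
function ChainRulesCore.rrule(config::RuleConfig{>:HasForwardsMode},
                              ::typeof(pointwise_op), h::AbstractArray, x::AbstractArray)
    results = map(h, x) do hi, xi
        # Tangent tuples are (ḟ, ḣ, ẋ); ḟ = NoTangent() since `kernel` has no fields.
        y, ∂h = frule_via_ad(config, (NoTangent(), one(hi), zero(xi)), kernel, hi, xi)
        _, ∂x = frule_via_ad(config, (NoTangent(), zero(hi), one(xi)), kernel, hi, xi)
        (y, ∂h, ∂x)
    end
    y = map(r -> r[1], results)
    ∂h = map(r -> r[2], results)
    ∂x = map(r -> r[3], results)
    # Pointwise op, so the pullback is an elementwise product with the stored partials.
    function pointwise_op_pullback(ȳ)
        ȳ = unthunk(ȳ)
        return NoTangent(), ȳ .* ∂h, ȳ .* ∂x
    end
    return y, pointwise_op_pullback
end

# Hypothetical stand-in for a forward-mode-capable AD. Complex-step
# differentiation is accurate to machine precision for real-analytic
# functions such as `kernel`; it is NOT a general-purpose AD.
struct ComplexStepConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end

function ChainRulesCore.frule_via_ad(::ComplexStepConfig, ȧrgs, f, args::Real...)
    ε = 1e-20
    ẋs = Base.tail(ȧrgs)  # drop the tangent for `f` itself
    z = f((a + im * ε * ȧ for (a, ȧ) in zip(args, ẋs))...)
    return real(z), imag(z) / ε
end
```

Replacing `ComplexStepConfig()` with a real AD's config (Diffractor's, or Zygote's if it gains `frule_via_ad`) would leave the `rrule` itself unchanged. This sketch spends two forward sweeps per element; a dual-number backend could seed both partials in a single chunked pass.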
Note that this is not a request to make `frule_via_ad` differentiable in reverse mode. Users would still be responsible for writing their own `rrule`s, but one could imagine swapping out Zygote for Diffractor (which already implements `frule_via_ad`) without making any code changes. Guarding on `RuleConfig{>:HasForwardsMode}` would be enough to ensure compatibility with ADs which do not support forward mode.
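The guard is ordinary dispatch on the config's trait parameter. The sketch below assumes ChainRulesCore's mode traits and its generic `rrule(::RuleConfig, args...)` fallback; `FwdCapableConfig`, `ReverseOnlyConfig`, `kernel` and `pointwise_op` are hypothetical stand-ins, not real AD configs.

```julia
using ChainRulesCore

# Hypothetical scalar kernel and the array op built from it.
kernel(x) = tanh(x) * x
pointwise_op(x::AbstractArray) = kernel.(x)

# Hypothetical configs for two AD systems: one advertises forward mode, one does not.
# (FwdCapableConfig would also need an frule_via_ad method before the rule could run.)
struct FwdCapableConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} end
struct ReverseOnlyConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} end

# Only ADs whose config includes HasForwardsMode dispatch here. Any other config
# falls through to ChainRulesCore's generic rrule(::RuleConfig, args...), which
# forwards to rrule(args...) and returns `nothing`, so the AD differentiates
# `pointwise_op` by itself instead.
function ChainRulesCore.rrule(config::RuleConfig{>:HasForwardsMode},
                              ::typeof(pointwise_op), x::AbstractArray)
    res = map(xi -> frule_via_ad(config, (NoTangent(), one(xi)), kernel, xi), x)
    y, ∂x = first.(res), last.(res)
    pointwise_op_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* ∂x)
    return y, pointwise_op_pullback
end
```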