Decorates a function and adds custom derivatives.
tfp.experimental.distributions.marginal_fns.tfp_custom_gradient.custom_gradient(
vjp_fwd=None, vjp_bwd=None, jvp_fn=None, nondiff_argnums=()
)
TF supports only custom VJPs (vector-Jacobian products, i.e. reverse-mode rules), so for TF we decorate the function with `tf.custom_gradient`.
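For reference, here is a minimal sketch of a VJP-only rule written directly with `tf.custom_gradient`, not through this decorator. The function `log1pexp` is a hypothetical example and is not part of this API. The decorated function returns its output together with a gradient function that maps the upstream gradient to the input gradient.

```python
import tensorflow as tf


@tf.custom_gradient
def log1pexp(x):
    # Hypothetical example function: log(1 + e^x).
    e = tf.exp(x)

    def grad(upstream):
        # VJP rule: scale the upstream gradient by d/dx log(1 + e^x).
        # This form stays finite when e^x overflows to inf.
        return upstream * (1.0 - 1.0 / (1.0 + e))

    # tf.custom_gradient expects (output, gradient_function).
    return tf.math.log(1.0 + e), grad


x = tf.constant([0.0, 100.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)  # finite at x = 100.0 thanks to the custom rule
```

Only a reverse-mode rule is supplied here, which is why a VJP is all the TF path can use.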
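JAX supports either a custom JVP (Jacobian-vector product) or a custom VJP. If a custom JVP is provided, JAX can transpose it to derive the VJP rule, so a single JVP rule serves both forward- and reverse-mode differentiation. Therefore we prefer `jvp_fn` if it is given, and fall back to `vjp_fwd`/`vjp_bwd` otherwise.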
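To illustrate why a JVP is preferred in JAX, the sketch below defines the same hypothetical function `log1pexp` twice using JAX's own `jax.custom_jvp` and `jax.custom_vjp`, rather than this decorator. Reverse mode (`jax.grad`) works for both, with the JVP version getting its VJP by transposition; forward mode (`jax.jvp`) works only for the JVP version.

```python
import jax
import jax.numpy as jnp


# Option 1: a custom JVP (forward-mode) rule.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # The tangent output is linear in x_dot, which is what lets JAX
    # transpose this rule into the reverse-mode rule used by jax.grad.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot


# Option 2: a custom VJP, split into a forward pass and a backward pass.
@jax.custom_vjp
def log1pexp_vjp(x):
    return jnp.log(1.0 + jnp.exp(x))


def log1pexp_fwd(x):
    # Return the primal output plus residuals saved for the backward pass.
    return log1pexp_vjp(x), x


def log1pexp_bwd(x, g):
    # Return one cotangent per primal input, as a tuple.
    return ((1.0 - 1.0 / (1.0 + jnp.exp(x))) * g,)


log1pexp_vjp.defvjp(log1pexp_fwd, log1pexp_bwd)

# Reverse mode works for both definitions.
grad_jvp = jax.grad(log1pexp)(100.0)
grad_vjp = jax.grad(log1pexp_vjp)(100.0)
# Forward mode works for the custom_jvp definition.
_, tangent_out = jax.jvp(log1pexp, (0.0,), (1.0,))
```

Calling `jax.jvp` on `log1pexp_vjp` would instead raise an error, because a `custom_vjp` function supports only reverse mode.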
| Returns |
|---|
| A decorator to be applied to a function `f(*args) => output`. |