tfp.experimental.distributions.marginal_fns.tfp_custom_gradient.custom_gradient
Decorates a function and adds custom derivatives.
tfp.experimental.distributions.marginal_fns.tfp_custom_gradient.custom_gradient(
vjp_fwd=None, vjp_bwd=None, jvp_fn=None, nondiff_argnums=()
)
TF only supports vector-Jacobian products (VJPs), so on TF we decorate with tf.custom_gradient, using vjp_fwd and vjp_bwd.
JAX supports either Jacobian-vector products (JVPs) or VJPs. If a custom JVP is provided, JAX can
transpose it to derive the VJP rule. Therefore we prefer jvp_fn if given, and fall
back to the VJP functions otherwise.
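A JVP rule is enough because a VJP applies the transpose of the same linear map (the Jacobian). JAX derives that transpose symbolically. The toy below, in plain Python with hypothetical names (mul_add_jvp, vjp_from_jvp), instead materializes the Jacobian column by column by calling the JVP on unit tangents, then contracts it with a cotangent. It is only practical for a handful of scalar inputs and is not how TFP or JAX implement transposition.

```python
# Toy illustration (not TFP or JAX internals): build a VJP from a JVP rule.


def mul_add_jvp(primals, tangents):
    """Hypothetical JVP rule for f(x, y) = (x * y, x + y)."""
    x, y = primals
    dx, dy = tangents
    primal_out = (x * y, x + y)
    # Tangent out = J @ (dx, dy), with J = [[y, x], [1, 1]].
    tangent_out = (y * dx + x * dy, dx + dy)
    return primal_out, tangent_out


def vjp_from_jvp(jvp, primals):
    """Return (primal_out, vjp) where vjp(cotangent) computes J^T @ cotangent."""
    n = len(primals)
    primal_out, _ = jvp(primals, tuple(0.0 for _ in range(n)))
    # Column j of the Jacobian is the JVP applied to the j-th unit tangent.
    columns = []
    for j in range(n):
        unit = tuple(1.0 if i == j else 0.0 for i in range(n))
        _, column = jvp(primals, unit)
        columns.append(column)

    def vjp(cotangent):
        # (J^T c)_j = sum_i J[i][j] * c[i], i.e. column j dotted with c.
        return tuple(
            sum(c_i * j_ij for c_i, j_ij in zip(cotangent, column))
            for column in columns
        )

    return primal_out, vjp


out, vjp = vjp_from_jvp(mul_add_jvp, (2.0, 3.0))
print(out)              # (6.0, 5.0)
print(vjp((1.0, 0.0)))  # (3.0, 2.0): the gradient of x * y
```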
Args
vjp_fwd: A function (args) => (output, auxiliaries).
vjp_bwd: A function (auxiliaries, output_gradient) => gradients with respect to the differentiable arguments only. None gradients will be inserted into the correct positions for nondiff_argnums.
jvp_fn: A function (nondiff_args, primals, tangents) => (primal_out, tangent_out).
nondiff_argnums: Tuple of indices of the arguments that are not differentiable. These arguments must be integers or other non-Tensors. A Tensor argument with no gradient should instead be indicated with None in the result of vjp_bwd.
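A minimal sketch of how these arguments fit together for a toy function power(x, n) = x ** n, where the integer exponent n (index 1) is listed in nondiff_argnums. Plain Python floats stand in for Tensors, and power_fwd, power_bwd, power_jvp, and insert_none_gradients are hypothetical names. insert_none_gradients only mimics the documented None-insertion step; it is not a TFP function. Passing the non-differentiable arguments positionally ahead of the primals and tangents tuples in power_jvp is an assumption modeled on JAX's custom_jvp convention.

```python
# Toy illustration of the argument contracts; plain floats stand in for Tensors.


def power_fwd(x, n):
    """vjp_fwd: (args) => (output, auxiliaries). Saves x and n for the backward pass."""
    return x ** n, (x, n)


def power_bwd(aux, g):
    """vjp_bwd: (auxiliaries, output_gradient) => gradients of the differentiable args.

    Only x is differentiable (n is listed in nondiff_argnums), so one gradient is returned.
    """
    x, n = aux
    return (g * n * x ** (n - 1),)


def power_jvp(n, primals, tangents):
    """jvp_fn: (nondiff_args, primals, tangents) => (primal_out, tangent_out)."""
    (x,), (dx,) = primals, tangents
    return x ** n, n * x ** (n - 1) * dx


def insert_none_gradients(diff_grads, nondiff_argnums, num_args):
    """Place None at each index in nondiff_argnums, filling the rest in order."""
    diff_iter = iter(diff_grads)
    return tuple(
        None if i in nondiff_argnums else next(diff_iter) for i in range(num_args)
    )


nondiff_argnums = (1,)  # The exponent n is an integer, so it is not differentiable.
out, aux = power_fwd(2.0, 3)
grads = insert_none_gradients(power_bwd(aux, 1.0), nondiff_argnums, num_args=2)
print(out, grads)                    # 8.0 (12.0, None)
print(power_jvp(3, (2.0,), (1.0,)))  # (8.0, 12.0)
```

Note that power_bwd never returns a value for n; the None in the full gradient tuple comes from the insertion step, which is why nondiff_argnums is reserved for non-Tensor arguments.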
Returns
A decorator to be applied to a function f(*args) => output.
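The returned decorator leaves the forward computation f(*args) => output unchanged and attaches a derivative rule to it. The sketch below is a hypothetical stand-in (toy_custom_gradient, square_fwd, square_bwd) that reproduces only that calling shape and the preference for jvp_fn described above; it does not register anything with TensorFlow or JAX. With the real decorator you would pass the same keyword arguments and apply it to the function in the same way.

```python
import functools


def toy_custom_gradient(vjp_fwd=None, vjp_bwd=None, jvp_fn=None, nondiff_argnums=()):
    """Hypothetical stand-in with the same keyword arguments; it only records a rule."""

    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args):
            # The forward pass is unchanged: f(*args) => output.
            return f(*args)

        # Mirror the JAX-side preference: use jvp_fn if given, else the VJP pair.
        if jvp_fn is not None:
            wrapped.rule = ("jvp", jvp_fn)
        elif vjp_fwd is not None and vjp_bwd is not None:
            wrapped.rule = ("vjp", (vjp_fwd, vjp_bwd))
        else:
            raise ValueError("Provide jvp_fn, or both vjp_fwd and vjp_bwd.")
        wrapped.nondiff_argnums = tuple(nondiff_argnums)
        return wrapped

    return decorator


def square_fwd(x):
    # Output plus the auxiliary value (x) needed by the backward pass.
    return x * x, x


def square_bwd(x, g):
    # Gradient of x * x with respect to x, scaled by the output gradient g.
    return (2.0 * x * g,)


@toy_custom_gradient(vjp_fwd=square_fwd, vjp_bwd=square_bwd)
def square(x):
    return x * x


print(square(3.0))      # 9.0
print(square.rule[0])   # vjp
print(square.__name__)  # square
```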
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.