tfp.math.custom_gradient
Embeds a custom gradient into a Tensor.
tfp.math.custom_gradient(
fx, gx, x, fx_gx_manually_stopped=False, name=None
)
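This function works by a clever application of stop_gradient. Observe that
h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
is such that h(x) == stop_gradient(f(x)) and grad[h(x), x] == stop_gradient(g(x)). This holds because x - stop_gradient(x) is zero in value but has derivative 1 with respect to x. The second term therefore adds nothing to the value of h(x) and contributes exactly stop_gradient(g(x)) to its gradient.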
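A minimal sketch of the identity above, written by hand in plain TensorFlow 2.x (eager mode assumed). The names f, g, and h are illustrative only and are not part of the library. The supplied g is deliberately not the true derivative of f, to show that the gradient comes from g while the value comes from f:

```python
import tensorflow as tf

def h(x, f, g):
  # Value: stop_gradient(f(x)) + g(x) * 0 == f(x).
  # Gradient w.r.t. x: only the (x - stop_gradient(x)) factor is differentiable,
  # and its derivative is 1, so the gradient is stop_gradient(g(x)).
  return (tf.stop_gradient(f(x))
          + tf.stop_gradient(g(x)) * (x - tf.stop_gradient(x)))

f = lambda x: x ** 2
g = lambda x: 10.0 * x  # hypothetical "custom gradient", NOT the true 2 * x

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = h(x, f, g)
dy_dx = tape.gradient(y, x)

print(float(y), float(dy_dx))  # 9.0 30.0
```

The value is f(3) = 9, while the gradient is g(3) = 30 rather than the true derivative 6.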
In addition to scalar-domain/scalar-range functions, this function also
supports tensor-domain/scalar-range functions.
Partial Custom Gradient:
Suppose h(x) = htilde(x, y). Note that dh/dx = stop(g(x)) but dh/dy = None. This is because a Tensor cannot have only a portion of its gradient stopped: by default, fx and gx are each passed through stop_gradient in their entirety, which also severs their dependence on y. To circumvent this issue, manually apply stop_gradient to only the relevant portions of f and g, and set fx_gx_manually_stopped=True. For an example, see the unit test test_works_correctly_fx_gx_manually_stopped.
Args:
fx: Tensor. Output of the function evaluated at x.
gx: Tensor or list of Tensors. Gradient of the function at (each) x.
x: Tensor or list of Tensors. Arguments at which f is evaluated.
fx_gx_manually_stopped: Python bool indicating that stop_gradient has already been applied manually to fx and gx. Default: False.
name: Python str name prefixed to Ops created by this function.
Returns:
fx: Floating-type Tensor equal to f(x) but which has gradient stop_gradient(g(x)).
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.