tfp.experimental.distribute.make_pbroadcast_function
Constructs a function that broadcasts inputs over named axes.
tfp.experimental.distribute.make_pbroadcast_function(
fn, in_axes, out_axes, out_dtype
)
Given a function fn, make_pbroadcast_function returns a new function that applies pbroadcast to the inputs according to the axis names provided in in_axes and out_axes. For each term in the output of fn, every input that lacks one or more of that term's output axes is pbroadcast-ed over the missing axes before the term is computed.
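The rule above can be sketched as a small planning step. This is a minimal pure-Python illustration, not TFP code: it assumes each input's and each output's named axes are given as a frozenset of strings, and the helpers axes_to_pbroadcast and pbroadcast_plan are hypothetical names. For every output term it lists the axes each input must be pbroadcast-ed over, namely the output's axes minus the input's axes.

```python
from typing import FrozenSet, List, Sequence

AxisNames = FrozenSet[str]


def axes_to_pbroadcast(out_axes: AxisNames, in_axes: AxisNames) -> List[str]:
    """Hypothetical helper: axes on the output term that the input lacks."""
    return sorted(out_axes - in_axes)


def pbroadcast_plan(in_axes: Sequence[AxisNames],
                    out_axes: Sequence[AxisNames]) -> List[List[List[str]]]:
    """plan[i][j] lists the axes input j is pbroadcast-ed over for output i."""
    return [[axes_to_pbroadcast(o, a) for a in in_axes] for o in out_axes]


# Hypothetical setup: input 0 has no named axes, input 1 is sharded over
# 'data'; output 0 has axis 'data', output 1 has axes 'data' and 'model'.
in_axes = [frozenset(), frozenset({'data'})]
out_axes = [frozenset({'data'}), frozenset({'data', 'model'})]
plan = pbroadcast_plan(in_axes, out_axes)
print(plan)  # [[['data'], []], [['data', 'model'], ['model']]]
```

An input that already carries every axis of an output term needs no pbroadcast for that term, which is why the plan contains empty lists.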
Args
fn | A callable to be transformed to have pbroadcasts at its inputs.
in_axes | A structure of axis names that should match the structure of the input to fn. If an input value lacks some of the axes of a particular output value, the gradient of that output value w.r.t. the input value will be psum-ed over the axes present in the output but not in the input.
out_axes | A structure of axis names that should match the structure of the output of fn. The inputs to fn will be pbroadcast-ed before each output term is computed, according to that term's output axes.
out_dtype | A structure of dtypes that matches the structure of the output of fn.
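The in_axes behavior can be checked with a small simulation of why the gradient must be psum-ed. This is a hedged pure-Python sketch, not TFP code: it simulates a named axis 'data' of size 4 as one list entry per device and assumes, consistent with the description of in_axes, that pbroadcast leaves values unchanged in the forward pass and psums gradients in the backward pass. The function is fn(w, x) = w * x, where w has no named axes, x is sharded over 'data', and the output is sharded over 'data'.

```python
from typing import List

# Simulated named axis 'data' of size 4: one list entry per device.
x_shards: List[float] = [1.0, 2.0, 3.0, 4.0]  # input sharded over 'data'
w = 0.5  # input with no named axes: the same value on every device


def loss(w_value: float) -> float:
    """fn(w, x) = w * x is sharded over 'data'; the loss sums it over that axis."""
    return sum(w_value * x_i for x_i in x_shards)


def psum(values: List[float]) -> List[float]:
    """Simulated all-reduce: every device receives the sum over the axis."""
    total = sum(values)
    return [total] * len(values)


# Each device differentiates only its own term w * x_i, which gives x_i.
unreduced = list(x_shards)

# Assumed pbroadcast semantics: identity forward; the backward pass psums the
# gradient over the axes the output has but the input lacks (here: 'data').
reduced = psum(unreduced)

# Central finite difference of the full loss with respect to w, for comparison.
eps = 1e-6
numeric = (loss(w + eps) - loss(w - eps)) / (2 * eps)

print(unreduced)          # [1.0, 2.0, 3.0, 4.0]
print(reduced)            # [10.0, 10.0, 10.0, 10.0]
print(round(numeric, 6))  # 10.0
```

Without the psum, each device would see only its local contribution x_i instead of the full gradient of 10.0 with respect to the shared w.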
Returns
A new function that applies pbroadcasts to the inputs of the original function.
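The shape of the returned function can be modeled in pure Python. This is a hypothetical sketch, not TFP's implementation: make_pbroadcast_function_sketch and the injected pbroadcast callable are illustrative names, axes are assumed to be frozensets of strings, and inputs and outputs are assumed to be flat tuples rather than arbitrary nested structures. The stand-in pbroadcast only records which axes were requested, so the example runs without any named-axis context.

```python
from typing import Callable, FrozenSet, Sequence, Tuple

AxisNames = FrozenSet[str]


def make_pbroadcast_function_sketch(
        fn: Callable[..., Tuple],
        in_axes: Sequence[AxisNames],
        out_axes: Sequence[AxisNames],
        pbroadcast: Callable[[object, Sequence[str]], object],
) -> Callable[..., Tuple]:
    """Hypothetical model: one pbroadcast of the inputs per output term."""
    def wrapped(*args):
        outputs = []
        for i, o_axes in enumerate(out_axes):
            # Broadcast each input over the axes output i has but it lacks.
            local_args = [pbroadcast(x, sorted(o_axes - a_axes))
                          for x, a_axes in zip(args, in_axes)]
            # Re-evaluate fn on the broadcast inputs and keep only term i.
            outputs.append(fn(*local_args)[i])
        return tuple(outputs)
    return wrapped


calls = []


def recording_pbroadcast(x, axes):
    """Stand-in for pbroadcast: identity, but logs the requested axes."""
    calls.append((x, tuple(axes)))
    return x


def fn(w, x):
    return (w * x, w + x)


wrapped = make_pbroadcast_function_sketch(
    fn,
    in_axes=[frozenset(), frozenset({'data'})],
    out_axes=[frozenset({'data'}), frozenset({'data', 'model'})],
    pbroadcast=recording_pbroadcast)

result = wrapped(2.0, 3.0)
print(result)  # (6.0, 5.0)
print(calls)   # [(2.0, ('data',)), (3.0, ()), (2.0, ('data', 'model')), (3.0, ('model',))]
```

This sketch re-runs fn once per output term so that each term can see inputs broadcast over different axes; a real implementation may organize the computation differently.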
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.