tfp.experimental.distribute.make_sharded_log_prob_parts
Constructs a log prob parts function that all-reduces over terms.
tfp.experimental.distribute.make_sharded_log_prob_parts(
log_prob_parts_fn, axis_names
)
Given a log_prob_parts function, this function returns a new one that adds
all-reduce sums over the terms marked as sharded by axis_names. It also adds
all-reduce sums for the gradients of sharded terms with respect to unsharded
values.
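To make the forward behavior concrete, here is a pure-Python simulation, not TFP code. It assumes a toy model with an unsharded scalar w ~ Normal(0, 1) and sharded observations x_i ~ Normal(w, 1), treats each list entry as one device, and stands in for the all-reduce with a plain Python sum. The helpers normal_log_prob, local_log_prob_parts, and simulate_sharded_log_prob_parts are hypothetical names chosen for illustration.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def local_log_prob_parts(value):
    """Per-device parts: [prior on unsharded w, likelihood of this device's x shard]."""
    w, x_shard = value
    return [normal_log_prob(w), sum(normal_log_prob(xi, loc=w) for xi in x_shard)]


def simulate_sharded_log_prob_parts(parts_fn, axis_names, values_per_device):
    """Hypothetical stand-in for the returned function, run on every simulated device.

    Terms whose axis name is None keep their locally computed value; terms with
    an axis name are summed across all devices (an all-reduce, like psum in JAX).
    """
    local_parts = [parts_fn(v) for v in values_per_device]
    results = []
    for parts in local_parts:
        reduced = []
        for term, name in enumerate(axis_names):
            if name is None:
                reduced.append(parts[term])
            else:
                reduced.append(sum(p[term] for p in local_parts))
        results.append(reduced)
    return results


w = 0.3
data = [0.1, -0.4, 1.2, 0.7]
shards = [data[:2], data[2:]]  # two simulated devices, each holding half the data
per_device = simulate_sharded_log_prob_parts(
    local_log_prob_parts, [None, "data"], [[w, s] for s in shards])

full = sum(local_log_prob_parts([w, data]))  # log prob computed on all the data at once
for parts in per_device:
    print(math.isclose(sum(parts), full))  # True on every device
```

The unsharded prior term is deliberately not summed: every device computes the same value for it, so summing it would count the prior once per device.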
Args
log_prob_parts_fn: a callable that takes in a structured value and returns a
structure of log densities, one for each term, that sum to a locally correct
log density.
axis_names: a structure of values that matches the input and output of
log_prob_parts_fn. Each value in axis_names is either None, a string naming a
mapped axis (JAX backend) or any non-None value (TF backend), or an iterable
of these corresponding to multiple sharding axes. If a term's axis name is not
None, the returned function adds all-reduce sum(s) for that term in the log
prob calculation. If it is None, the returned function adds an all-reduce sum
over the gradient of sharded terms with respect to the unsharded value.
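The gradient rule for None entries can be illustrated the same way. The sketch below is again a hypothetical pure-Python simulation of the same toy model (unsharded w ~ Normal(0, 1), sharded x_i ~ Normal(w, 1)); it writes the derivatives out by hand instead of using autodiff and does not reproduce how TFP actually implements the gradient. It shows that each device's local gradient of the sharded likelihood with respect to the unsharded w covers only its own shard, and that summing those local gradients across devices recovers the full-data gradient. The function names are illustrative only.

```python
import math


def prior_grad(w):
    """d/dw of log Normal(w; 0, 1)."""
    return -w


def local_likelihood_grad(w, x_shard):
    """d/dw of sum_i log Normal(x_i; w, 1), over this device's shard only."""
    return sum(xi - w for xi in x_shard)


def simulate_grad_wrt_unsharded(w, shards):
    """Hypothetical per-device gradient of the total log prob w.r.t. unsharded w.

    The unsharded prior's gradient is used as-is on each device, while the
    sharded likelihood's local gradients are all-reduce summed across devices.
    """
    all_reduced = sum(local_likelihood_grad(w, s) for s in shards)
    return [prior_grad(w) + all_reduced for _ in shards]


w = 0.3
data = [0.1, -0.4, 1.2, 0.7]
shards = [data[:2], data[2:]]  # two simulated devices
per_device = simulate_grad_wrt_unsharded(w, shards)

full_grad = prior_grad(w) + local_likelihood_grad(w, data)  # gradient on all the data
print([math.isclose(g, full_grad) for g in per_device])  # [True, True]
```

Without the all-reduce, device 0 would see only prior_grad(w) + local_likelihood_grad(w, shards[0]), which ignores the other shard's contribution to the gradient.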
Returns
A new log prob parts function that can be run inside a strategy.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.