Returns (nested) explicit dtype from `args` if there is one.
```python
tfp.experimental.distributions.marginal_fns.ps.dtype_util.common_dtype(
    args, dtype_hint=None
)
```
| Returns | |
|---|---|
| `dtype` | The (nested) dtype common across all elements of `args`, or `dtype_hint` if no element carries an explicit dtype (`None` when no hint is given). |
Examples
Usage with non-nested dtype:
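```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
common_dtype = tfp.experimental.distributions.marginal_fns.ps.dtype_util.common_dtype

x = tf.ones([3, 4], dtype=tf.float64)
y = 4.
z = None
common_dtype([x, y, z], dtype_hint=tf.float32)  # ==> tf.float64
common_dtype([y, z], dtype_hint=tf.float32)     # ==> tf.float32
# The arg to `common_dtype` can be an arbitrary nested structure; it is
# flattened, and the common dtype of its contents is returned.
common_dtype({'x': x, 'yz': (y, z)})
# ==> tf.float64
```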
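The non-nested behavior above can be approximated in plain Python. This is a sketch, not TFP's implementation: `FakeTensor`, `flatten`, and `common_dtype_sketch` are hypothetical names, dtypes are plain strings, and the sketch assumes that two conflicting explicit dtypes raise `TypeError`. It flattens `args` (dict values in sorted-key order, as `tf.nest` does), skips leaves without a `dtype` attribute such as Python floats and `None`, and falls back to `dtype_hint` when no leaf carries a dtype.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: it only carries a dtype name."""
    dtype: str


def flatten(structure: Any) -> list:
    """Flatten dicts (sorted keys), lists and tuples; anything else is a leaf."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def common_dtype_sketch(args: Any, dtype_hint: Optional[str] = None) -> Optional[str]:
    """Return the single explicit dtype among the leaves of `args`, else `dtype_hint`."""
    found = None
    for leaf in flatten(args):
        dtype = getattr(leaf, "dtype", None)  # Python floats and None have no dtype
        if dtype is None:
            continue
        if found is not None and dtype != found:
            # Assumption of this sketch: conflicting explicit dtypes are an error.
            raise TypeError(f"Found incompatible dtypes, {found} and {dtype}.")
        found = dtype
    return dtype_hint if found is None else found


x = FakeTensor("float64")
y = 4.
z = None
print(common_dtype_sketch([x, y, z], dtype_hint="float32"))  # float64
print(common_dtype_sketch([y, z], dtype_hint="float32"))     # float32
print(common_dtype_sketch({"x": x, "yz": (y, z)}))           # float64
```

The hint only matters when nothing in `args` has a dtype of its own, which is why `y = 4.` never overrides `x`.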
Usage with nested dtype:
```python
# Define `x` and `y` as JointDistributions with the same nested dtype.
x = tfd.JointDistributionNamed(
    {'a': tfd.Uniform(np.float64(0.), 1.),
     'b': tfd.JointDistributionSequential(
        [tfd.Normal(0., 2.), tfd.Bernoulli(0.4)])})
x.dtype  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
y = tfd.JointDistributionNamed(
    {'a': tfd.LogitNormal(np.float64(0.), 1.),
     'b': tfd.JointDistributionSequential(
        [tfd.Normal(-1., 1.), tfd.Bernoulli(0.6)])})
y.dtype  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
# Pack x and y into an arbitrary nested structure and pass it to
# `common_dtype`.
args0 = [x, y]
common_dtype(args0)  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
# If `dtype_hint` is not structured, the nested structure of the argument
# to `common_dtype` is flattened and ignored, and only the nested structures
# of the dtypes are relevant.
args1 = {'x': x, 'yz': {'y': y, 'z': None}}
common_dtype(args1)  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
```
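To see why the nesting of `args` does not matter here, the sketch below treats an object with a structured dtype as a single leaf and compares its whole dtype against the others. As before, this is a hypothetical plain-Python illustration, not the library's code: `FakeJointDistribution`, `FakeTensor`, and `common_dtype_sketch` are made-up stand-ins with dtypes written as strings. It also shows the failure mode that the structured-hint example works around: a structured dtype never equals a plain `"float32"`, so mixing the two raises `TypeError` in this sketch.

```python
from typing import Any


class FakeJointDistribution:
    """Hypothetical stand-in for a JointDistribution: it only carries a nested dtype."""

    def __init__(self, dtype: Any):
        self.dtype = dtype


class FakeTensor:
    """Hypothetical stand-in for a tensor with a plain (non-nested) dtype."""

    def __init__(self, dtype: str):
        self.dtype = dtype


def flatten(structure: Any) -> list:
    """Flatten dicts (sorted keys), lists and tuples; any other object is a leaf."""
    if isinstance(structure, dict):
        return [leaf for k in sorted(structure) for leaf in flatten(structure[k])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def common_dtype_sketch(args: Any, dtype_hint: Any = None) -> Any:
    """Return the one (possibly nested) dtype found among the leaves of `args`."""
    found = None
    for leaf in flatten(args):
        dtype = getattr(leaf, "dtype", None)
        if dtype is None:
            continue
        # A nested dtype is compared as a whole value; the object itself is never
        # opened up, so the nesting of `args` around it does not matter.
        if found is not None and dtype != found:
            raise TypeError(f"Found incompatible dtypes, {found} and {dtype}.")
        found = dtype
    return dtype_hint if found is None else found


nested = {"a": "float64", "b": ["float32", "int32"]}
x = FakeJointDistribution(nested)
y = FakeJointDistribution({"a": "float64", "b": ["float32", "int32"]})

print(common_dtype_sketch([x, y]))                               # same as `nested`
print(common_dtype_sketch({"x": x, "yz": {"y": y, "z": None}}))  # same as `nested`

# Mixing the structured dtype with a plain float32 leaf fails, because the two
# dtypes are compared whole and a dict never equals "float32".
try:
    common_dtype_sketch([x, FakeTensor("float32")])
except TypeError as err:
    print("TypeError:", err)
```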
Usage with structured `dtype_hint`:
```python
# Use structured `dtype_hint` to indicate the structure of the expected dtype.
# In this example, `x` is an object with structured dtype, and `t` is a
# structure of objects whose dtypes are compatible with the corresponding
# components of `x.dtype`. Without structured `dtype_hint`, this example
# would fail, since the args `[x, t]` would be flattened entirely, and the
# structured `x.dtype` is incompatible with the non-structured `float32`
# contained in `t`.
t = {'a': [1., 2., 3.], 'b': [np.float32(1.), [[4, 5]]]}
common_dtype([x, t], dtype_hint={'a': None, 'b': [None, None]})
#   ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
```
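The structured-hint case can be sketched by unifying dtypes position by position along the leaves of the hint. This is a hypothetical reconstruction that reproduces the example above, not TFP's algorithm: `HasDtype`, `matches`, `flatten_up_to`, `pack_as`, and `common_dtype_structured` are illustrative names, dtypes are strings, and `HasDtype("float32")` stands in for `np.float32(1.)`. An object with its own `dtype` has that dtype split along the hint; a plain structure shaped like the hint contributes the dtype of each piece; any other container is searched recursively.

```python
from typing import Any, List, Optional


class HasDtype:
    """Hypothetical stand-in for any object exposing a (possibly nested) `dtype`."""

    def __init__(self, dtype: Any):
        self.dtype = dtype


def matches(shallow: Any, deep: Any) -> bool:
    """True if `deep` is nested at least like `shallow` (same dict keys, list lengths)."""
    if isinstance(shallow, dict):
        return (isinstance(deep, dict) and set(deep) == set(shallow)
                and all(matches(shallow[k], deep[k]) for k in shallow))
    if isinstance(shallow, (list, tuple)):
        return (isinstance(deep, (list, tuple)) and len(deep) == len(shallow)
                and all(matches(s, d) for s, d in zip(shallow, deep)))
    return True  # a leaf of `shallow` absorbs whatever sub-structure `deep` has


def flatten_up_to(shallow: Any, deep: Any) -> List[Any]:
    """Flatten `deep` only as far as `shallow` is nested (dict keys in sorted order)."""
    if isinstance(shallow, dict):
        return [leaf for k in sorted(shallow) for leaf in flatten_up_to(shallow[k], deep[k])]
    if isinstance(shallow, (list, tuple)):
        return [leaf for s, d in zip(shallow, deep) for leaf in flatten_up_to(s, d)]
    return [deep]


def pack_as(structure: Any, flat: List[Any]) -> Any:
    """Rebuild `structure`, filling its leaves (sorted-key order) from `flat`."""
    leaves = iter(flat)

    def build(s: Any) -> Any:
        if isinstance(s, dict):
            return {k: build(s[k]) for k in sorted(s)}
        if isinstance(s, (list, tuple)):
            return type(s)(build(item) for item in s)
        return next(leaves)

    return build(structure)


def common_dtype_structured(args: Any, dtype_hint: Any) -> Any:
    """Unify dtypes leaf by leaf along the structure of a nested `dtype_hint`."""
    if not isinstance(dtype_hint, (dict, list, tuple)):
        raise ValueError("this sketch only handles a structured dtype_hint")
    hint_leaves = flatten_up_to(dtype_hint, dtype_hint)
    unified: List[Optional[str]] = [None] * len(hint_leaves)

    def unify(per_leaf: List[Any]) -> None:
        for i, dt in enumerate(per_leaf):
            if dt is None:
                continue
            if unified[i] is not None and unified[i] != dt:
                raise TypeError(f"Found incompatible dtypes, {unified[i]} and {dt}.")
            unified[i] = dt

    def visit(node: Any) -> None:
        dtype = getattr(node, "dtype", None)
        if dtype is not None:
            # Object with its own (nested) dtype: split that dtype along the hint.
            if not matches(dtype_hint, dtype):
                raise TypeError(f"dtype {dtype} does not fit hint {dtype_hint}.")
            unify(flatten_up_to(dtype_hint, dtype))
        elif matches(dtype_hint, node):
            # Plain structure shaped like the hint: each piece contributes its dtype.
            unify([getattr(p, "dtype", None) for p in flatten_up_to(dtype_hint, node)])
        elif isinstance(node, dict):
            for k in sorted(node):
                visit(node[k])
        elif isinstance(node, (list, tuple)):
            for item in node:
                visit(item)
        # Anything else (Python numbers, None, ...) carries no dtype and is skipped.

    visit(args)
    return pack_as(dtype_hint, [u if u is not None else h
                                for u, h in zip(unified, hint_leaves)])


x = HasDtype({"a": "float64", "b": ["float32", "int32"]})
t = {"a": [1., 2., 3.], "b": [HasDtype("float32"), [[4, 5]]]}
result = common_dtype_structured([x, t], dtype_hint={"a": None, "b": [None, None]})
print(result)  # {'a': 'float64', 'b': ['float32', 'int32']}
```

Because the hint's leaf at `'a'` is `None`, the whole list `[1., 2., 3.]` in `t` is treated as one piece; it has no `dtype`, so position `'a'` is decided by `x` alone.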