Builds a trainable blockwise tf.linalg.LinearOperator.
tfp.experimental.vi.util.build_trainable_linear_operator_block(
*args, seed=None, **kwargs
)
This function returns a trainable blockwise LinearOperator. If operators is a flat list, it is interpreted as blocks along the diagonal of the structure, and an instance of tf.linalg.LinearOperatorBlockDiag is returned. If operators is a nested list (a list of lists), a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned, with the block in row i, column j (i >= j) given by operators[i][j].
The operators list may contain LinearOperator instances, LinearOperator subclasses, or callables defining custom constructors (see example below). The dimensions of the blocks are given by block_dims; this argument may be omitted if operators contains only LinearOperator instances.
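When every entry is already a concrete operator, the block sizes are fixed by the operators themselves. A minimal sketch of that inference, assuming each block size equals the last dimension of the corresponding diagonal block; infer_block_dims is a hypothetical helper (not a TFP function), and NumPy arrays stand in for LinearOperator instances.

```python
import numpy as np


def infer_block_dims(operators):
    """Hypothetical helper: read block sizes off the diagonal blocks.

    `operators` holds concrete matrices (stand-ins for LinearOperator
    instances), either as a flat list (block diagonal) or as a list of lists
    (block lower-triangular, with diagonal blocks operators[i][i]). The size
    of each block is the last dimension of its diagonal block, so leading
    batch dimensions are ignored.
    """
    nested = isinstance(operators[0], (list, tuple))
    diagonal = ([row[i] for i, row in enumerate(operators)] if nested
                else list(operators))
    return [int(op.shape[-1]) for op in diagonal]


# Stand-ins shaped like the 6x6 example in the Examples section: a batch of
# four 3x3 blocks, a 1x1 block and a 2x2 block.
dims = infer_block_dims([np.ones((4, 3, 3)), np.array([[4.0]]), np.eye(2)])
```

The sizes come out as [3, 1, 2], which sum to the 6x6 overall shape. Off-diagonal blocks are not needed for the inference because their shapes are implied by the diagonal blocks in their row and column.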
Args | |
---|---|
operators | A list or tuple containing LinearOperator subclasses, LinearOperator instances, and/or callables returning (init_fn, apply_fn) pairs. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance is returned. Otherwise, the list must be a list of lists (nested one level deep), with the first inner list of length 1, the second of length 2, and so on; the inner lists are interpreted as rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned. Callables contained in the lists must take two arguments, shape (the shape of the parameter instantiating the LinearOperator) and dtype (the tf.DType of the LinearOperator), and return a further pair of callables representing a stateless trainable operator (see example below). |
block_dims | List or tuple of integers representing the sizes of the blocks along one dimension of the (square) blockwise LinearOperator. If operators contains only LinearOperator instances, block_dims may be None and the dimensions are inferred. |
batch_shape | Batch shape of the LinearOperator. |
dtype | tf.DType of the LinearOperator. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
name | str, name for tf.name_scope. |
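The structural rule for operators can be checked before any operator is built. The sketch below is a strict reading of the documented rule (flat list, or rows of length 1, 2, 3, ...); block_structure is a hypothetical helper, strings stand in for operators, and the library's own validation may be more permissive.

```python
def block_structure(operators):
    """Hypothetical validator for the documented `operators` structure.

    Returns "block_diag" for a flat list or tuple, and
    "block_lower_triangular" for a list of lists whose i-th row has i + 1
    entries. Anything else raises ValueError.
    """
    if not isinstance(operators, (list, tuple)) or not operators:
        raise ValueError("operators must be a non-empty list or tuple")
    is_row = [isinstance(op, (list, tuple)) for op in operators]
    if not any(is_row):
        return "block_diag"
    if not all(is_row):
        # Strict reading: every row, including the first, is its own list.
        raise ValueError("cannot mix rows (lists) with single operators")
    for i, row in enumerate(operators):
        if len(row) != i + 1:
            raise ValueError(
                f"row {i} has {len(row)} entries; expected {i + 1}")
    return "block_lower_triangular"


flat_kind = block_structure(["diag", "tril"])
nested_kind = block_structure([("diag",), ("full", "tril")])
```

Under this reading a first row must be wrapped in its own one-element tuple, so ("diag",) is accepted as row 0 while a bare "diag" next to a tuple row is rejected.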
Returns | |
---|---|
instance | tf.linalg.LinearOperator instance parameterized by trainable tf.Variables. |
Examples
To build a 5x5 trainable LinearOperatorBlockDiag given LinearOperator subclasses and block_dims:
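import tensorflow as tf
import tensorflow_probability as tfp
build_trainable_linear_operator_block = (
    tfp.experimental.vi.util.build_trainable_linear_operator_block)

op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag,
               tf.linalg.LinearOperatorLowerTriangular),
    block_dims=[3, 2],
    dtype=tf.float32)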
If operators contains only LinearOperator instances, the block_dims argument is not necessary:
# Builds a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`.
op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
               tf.linalg.LinearOperatorFullMatrix([[4.]]),
               tf.linalg.LinearOperatorIdentity(2)))
A custom operator constructor may be specified as a callable taking arguments shape and dtype, and returning a pair of callables (init_fn, apply_fn) describing a parameterized operator, with the following signatures:
raw_parameters = init_fn(seed)
linear_operator = apply_fn(raw_parameters)
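The protocol separates initialization from construction: init_fn maps a seed to raw parameter values, and apply_fn maps raw parameters to an operator, with neither function holding state. A minimal NumPy sketch of a constructor following that protocol; diag_operator_with_uniform_init is hypothetical, and apply_fn returns a dense diagonal matrix as a stand-in for tf.linalg.LinearOperatorDiag.

```python
import numpy as np


def diag_operator_with_uniform_init(shape, dtype):
    """Hypothetical constructor following the (init_fn, apply_fn) protocol.

    NumPy stand-in: the "operator" returned by apply_fn is a dense diagonal
    matrix rather than a tf.linalg.LinearOperatorDiag.
    """
    def init_fn(seed):
        # Deterministic in `seed`: the same seed always gives the same values,
        # drawn uniformly from [0, 2).
        rng = np.random.default_rng(seed)
        return rng.uniform(0.0, 2.0, size=shape).astype(dtype)

    def apply_fn(raw_parameters):
        # Pure function of the parameters; no hidden state.
        return np.diag(raw_parameters)

    return init_fn, apply_fn


init_fn, apply_fn = diag_operator_with_uniform_init(shape=(4,), dtype=np.float64)
raw_parameters = init_fn(seed=42)
linear_operator = apply_fn(raw_parameters)
```

Keeping both functions pure is what allows the builder to manage state itself, for example by holding the raw parameters in trainable tf.Variables as described under Returns.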
For example, to define a custom initialization for a diagonal operator:
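import tensorflow as tf
from tensorflow_probability.python.internal import samplers

def diag_operator_with_uniform_initialization(shape, dtype):
  # init_fn maps a seed to initial raw parameters, uniform on [0, 2).
  init_fn = lambda seed: samplers.uniform(
      shape, maxval=2., dtype=dtype, seed=seed)
  # apply_fn maps raw parameters to a non-singular diagonal operator.
  apply_fn = lambda scale_diag: tf.linalg.LinearOperatorDiag(
      scale_diag, is_non_singular=True)
  return init_fn, apply_fn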
# Build an 8x8 `LinearOperatorBlockLowerTriangular`, with our custom diagonal
# operator in the upper left block, and `LinearOperator` subclasses in the
# lower two blocks.
op = build_trainable_linear_operator_block(
    operators=((diag_operator_with_uniform_initialization,),
               (tf.linalg.LinearOperatorFullMatrix,
                tf.linalg.LinearOperatorLowerTriangular)),
    block_dims=[4, 4],
    dtype=tf.float64)