Evaluates a Kronecker-Factored Lattice using hypercube interpolation.
tfl.kronecker_factored_lattice_lib.evaluate_with_hypercube_interpolation(
    inputs, scale, bias, kernel, units, num_terms, lattice_sizes, clip_inputs
)
The Kronecker-Factored Lattice function is the product of the piecewise-linear
interpolation weights for each dimension of the input.
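The factorization can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: it assumes lattice vertices sit at the integer positions 0, 1, ..., lattice_sizes - 1, and the helper names are hypothetical. Because one term's lattice parameters form an outer (Kronecker) product of per-dimension vectors, full hypercube (multilinear) interpolation over every vertex gives the same value as multiplying the per-dimension 1-D interpolations. The sketch computes both to show this.

```python
import itertools


def interpolation_weights(x, lattice_size):
    """Piecewise-linear (hat-function) weights of scalar x at each vertex.

    Assumes vertices sit at the integer positions 0, 1, ..., lattice_size - 1.
    """
    return [max(0.0, 1.0 - abs(x - v)) for v in range(lattice_size)]


def factored_value(point, kernel_1d, lattice_size):
    """Product over dimensions of each dimension's interpolated 1-D kernel."""
    result = 1.0
    for x, vertex_values in zip(point, kernel_1d):
        weights = interpolation_weights(x, lattice_size)
        result *= sum(w * k for w, k in zip(weights, vertex_values))
    return result


def hypercube_value(point, kernel_1d, lattice_size):
    """Brute-force multilinear interpolation over the dense outer-product lattice."""
    per_dim = [interpolation_weights(x, lattice_size) for x in point]
    total = 0.0
    for vertex in itertools.product(range(lattice_size), repeat=len(point)):
        weight = 1.0
        parameter = 1.0
        for d, v in enumerate(vertex):
            weight *= per_dim[d][v]       # hypercube interpolation weight
            parameter *= kernel_1d[d][v]  # Kronecker-product lattice parameter
        total += weight * parameter
    return total


# Two dimensions, three vertices per dimension (hypothetical values).
kernel_1d = [[0.0, 1.0, 2.0], [1.0, 3.0, 5.0]]
point = [1.25, 0.5]
print(interpolation_weights(1.25, 3))       # [0.0, 0.75, 0.25]
print(factored_value(point, kernel_1d, 3))  # 2.5
print(hypercube_value(point, kernel_1d, 3)) # 2.5
```

The factored form costs one 1-D interpolation per dimension, while the dense form visits lattice_size ** dims vertices. This difference is the usual motivation for Kronecker factorization.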
Args:
  inputs: Tensor representing the points to which lattice interpolation is
    applied. If units = 1, the tensor should have shape
    (batch_size, ..., dims), or inputs may be a list of dims tensors, each of
    shape (batch_size, ..., 1). If units > 1, the tensor should have shape
    (batch_size, ..., units, dims), or inputs may be a list of dims tensors,
    each of shape (batch_size, ..., units, 1). A typical shape is
    (batch_size, dims).
  scale: Kronecker-Factored Lattice scale of shape (units, num_terms).
  bias: Kronecker-Factored Lattice bias of shape (units,).
  kernel: Kronecker-Factored Lattice kernel of shape
    (1, lattice_sizes, units * dims, num_terms).
  units: Output dimension of the Kronecker-Factored Lattice.
  num_terms: Number of independently trained submodels per unit, whose
    outputs are averaged to produce the final output.
  lattice_sizes: Number of vertices per dimension (a single integer shared by
    all dimensions).
  clip_inputs: Whether inputs should be clipped to the input range of the
    Kronecker-Factored Lattice.
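The argument shapes listed above must agree with one another. The following sketch uses two hypothetical helpers, not part of TensorFlow Lattice, to check that agreement on plain shape tuples, recover dims from the kernel's third axis, and clamp coordinates. It assumes the "input range" referred to by clip_inputs is [0, lattice_sizes - 1].

```python
def infer_dims(scale_shape, bias_shape, kernel_shape, units, num_terms, lattice_sizes):
    """Check that the documented parameter shapes agree; return the input dims.

    Hypothetical helper: shapes are plain tuples such as those from tensor.shape.
    """
    if tuple(scale_shape) != (units, num_terms):
        raise ValueError(f"scale must have shape {(units, num_terms)}, got {scale_shape}")
    if tuple(bias_shape) != (units,):
        raise ValueError(f"bias must have shape {(units,)}, got {bias_shape}")
    one, sizes, columns, terms = kernel_shape
    if one != 1 or sizes != lattice_sizes or terms != num_terms or columns % units:
        raise ValueError(f"unexpected kernel shape {kernel_shape}")
    # The kernel's third axis holds units * dims columns.
    return columns // units


def clip_to_lattice_range(values, lattice_sizes):
    """Clamp each coordinate to [0, lattice_sizes - 1] (assumed input range)."""
    return [min(max(x, 0.0), lattice_sizes - 1.0) for x in values]


print(infer_dims((4, 2), (4,), (1, 3, 12, 2), units=4, num_terms=2, lattice_sizes=3))  # 3
print(clip_to_lattice_range([-0.5, 1.5, 7.0], lattice_sizes=3))  # [0.0, 1.5, 2.0]
```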
Returns:
  Tensor of shape (batch_size, ..., units).
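The sketch below puts the pieces together in pure Python. It shows how the documented shapes flow from inputs to an output of shape (batch_size, units). It is a hypothetical re-implementation under stated assumptions, not the library's code:

- inputs have no extra "..." axes;
- kernel column u * dims + d belongs to unit u and dimension d;
- each term is multiplied by scale[u][t] before the terms are averaged;
- bias[u] is added after averaging.

The library's actual layout and ordering may differ.

```python
def interpolation_weights(x, lattice_size):
    """Hat-function weights of x at vertices 0, 1, ..., lattice_size - 1."""
    return [max(0.0, 1.0 - abs(x - v)) for v in range(lattice_size)]


def evaluate_kfl_sketch(inputs, scale, bias, kernel, units, num_terms,
                        lattice_sizes, clip_inputs=True):
    """Hypothetical pure-Python sketch of Kronecker-Factored Lattice evaluation.

    inputs: nested lists of shape (batch_size, dims) if units == 1, else
        (batch_size, units, dims); extra "..." axes are not handled.
    scale: (units, num_terms); bias: (units,);
    kernel: (1, lattice_sizes, units * dims, num_terms).
    Returns nested lists of shape (batch_size, units).
    """
    outputs = []
    for example in inputs:
        per_unit_points = [example] if units == 1 else example
        unit_outputs = []
        for u, point in enumerate(per_unit_points):
            dims = len(point)
            if clip_inputs:
                # Assumed input range of the lattice: [0, lattice_sizes - 1].
                point = [min(max(x, 0.0), lattice_sizes - 1.0) for x in point]
            term_total = 0.0
            for t in range(num_terms):
                product = 1.0
                for d, x in enumerate(point):
                    column = u * dims + d  # assumed unit-major column layout
                    weights = interpolation_weights(x, lattice_sizes)
                    product *= sum(weights[v] * kernel[0][v][column][t]
                                   for v in range(lattice_sizes))
                # Assumption: each term is scaled before the terms are averaged.
                term_total += scale[u][t] * product
            # Average over terms, then add the per-unit bias (assumed order).
            unit_outputs.append(term_total / num_terms + bias[u])
        outputs.append(unit_outputs)
    return outputs


# units = 1, dims = 2, lattice_sizes = 3, num_terms = 2 (hypothetical values).
# kernel[0][v][d][t] holds the value of vertex v for dimension d in term t.
kernel = [[
    [[0.0, 1.0], [1.0, 2.0]],  # vertex 0
    [[1.0, 1.0], [3.0, 2.0]],  # vertex 1
    [[2.0, 1.0], [5.0, 2.0]],  # vertex 2
]]
result = evaluate_kfl_sketch(
    inputs=[[1.25, 0.5], [-1.0, 5.0]], scale=[[1.0, 0.5]], bias=[0.25],
    kernel=kernel, units=1, num_terms=2, lattice_sizes=3)
print(result)  # [[2.0], [0.75]]
```

In the second example the point [-1.0, 5.0] is clipped to [0.0, 2.0] before interpolation. With clip_inputs=False, the hat weights would all be zero outside the vertex range under this sketch's assumptions.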