LossesHelper

public class LossesHelper

These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics packages.

Public Constructors

Public Methods

static <T extends TNumber> Operand<TInt32>
allAxes(Ops tf, Operand<T> op)
Gets a Constant integer array representing all the axes of the operand.
static <T extends TNumber> Operand<T>
computeWeightedLoss(Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight)
Computes the weighted loss.
static <T extends TNumber> Operand<T>
rangeCheck(Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue)
Performs an inclusive range check on the values.
static <T extends TNumber> LossTuple<T>
removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes the last dimension of labels or predictions if their rank difference is off from the expected value by exactly 1.
static <T extends TNumber> LossTuple<T>
removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff)
Squeezes the last dimension of labels or predictions if their rank difference is off from the expected value by exactly 1.
static <T extends TNumber> Operand<T>
safeMean(Ops tf, Operand<T> losses, long numElements)
Computes a safe mean of the losses.
static <T extends TNumber> LossTuple<T>
squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes or expands the last dimension if needed, with no sample weights (equivalent to a sample weight of one).
static <T extends TNumber> LossTuple<T>
squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights)
Squeezes or expands the last dimension if needed.
static <T extends TNumber> Operand<T>
valueCheck(Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues)
Checks whether all the values are in the set of allowed values.


Public Constructors

public LossesHelper ()

Public Methods

public static <T extends TNumber> Operand<TInt32> allAxes (Ops tf, Operand<T> op)

Gets a Constant integer array representing all the axes of the operand.

Parameters
tf the TensorFlow Ops
op the operand whose axes are returned
Returns
  • a Constant that represents all the axes of the operand.
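Conceptually, "all the axes" of an operand of rank r are the indices 0, 1, ..., r - 1, which is what a reduction over every dimension (for example a full sum) needs. The plain-Java sketch below models that on a shape given as a long[]; it does not use the TensorFlow Java API, and the class name AllAxesSketch is hypothetical.

```java
import java.util.stream.IntStream;

/** Hypothetical plain-Java model of LossesHelper.allAxes; no TensorFlow types are involved. */
final class AllAxesSketch {
    private AllAxesSketch() {}

    /** Returns the axis indices 0, 1, ..., rank - 1 for a tensor with the given shape. */
    static int[] allAxes(long[] shape) {
        // The rank of a tensor is the number of dimensions in its shape.
        return IntStream.range(0, shape.length).toArray();
    }
}
```

For a shape such as {2, 3, 4} the sketch yields {0, 1, 2}; for a scalar (empty shape) it yields an empty array.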

public static <T extends TNumber> Operand<T> computeWeightedLoss (Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight)

Computes the weighted loss by multiplying loss by sampleWeight and then applying the requested reduction.

Parameters
tf the TensorFlow Ops
loss the unweighted loss
reduction the type of reduction
sampleWeight the sample weight; if null, it defaults to one.
Returns
  • the weighted loss
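The weight-then-reduce order can be illustrated in plain Java on flat arrays. This is a minimal sketch under stated assumptions: the Reduction enum below is a hypothetical stand-in covering only NONE, SUM and SUM_OVER_BATCH_SIZE (modeled on the Keras-style reduction types, not copied from TensorFlow Java), sampleWeight is either null or the same length as loss, and the class name WeightedLossSketch is hypothetical.

```java
/** Hypothetical plain-Java model of a weighted-loss computation on flat (1-D) arrays. */
final class WeightedLossSketch {
    private WeightedLossSketch() {}

    /** Assumed subset of reduction modes, for illustration only. */
    enum Reduction { NONE, SUM, SUM_OVER_BATCH_SIZE }

    /**
     * Multiplies each loss by its sample weight (one when sampleWeight is null), then reduces.
     * NONE returns the per-element weighted losses; SUM and SUM_OVER_BATCH_SIZE return a
     * one-element array holding the reduced value.
     */
    static double[] computeWeightedLoss(double[] loss, Reduction reduction, double[] sampleWeight) {
        double[] weighted = new double[loss.length];
        for (int i = 0; i < loss.length; i++) {
            double w = (sampleWeight == null) ? 1.0 : sampleWeight[i];
            weighted[i] = loss[i] * w;
        }
        if (reduction == Reduction.NONE) {
            return weighted;
        }
        double sum = 0.0;
        for (double v : weighted) {
            sum += v;
        }
        if (reduction == Reduction.SUM_OVER_BATCH_SIZE) {
            // Divide by the element count, returning zero instead of NaN for empty input.
            sum = weighted.length == 0 ? 0.0 : sum / weighted.length;
        }
        return new double[] {sum};
    }
}
```

With loss {1, 2, 3} and weights {1, 0, 2}, the weighted losses are {1, 0, 6}, so SUM gives 7 and SUM_OVER_BATCH_SIZE gives 7 / 3.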

public static <T extends TNumber> Operand<T> rangeCheck (Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue)

Performs an inclusive range check on the values, that is, checks that minValue <= value <= maxValue holds for every element.

Parameters
tf the TensorFlow Ops
prefix A String prefix to include in the error message
values the values to check
minValue the minimum value
maxValue the maximum value
Returns
  • the values, possibly with control dependencies if the TensorFlow Ops represents a Graph Session
Throws
IllegalArgumentException if the TensorFlow Ops represents an Eager Session and at least one value is outside the range [minValue, maxValue]
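In Eager mode the check runs immediately, so an out-of-range value surfaces as an exception from the call itself. The plain-Java sketch below models only that Eager-mode behavior on a double[]; Graph-mode control dependencies have no counterpart here. The class name RangeCheckSketch and the error-message text are hypothetical.

```java
/** Hypothetical plain-Java model of an inclusive range check, mimicking Eager-mode behavior. */
final class RangeCheckSketch {
    private RangeCheckSketch() {}

    /** Returns values unchanged if every element v satisfies minValue <= v <= maxValue. */
    static double[] rangeCheck(String prefix, double[] values, double minValue, double maxValue) {
        for (double v : values) {
            // Inclusive bounds: values equal to minValue or maxValue pass; NaN fails both comparisons.
            if (!(v >= minValue && v <= maxValue)) {
                throw new IllegalArgumentException(
                        prefix + ": value " + v + " is outside [" + minValue + ", " + maxValue + "]");
            }
        }
        return values;
    }
}
```

A typical use is validating probabilities, for example rangeCheck("predictions", probs, 0.0, 1.0), where both 0.0 and 1.0 are accepted.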

public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions (Ops tf, Operand<T> labels, Operand<T> predictions)

Squeezes the last dimension of labels or predictions if their rank difference is off from the expected value by exactly 1.

Parameters
tf the TensorFlow Ops
labels Label values, an Operand whose dimensions match predictions.
predictions Predicted values, an Operand of arbitrary dimensions.
Returns
  • a LossTuple of labels and predictions, possibly with the last dimension squeezed.

public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions (Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff)

Squeezes the last dimension of labels or predictions if their rank difference is off from the expected value by exactly 1.

Parameters
tf the TensorFlow Ops
labels Label values, an Operand whose dimensions match predictions.
predictions Predicted values, an Operand of arbitrary dimensions.
expectedRankDiff Expected result of rank(predictions) - rank(labels).
Returns
  • a LossTuple of labels and predictions, possibly with the last dimension squeezed.
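The squeeze rule can be modeled on static shapes alone. In this plain-Java sketch, if rank(predictions) - rank(labels) is one more than expectedRankDiff and the last dimension of predictions is 1, that dimension is dropped; if it is one less and the last dimension of labels is 1, the labels dimension is dropped instead. Two assumptions, flagged here and in the comments: the three-argument overload is taken to use an expected rank difference of 0 (as in the Keras helper of the same name), and the class SqueezeSketch with its long[][] return value is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical plain-Java model of removeSqueezableDimensions, operating on static shapes only. */
final class SqueezeSketch {
    private SqueezeSketch() {}

    /** Assumption: the overload without expectedRankDiff behaves as if it were 0. */
    static long[][] removeSqueezableDimensions(long[] labels, long[] predictions) {
        return removeSqueezableDimensions(labels, predictions, 0);
    }

    /** Returns {labelsShape, predictionsShape}, with at most one trailing size-1 dimension removed. */
    static long[][] removeSqueezableDimensions(long[] labels, long[] predictions, int expectedRankDiff) {
        int rankDiff = predictions.length - labels.length;
        if (rankDiff == expectedRankDiff + 1 && lastDimIsOne(predictions)) {
            // predictions has one dimension too many and it has size 1: drop it.
            predictions = dropLast(predictions);
        } else if (rankDiff == expectedRankDiff - 1 && lastDimIsOne(labels)) {
            // labels has one dimension too many and it has size 1: drop it.
            labels = dropLast(labels);
        }
        return new long[][] {labels, predictions};
    }

    private static boolean lastDimIsOne(long[] shape) {
        return shape.length > 0 && shape[shape.length - 1] == 1;
    }

    private static long[] dropLast(long[] shape) {
        return Arrays.copyOf(shape, shape.length - 1);
    }
}
```

For example, labels of shape {4} and predictions of shape {4, 1} become {4} and {4}, while predictions of shape {4, 3} are left alone because the extra dimension is not of size 1.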

public static <T extends TNumber> Operand<T> safeMean (Ops tf, Operand<T> losses, long numElements)

Computes a safe mean of the losses: the sum of all elements of losses divided by numElements, guarding against division by zero.

Parameters
tf the TensorFlow Ops
losses Operand whose elements contain individual loss measurements.
numElements The number of measurable elements in losses.
Returns
  • A scalar representing the mean of losses. If numElements is zero, then zero is returned.
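The "safe" part is the zero-denominator case. The plain-Java sketch below models the documented contract (sum divided by numElements, zero when numElements is zero) on a double[]; the real method works on an Operand and returns a scalar Operand, and the class name SafeMeanSketch is hypothetical.

```java
/** Hypothetical plain-Java model of LossesHelper.safeMean. */
final class SafeMeanSketch {
    private SafeMeanSketch() {}

    /** Returns sum(losses) / numElements, or 0 when numElements is 0. */
    static double safeMean(double[] losses, long numElements) {
        double total = 0.0;
        for (double l : losses) {
            total += l;
        }
        // A zero denominator yields zero rather than NaN or infinity.
        return numElements == 0 ? 0.0 : total / numElements;
    }
}
```

Because numElements is passed separately, it need not equal the array length; safeMean({2, 4}, 4) returns 1.5.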

public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions (Ops tf, Operand<T> labels, Operand<T> predictions)

Squeezes or expands the last dimension if needed, with no sample weights (equivalent to a sample weight of one).

  1. Squeezes the last dimension of predictions or labels if their ranks differ by 1 (using removeSqueezableDimensions(Ops, Operand<T>, Operand<T>)).
  2. Squeezes or expands the last dimension of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is a scalar, it is kept as a scalar.

Parameters
tf the TensorFlow Ops
labels Optional label Operand whose dimensions match predictions.
predictions Predicted values, an Operand of arbitrary dimensions.
Returns
  • a LossTuple of predictions and labels, each possibly with the last dimension squeezed; the tuple's sampleWeight is null.

public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions (Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights)

Squeezes or expands the last dimension if needed.

  1. Squeezes the last dimension of predictions or labels if their ranks differ by 1.
  2. Squeezes or expands the last dimension of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is a scalar, it is kept as a scalar.

Parameters
tf the TensorFlow Ops
labels Optional label Operand whose dimensions match predictions.
predictions Predicted values, an Operand of arbitrary dimensions.
sampleWeights Optional sample weight(s) Operand whose dimensions match predictions.
Returns
  • a LossTuple of predictions, labels and sampleWeight. Each of them may have its last dimension squeezed, and sampleWeight may be expanded by one dimension. If sampleWeight is null, only the possibly reshaped predictions and labels are returned.
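Step 2 above, the sample-weight adjustment, can be modeled on static shapes. This plain-Java sketch takes the predictions shape after step 1 and returns the adjusted sample-weight shape. It is a minimal sketch under stated assumptions: the class SqueezeOrExpandSketch and method adjustSampleWeightShape are hypothetical, step 1 is not repeated here, and squeezing a trailing dimension that is not of size 1 is treated as an error.

```java
import java.util.Arrays;

/** Hypothetical plain-Java model of the sample-weight step of squeezeOrExpandDimensions. */
final class SqueezeOrExpandSketch {
    private SqueezeOrExpandSketch() {}

    /** Returns the adjusted sample-weight shape, or null when there are no sample weights. */
    static long[] adjustSampleWeightShape(long[] predictions, long[] sampleWeight) {
        if (sampleWeight == null) {
            return null; // no weights: nothing to adjust
        }
        if (sampleWeight.length == 0) {
            return sampleWeight; // scalar weights stay scalar
        }
        int diff = sampleWeight.length - predictions.length;
        if (diff == 1) {
            // One extra trailing dimension: squeeze it (assumed to be of size 1).
            if (sampleWeight[sampleWeight.length - 1] != 1) {
                throw new IllegalArgumentException("cannot squeeze a dimension that is not of size 1");
            }
            return Arrays.copyOf(sampleWeight, sampleWeight.length - 1);
        }
        if (diff == -1) {
            // One dimension short: append a trailing dimension of size 1.
            long[] expanded = Arrays.copyOf(sampleWeight, sampleWeight.length + 1);
            expanded[sampleWeight.length] = 1;
            return expanded;
        }
        return sampleWeight; // ranks already compatible (or differ by more than 1)
    }
}
```

For predictions of shape {4, 3} and per-example weights of shape {4}, the weights become {4, 1}, which then broadcasts across the last axis when multiplied with the losses.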

public static <T extends TNumber> Operand<T> valueCheck (Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues)

Checks whether all the values are in the set of allowed values. In Graph mode, running the returned operand throws a TFInvalidArgumentException if at least one value is not in the allowed set. In Eager mode, this method throws an IllegalArgumentException if at least one value is not in the allowed set.

Parameters
tf the TensorFlow Ops
prefix A String prefix to include in the error message
values the values to check
allowedValues the allowed values
Returns
  • the values, possibly with control dependencies if the TensorFlow Ops represents a Graph Session
Throws
IllegalArgumentException if the Session is in Eager mode and at least one value is not in the allowed values set
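The Eager-mode contract of valueCheck is a set-membership test. The plain-Java sketch below models it on int[] values (for example, binary labels that must be 0 or 1); it has no Graph-mode counterpart, and the class name ValueCheckSketch and the message text are hypothetical.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/** Hypothetical plain-Java model of valueCheck, mimicking Eager-mode behavior. */
final class ValueCheckSketch {
    private ValueCheckSketch() {}

    /** Returns values unchanged if every element is one of allowedValues. */
    static int[] valueCheck(String prefix, int[] values, int[] allowedValues) {
        // Put the allowed values in a set for constant-time membership tests.
        Set<Integer> allowed = new HashSet<>();
        for (int a : allowedValues) {
            allowed.add(a);
        }
        for (int v : values) {
            if (!allowed.contains(v)) {
                throw new IllegalArgumentException(
                        prefix + ": value " + v + " is not in the allowed set " + Arrays.toString(allowedValues));
            }
        }
        return values;
    }
}
```

Calling valueCheck("labels", new int[] {0, 1, 1, 0}, new int[] {0, 1}) returns the input, while a label of 2 raises the exception.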