These are helper methods for Losses and Metrics; they will become module-private once Java modularity is applied to TensorFlow Java. Do not use these methods outside the losses and metrics packages.
Public Constructors
Public Methods
static <T extends TNumber> Operand<TInt32> | allAxes(Ops tf, Operand<T> op): Gets a Constant integer array representing all the axes of the operand. |
---|---|
static <T extends TNumber> Operand<T> | computeWeightedLoss(Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight): Computes the weighted loss. |
static <T extends TNumber> Operand<T> | rangeCheck(Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue): Performs an inclusive range check on the values. |
static <T extends TNumber> LossTuple<T> | removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions): Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1. |
static <T extends TNumber> LossTuple<T> | removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff): Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1. |
static <T extends TNumber> Operand<T> | safeMean(Ops tf, Operand<T> losses, long numElements): Computes a safe mean of the losses. |
static <T extends TNumber> LossTuple<T> | squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions): Squeezes or expands the last dimension if needed, with a sample weight of one. |
static <T extends TNumber> LossTuple<T> | squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights): Squeezes or expands the last dimension if needed. |
static <T extends TNumber> Operand<T> | valueCheck(Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues): Checks whether all the values are in the allowed values set. |
Public Constructors
public LossesHelper ()
Public Methods
public static Operand<TInt32> allAxes (Ops tf, Operand<T> op)
Gets a Constant integer array representing all the axes of the operand.
Parameters
tf | the TensorFlow Ops |
---|---|
op | the operand whose axes are returned |
Returns
- a Constant that represents all the axes of the operand.
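The result is the vector of axis indices `[0, 1, ..., rank - 1]`, the argument a reduction over every axis expects. The plain-Java sketch below models only that computation for a shape whose rank is known up front; it does not use the TensorFlow Java API, and the class and method names are hypothetical.

```java
/**
 * Hypothetical plain-Java model of allAxes: it only computes the axis indices
 * [0, 1, ..., rank - 1] for a shape whose rank is known up front.
 */
final class AllAxesSketch {
  private AllAxesSketch() {}

  /** Returns {0, 1, ..., shape.length - 1}; a scalar shape (length 0) yields an empty array. */
  static int[] allAxes(long[] shape) {
    int rank = shape.length;
    int[] axes = new int[rank];
    for (int i = 0; i < rank; i++) {
      axes[i] = i;
    }
    return axes;
  }
}
```

For example, a shape of `{2, 3, 4}` has rank 3, so the sketch returns `{0, 1, 2}`.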
public static Operand<T> computeWeightedLoss (Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight)
Computes the weighted loss.
Parameters
tf | the TensorFlow Ops |
---|---|
loss | the unweighted loss |
reduction | the type of reduction |
sampleWeight | the sample weight; if null, this defaults to one. |
Returns
- the weighted loss
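As a rough illustration of what "weighted loss" means here, the sketch below multiplies each loss element by its weight (one when the weight is null) and then reduces. It assumes Keras-style reduction semantics: `NONE` keeps the per-element weighted losses, `SUM` adds them, and `SUM_OVER_BATCH_SIZE` divides that sum by the number of elements. `ReductionSketch` is a hypothetical stand-in for the library's `Reduction` type, the weights are assumed to already match the loss element-for-element (no broadcasting), and flat `double[]` arrays replace operands.

```java
/** Hypothetical stand-in for the library's Reduction type (Keras-style semantics assumed). */
enum ReductionSketch { NONE, SUM, SUM_OVER_BATCH_SIZE }

/** Plain-Java model of a weighted-loss computation on flat arrays. */
final class WeightedLossSketch {
  private WeightedLossSketch() {}

  /** Returns the per-element weighted losses for NONE, or a one-element array otherwise. */
  static double[] computeWeightedLoss(double[] loss, ReductionSketch reduction, double[] sampleWeight) {
    if (sampleWeight != null && sampleWeight.length != loss.length) {
      throw new IllegalArgumentException("sampleWeight must match loss element-for-element");
    }
    double[] weighted = new double[loss.length];
    for (int i = 0; i < loss.length; i++) {
      // A null sampleWeight means every element is weighted by one.
      double w = (sampleWeight == null) ? 1.0 : sampleWeight[i];
      weighted[i] = loss[i] * w;
    }
    switch (reduction) {
      case NONE:
        return weighted;
      case SUM:
        return new double[] {sum(weighted)};
      case SUM_OVER_BATCH_SIZE:
        // Safe mean: an empty loss reduces to 0 instead of NaN.
        return new double[] {weighted.length == 0 ? 0.0 : sum(weighted) / weighted.length};
      default:
        throw new IllegalArgumentException("Unsupported reduction: " + reduction);
    }
  }

  private static double sum(double[] values) {
    double total = 0.0;
    for (double v : values) {
      total += v;
    }
    return total;
  }
}
```

With losses `{1, 2, 3}` and weights `{1, 0.5, 0}`, `SUM` gives `1 + 1 + 0 = 2`.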
public static Operand<T> rangeCheck (Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue)
Performs an inclusive range check on the values.
Parameters
tf | the TensorFlow Ops |
---|---|
prefix | A String prefix to include in the error message |
values | the values to check |
minValue | the minimum value |
maxValue | the maximum value |
Returns
- the values, possibly with control dependencies if the TensorFlow Ops represents a Graph Session
Throws
IllegalArgumentException | if the TensorFlow Ops represents an Eager Session and at least one value is outside the range |
---|
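The sketch below models only the Eager-mode path in plain Java: every value must satisfy `minValue <= v <= maxValue` (both endpoints allowed), otherwise an `IllegalArgumentException` is thrown; in Graph mode the method instead attaches the check as a control dependency, which is not modeled. The class name and error-message format are assumptions.

```java
/** Hypothetical plain-Java model of the Eager-mode behaviour of an inclusive range check. */
final class RangeCheckSketch {
  private RangeCheckSketch() {}

  /** Returns values unchanged if all lie in [minValue, maxValue]; otherwise throws. */
  static double[] rangeCheck(String prefix, double[] values, double minValue, double maxValue) {
    for (double v : values) {
      // Inclusive check: both endpoints pass. NaN fails both comparisons, so it is rejected.
      if (!(v >= minValue && v <= maxValue)) {
        throw new IllegalArgumentException(
            String.format("%s: value %s out of range [%s, %s]", prefix, v, minValue, maxValue));
      }
    }
    return values;
  }
}
```

A typical use is validating probabilities, for example `rangeCheck("predictions", probs, 0.0, 1.0)`, where exactly `0.0` and `1.0` are accepted.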
public static LossTuple<T> removeSqueezableDimensions (Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1.
Parameters
tf | the TensorFlow Ops |
---|---|
labels | Label values, a Tensor whose dimensions match predictions. |
predictions | Predicted values, a Tensor of arbitrary dimensions. |
Returns
- labels and predictions, possibly with the last dimension squeezed.
public static LossTuple<T> removeSqueezableDimensions (Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff)
Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1.
Parameters
tf | the TensorFlow Ops |
---|---|
labels | Label values, an Operand whose dimensions match predictions. |
predictions | Predicted values, a Tensor of arbitrary dimensions. |
expectedRankDiff | Expected result of rank(predictions) - rank(labels). |
Returns
- labels and predictions, possibly with the last dimension squeezed.
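The squeeze rule can be illustrated on shapes alone. The sketch below is modeled on the static-rank logic of TensorFlow Python's `remove_squeezable_dimensions`; treating the Java methods as equivalent is an assumption, as is the default `expectedRankDiff` of 0 for the two-argument form (the Python helper's default). With `rankDiff = rank(predictions) - rank(labels)`, the predictions' last axis is dropped when `rankDiff == expectedRankDiff + 1`, and the labels' last axis when `rankDiff == expectedRankDiff - 1`, provided that axis has size 1. Unknown dimensions are not modeled, and `SqueezeSketch` and `Shapes` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical shape-only model of removing a squeezable trailing dimension. */
final class SqueezeSketch {
  private SqueezeSketch() {}

  /** Holder for the (possibly squeezed) label and prediction shapes. */
  static final class Shapes {
    final long[] labels;
    final long[] predictions;

    Shapes(long[] labels, long[] predictions) {
      this.labels = labels;
      this.predictions = predictions;
    }
  }

  /** Two-argument form; the expected rank difference of 0 is an assumption. */
  static Shapes removeSqueezableDimensions(long[] labels, long[] predictions) {
    return removeSqueezableDimensions(labels, predictions, 0);
  }

  static Shapes removeSqueezableDimensions(long[] labels, long[] predictions, int expectedRankDiff) {
    int rankDiff = predictions.length - labels.length;
    if (rankDiff == expectedRankDiff + 1 && lastDimIsOne(predictions)) {
      // predictions carries one extra trailing axis of size 1: drop it.
      return new Shapes(labels, Arrays.copyOf(predictions, predictions.length - 1));
    }
    if (rankDiff == expectedRankDiff - 1 && lastDimIsOne(labels)) {
      // labels carries one extra trailing axis of size 1: drop it.
      return new Shapes(Arrays.copyOf(labels, labels.length - 1), predictions);
    }
    // Any other combination leaves both shapes untouched.
    return new Shapes(labels, predictions);
  }

  private static boolean lastDimIsOne(long[] shape) {
    return shape.length > 0 && shape[shape.length - 1] == 1;
  }
}
```

For instance, labels of shape `{4, 1}` against predictions of shape `{4}` become `{4}`, while predictions of shape `{4, 3}` against labels `{4}` are left alone because the trailing axis is not of size 1.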
public static Operand<T> safeMean (Ops tf, Operand<T> losses, long numElements)
Computes a safe mean of the losses.
Parameters
tf | the TensorFlow Ops |
---|---|
losses | Operand whose elements contain individual loss measurements. |
numElements | The number of measurable elements in losses. |
Returns
- A scalar representing the mean of losses. If numElements is zero, then zero is returned.
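The "safe" part is the zero guard: dividing by a `numElements` of zero would otherwise yield NaN. The plain-Java sketch below assumes the mean is formed as the sum of all loss elements divided by `numElements`, and returns 0 in the zero case as described above; the class name is hypothetical.

```java
/** Hypothetical plain-Java model of a safe mean over loss values. */
final class SafeMeanSketch {
  private SafeMeanSketch() {}

  /** Sum of losses divided by numElements, or 0 when numElements is zero (avoids 0/0 = NaN). */
  static double safeMean(double[] losses, long numElements) {
    if (numElements == 0) {
      return 0.0;
    }
    double total = 0.0;
    for (double loss : losses) {
      total += loss;
    }
    return total / numElements;
  }
}
```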
public static LossTuple<T> squeezeOrExpandDimensions (Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes or expands the last dimension if needed, with a sample weight of one.
- Squeezes the last dim of predictions or labels if their ranks differ by 1 (using removeSqueezableDimensions(Ops, Operand<T>, Operand<T>)).
- Squeezes or expands the last dim of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is scalar, it is kept scalar.
Parameters
tf | the TensorFlow Ops |
---|---|
labels | Optional label Operand whose dimensions match predictions. |
predictions | Predicted values, an Operand of arbitrary dimensions. |
Returns
- LossTuple of predictions and labels, each possibly with the last dimension squeezed; the sampleWeight of the tuple will be null.
public static LossTuple<T> squeezeOrExpandDimensions (Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights)
Squeezes or expands the last dimension if needed.
- Squeezes the last dim of predictions or labels if their ranks differ by 1.
- Squeezes or expands the last dim of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is scalar, it is kept scalar.
Parameters
tf | the TensorFlow Ops |
---|---|
labels | Optional label Operand whose dimensions match predictions. |
predictions | Predicted values, an Operand of arbitrary dimensions. |
sampleWeights | Optional sample weight(s) Operand whose dimensions match predictions. |
Returns
- LossTuple of predictions, labels and sampleWeight. Each of them possibly has the last dimension squeezed, and sampleWeight could be extended by one dimension. If sampleWeight is null, only the possibly shape-modified predictions and labels are returned.
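The sampleWeight step can be sketched on shapes alone. The model below takes the predictions' shape after any label/prediction squeezing (that first step is not modeled here) and adjusts the weight shape: scalar weights stay scalar, one extra trailing axis is squeezed, and one missing axis is added as a trailing axis of size 1. It is loosely modeled on the static-rank path of TensorFlow Python's `squeeze_or_expand_dimensions`; that correspondence, the error on squeezing a non-unit axis, and all names here are assumptions.

```java
import java.util.Arrays;

/** Hypothetical shape-only model of adjusting sampleWeight to the predictions' rank. */
final class SampleWeightShapeSketch {
  private SampleWeightShapeSketch() {}

  /**
   * Returns the adjusted sample-weight shape, or null when there are no sample weights.
   * predictionsShape is the shape after labels/predictions squeezing.
   */
  static long[] adjustSampleWeightShape(long[] predictionsShape, long[] sampleWeightShape) {
    if (sampleWeightShape == null) {
      return null; // no weights: only predictions and labels matter
    }
    if (sampleWeightShape.length == 0) {
      return sampleWeightShape; // scalar weights are kept scalar
    }
    int rankDiff = sampleWeightShape.length - predictionsShape.length;
    if (rankDiff == 1) {
      // One extra axis: squeeze the trailing axis, which must have size 1.
      if (sampleWeightShape[sampleWeightShape.length - 1] != 1) {
        throw new IllegalArgumentException(
            "cannot squeeze last dim of " + Arrays.toString(sampleWeightShape));
      }
      return Arrays.copyOf(sampleWeightShape, sampleWeightShape.length - 1);
    }
    if (rankDiff == -1) {
      // One missing axis: append a trailing axis of size 1.
      long[] expanded = Arrays.copyOf(sampleWeightShape, sampleWeightShape.length + 1);
      expanded[expanded.length - 1] = 1;
      return expanded;
    }
    return sampleWeightShape; // ranks already match (or differ by more than 1): unchanged
  }
}
```

For per-sample weights of shape `{4}` against predictions of shape `{4, 3}`, the sketch returns `{4, 1}`, which then broadcasts across the last axis.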
public static Operand<T> valueCheck (Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues)
Checks to see if all the values are in the allowed values set. Running the operand in Graph mode will throw TFInvalidArgumentException if at least one value is not in the allowed values set. In Eager mode, this method will throw an IllegalArgumentException if at least one value is not in the allowed values set.
Parameters
tf | the TensorFlow Ops |
---|---|
prefix | A String prefix to include in the error message |
values | the values to check |
allowedValues | the allowed values |
Returns
- the values, possibly with control dependencies if the TensorFlow Ops represents a Graph Session
Throws
IllegalArgumentException | if the Session is in Eager mode and at least one value is not in the allowed values set |
---|
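The Eager-mode behaviour described above can be modeled in plain Java: collect the allowed values into a set, then reject the first value that is not a member. The sketch uses `long` values (for example, integer class labels) to sidestep floating-point equality questions, whereas the method itself is generic over numeric operands; the class name and message format are hypothetical, and the Graph-mode path is not modeled.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical plain-Java model of the Eager-mode behaviour of valueCheck. */
final class ValueCheckSketch {
  private ValueCheckSketch() {}

  /** Returns values unchanged if every value is in allowedValues; otherwise throws. */
  static long[] valueCheck(String prefix, long[] values, long[] allowedValues) {
    Set<Long> allowed = new HashSet<>();
    for (long a : allowedValues) {
      allowed.add(a);
    }
    for (long v : values) {
      if (!allowed.contains(v)) {
        throw new IllegalArgumentException(
            prefix + ": value " + v + " is not in the allowed values set");
      }
    }
    return values;
  }
}
```

A typical use is checking that binary labels contain only `0` and `1`, as in `valueCheck("labels", labels, new long[] {0, 1})`.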