Initializer capable of adapting its scale to the shape of weight tensors.
With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a truncated or untruncated normal distribution with a mean of zero and a standard deviation (after truncation, if used) of stddev = Math.sqrt(scale / n), where n is:
- the number of input units in the weight tensor, if mode=FAN_IN
- the number of output units, if mode=FAN_OUT
- the average of the numbers of input and output units, if mode=FAN_AVG
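To make the role of n concrete, here is a minimal plain-Java sketch that derives fan-in and fan-out from a weight shape and computes stddev for the normal cases. It does not call TensorFlow Java: computeFans, fanFor, and stddev are hypothetical helpers, and the rules for deriving fans from a shape (the two dimensions of a dense kernel as [in, out]; for a convolution kernel, the spatial dimensions multiplied into both fans; n floored at 1) follow the Keras convention and are an assumption about this class.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: derives n from a weight shape and computes the normal stddev.
 * The fan rules mirror the Keras convention (assumed, not taken from TensorFlow Java).
 */
class FanSketch {

    // Local stand-in for VarianceScaling.Mode; not the library enum.
    enum Mode { FAN_IN, FAN_OUT, FAN_AVG }

    /** Returns {fanIn, fanOut} for a shape. */
    static double[] computeFans(long[] shape) {
        if (shape.length == 0) {
            return new double[] {1, 1};                 // scalar
        }
        if (shape.length == 1) {
            return new double[] {shape[0], shape[0]};   // vector
        }
        if (shape.length == 2) {
            return new double[] {shape[0], shape[1]};   // dense kernel: [in, out]
        }
        // Convolution kernel: [...spatial..., in, out]; spatial dims form the receptive field.
        double receptiveField = 1;
        for (int i = 0; i < shape.length - 2; i++) {
            receptiveField *= shape[i];
        }
        return new double[] {
            shape[shape.length - 2] * receptiveField,
            shape[shape.length - 1] * receptiveField
        };
    }

    /** n for the given mode: fan-in, fan-out, or their average, floored at 1. */
    static double fanFor(long[] shape, Mode mode) {
        double[] fans = computeFans(shape);
        double n;
        switch (mode) {
            case FAN_IN:  n = fans[0]; break;
            case FAN_OUT: n = fans[1]; break;
            default:      n = (fans[0] + fans[1]) / 2.0; break;
        }
        return Math.max(1.0, n);
    }

    /** stddev = Math.sqrt(scale / n) for NORMAL and TRUNCATED_NORMAL. */
    static double stddev(double scale, long[] shape, Mode mode) {
        return Math.sqrt(scale / fanFor(shape, mode));
    }

    public static void main(String[] args) {
        long[] conv = {3, 3, 16, 32};  // hypothetical 3x3 conv kernel, 16 -> 32 channels
        System.out.println(Arrays.toString(computeFans(conv)));  // [144.0, 288.0]
        System.out.println(fanFor(conv, Mode.FAN_AVG));          // 216.0
        System.out.println(stddev(2.0, new long[] {100, 50}, Mode.FAN_IN)); // about 0.1414
    }
}
```

For the convolution kernel, each of the 16 input channels contributes a 3x3 window, which is why the receptive field size (9) multiplies both fans.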
With distribution=UNIFORM, samples are drawn from a uniform distribution within [-limit, limit], where limit = Math.sqrt(3 * scale / n).
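The sampling rules above can be checked numerically. The sketch below draws from each distribution with java.util.Random rather than TensorFlow ops and reports the empirical standard deviation, which should be close to Math.sqrt(scale / n) in every case: a uniform distribution on [-limit, limit] has variance limit * limit / 3 = scale / n. For TRUNCATED_NORMAL it assumes the Keras approach of truncating at two standard deviations and dividing by 0.87962566103423978 (the standard deviation of a standard normal truncated to [-2, 2]) so that the stddev after truncation matches; the sample helper and the local Distribution enum are hypothetical.

```java
import java.util.Random;

/**
 * Hypothetical sketch of the three sampling rules using java.util.Random.
 * The truncated-normal correction follows the Keras convention (assumed here).
 */
class SamplingSketch {

    // Local stand-in for VarianceScaling.Distribution; not the library enum.
    enum Distribution { TRUNCATED_NORMAL, NORMAL, UNIFORM }

    // Stddev of a standard normal truncated to [-2, 2].
    static final double TRUNCATION_STDDEV = 0.87962566103423978;

    static double sample(Random rng, double scale, double n, Distribution d) {
        switch (d) {
            case UNIFORM: {
                double limit = Math.sqrt(3.0 * scale / n);
                return (2.0 * rng.nextDouble() - 1.0) * limit;  // uniform in [-limit, limit)
            }
            case NORMAL:
                return rng.nextGaussian() * Math.sqrt(scale / n);
            default: {
                // Widen the pre-truncation stddev so the stddev after truncation is sqrt(scale / n).
                double sigma = Math.sqrt(scale / n) / TRUNCATION_STDDEV;
                double z;
                do {
                    z = rng.nextGaussian();
                } while (Math.abs(z) > 2.0);  // rejection sampling: keep draws within 2 sigma
                return z * sigma;
            }
        }
    }

    /** Empirical standard deviation of many draws (the mean is zero by construction). */
    static double empiricalStddev(double scale, double n, Distribution d, long seed) {
        Random rng = new Random(seed);
        int count = 200_000;
        double sumSq = 0;
        for (int i = 0; i < count; i++) {
            double x = sample(rng, scale, n, d);
            sumSq += x * x;
        }
        return Math.sqrt(sumSq / count);
    }

    public static void main(String[] args) {
        double scale = 2.0;
        double n = 50.0;  // target stddev = sqrt(2 / 50) = 0.2
        for (Distribution d : Distribution.values()) {
            System.out.printf("%s: %.3f%n", d, empiricalStddev(scale, n, d, 1234L));
        }
    }
}
```

Each printed value should be close to 0.200; the exact digits depend on the seed and the number of draws.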
Examples:
    long seed = 1234L;
    float scale = 0.1f;
    VarianceScaling<TFloat32, TFloat32> initializer =
        new org.tensorflow.framework.initializers.VarianceScaling<>(
            tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed);
    Operand<TFloat32> values =
        initializer.call(tf.constant(Shape.of(2, 2)), TFloat32.class);
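As a worked check of the example, assuming FAN_IN takes the first dimension of a 2-D shape as the number of input units (the Keras convention), Shape.of(2, 2) gives n = 2, so with scale = 0.1 the values are drawn from [-limit, limit] with limit = Math.sqrt(3 * 0.1 / 2) = Math.sqrt(0.15), about 0.3873. The snippet below is a hypothetical standalone computation, not part of the library.

```java
import java.util.Locale;

class ExampleBounds {

    /** limit = Math.sqrt(3 * scale / n), as in the UNIFORM description. */
    static double limit(double scale, double n) {
        return Math.sqrt(3.0 * scale / n);
    }

    public static void main(String[] args) {
        long[] shape = {2, 2};  // Shape.of(2, 2) from the example
        double n = shape[0];    // FAN_IN: first dimension = input units (assumed convention)
        double scale = 0.1;     // the example passes 0.1f; the float rounding is negligible here
        System.out.printf(Locale.ROOT, "n = %.1f, limit = %.4f%n", n, limit(scale, n));
        // prints: n = 2.0, limit = 0.3873
    }
}
```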
Nested Classes
enum | VarianceScaling.Distribution | The random distribution to use when initializing the values. |
enum | VarianceScaling.Mode | The mode to use for calculating the fan values. |
Constants
double | SCALE_DEFAULT |
Fields
public static final VarianceScaling.Distribution | DISTRIBUTION_DEFAULT |
public static final VarianceScaling.Mode | MODE_DEFAULT |
Public Constructors
VarianceScaling(Ops tf, long seed)
Creates a VarianceScaling Initializer
VarianceScaling(Ops tf, double scale, VarianceScaling.Mode mode, VarianceScaling.Distribution distribution, long seed)
Creates a VarianceScaling Initializer
Public Methods
Operand<T> |
Constants
public static final double SCALE_DEFAULT
Fields
Public Constructors
public VarianceScaling (Ops tf, long seed)
Creates a VarianceScaling initializer using SCALE_DEFAULT, MODE_DEFAULT, and DISTRIBUTION_DEFAULT for the scale, mode, and distribution.
Parameters
tf | the TensorFlow Ops |
---|---|
seed | Used to create random seeds. |
public VarianceScaling (Ops tf, double scale, VarianceScaling.Mode mode, VarianceScaling.Distribution distribution, long seed)
Creates a VarianceScaling Initializer
Parameters
tf | the TensorFlow Ops |
---|---|
scale | Scaling factor; must be a positive number. |
mode | The mode used to compute n, the fan value that scales the variance (FAN_IN, FAN_OUT, or FAN_AVG). |
distribution | Random distribution to use. |
seed | Used to create random seeds. |