public class SGD<Model: Differentiable>: Optimizer
where
Model.TangentVector: VectorProtocol & ElementaryFunctions & KeyPathIterable,
Model.TangentVector.VectorSpaceScalar == Float
A stochastic gradient descent (SGD) optimizer.
Implements the stochastic gradient descent algorithm with support for momentum, learning rate decay, and Nesterov momentum. Momentum and Nesterov momentum (a.k.a. the Nesterov accelerated gradient method) are first-order optimization methods that can improve the training speed and convergence rate of gradient descent.
References:
- “A Stochastic Approximation Method” (Robbins and Monro, 1951)
- “On the Stochastic Approximation Method of Robbins and Monro” (Wolfowitz, 1952)
- “Stochastic Estimation of the Maximum of a Regression Function” (Kiefer and Wolfowitz, 1952)
- “Some methods of speeding up the convergence of iteration methods” (Polyak, 1964)
- “A method for unconstrained convex minimization problem with the rate of convergence O(1/k^2)” (Nesterov, 1983)
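The update rule described above can be sketched in plain Swift. This is a minimal illustration, not the library's implementation: it assumes the common formulation in which the effective learning rate is learningRate / (1 + decay * step) and the velocity is momentum * velocity - rate * gradient. The SGDSketch type and its [Float] parameters are hypothetical stand-ins for a Differentiable model and its TangentVector.

```swift
// Hypothetical stand-alone sketch of SGD with momentum, decay, and Nesterov momentum.
// Parameters and gradients are plain [Float] arrays rather than a model's TangentVector.
struct SGDSketch {
    var learningRate: Float
    var momentum: Float
    var decay: Float
    var nesterov: Bool
    var velocity: [Float]   // One velocity entry per parameter, initially zero.
    var step: Int = 0       // Number of updates applied so far.

    init(
        parameterCount: Int,
        learningRate: Float = 0.01,
        momentum: Float = 0,
        decay: Float = 0,
        nesterov: Bool = false
    ) {
        self.learningRate = learningRate
        self.momentum = momentum
        self.decay = decay
        self.nesterov = nesterov
        self.velocity = [Float](repeating: 0, count: parameterCount)
    }

    mutating func update(_ parameters: inout [Float], along gradient: [Float]) {
        precondition(parameters.count == gradient.count && gradient.count == velocity.count)
        step += 1
        // Assumed decay schedule: the rate shrinks as the step count grows.
        let rate = learningRate / (1 + decay * Float(step))
        for i in parameters.indices {
            velocity[i] = momentum * velocity[i] - rate * gradient[i]
            if nesterov {
                // Nesterov momentum: look ahead along the freshly updated velocity.
                parameters[i] += momentum * velocity[i] - rate * gradient[i]
            } else {
                parameters[i] += velocity[i]
            }
        }
    }
}

// Usage: one plain SGD step (momentum 0) with learning rate 0.5 and gradient 2.
var optimizer = SGDSketch(parameterCount: 1, learningRate: 0.5)
var parameters: [Float] = [1.0]
optimizer.update(&parameters, along: [2.0])
print(parameters)  // [0.0]
```

With momentum set to 0 and decay set to 0, the sketch reduces to the basic rule of subtracting learningRate times the gradient from each parameter.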
-
Declaration
public typealias Model = Model
-
The learning rate.
Declaration
public var learningRate: Float
-
The momentum factor. It accelerates stochastic gradient descent in the relevant direction and dampens oscillations.
Declaration
public var momentum: Float
-
The learning rate decay.
Declaration
public var decay: Float
-
Use Nesterov momentum if true.
Declaration
public var nesterov: Bool
-
The velocity state of the model.
Declaration
public var velocity: Model.TangentVector
-
The number of steps taken.
Declaration
public var step: Int
-
Creates an instance for model.
Declaration
public init( for model: __shared Model, learningRate: Float = 0.01, momentum: Float = 0, decay: Float = 0, nesterov: Bool = false )
Parameters
learningRate
The learning rate. The default value is 0.01.
momentum
The momentum factor that accelerates stochastic gradient descent in the relevant direction and dampens oscillations. The default value is 0.
decay
The learning rate decay. The default value is 0.
nesterov
Use Nesterov momentum iff true. The default value is false.
-
Declaration
public required init(copying other: SGD, to device: Device)