My Project
cuda_mlp::CudaSGD Class Reference

SGD with optional momentum and learning-rate decay.

Inheritance diagram for cuda_mlp::CudaSGD: derives from cuda_mlp::CudaMinimizerBase (graph not reproduced).
Collaboration diagram for cuda_mlp::CudaSGD: (graph not reproduced).

Public Member Functions

 CudaSGD (CublasHandle &handle)
 Construct the optimizer.

void setLearningRate (CudaScalar lr)
 Set the base learning rate.

void setMomentum (CudaScalar momentum)
 Set the momentum factor, in [0, 1).

void setBatchSize (int batch_size)
 Set the minibatch size (samples per gradient step).

void setLearningRateDecay (CudaScalar rate, int step_size)
 Configure step-wise learning-rate decay.

void setDimensions (int input_dim, int output_dim)
 Set the per-sample input and output dimensions used to stride through batches.

void solve (int n, CudaScalar *params, const CudaScalar *input, const CudaScalar *target, int total_samples, const LossGradFun &loss_grad) override
 Run SGD optimization.

- Public Member Functions inherited from cuda_mlp::CudaMinimizerBase
 CudaMinimizerBase (CublasHandle &handle)
 Construct with a cuBLAS handle reference.

virtual ~CudaMinimizerBase ()=default

int iterations () const noexcept
 Return the number of iterations performed in the last solve.

void setRecorder (::IterationRecorder< CudaBackend > *recorder)
 Attach a recorder for loss and gradient-norm history.

void setMaxIterations (int iters)
 Set the maximum number of iterations.

void setTolerance (CudaScalar tol)
 Set the stopping tolerance (its interpretation depends on the optimizer).

void setLineSearchParams (int max_iters, CudaScalar c1, CudaScalar rho)
 Configure Armijo line search parameters.
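The three line-search parameters correspond to the usual backtracking rule: starting from a trial step, shrink it by `rho` until the Armijo sufficient-decrease condition with constant `c1` holds, giving up after `max_iters` trials (the protected defaults are 20, 1e-4 and 0.5). The host-side sketch below is illustrative only: `armijoBacktrack` is a hypothetical helper, `Scalar` stands in for `CudaScalar` (assumed to be `float`), and this page does not say whether CudaSGD itself runs a line search.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Scalar = float;  // stand-in for CudaScalar (assumed float)

// Hypothetical helper: backtracking line search on the Armijo condition
//   f(x + a*d) <= f(x) + c1 * a * dot(g, d)
// Starts at a = 1 and multiplies a by rho after each rejected trial.
// Returns the first accepted step, or 0 if none is accepted within max_iters trials.
Scalar armijoBacktrack(const std::function<Scalar(const std::vector<Scalar> &)> &f,
                       const std::vector<Scalar> &x, const std::vector<Scalar> &g,
                       const std::vector<Scalar> &d, int max_iters = 20,
                       Scalar c1 = 1e-4f, Scalar rho = 0.5f) {
  assert(x.size() == g.size() && x.size() == d.size());
  Scalar slope = 0;  // dot(g, d); negative when d is a descent direction
  for (std::size_t i = 0; i < x.size(); ++i) slope += g[i] * d[i];
  const Scalar f0 = f(x);
  std::vector<Scalar> trial(x.size());
  Scalar a = 1;
  for (int k = 0; k < max_iters; ++k) {
    for (std::size_t i = 0; i < x.size(); ++i) trial[i] = x[i] + a * d[i];
    if (f(trial) <= f0 + c1 * a * slope) return a;  // sufficient decrease reached
    a *= rho;                                        // shrink the step and retry
  }
  return 0;  // no trial step satisfied the condition
}
```

For f(x) = x² at x = 1 with the steepest-descent direction d = -2, the full step a = 1 lands on x = -1 (no decrease) and is rejected; a = 0.5 lands on x = 0 and is accepted.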
 

Additional Inherited Members

- Public Types inherited from cuda_mlp::CudaMinimizerBase
using LossGradFun = std::function< CudaScalar(const CudaScalar *params, CudaScalar *grad, const CudaScalar *input, const CudaScalar *target, int batch)>
 Loss and gradient callback signature.

using IterHook = std::function< void(int)>
 Optional per-iteration hook signature.
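A `LossGradFun` receives the current parameters, a buffer for the gradient, pointers to one batch of inputs and targets, and the batch size, and returns the batch loss. As a minimal sketch of that contract, assuming `CudaScalar` is `float`, the callback below fits a hypothetical 1-D linear model y = w*x + b on the host; a real callback receives device pointers and would compute the same quantities with kernels or cuBLAS calls.

```cpp
#include <cassert>
#include <functional>

using CudaScalar = float;  // assumption: CudaScalar is float

// Same shape as CudaMinimizerBase::LossGradFun.
using LossGradFun = std::function<CudaScalar(const CudaScalar *params, CudaScalar *grad,
                                             const CudaScalar *input, const CudaScalar *target,
                                             int batch)>;

// Hypothetical host-side callback for y = w*x + b with params = {w, b}.
// Returns L = (1 / (2*batch)) * sum_i (w*x_i + b - t_i)^2 and writes dL/dw, dL/db into grad.
LossGradFun makeLinearMseLoss() {
  return [](const CudaScalar *params, CudaScalar *grad, const CudaScalar *input,
            const CudaScalar *target, int batch) -> CudaScalar {
    assert(batch > 0);
    const CudaScalar w = params[0], b = params[1];
    CudaScalar loss = 0, gw = 0, gb = 0;
    for (int i = 0; i < batch; ++i) {
      const CudaScalar r = w * input[i] + b - target[i];  // residual for sample i
      loss += r * r;
      gw += r * input[i];
      gb += r;
    }
    grad[0] = gw / batch;  // mean gradient w.r.t. w
    grad[1] = gb / batch;  // mean gradient w.r.t. b
    return loss / (2 * batch);
  };
}
```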
 
- Protected Attributes inherited from cuda_mlp::CudaMinimizerBase
CublasHandle &handle_
 cuBLAS handle used by the optimizer.

int max_iters_ = 200

int max_line_iters_ = 20
 Iteration limits (maximum optimizer iterations and maximum line-search steps).

CudaScalar tol_ = 1e-6f

CudaScalar c1_ = 1e-4f

CudaScalar rho_ = 0.5f
 Stopping tolerance and line-search parameters.

int last_iterations_ = 0
 Iterations performed in the last run.

::IterationRecorder< CudaBackend > * recorder_ = nullptr
 Optional recorder for diagnostics.
 

Detailed Description

SGD with optional momentum and learning-rate decay.

Constructor & Destructor Documentation

◆ CudaSGD()

cuda_mlp::CudaSGD::CudaSGD (CublasHandle &handle)
inline explicit

Construct the optimizer.

Member Function Documentation

◆ setBatchSize()

void cuda_mlp::CudaSGD::setBatchSize (int batch_size)
inline

Set the minibatch size, i.e. the number of samples passed to the loss/gradient callback per step.

◆ setDimensions()

void cuda_mlp::CudaSGD::setDimensions (int input_dim, int output_dim)
inline

Set the per-sample input and output dimensions, which are used to stride through the input and target arrays batch by batch.
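Striding means that batch `b` starts at sample `b * batch_size`, so its inputs begin `b * batch_size * input_dim` elements into `input` and its targets `b * batch_size * output_dim` elements into `target`. The sketch below computes those offsets, assuming samples are stored contiguously one after another; `BatchView`, `numBatches` and `batchAt` are hypothetical names, and whether CudaSGD keeps or drops a final partial batch is not stated on this page (this sketch keeps it).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical description of one minibatch (not part of cuda_mlp).
struct BatchView {
  std::size_t input_offset;   // element offset into the input array
  std::size_t target_offset;  // element offset into the target array
  int size;                   // number of samples in this batch
};

// Number of minibatches, counting a final partial batch.
int numBatches(int total_samples, int batch_size) {
  assert(total_samples > 0 && batch_size > 0);
  return (total_samples + batch_size - 1) / batch_size;  // ceiling division
}

// Offsets and size of batch b, assuming sample-contiguous storage
// (input_dim values per input sample, output_dim values per target sample).
BatchView batchAt(int b, int batch_size, int total_samples, int input_dim, int output_dim) {
  assert(b >= 0 && b < numBatches(total_samples, batch_size));
  const std::size_t first = static_cast<std::size_t>(b) * batch_size;  // first sample index
  const int size = std::min(batch_size, total_samples - static_cast<int>(first));
  return {first * input_dim, first * output_dim, size};
}
```

For 10 samples with `batch_size = 4`, this gives 3 batches, the last holding 2 samples; the device pointer for that batch's inputs would be `input + view.input_offset`.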

◆ setLearningRate()

void cuda_mlp::CudaSGD::setLearningRate (CudaScalar lr)
inline

Set the base learning rate.

◆ setLearningRateDecay()

void cuda_mlp::CudaSGD::setLearningRateDecay (CudaScalar rate, int step_size)
inline

Configure step-wise learning-rate decay.

Parameters
rate: Multiplicative decay factor
step_size: Iterations between decays
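Read literally, the two parameters give the schedule lr_k = lr_0 * rate^floor(k / step_size), so `rate = 0.5` with `step_size = 10` halves the learning rate every 10 iterations. A minimal sketch, assuming iterations are counted from 0 and `CudaScalar` is `float`; `decayedLearningRate` is a hypothetical helper, not part of cuda_mlp.

```cpp
#include <cassert>
#include <cmath>

// Step decay: multiply the base rate by `rate` once per completed block of
// `step_size` iterations, i.e. lr_k = lr0 * rate^floor(k / step_size).
float decayedLearningRate(float lr0, float rate, int step_size, int iteration) {
  assert(step_size > 0 && iteration >= 0);
  const int decays = iteration / step_size;  // integer division acts as floor
  return lr0 * std::pow(rate, static_cast<float>(decays));
}
```

For example, lr_0 = 0.1, rate = 0.5 and step_size = 10 give 0.1 for iterations 0 to 9, 0.05 for 10 to 19 and 0.025 for 20 to 29.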

◆ setMomentum()

void cuda_mlp::CudaSGD::setMomentum (CudaScalar momentum)
inline

Set the momentum factor, which is expected to lie in [0, 1).
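This page does not spell out the update rule. Assuming the common heavy-ball form, v <- momentum * v - lr * g followed by p <- p + v, the host-side sketch below shows one step; `momentumStep` is a hypothetical helper, and the class may use a different formulation (for example Nesterov momentum or a dampened variant).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One SGD step with classical (heavy-ball) momentum on host vectors:
//   v <- momentum * v - lr * g
//   p <- p + v
// With momentum = 0 this reduces to plain SGD: p <- p - lr * g.
void momentumStep(std::vector<float> &params, std::vector<float> &velocity,
                  const std::vector<float> &grad, float lr, float momentum) {
  assert(params.size() == grad.size() && velocity.size() == grad.size());
  assert(momentum >= 0.0f && momentum < 1.0f);  // documented range [0, 1)
  for (std::size_t i = 0; i < params.size(); ++i) {
    velocity[i] = momentum * velocity[i] - lr * grad[i];
    params[i] += velocity[i];
  }
}
```

On the device, the same step could be expressed in single precision with cuBLAS Level-1 calls: `cublasSscal` to scale v by the momentum, then `cublasSaxpy` twice (v += -lr * g, then p += v).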

◆ solve()

void cuda_mlp::CudaSGD::solve (int n, CudaScalar *params, const CudaScalar *input, const CudaScalar *target, int total_samples, const LossGradFun &loss_grad)
inline override virtual

Run SGD optimization.

Parameters
n: Number of parameters
params: Parameter vector (device memory)
input: Input data (device memory)
target: Target data (device memory)
total_samples: Total number of samples
loss_grad: Callback that returns the batch loss and writes the gradient

Implements cuda_mlp::CudaMinimizerBase.
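Putting the setters together, one plausible shape for the `solve()` loop is sketched below on host memory. Everything beyond the argument list is an assumption: `SgdConfig` and `sgdSolveHost` are hypothetical, one iteration is taken to be one pass over all minibatches, momentum is the heavy-ball form, decay follows lr_0 * rate^floor(k / step_size), and `CudaScalar` is `float`. The real implementation works on device buffers through the cuBLAS handle and may also consult the tolerance, the recorder and the iteration counter.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

using CudaScalar = float;  // assumption: CudaScalar is float
using LossGradFun = std::function<CudaScalar(const CudaScalar *, CudaScalar *, const CudaScalar *,
                                             const CudaScalar *, int)>;

// Hypothetical host-side mirror of the CudaSGD setters.
struct SgdConfig {
  CudaScalar lr = 0.01f;
  CudaScalar momentum = 0.0f;
  int batch_size = 32;
  CudaScalar decay_rate = 1.0f;  // 1 means no decay
  int decay_step = 1;
  int input_dim = 1, output_dim = 1;
  int max_iters = 200;  // matches the max_iters_ default in CudaMinimizerBase
};

// Minibatch SGD over host arrays with the same argument list as CudaSGD::solve.
// Returns the mean batch loss of the final pass.
CudaScalar sgdSolveHost(int n, CudaScalar *params, const CudaScalar *input,
                        const CudaScalar *target, int total_samples,
                        const LossGradFun &loss_grad, const SgdConfig &cfg) {
  assert(n > 0 && total_samples > 0 && cfg.batch_size > 0 && cfg.decay_step > 0);
  std::vector<CudaScalar> grad(n), velocity(n, 0.0f);
  CudaScalar pass_loss = 0;
  for (int it = 0; it < cfg.max_iters; ++it) {
    // Step decay of the learning rate.
    const CudaScalar lr =
        cfg.lr * std::pow(cfg.decay_rate, static_cast<CudaScalar>(it / cfg.decay_step));
    pass_loss = 0;
    int batches = 0;
    for (int first = 0; first < total_samples; first += cfg.batch_size, ++batches) {
      const int batch = std::min(cfg.batch_size, total_samples - first);  // last batch may be partial
      pass_loss += loss_grad(params, grad.data(), input + first * cfg.input_dim,
                             target + first * cfg.output_dim, batch);
      for (int i = 0; i < n; ++i) {  // heavy-ball momentum update
        velocity[i] = cfg.momentum * velocity[i] - lr * grad[i];
        params[i] += velocity[i];
      }
    }
    pass_loss /= batches;
  }
  return pass_loss;
}
```

With the step decay and momentum left at their neutral values (1 and 0), this reduces to plain minibatch SGD with a constant learning rate.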


The documentation for this class was generated from the following file: