3 #include "../common.hpp"
9 #include <autodiff/reverse/var.hpp>
10 #include <autodiff/reverse/var/eigen.hpp>
// Callback type for a mini-batch gradient evaluation. It receives a weight
// vector and a list of sample indices, returns nothing, and delivers its
// result through the trailing non-const V& argument.
37 using BatchGradFun = std::function<void(
const V &,
const std::vector<size_t> &, V &)>;
// Callback type for a mini-batch loss evaluation. It receives a weight vector
// and a list of sample indices and returns the loss as a double.
38 using BatchLossFun = std::function<double(
const V &,
const std::vector<size_t> &)>;
// finite_difference_hvp_batch: central finite-difference quotient of the
// gradient callback g along the direction v,
//     (g(weights + epsilon*v) - g(weights - epsilon*v)) / (2*epsilon).
// NOTE(review): presumably used as a Hessian-vector product estimate ("hvp");
// confirm against the caller.
// NOTE(review): this listing omits the line that carries the function name
// and return type, and the closing brace.
//
// Parameters:
//   g        callable invoked as g(point, indices, out); out is read back
//            after the call.
//   weights  point around which the two evaluations are centred.
//   indices  sample indices, forwarded unchanged to both calls of g.
//   v        direction along which the point is displaced.
//   epsilon  displacement scale, default 1e-4; also the divisor of the
//            quotient, so it must be non-zero.
// Returns the difference quotient described above.
88 template <
typename V,
typename BatchFn>
90 BatchFn &g,
const V &weights,
const std::vector<size_t> &indices,
const V &v,
double epsilon = 1e-4) {
// Two evaluation points, displaced symmetrically along v.
91 V w_plus = weights + epsilon * v;
92 V w_minus = weights - epsilon * v;
// Output buffers, zero-initialised to the size of weights before g fills them.
94 V grad_plus = V::Zero(weights.size());
95 V grad_minus = V::Zero(weights.size());
// Both evaluations use the same index set.
97 g(w_plus, indices, grad_plus);
98 g(w_minus, indices, grad_minus);
// Symmetric difference quotient.
100 return (grad_plus - grad_minus) / (2.0 * epsilon);
// lbfgs_two_loop: two passes over the stored (s, y, rho) triples, one in
// descending and one in ascending index order, with a scalar gamma computed
// in between from the back() entries of s_list and y_list.
// NOTE(review): this listing omits several original lines: the signature,
// the initialisation of q, the declaration of gamma, the body of the
// near-zero guard, the initialisation of r, and the return statement.
// Comments below describe only the statements that are visible.
106 template <
typename V>
// Number of stored pairs; one alpha coefficient is kept per pair.
108 int M = s_list.
size();
109 std::vector<double> alpha(M);
// First pass, indices M-1 down to 0: compute alpha[i] from the current q and
// subtract alpha[i] * y_list[i] from q.
113 for (
int i = M - 1; i >= 0; --i) {
114 alpha[i] = rho_list[i] * s_list[i].dot(q);
115 q = q - alpha[i] * y_list[i];
// Squared norm of the back() entry of y_list, used as the denominator of
// gamma. The guard below tests for a near-zero denominator; what it does in
// that case is not shown in this listing.
121 double denom = y_list.
back().dot(y_list.
back());
122 if (std::abs(denom) < 1e-12)
// gamma = (s_back . y_back) / (y_back . y_back), then clamped to [1e-6, 1e6].
125 gamma = s_list.
back().dot(y_list.
back()) / denom;
126 gamma = std::min(std::max(gamma, 1e-6), 1e6);
// Second pass, indices 0 up to M-1: compute beta from the current r and add
// s_list[i] * (alpha[i] - beta) to r.
131 for (
int i = 0; i < M; ++i) {
132 double beta = rho_list[i] * y_list[i].dot(r);
133 r = r + s_list[i] * (alpha[i] - beta);
// sample_minibatch_indices: builds a list of indices drawn from 0..N-1.
// NOTE(review): this listing omits the signature line, the return statements
// of the last two paths, and the closing braces; the comments below cover
// only the visible statements.
//
//   N           size of the index range.
//   batch_size  requested number of indices.
//   rng         random engine used to draw the swap positions.
141 template <
typename V,
typename M>
// Nothing to sample: return an empty list.
143 if (N == 0 || batch_size == 0)
return {};
// Request covers the whole range: build the identity list 0..N-1.
144 if (batch_size >= N) {
145 std::vector<size_t> idx(N);
146 std::iota(idx.begin(), idx.end(), 0);
// Otherwise start from the identity list and run the first batch_size steps
// of a Fisher-Yates shuffle: position i is swapped with a position j drawn
// uniformly from the closed range [i, N-1].
151 std::vector<size_t> idx(N);
152 std::iota(idx.begin(), idx.end(), 0);
153 for (
size_t i = 0; i < batch_size; ++i) {
154 std::uniform_int_distribution<size_t> dist(i, N - 1);
155 size_t j = dist(rng);
156 std::swap(idx[i], idx[j]);
// Keep only the first batch_size entries.
158 idx.resize(batch_size);
// stochastic_solve: outer/inner iteration that combines a full gradient with
// mini-batch gradient differences and passes the result to lbfgs_two_loop to
// obtain the step direction.
// NOTE(review): this listing omits many original lines, including the
// signature, the definitions of wt, rng, w_history, u_list, s_list, y_list,
// rho_list, s and y, several closing braces and the return statement.
// Comments below describe only the statements that are visible.
165 template <
typename V,
typename M>
185 int dim_weights = weights.size();
// Store the step_size argument in the member of the same name.
187 this->step_size = step_size;
// Index list 0..N-1, used for every full-data evaluation below.
192 std::vector<size_t> full_indices(N);
193 std::iota(full_indices.begin(), full_indices.end(), 0);
// timing is true exactly when a recorder is attached; the recorder is reset
// and a start time is captured before the main loop.
195 const bool timing = (this->recorder_ !=
nullptr);
196 if (this->recorder_) this->recorder_->reset();
197 auto start_time = std::chrono::steady_clock::now();
// Outer loop, bounded by _max_iters.
199 while (_iters < _max_iters) {
// Gradient over all N indices at the current weights.
204 V full_gradient = V::Zero(dim_weights);
206 batch_g(weights, full_indices, full_gradient);
// When the full-gradient norm is below _tol a message is printed; the rest of
// this branch is not shown in this listing.
208 if (full_gradient.norm() < _tol) {
209 std::cout <<
"Converged: gradient norm " << full_gradient.norm() << std::endl;
215 V variance_reduced_gradient = V::Zero(dim_weights);
// Inner loop of m steps.
218 for (
int t = 0; t < m; ++t) {
// Draw a mini-batch of size b and evaluate the gradient callback on it at
// both the inner iterate wt and the outer iterate weights.
220 auto minibatch_indices = sample_minibatch_indices(N, b, rng);
222 V grad_estimate_wt = V::Zero(dim_weights);
223 V grad_estimate_wk = V::Zero(dim_weights);
225 batch_g(wt, minibatch_indices, grad_estimate_wt);
226 batch_g(weights, minibatch_indices, grad_estimate_wk);
// Mini-batch gradient difference plus the full gradient at weights.
228 variance_reduced_gradient = (grad_estimate_wt - grad_estimate_wk) + full_gradient;
// Direction from the two-loop routine, then a step of size step_size.
230 V direction =
lbfgs_two_loop(s_list, y_list, rho_list, variance_reduced_gradient);
231 wt = wt - this->step_size * direction;
// Every L inner steps, excluding t == 0, run the update below.
236 if (t > 0 && t % L == 0) {
// u is divided by the number of entries in w_history when that number is
// positive. The statement executed inside the loop over w_history is not
// shown in this listing; presumably it accumulates the entries into u.
238 V u = V::Zero(dim_weights);
239 const int num_wt =
static_cast<int>(w_history.
size());
240 for (
size_t i = 0; i < w_history.
size(); ++i)
242 if (num_wt > 0) u /=
static_cast<double>(num_wt);
// Only when u_list already holds an entry: take its back() entry and sample
// a separate mini-batch of size b_H.
244 if (!u_list.
empty()) {
245 const V &u_prev = u_list.
back();
248 auto batch_indices_H = sample_minibatch_indices(N, b_H, rng);
// The branch below is entered only when |y . s| exceeds 1e-10. The
// definitions of y and s and the branch body are not shown in this listing.
252 double ys = y.dot(s);
253 if (std::abs(ys) > 1e-10) {
// With at least two entries in w_history, weights is replaced by an entry
// drawn uniformly from index positions 0 .. size()-2.
265 if (w_history.
size() >= 2) {
266 std::uniform_int_distribution<size_t> pick_i(0, w_history.
size() - 2);
267 weights = w_history[pick_i(rng)];
// With a recorder attached: evaluate the full loss and full gradient at
// weights and record them with the iteration counter and elapsed time.
274 if (this->recorder_) {
275 double full_loss = f(weights, full_indices);
276 V grad_log = V::Zero(dim_weights);
277 batch_g(weights, full_indices, grad_log);
// elapsed_ms starts at 0.0 and is then set to the milliseconds since
// start_time. NOTE(review): any condition guarding that assignment is not
// shown in this listing; confirm whether it depends on timing.
278 double elapsed_ms = 0.0;
280 auto now = std::chrono::steady_clock::now();
281 elapsed_ms = std::chrono::duration<double, std::milli>(now - start_time).count();
283 this->recorder_->record(_iters, full_loss, grad_log.norm(), elapsed_ms);
A fixed-capacity ring buffer (circular buffer).
Definition: ring_buffer.hpp:15
bool empty() const
Checks if the buffer is empty.
Definition: ring_buffer.hpp:104
void clear()
Clears the buffer content.
Definition: ring_buffer.hpp:112
void push_back(const T &val)
Pushes a new element into the buffer.
Definition: ring_buffer.hpp:43
size_t size() const
Returns the number of elements currently stored.
Definition: ring_buffer.hpp:101
T & back()
Access the newest element.
Definition: ring_buffer.hpp:86
Stochastic Limited-memory BFGS (S-LBFGS) minimizer.
Definition: s_lbfgs.hpp:22
std::function< double(const V &, const std::vector< size_t > &)> BatchLossFun
Definition: s_lbfgs.hpp:38
static std::vector< size_t > sample_minibatch_indices(const size_t N, size_t batch_size, std::mt19937 &rng)
Helper to sample minibatch indices.
Definition: s_lbfgs.hpp:142
void setData(const BatchLossFun &f, const BatchGradFun &g)
Configure data and callbacks for the solver.
Definition: s_lbfgs.hpp:75
double step_size
Definition: stochastic_minimizer.hpp:47
V stochastic_solve(V weights, const BatchLossFun &f, const BatchGradFun &batch_g, int m, int M_param, int L, int b, int b_H, double step_size, int N)
Stochastic Solve using Batch Callbacks.
Definition: s_lbfgs.hpp:166
std::function< void(const V &, const std::vector< size_t > &, V &)> BatchGradFun
Definition: s_lbfgs.hpp:37
Base class for Stochastic Minimizers.
Definition: stochastic_minimizer.hpp:16
void setMaxIterations(int max_iters)
Sets the maximum number of iterations.
Definition: stochastic_minimizer.hpp:24
unsigned int _iters
Definition: stochastic_minimizer.hpp:45
unsigned int _max_iters
Definition: stochastic_minimizer.hpp:44
void setTolerance(double tol)
Sets the tolerance for convergence (full gradient norm).
Definition: stochastic_minimizer.hpp:36
void setStepSize(double s)
Sets the step size (learning rate).
Definition: stochastic_minimizer.hpp:30
double step_size
Definition: stochastic_minimizer.hpp:47
double _tol
Definition: stochastic_minimizer.hpp:46
V lbfgs_two_loop(const RingBuffer< V > &s_list, const RingBuffer< V > &y_list, const RingBuffer< double > &rho_list, const V &v)
Definition: s_lbfgs.hpp:107
V finite_difference_hvp_batch(BatchFn &g, const V &weights, const std::vector< size_t > &indices, const V &v, double epsilon=1e-4)
Definition: s_lbfgs.hpp:89
constexpr unsigned int kDefaultSeed
Definition: seed.hpp:4