My Project
s_lbfgs.hpp
Go to the documentation of this file.
1 #pragma once
2 
#include "../common.hpp"
#include "../seed.hpp"
#include "ring_buffer.hpp"
#include "stochastic_minimizer.hpp"

#include <Eigen/Eigen>
#include <autodiff/reverse/var.hpp>
#include <autodiff/reverse/var/eigen.hpp>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <functional>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
15 
16 namespace cpu_mlp {
17 
22 template <typename V, typename M> class SLBFGS : public StochasticMinimizer<V, M> {
23 public:
25 
26 protected:
27  using Base::_iters;
28  using Base::_max_iters;
29  using Base::_tol;
30  using Base::step_size;
31 
32 public:
34  using Base::setStepSize;
35  using Base::setTolerance;
36 
37  using BatchGradFun = std::function<void(const V &, const std::vector<size_t> &, V &)>;
38  using BatchLossFun = std::function<double(const V &, const std::vector<size_t> &)>;
39 
54  V stochastic_solve(V weights,
55  const BatchLossFun &f,
56  const BatchGradFun &batch_g,
57  int m,
58  int M_param,
59  int L,
60  int b,
61  int b_H,
62  double step_size,
63  int N);
64 
68  static std::vector<size_t> sample_minibatch_indices(const size_t N, size_t batch_size, std::mt19937 &rng);
69 
75  void setData(const BatchLossFun &f, const BatchGradFun &g) {
76  _sf = f;
77  _sg = g;
78  }
79 
80 private:
81  BatchLossFun _sf;
82  BatchGradFun _sg;
83 };
84 
85 // -------------------------------------------------------------------------
86 // Helper: Finite Difference HVP on a Batch
87 // -------------------------------------------------------------------------
/**
 * @brief Central-difference approximation of a Hessian-vector product on a
 *        minibatch: H*v ~= (g(w + eps*v) - g(w - eps*v)) / (2*eps).
 *
 * NOTE(review): the opening signature line of this template was missing from
 * the extracted source; it is restored here to match the documented
 * declaration.
 *
 * @param g       Batch gradient callback (weights, indices, out-gradient).
 * @param weights Point at which the Hessian is approximated.
 * @param indices Minibatch sample indices forwarded to @p g.
 * @param v       Direction vector.
 * @param epsilon Finite-difference step size.
 * @return Approximate H*v, same dimension as @p weights.
 */
template <typename V, typename BatchFn>
V finite_difference_hvp_batch(
    BatchFn &g, const V &weights, const std::vector<size_t> &indices, const V &v, double epsilon = 1e-4) {
  // Perturb the evaluation point symmetrically along v.
  V w_plus = weights + v * epsilon;
  V w_minus = weights - v * epsilon;

  V grad_plus = V::Zero(weights.size());
  V grad_minus = V::Zero(weights.size());

  g(w_plus, indices, grad_plus);
  g(w_minus, indices, grad_minus);

  // Second-order accurate central difference.
  return (grad_plus - grad_minus) / (2.0 * epsilon);
}
102 
103 // -------------------------------------------------------------------------
104 // Helper: L-BFGS Two Loop Recursion
105 // -------------------------------------------------------------------------
106 template <typename V>
107 V lbfgs_two_loop(const RingBuffer<V> &s_list, const RingBuffer<V> &y_list, const RingBuffer<double> &rho_list, const V &v) {
108  int M = s_list.size();
109  std::vector<double> alpha(M);
110  V q = v;
111 
112  // Backward
113  for (int i = M - 1; i >= 0; --i) {
114  alpha[i] = rho_list[i] * s_list[i].dot(q);
115  q = q - alpha[i] * y_list[i];
116  }
117 
118  // Scaling
119  double gamma = 1.0;
120  if (M > 0) {
121  double denom = y_list.back().dot(y_list.back());
122  if (std::abs(denom) < 1e-12)
123  gamma = 1.0;
124  else
125  gamma = s_list.back().dot(y_list.back()) / denom;
126  gamma = std::min(std::max(gamma, 1e-6), 1e6);
127  }
128  V r = gamma * q;
129 
130  // Forward
131  for (int i = 0; i < M; ++i) {
132  double beta = rho_list[i] * y_list[i].dot(r);
133  r = r + s_list[i] * (alpha[i] - beta);
134  }
135  return r;
136 }
137 
138 // -------------------------------------------------------------------------
139 // Implementation: sample_minibatch_indices
140 // -------------------------------------------------------------------------
141 template <typename V, typename M>
142 std::vector<size_t> SLBFGS<V, M>::sample_minibatch_indices(const size_t N, size_t batch_size, std::mt19937 &rng) {
143  if (N == 0 || batch_size == 0) return {};
144  if (batch_size >= N) {
145  std::vector<size_t> idx(N);
146  std::iota(idx.begin(), idx.end(), 0);
147  return idx;
148  }
149 
150  // Partial Fisher-Yates
151  std::vector<size_t> idx(N);
152  std::iota(idx.begin(), idx.end(), 0);
153  for (size_t i = 0; i < batch_size; ++i) {
154  std::uniform_int_distribution<size_t> dist(i, N - 1);
155  size_t j = dist(rng);
156  std::swap(idx[i], idx[j]);
157  }
158  idx.resize(batch_size);
159  return idx;
160 }
161 
162 // -------------------------------------------------------------------------
163 // Implementation: stochastic_solve
164 // -------------------------------------------------------------------------
165 template <typename V, typename M>
167  const BatchLossFun &f,
168  const BatchGradFun &batch_g,
169  int m,
170  int M_param,
171  int L,
172  int b,
173  int b_H,
174  double step_size,
175  int N) {
176 
177  _iters = 0;
178  RingBuffer<V> u_list(M_param > 0 ? M_param + 1 : 0);
179  RingBuffer<V> s_list(M_param > 0 ? M_param : 0);
180  RingBuffer<V> y_list(M_param > 0 ? M_param : 0);
181  RingBuffer<double> rho_list(M_param > 0 ? M_param : 0);
182 
183  std::mt19937 rng(kDefaultSeed);
184 
185  int dim_weights = weights.size();
186  V wt = weights;
187  this->step_size = step_size;
188 
189  RingBuffer<V> w_history(L + 1);
190 
191  // Full Batch Indices
192  std::vector<size_t> full_indices(N);
193  std::iota(full_indices.begin(), full_indices.end(), 0);
194 
195  const bool timing = (this->recorder_ != nullptr);
196  if (this->recorder_) this->recorder_->reset();
197  auto start_time = std::chrono::steady_clock::now();
198 
199  while (_iters < _max_iters) {
200 
201  w_history.clear();
202 
203  // 1. Compute Full Gradient (Variance Reduction Anchor)
204  V full_gradient = V::Zero(dim_weights);
205 
206  batch_g(weights, full_indices, full_gradient);
207 
208  if (full_gradient.norm() < _tol) {
209  std::cout << "Converged: gradient norm " << full_gradient.norm() << std::endl;
210  break;
211  }
212 
213  wt = weights;
214  w_history.push_back(wt);
215  V variance_reduced_gradient = V::Zero(dim_weights);
216 
217  // 2. Inner Loop (Stochastic Updates)
218  for (int t = 0; t < m; ++t) {
219 
220  auto minibatch_indices = sample_minibatch_indices(N, b, rng);
221 
222  V grad_estimate_wt = V::Zero(dim_weights);
223  V grad_estimate_wk = V::Zero(dim_weights);
224 
225  batch_g(wt, minibatch_indices, grad_estimate_wt);
226  batch_g(weights, minibatch_indices, grad_estimate_wk);
227 
228  variance_reduced_gradient = (grad_estimate_wt - grad_estimate_wk) + full_gradient;
229 
230  V direction = lbfgs_two_loop(s_list, y_list, rho_list, variance_reduced_gradient);
231  wt = wt - this->step_size * direction;
232 
233  w_history.push_back(wt);
234 
235  // 3. Hessian Update (Curvature Pairs)
236  if (t > 0 && t % L == 0) {
237 
238  V u = V::Zero(dim_weights);
239  const int num_wt = static_cast<int>(w_history.size());
240  for (size_t i = 0; i < w_history.size(); ++i)
241  u += w_history[i];
242  if (num_wt > 0) u /= static_cast<double>(num_wt);
243 
244  if (!u_list.empty()) {
245  const V &u_prev = u_list.back();
246  V s = u - u_prev;
247 
248  auto batch_indices_H = sample_minibatch_indices(N, b_H, rng);
249 
250  V y = finite_difference_hvp_batch(batch_g, u, batch_indices_H, s);
251 
252  double ys = y.dot(s);
253  if (std::abs(ys) > 1e-10) {
254  s_list.push_back(s);
255  y_list.push_back(y);
256  rho_list.push_back(1.0 / ys);
257  }
258  }
259 
260  u_list.push_back(u);
261  }
262  }
263 
264  // Reset anchor
265  if (w_history.size() >= 2) {
266  std::uniform_int_distribution<size_t> pick_i(0, w_history.size() - 2);
267  weights = w_history[pick_i(rng)];
268  } else {
269  weights = wt;
270  }
271 
272  // Logging (approximate loss via simple full gradient norm or callback if desired)
273  // Calculating full loss is expensive. We skip it or use passed callback on full batch.
274  if (this->recorder_) {
275  double full_loss = f(weights, full_indices);
276  V grad_log = V::Zero(dim_weights);
277  batch_g(weights, full_indices, grad_log);
278  double elapsed_ms = 0.0;
279  if (timing) {
280  auto now = std::chrono::steady_clock::now();
281  elapsed_ms = std::chrono::duration<double, std::milli>(now - start_time).count();
282  }
283  this->recorder_->record(_iters, full_loss, grad_log.norm(), elapsed_ms);
284  }
285 
286  _iters++;
287  }
288 
289  return weights;
290 };
291 } // namespace cpu_mlp
A fixed-capacity ring buffer (circular buffer).
Definition: ring_buffer.hpp:15
bool empty() const
Checks if the buffer is empty.
Definition: ring_buffer.hpp:104
void clear()
Clears the buffer content.
Definition: ring_buffer.hpp:112
void push_back(const T &val)
Pushes a new element into the buffer.
Definition: ring_buffer.hpp:43
size_t size() const
Returns the number of elements currently stored.
Definition: ring_buffer.hpp:101
T & back()
Access the newest element.
Definition: ring_buffer.hpp:86
Stochastic Limited-memory BFGS (S-LBFGS) minimizer.
Definition: s_lbfgs.hpp:22
std::function< double(const V &, const std::vector< size_t > &)> BatchLossFun
Definition: s_lbfgs.hpp:38
static std::vector< size_t > sample_minibatch_indices(const size_t N, size_t batch_size, std::mt19937 &rng)
Helper to sample minibatch indices.
Definition: s_lbfgs.hpp:142
void setData(const BatchLossFun &f, const BatchGradFun &g)
Configure data and callbacks for the solver.
Definition: s_lbfgs.hpp:75
double step_size
Definition: stochastic_minimizer.hpp:47
V stochastic_solve(V weights, const BatchLossFun &f, const BatchGradFun &batch_g, int m, int M_param, int L, int b, int b_H, double step_size, int N)
Stochastic Solve using Batch Callbacks.
Definition: s_lbfgs.hpp:166
std::function< void(const V &, const std::vector< size_t > &, V &)> BatchGradFun
Definition: s_lbfgs.hpp:37
Base class for Stochastic Minimizers.
Definition: stochastic_minimizer.hpp:16
void setMaxIterations(int max_iters)
Sets the maximum number of iterations.
Definition: stochastic_minimizer.hpp:24
unsigned int _iters
Definition: stochastic_minimizer.hpp:45
unsigned int _max_iters
Definition: stochastic_minimizer.hpp:44
void setTolerance(double tol)
Sets the tolerance for convergence (full gradient norm).
Definition: stochastic_minimizer.hpp:36
void setStepSize(double s)
Sets the step size (learning rate).
Definition: stochastic_minimizer.hpp:30
double step_size
Definition: stochastic_minimizer.hpp:47
double _tol
Definition: stochastic_minimizer.hpp:46
Definition: layer.hpp:13
V lbfgs_two_loop(const RingBuffer< V > &s_list, const RingBuffer< V > &y_list, const RingBuffer< double > &rho_list, const V &v)
Definition: s_lbfgs.hpp:107
V finite_difference_hvp_batch(BatchFn &g, const V &weights, const std::vector< size_t > &indices, const V &v, double epsilon=1e-4)
Definition: s_lbfgs.hpp:89
constexpr unsigned int kDefaultSeed
Definition: seed.hpp:4