My Project
full_batch_minimizer.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "../common.hpp"
4 #include "../iteration_recorder.hpp"
5 #include <Eigen/Eigen>
6 #include <cmath>
7 #include <functional>
8 #include <limits>
9 
// Enzyme automatic-differentiation interface, declared with C linkage.
// Calls to __enzyme_autodiff are rewritten by the Enzyme compiler pass; the
// enzyme_* globals are activity markers passed positionally before each
// argument. Every name below is only declared here -- definitions live elsewhere.
extern "C" double __enzyme_autodiff(void *, ...);
extern "C" int enzyme_dup;   // argument is active and followed by its shadow (gradient) buffer
extern "C" int enzyme_const; // argument is treated as a constant (no derivative)
extern "C" int enzyme_out;   // scalar argument whose derivative is returned
16 
17 namespace cpu_mlp {
18 
23 template <typename V, typename M> class FullBatchMinimizer {
24 public:
25  virtual ~FullBatchMinimizer() = default;
26 
34  virtual V solve(V x, VecFun<V, double> &f, GradFun<V> &Gradient) = 0;
35 
40  virtual void setInitialHessian(const M & /*hess*/) {}
41 
46  virtual void setHessian(const HessFun<V, M> & /*hessFun*/) {}
47 
56  template <auto LossFn, typename DataType> V solve_with_enzyme(V x, DataType *data) {
57  VecFun<V, double> f = [data](const V &w) -> double { return LossFn(const_cast<double *>(w.data()), data); };
58  GradFun<V> Gradient = [data](const V &w) -> V {
59  V grad = V::Zero(w.size());
60  __enzyme_autodiff((void *)LossFn, enzyme_dup, const_cast<double *>(w.data()), grad.data(), enzyme_const, data);
61  return grad;
62  };
63  return this->solve(x, f, Gradient);
64  }
65 
70  void setMaxIterations(int max_iters) { _max_iters = max_iters; }
75  void setMaxLineIters(int max_line) { max_line_iters = max_line; }
76 
81  void setArmijoMaxIter(int max_armijo) { max_line_iters = max_armijo; }
82 
87  void setTolerance(double tol) { _tol = tol; }
88 
93  unsigned int iterations() const { return _iters; }
94 
99  double tolerance() const { return _tol; }
104  void setRecorder(::IterationRecorder<CpuBackend> *recorder) { recorder_ = recorder; }
105 
106 protected:
107  unsigned int _max_iters = 1000;
108  unsigned int _iters = 0;
109  double _tol = 1e-10;
111 
112  // Line Search Parameters
113  double c1 = 1e-4;
114  double c2 = 0.9;
115  double rho = 0.5;
116  double max_line_iters = 50;
117 
126  double line_search(V x, V p, VecFun<V, double> &f, GradFun<V> &Gradient) {
127  double f_old = f(x);
128  double grad_f_old = Gradient(x).dot(p);
129  double inf = std::numeric_limits<double>::infinity();
130  double alpha_min = 0.0;
131  double alpha_max = inf;
132  double alpha = 1.0;
133 
134  for (int i = 0; i < max_line_iters; ++i) {
135  V x_new = x + alpha * p;
136  double f_new = f(x_new);
137 
138  if (f_new > f_old + c1 * alpha * grad_f_old) {
139  alpha_max = alpha;
140  alpha = rho * (alpha_min + alpha_max);
141  continue;
142  }
143 
144  double grad_f_new_dot_p = Gradient(x_new).dot(p);
145 
146  if (grad_f_new_dot_p < c2 * grad_f_old) {
147  alpha_min = alpha;
148  if (alpha_max == inf)
149  alpha *= 2;
150  else
151  alpha = rho * (alpha_min + alpha_max);
152  continue;
153  }
154  return alpha;
155  }
156  return alpha;
157  }
158 };
159 
160 } // namespace cpu_mlp
CPU recorder that stores loss/gradient history on host.
Definition: iteration_recorder.hpp:18
Base class for Full Batch Minimizers.
Definition: full_batch_minimizer.hpp:23
double c2
Definition: full_batch_minimizer.hpp:114
double tolerance() const
Returns the tolerance used.
Definition: full_batch_minimizer.hpp:99
V solve_with_enzyme(V x, DataType *data)
Helper to solve directly using an Enzyme-compatible raw function.
Definition: full_batch_minimizer.hpp:56
double rho
Definition: full_batch_minimizer.hpp:115
double _tol
Definition: full_batch_minimizer.hpp:109
virtual void setHessian(const HessFun< V, M > &)
Sets the Hessian function (for Second Order methods).
Definition: full_batch_minimizer.hpp:46
double line_search(V x, V p, VecFun< V, double > &f, GradFun< V > &Gradient)
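Bracketing line search (contracts the step when sufficient decrease fails, doubles it until a bracket exists when the curvature test fails) that seeks a step satisfying the weak Wolfe conditions.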
Definition: full_batch_minimizer.hpp:126
void setRecorder(::IterationRecorder< CpuBackend > *recorder)
Attach a recorder for loss/grad history.
Definition: full_batch_minimizer.hpp:104
virtual ~FullBatchMinimizer()=default
double max_line_iters
Definition: full_batch_minimizer.hpp:116
unsigned int _iters
Definition: full_batch_minimizer.hpp:108
void setTolerance(double tol)
Sets the tolerance for convergence.
Definition: full_batch_minimizer.hpp:87
unsigned int iterations() const
Returns the number of iterations performed.
Definition: full_batch_minimizer.hpp:93
::IterationRecorder< CpuBackend > * recorder_
Optional recorder for diagnostics.
Definition: full_batch_minimizer.hpp:110
double c1
Definition: full_batch_minimizer.hpp:113
virtual void setInitialHessian(const M &)
Sets the initial Hessian approximation (if applicable).
Definition: full_batch_minimizer.hpp:40
void setMaxLineIters(int max_line)
Sets the maximum number of iterations for the line search.
Definition: full_batch_minimizer.hpp:75
void setArmijoMaxIter(int max_armijo)
Alias for setMaxLineIters: sets the maximum number of line-search iterations (the cap covers both the Armijo and curvature checks).
Definition: full_batch_minimizer.hpp:81
unsigned int _max_iters
Definition: full_batch_minimizer.hpp:107
void setMaxIterations(int max_iters)
Sets the maximum number of iterations.
Definition: full_batch_minimizer.hpp:70
virtual V solve(V x, VecFun< V, double > &f, GradFun< V > &Gradient)=0
Performs optimization.
std::function< T(T)> GradFun
Gradient function type alias (T -> T).
Definition: common.hpp:32
std::function< W(T)> VecFun
Objective function type alias (T -> W).
Definition: common.hpp:35
std::function< M(V)> HessFun
Hessian function type alias (V -> M).
Definition: common.hpp:38
int enzyme_dup
int enzyme_out
int enzyme_const
double __enzyme_autodiff(void *,...)
Definition: layer.hpp:13