3 #include "../common.hpp"
6 #include <autodiff/reverse/var.hpp>
7 #include <autodiff/reverse/var/eigen.hpp>
11 template <
typename M> constexpr
bool isSparse = std::is_base_of_v<Eigen::SparseMatrixBase<M>, M>;
14 using DefaultSolverT =
typename std::conditional<isSparse<M>, Eigen::ConjugateGradient<M>, Eigen::LDLT<M>>::type;
19 template <
typename V,
typename M,
typename Solver = DefaultSolverT<M>>
class BFGS :
public FullBatchMinimizer<V, M> {
27 using SolverT =
typename std::conditional<UseDefaultSolver, Solver, Solver &>::type;
57 if (_B.rows() != x.size()) {
58 _B = M::Identity(x.size(), x.size());
61 for (_iters = 0; _iters < _max_iters && Gradient(x).norm() > _tol; ++_iters) {
64 check(_solver.info() == Eigen::Success,
"conjugate gradient solver error");
66 V p = _solver.solve(-Gradient(x));
74 V y = Gradient(x_next) - Gradient(x);
77 _B = _B + (y * y.transpose()) / (y.transpose() * s) - (b_prod * b_prod.transpose()) / (s.transpose() * _B * s);
BFGS (Broyden–Fletcher–Goldfarb–Shanno) minimizer.
Definition: bfgs.hpp:19
typename std::conditional< UseDefaultSolver, Solver, Solver & >::type SolverT
Definition: bfgs.hpp:27
void setInitialHessian(const M &b)
Sets the initial approximate Hessian matrix.
Definition: bfgs.hpp:44
V solve(V x, VecFun< V, double > &f, GradFun< V > &Gradient) override
Solves the optimization problem using the BFGS method.
Definition: bfgs.hpp:53
requires(UseDefaultSolver)
Definition: bfgs.hpp:35
static constexpr bool UseDefaultSolver
Definition: bfgs.hpp:26
Base class for full-batch minimizers.
Definition: full_batch_minimizer.hpp:23
double _tol
Definition: full_batch_minimizer.hpp:109
double line_search(V x, V p, VecFun< V, double > &f, GradFun< V > &Gradient)
Backtracking line search satisfying the Wolfe conditions.
Definition: full_batch_minimizer.hpp:126
unsigned int _iters
Definition: full_batch_minimizer.hpp:108
unsigned int _max_iters
Definition: full_batch_minimizer.hpp:107
std::function< T(T)> GradFun
Gradient function type alias (T -> T).
Definition: common.hpp:32
std::function< W(T)> VecFun
Objective function type alias (T -> W).
Definition: common.hpp:35
#define check(condition, message)
Debug assertion with message and source location.
Definition: common.hpp:14
constexpr bool isSparse
Definition: bfgs.hpp:11
typename std::conditional< isSparse< M >, Eigen::ConjugateGradient< M >, Eigen::LDLT< M > >::type DefaultSolverT
Definition: bfgs.hpp:14