GradientDescent.h 2.28 KB
Newer Older
Tobias Lasser's avatar
Tobias Lasser committed
1
2
3
4
5
6
7
8
9
10
11
12
13
#pragma once

#include "Solver.h"

namespace elsa
{
    /**
     * \brief Class representing a simple gradient descent solver with a fixed, given step size.
     *
     * \author Tobias Lasser - initial code
     *
     * \tparam data_t data type for the domain and range of the problem, defaulting to real_t
     *
Jens Petit's avatar
Jens Petit committed
14
15
16
     * This class implements a simple gradient descent iterative solver with a fixed, given step
     * size. No particular stopping rule is currently implemented (only a fixed number of
     * iterations, default to 100).
Tobias Lasser's avatar
Tobias Lasser committed
17
18
     */
    template <typename data_t = real_t>
Jens Petit's avatar
Jens Petit committed
19
20
    class GradientDescent : public Solver<data_t>
    {
Tobias Lasser's avatar
Tobias Lasser committed
21
22
23
24
25
26
27
28
29
30
31
32
    public:
        /**
         * \brief Constructor for gradient descent, accepting a problem and a fixed step size
         *
         * \param[in] problem the problem that is supposed to be solved
         * \param[in] stepSize the fixed step size to be used while solving
         */
        GradientDescent(const Problem<data_t>& problem, real_t stepSize);

        /// default destructor
        ~GradientDescent() override = default;

33
34
35
36
    protected:
        /// default copy constructor, hidden from non-derived classes to prevent potential slicing
        GradientDescent(const GradientDescent<data_t>&) = default;

Tobias Lasser's avatar
Tobias Lasser committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    private:
        /// the step size
        real_t _stepSize;

        /// the default number of iterations
        const index_t _defaultIterations{100};

        /// lift the base class method getCurrentSolution
        using Solver<data_t>::getCurrentSolution;

        /// lift the base class variable _problem
        using Solver<data_t>::_problem;

        /**
Jens Petit's avatar
Jens Petit committed
51
52
53
54
55
56
57
58
         * \brief Solve the optimization problem, i.e. apply iterations number of iterations of
         * gradient descent
         *
         * \param[in] iterations number of iterations to execute (the default 0 value executes
         * _defaultIterations of iterations)
         *
         * \returns a reference to the current solution
         */
59
        DataContainer<data_t>& solveImpl(index_t iterations) override;
Tobias Lasser's avatar
Tobias Lasser committed
60
61
62
63
64
65
66
67

        /// implement the polymorphic clone operation
        GradientDescent<data_t>* cloneImpl() const override;

        /// implement the polymorphic comparison operation
        bool isEqual(const Solver<data_t>& other) const override;
    };
} // namespace elsa