Problem.cpp 7.91 KB
Newer Older
Tobias Lasser's avatar
Tobias Lasser committed
1
#include "Problem.h"
2
#include "Scaling.h"
3
4
#include "Logger.h"
#include "Timer.h"
Tobias Lasser's avatar
Tobias Lasser committed
5
6
7
8

namespace elsa
{
    template <typename data_t>
9
10
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
                             const std::vector<RegularizationTerm<data_t>>& regTerms,
11
12
13
14
15
16
                             const DataContainer<data_t>& x0,
                             const std::optional<data_t> lipschitzConstant)
        : _dataTerm{dataTerm.clone()},
          _regTerms{regTerms},
          _currentSolution{x0},
          _lipschitzConstant{lipschitzConstant}
17
18
    {
        // sanity checks
19
20
        if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
            != this->_currentSolution.getSize())
21
            throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
22
23
24
25
    }

    template <typename data_t>
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
26
27
                             const std::vector<RegularizationTerm<data_t>>& regTerms,
                             const std::optional<data_t> lipschitzConstant)
Jens Petit's avatar
Jens Petit committed
28
29
        : _dataTerm{dataTerm.clone()},
          _regTerms{regTerms},
30
31
          _currentSolution{dataTerm.getDomainDescriptor()},
          _lipschitzConstant{lipschitzConstant}
32
    {
33
        _currentSolution = 0;
34
35
36
    }

    template <typename data_t>
Jens Petit's avatar
Jens Petit committed
37
38
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
                             const RegularizationTerm<data_t>& regTerm,
39
40
41
42
43
44
                             const DataContainer<data_t>& x0,
                             const std::optional<data_t> lipschitzConstant)
        : _dataTerm{dataTerm.clone()},
          _regTerms{regTerm},
          _currentSolution{x0},
          _lipschitzConstant{lipschitzConstant}
45
46
    {
        // sanity checks
47
48
        if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
            != this->_currentSolution.getSize())
49
            throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
50
51
52
53
    }

    template <typename data_t>
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
54
55
                             const RegularizationTerm<data_t>& regTerm,
                             const std::optional<data_t> lipschitzConstant)
Jens Petit's avatar
Jens Petit committed
56
57
        : _dataTerm{dataTerm.clone()},
          _regTerms{regTerm},
58
59
          _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
          _lipschitzConstant{lipschitzConstant}
60
    {
61
        _currentSolution = 0;
62
63
64
    }

    template <typename data_t>
65
66
67
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm, const DataContainer<data_t>& x0,
                             const std::optional<data_t> lipschitzConstant)
        : _dataTerm{dataTerm.clone()}, _currentSolution{x0}, _lipschitzConstant{lipschitzConstant}
68
69
    {
        // sanity check
70
71
        if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
            != this->_currentSolution.getSize())
72
            throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
73
74
75
    }

    template <typename data_t>
76
77
    Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
                             const std::optional<data_t> lipschitzConstant)
78
        : _dataTerm{dataTerm.clone()},
79
80
          _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
          _lipschitzConstant{lipschitzConstant}
Jens Petit's avatar
Jens Petit committed
81
    {
82
        _currentSolution = 0;
Jens Petit's avatar
Jens Petit committed
83
    }
Tobias Lasser's avatar
Tobias Lasser committed
84
85

    template <typename data_t>
Jens Petit's avatar
Jens Petit committed
86
    Problem<data_t>::Problem(const Problem<data_t>& problem)
87
88
        : Cloneable<Problem<data_t>>(),
          _dataTerm{problem._dataTerm->clone()},
Jens Petit's avatar
Jens Petit committed
89
          _regTerms{problem._regTerms},
90
91
          _currentSolution{problem._currentSolution},
          _lipschitzConstant{problem._lipschitzConstant}
Jens Petit's avatar
Jens Petit committed
92
93
    {
    }
Tobias Lasser's avatar
Tobias Lasser committed
94

95
96
97
98
99
100
101
102
103
104
105
    /// Read-only access to the (owned) data fidelity term.
    template <typename data_t>
    const Functional<data_t>& Problem<data_t>::getDataTerm() const
    {
        const Functional<data_t>& term = *_dataTerm;
        return term;
    }

    /// Read-only access to the list of regularization terms (possibly empty).
    template <typename data_t>
    const std::vector<RegularizationTerm<data_t>>& Problem<data_t>::getRegularizationTerms() const
    {
        const auto& terms = _regTerms;
        return terms;
    }

    /// Read-only access to the current solution.
    template <typename data_t>
    const DataContainer<data_t>& Problem<data_t>::getCurrentSolution() const
    {
        const auto& solution = _currentSolution;
        return solution;
    }

    /// Mutable access to the current solution, so callers can update it in place.
    template <typename data_t>
    DataContainer<data_t>& Problem<data_t>::getCurrentSolution()
    {
        auto& solution = _currentSolution;
        return solution;
    }

    template <typename data_t>
120
    data_t Problem<data_t>::evaluateImpl()
121
122
123
124
125
126
127
128
129
130
    {
        data_t result = _dataTerm->evaluate(_currentSolution);

        for (auto& regTerm : _regTerms)
            result += regTerm.getWeight() * regTerm.getFunctional().evaluate(_currentSolution);

        return result;
    }

    template <typename data_t>
131
    void Problem<data_t>::getGradientImpl(DataContainer<data_t>& result)
132
133
134
135
136
137
138
139
    {
        _dataTerm->getGradient(_currentSolution, result);

        for (auto& regTerm : _regTerms)
            result += regTerm.getWeight() * regTerm.getFunctional().getGradient(_currentSolution);
    }

    template <typename data_t>
140
    LinearOperator<data_t> Problem<data_t>::getHessianImpl() const
141
142
143
144
145
146
147
148
149
150
151
    {
        auto hessian = _dataTerm->getHessian(_currentSolution);

        for (auto& regTerm : _regTerms) {
            Scaling weight(_currentSolution.getDataDescriptor(), regTerm.getWeight());
            hessian = hessian + (weight * regTerm.getFunctional().getHessian(_currentSolution));
        }

        return hessian;
    }

152
153
154
    /**
     * Lipschitz constant of the objective's gradient.
     *
     * If a constant was supplied at construction it is returned directly. Otherwise it is estimated
     * as the largest eigenvalue of the Hessian at the current solution. The estimate runs
     * nIterations steps of power iteration from the all-ones vector, then evaluates the Rayleigh
     * quotient <b, H b> / <b, b>.
     *
     * @param nIterations number of power-iteration steps; values <= 0 skip the iteration and
     *        evaluate the Rayleigh quotient on the all-ones start vector
     * @return the user-supplied or estimated Lipschitz constant
     */
    template <typename data_t>
    data_t Problem<data_t>::getLipschitzConstantImpl(index_t nIterations) const
    {
        // a user-supplied constant needs no computation, so don't time/log one that never happens
        if (_lipschitzConstant.has_value()) {
            return _lipschitzConstant.value();
        }

        Timer guard("Problem", "Calculating Lipschitz constant");
        Logger::get("Problem")->info("Calculating Lipschitz constant");

        // compute the Lipschitz Constant as the largest eigenvalue of the Hessian
        const auto hessian = getHessian();

        DataContainer<data_t> dcB(hessian.getDomainDescriptor());
        dcB = 1;

        for (index_t i = 0; i < nIterations; i++) {
            dcB = hessian.apply(dcB);
            dcB = dcB / dcB.l2Norm();
        }

        // Rayleigh quotient: divide by the squared norm so the estimate is also correct when dcB
        // is not unit length (e.g. nIterations <= 0, where dcB is still the all-ones vector)
        const auto norm = dcB.l2Norm();
        return dcB.dot(hessian.apply(dcB)) / (norm * norm);
    }

    /// Polymorphic copy via the copy constructor; the caller takes ownership of the result.
    template <typename data_t>
    Problem<data_t>* Problem<data_t>::cloneImpl() const
    {
        auto* copy = new Problem(*this);
        return copy;
    }

    /**
     * Two problems are equal if they share a dynamic type and have equal current solutions, equal
     * data terms and pairwise-equal regularization terms in the same order.
     *
     * NOTE(review): _lipschitzConstant is not part of the comparison — confirm this is intended.
     */
    template <typename data_t>
    bool Problem<data_t>::isEqual(const Problem<data_t>& other) const
    {
        if (typeid(*this) != typeid(other))
            return false;

        if (_currentSolution != other._currentSolution || *_dataTerm != *other._dataTerm)
            return false;

        const std::size_t termCount = _regTerms.size();
        if (termCount != other._regTerms.size())
            return false;

        for (std::size_t k = 0; k < termCount; ++k) {
            if (_regTerms[k] != other._regTerms[k])
                return false;
        }

        return true;
    }

    /// Public entry point: value of the objective at the current solution (see evaluateImpl).
    template <typename data_t>
    data_t Problem<data_t>::evaluate()
    {
        const data_t value = this->evaluateImpl();
        return value;
    }

    /**
     * Public entry point: gradient of the objective at the current solution, returned in a new
     * container with the same descriptor and handler type as the current solution.
     */
    template <typename data_t>
    DataContainer<data_t> Problem<data_t>::getGradient()
    {
        const auto& descriptor = _currentSolution.getDataDescriptor();
        DataContainer<data_t> gradient(descriptor, _currentSolution.getDataHandlerType());

        getGradient(gradient);
        return gradient;
    }

    /// Public entry point: write the gradient at the current solution into result.
    template <typename data_t>
    void Problem<data_t>::getGradient(DataContainer<data_t>& result)
    {
        this->getGradientImpl(result);
    }

    template <typename data_t>
223
    LinearOperator<data_t> Problem<data_t>::getHessian() const
Tobias Lasser's avatar
Tobias Lasser committed
224
    {
225
        return getHessianImpl();
Tobias Lasser's avatar
Tobias Lasser committed
226
227
    }

228
229
230
231
232
233
    /**
     * Public entry point: Lipschitz constant of the objective's gradient.
     *
     * @param nIterations forwarded unchanged to getLipschitzConstantImpl
     */
    template <typename data_t>
    data_t Problem<data_t>::getLipschitzConstant(index_t nIterations) const
    {
        return this->getLipschitzConstantImpl(nIterations);
    }

    // ------------------------------------------
    // explicit template instantiation
    template class Problem<float>;
237

Tobias Lasser's avatar
Tobias Lasser committed
238
    template class Problem<double>;
239

240
    template class Problem<complex<float>>;
241

242
    template class Problem<complex<double>>;
Tobias Lasser's avatar
Tobias Lasser committed
243
244

} // namespace elsa