WeightedL1Norm.cpp 2.82 KB
Newer Older
Andi Braimllari's avatar
Andi Braimllari committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "WeightedL1Norm.h"
#include "LinearOperator.h"

namespace elsa
{
    // Constructs a weighted L1 norm whose domain is the descriptor of the weight
    // container itself (no residual is supplied, so the base Functional is built
    // from that descriptor).
    //
    // weightingOp: per-element weights; a copy is stored in _weightingOp.
    // Throws InvalidArgumentError when the smallest weight is negative.
    template <typename data_t>
    WeightedL1Norm<data_t>::WeightedL1Norm(const DataContainer<data_t>& weightingOp)
        : Functional<data_t>(weightingOp.getDataDescriptor()), _weightingOp{weightingOp}
    {
        // Checking only the minimum suffices: if it is non-negative, all weights are.
        const auto smallestWeight = weightingOp.minElement();
        if (smallestWeight < 0) {
            throw InvalidArgumentError(
                "WeightedL1Norm: all weights in the w vector should be >= 0");
        }
    }

    // Constructs a weighted L1 norm applied to the output of `residual`.
    //
    // residual:    supplies the value the norm is evaluated on (passed to Functional).
    // weightingOp: per-element weights; a copy is stored in _weightingOp.
    // Throws InvalidArgumentError when the residual's range and the weights do not
    // have the same number of coefficients, or when any weight is negative.
    template <typename data_t>
    WeightedL1Norm<data_t>::WeightedL1Norm(const Residual<data_t>& residual,
                                           const DataContainer<data_t>& weightingOp)
        : Functional<data_t>(residual), _weightingOp{weightingOp}
    {
        // Weights pair element-wise with the residual's output, so the counts must agree.
        const auto residualCoefficients = residual.getRangeDescriptor().getNumberOfCoefficients();
        const auto weightCoefficients = weightingOp.getDataDescriptor().getNumberOfCoefficients();
        if (residualCoefficients != weightCoefficients) {
            throw InvalidArgumentError(
                "WeightedL1Norm: sizes of residual and weighting operator do not match");
        }

        // Checking only the minimum suffices: if it is non-negative, all weights are.
        const auto smallestWeight = weightingOp.minElement();
        if (smallestWeight < 0) {
            throw InvalidArgumentError(
                "WeightedL1Norm: all weights in the w vector should be >= 0");
        }
    }

    // Read-only access to the stored weights (the copy made at construction time).
    template <typename data_t>
    const DataContainer<data_t>& WeightedL1Norm<data_t>::getWeightingOperator() const
    {
        return this->_weightingOp;
    }

    // Evaluates the weighted L1 norm at the residual value Rx: the inner product of
    // the stored weights with the element-wise absolute values of Rx.
    template <typename data_t>
    data_t WeightedL1Norm<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
    {
        const auto magnitudes = cwiseAbs(Rx);
        return _weightingOp.dot(magnitudes);
    }

    // Gradient is unsupported for this functional: every call throws LogicError.
    // The argument is intentionally left unnamed because it is never read.
    template <typename data_t>
    void WeightedL1Norm<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& /* Rx */)
    {
        throw LogicError("WeightedL1Norm: not differentiable, so no gradient! (busted!)");
    }

    // Hessian is unsupported for this functional: every call throws LogicError, so
    // no LinearOperator is ever returned. The argument is unnamed since it is unused.
    template <typename data_t>
    LinearOperator<data_t>
        WeightedL1Norm<data_t>::getHessianImpl(const DataContainer<data_t>& /* Rx */)
    {
        throw LogicError("WeightedL1Norm: not differentiable, so no Hessian! (busted!)");
    }

    // Produces a heap-allocated copy built from this object's residual and weights.
    // Going through the residual constructor re-runs its size and sign checks.
    // NOTE(review): the raw pointer is presumably adopted by the base class's clone()
    // wrapper — confirm ownership handling in Functional/Cloneable.
    template <typename data_t>
    WeightedL1Norm<data_t>* WeightedL1Norm<data_t>::cloneImpl() const
    {
        const auto& residual = this->getResidual();
        auto* duplicate = new WeightedL1Norm(residual, _weightingOp);
        return duplicate;
    }

    // Equality: the base Functional comparison must hold, `other` must actually be a
    // WeightedL1Norm, and both objects must carry equal weight containers.
    // The base check runs first, so the dynamic_cast is only attempted when it passes.
    template <typename data_t>
    bool WeightedL1Norm<data_t>::isEqual(const Functional<data_t>& other) const
    {
        if (Functional<data_t>::isEqual(other)) {
            const auto* asWeighted = dynamic_cast<const WeightedL1Norm*>(&other);
            return asWeighted != nullptr && !(_weightingOp != asWeighted->_weightingOp);
        }
        return false;
    }

    // ------------------------------------------
    // Explicit instantiations: the member definitions above live in this .cpp, so
    // every element type used elsewhere must be instantiated here.
    template class WeightedL1Norm<double>;
    template class WeightedL1Norm<float>;
} // namespace elsa