/**
 * @file test_LASSOProblem.cpp
 *
 * @brief Tests for the LASSOProblem class
 *
 * @author Andi Braimllari
 */

9
10
#include "doctest/doctest.h"

11
#include "Error.h"
12
13
14
15
#include "L2NormPow2.h"
#include "LASSOProblem.h"
#include "VolumeDescriptor.h"
#include "Identity.h"
16
#include "testHelpers.h"
17
18

using namespace elsa;
19
using namespace doctest;
20

21
22
TEST_SUITE_BEGIN("problems");

23
TEST_CASE_TEMPLATE("Scenario: Testing LASSOProblem", data_t, float, double)
24
25
26
27
28
29
30
{
    GIVEN("some data term and a regularization term")
    {
        IndexVector_t numCoeff(2);
        numCoeff << 17, 53;
        VolumeDescriptor dd(numCoeff);

31
        Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
32
        scaling.setRandom();
33
        DataContainer<data_t> dcScaling(dd, scaling);
34
35
        Scaling scaleOp(dd, dcScaling);

36
        Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
37
        dataVec.setRandom();
38
        DataContainer<data_t> dcData(dd, dataVec);
39

40
        WLSProblem<data_t> wlsProblem(scaleOp, dcData);
41

42
43
        auto invalidWeight = static_cast<data_t>(-2.0);
        auto weight = static_cast<data_t>(2.0);
44
45

        // l1 norm regularization term
46
        L1Norm<data_t> regFunc(dd);
47
48
49

        WHEN("setting up a LASSOProblem with a negative regularization weight")
        {
50
            RegularizationTerm<data_t> invalidRegTerm(invalidWeight, regFunc);
51
52
            THEN("an invalid_argument exception is thrown")
            {
53
                REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
54
                                  InvalidArgumentError);
55
56
57
58
            }
        }

        // l2 norm regularization term
59
        L2NormPow2<data_t> invalidRegFunc(dd);
60
61
62

        WHEN("setting up a LASSOProblem with a L2NormPow2 regularization term")
        {
63
            RegularizationTerm<data_t> invalidRegTerm(weight, invalidRegFunc);
64
65
            THEN("an invalid_argument exception is thrown")
            {
66
                REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
67
                                  InvalidArgumentError);
68
69
70
            }
        }

71
        RegularizationTerm<data_t> regTerm(weight, regFunc);
72
73
74

        WHEN("setting up a LASSOProblem without an x0")
        {
75
            LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
76
77
78
79
80

            THEN("cloned LASSOProblem equals original LASSOProblem")
            {
                auto lassoProbClone = lassoProb.clone();

81
82
                REQUIRE_NE(lassoProbClone.get(), &lassoProb);
                REQUIRE_EQ(*lassoProbClone, lassoProb);
83
84
85
86
87
            }
        }

        WHEN("setting up a LASSOProblem with an x0")
        {
88
            Eigen::Matrix<data_t, Eigen::Dynamic, 1> x0Vec(dd.getNumberOfCoefficients());
89
            x0Vec.setRandom();
90
            DataContainer<data_t> dcX0(dd, x0Vec);
91
92

            wlsProblem.getCurrentSolution() = dcX0;
93
            LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
94
95
96
97
98

            THEN("cloned LASSOProblem equals original LASSOProblem")
            {
                auto lassoProbClone = lassoProb.clone();

99
100
                REQUIRE_NE(lassoProbClone.get(), &lassoProb);
                REQUIRE_EQ(*lassoProbClone, lassoProb);
101
102
103
            }
        }

104
105
        Identity<data_t> idOp(dd);
        WLSProblem<data_t> wlsProblemForLC(idOp, dcData);
106
107
108

        WHEN("setting up the Lipschitz Constant of a LASSOProblem without an x0")
        {
109
            LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
110

111
            auto lipschitzConstant = lassoProb.getLipschitzConstant();
112
113
114
115

            THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
                 "Linear Operator A is 1")
            {
116
                REQUIRE_UNARY(checkApproxEq(lipschitzConstant, as<data_t>(1.0)));
117
118
119
120
121
            }
        }

        WHEN("setting up the Lipschitz Constant of a LASSOProblem with an x0")
        {
122
            Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
123
            x0Vec.setRandom();
124
            DataContainer<data_t> dcX0(dd, x0Vec);
125
126
            wlsProblemForLC.getCurrentSolution() = dcX0;

127
            LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
128

129
            auto lipschitzConstant = lassoProb.getLipschitzConstant();
130
131
132
133

            THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
                 "Linear Operator A is 1")
            {
134
                REQUIRE_EQ(lipschitzConstant, Approx(as<data_t>(1.0)));
135
136
137
138
            }
        }
    }
}
139
140

TEST_SUITE_END();