Commit baf04e2e authored by Andi Braimllari's avatar Andi Braimllari Committed by Tobias Lasser
Browse files

add Weighted L1 Norm

parent 65e71040
Pipeline #869234 passed with stages
in 46 minutes and 42 seconds
...@@ -40,6 +40,11 @@ L1Norm ...@@ -40,6 +40,11 @@ L1Norm
.. doxygenclass:: elsa::L1Norm .. doxygenclass:: elsa::L1Norm
WeightedL1Norm
==============
.. doxygenclass:: elsa::WeightedL1Norm
L2NormPow2 L2NormPow2
========== ==========
......
...@@ -14,6 +14,7 @@ set(MODULE_HEADERS ...@@ -14,6 +14,7 @@ set(MODULE_HEADERS
EmissionLogLikelihood.h EmissionLogLikelihood.h
TransmissionLogLikelihood.h TransmissionLogLikelihood.h
Constraint.h Constraint.h
WeightedL1Norm.h
) )
# list all the code files of the module # list all the code files of the module
...@@ -32,6 +33,7 @@ set(MODULE_SOURCES ...@@ -32,6 +33,7 @@ set(MODULE_SOURCES
EmissionLogLikelihood.cpp EmissionLogLikelihood.cpp
TransmissionLogLikelihood.cpp TransmissionLogLikelihood.cpp
Constraint.cpp Constraint.cpp
WeightedL1Norm.cpp
) )
list(APPEND MODULE_PUBLIC_DEPS elsa_core elsa_logging elsa_operators) list(APPEND MODULE_PUBLIC_DEPS elsa_core elsa_logging elsa_operators)
......
#include "WeightedL1Norm.h"
#include "LinearOperator.h"
namespace elsa
{
template <typename data_t>
WeightedL1Norm<data_t>::WeightedL1Norm(const DataContainer<data_t>& weightingOp)
: Functional<data_t>(weightingOp.getDataDescriptor()), _weightingOp{weightingOp}
{
// sanity check
if (weightingOp.minElement() < 0) {
throw InvalidArgumentError(
"WeightedL1Norm: all weights in the w vector should be >= 0");
}
}
template <typename data_t>
WeightedL1Norm<data_t>::WeightedL1Norm(const Residual<data_t>& residual,
const DataContainer<data_t>& weightingOp)
: Functional<data_t>(residual), _weightingOp{weightingOp}
{
// sanity check
if (residual.getRangeDescriptor().getNumberOfCoefficients()
!= weightingOp.getDataDescriptor().getNumberOfCoefficients()) {
throw InvalidArgumentError(
"WeightedL1Norm: sizes of residual and weighting operator do not match");
}
// sanity check
if (weightingOp.minElement() < 0) {
throw InvalidArgumentError(
"WeightedL1Norm: all weights in the w vector should be >= 0");
}
}
template <typename data_t>
const DataContainer<data_t>& WeightedL1Norm<data_t>::getWeightingOperator() const
{
return _weightingOp;
}
template <typename data_t>
data_t WeightedL1Norm<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
{
return _weightingOp.dot(cwiseAbs(Rx));
}
template <typename data_t>
void WeightedL1Norm<data_t>::getGradientInPlaceImpl([[maybe_unused]] DataContainer<data_t>& Rx)
{
throw LogicError("WeightedL1Norm: not differentiable, so no gradient! (busted!)");
}
template <typename data_t>
LinearOperator<data_t>
WeightedL1Norm<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx)
{
throw LogicError("WeightedL1Norm: not differentiable, so no Hessian! (busted!)");
}
template <typename data_t>
WeightedL1Norm<data_t>* WeightedL1Norm<data_t>::cloneImpl() const
{
return new WeightedL1Norm(this->getResidual(), _weightingOp);
}
template <typename data_t>
bool WeightedL1Norm<data_t>::isEqual(const Functional<data_t>& other) const
{
if (!Functional<data_t>::isEqual(other))
return false;
auto otherWL1 = dynamic_cast<const WeightedL1Norm*>(&other);
if (!otherWL1)
return false;
if (_weightingOp != otherWL1->_weightingOp)
return false;
return true;
}
// ------------------------------------------
// explicit template instantiation
template class WeightedL1Norm<float>;
template class WeightedL1Norm<double>;
} // namespace elsa
#pragma once
#include "Functional.h"
namespace elsa
{
/**
* @brief Class representing a weighted l1 norm functional.
*
* @author Andi Braimllari - initial code
*
* @tparam data_t data type for the domain of the functional, defaulting to real_t
*
* The weighted l1 norm functional evaluates to @f$ \| x \|_{w,1} = \sum_{i=1}^n w_{i} *
* |x_{i}| @f$ where @f$ w_{i} >= 0 @f$.
*/
template <typename data_t = real_t>
class WeightedL1Norm : public Functional<data_t>
{
public:
/**
* @brief Constructor for the weighted l1 norm, mapping domain vector to a scalar
* (without a residual)
*
* @param[in] weightingOp container of the weights
*/
explicit WeightedL1Norm(const DataContainer<data_t>& weightingOp);
/**
* @brief Constructor for the weighted l1 norm, using a residual as input to map to
* a scalar
*
* @param[in] residual to be used when evaluating the functional (or its derivatives)
* @param[in] weightingOp container of the weights
*/
WeightedL1Norm(const Residual<data_t>& residual, const DataContainer<data_t>& weightingOp);
/// make copy constructor deletion explicit
WeightedL1Norm(const WeightedL1Norm<data_t>&) = delete;
/// default destructor
~WeightedL1Norm() override = default;
/// returns the weighting operator
const DataContainer<data_t>& getWeightingOperator() const;
protected:
/// the evaluation of the weighted l1 norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
/// the computation of the gradient (in place)
void getGradientInPlaceImpl(DataContainer<data_t>& Rx) override;
/// the computation of the Hessian
LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) override;
/// implement the polymorphic clone operation
WeightedL1Norm<data_t>* cloneImpl() const override;
/// implement the polymorphic comparison operation
bool isEqual(const Functional<data_t>& other) const override;
private:
/// the weighting operator
DataContainer<data_t> _weightingOp;
};
} // namespace elsa
...@@ -26,3 +26,4 @@ ELSA_DOCTEST(Quadric) ...@@ -26,3 +26,4 @@ ELSA_DOCTEST(Quadric)
ELSA_DOCTEST(EmissionLogLikelihood) ELSA_DOCTEST(EmissionLogLikelihood)
ELSA_DOCTEST(TransmissionLogLikelihood) ELSA_DOCTEST(TransmissionLogLikelihood)
ELSA_DOCTEST(Constraint) ELSA_DOCTEST(Constraint)
ELSA_DOCTEST(WeightedL1Norm)
/**
* @file test_WeightedL1Norm.cpp
*
* @brief Tests for the WeightedL1Norm class
*
* @author Andi Braimllari
*/
#include "WeightedL1Norm.h"
#include "LinearResidual.h"
#include "Identity.h"
#include "VolumeDescriptor.h"
#include "testHelpers.h"
#include "TypeCasts.hpp"
#include <doctest/doctest.h>
using namespace elsa;
using namespace doctest;
TEST_SUITE_BEGIN("functionals");
TEST_CASE_TEMPLATE("WeightedL1Norm: Testing the weighted, l1 norm functional", TestType, float,
double)
{
using Vector = Eigen::Matrix<TestType, Eigen::Dynamic, 1>;
GIVEN("a linear residual and weights with a non-positive element")
{
IndexVector_t numCoeff(2);
numCoeff << 25, 27;
VolumeDescriptor dd(numCoeff);
Vector randomData(dd.getNumberOfCoefficients());
randomData.setRandom();
DataContainer<TestType> b(dd, randomData);
Identity<TestType> A(dd);
LinearResidual<TestType> linRes(A, b);
// scaling operator
DataContainer<TestType> scaleFactors(dd);
scaleFactors = 1;
scaleFactors[3] = -8;
WHEN("instantiating an WeightedL1Norm object")
{
THEN("an InvalidArgumentError is thrown")
{
REQUIRE_THROWS_AS(WeightedL1Norm<TestType>{scaleFactors}, InvalidArgumentError);
REQUIRE_THROWS_AS(WeightedL1Norm<TestType>(linRes, scaleFactors),
InvalidArgumentError);
}
}
}
GIVEN("weights of value 1 and no residual")
{
IndexVector_t numCoeff(2);
numCoeff << 7, 17;
VolumeDescriptor dd(numCoeff);
DataContainer<TestType> scaleFactors(dd);
scaleFactors = 1;
WHEN("instantiating an WeightedL1Norm object")
{
WeightedL1Norm<TestType> func(scaleFactors);
THEN("the functional is as expected")
{
REQUIRE(func.getDomainDescriptor() == dd);
REQUIRE(func.getWeightingOperator() == scaleFactors);
const auto* linRes =
dynamic_cast<const LinearResidual<TestType>*>(&func.getResidual());
REQUIRE(linRes);
REQUIRE(linRes->hasOperator() == false);
REQUIRE(linRes->hasDataVector() == false);
}
THEN("a clone behaves as expected")
{
auto wl1Clone = func.clone();
REQUIRE(wl1Clone.get() != &func);
REQUIRE(*wl1Clone == func);
}
Vector dataVec(dd.getNumberOfCoefficients());
dataVec.setRandom();
DataContainer<TestType> x(dd, dataVec);
THEN("the evaluate works as expected")
{
REQUIRE(func.evaluate(x) == Approx(scaleFactors.dot(cwiseAbs(x))));
}
THEN("the gradient and Hessian throw as expected")
{
REQUIRE_THROWS_AS(func.getGradient(x), LogicError);
REQUIRE_THROWS_AS(func.getHessian(x), LogicError);
}
}
}
GIVEN("different sizes of the linear residual and weighting operator")
{
// linear residual
IndexVector_t numCoeff(2);
numCoeff << 47, 11;
VolumeDescriptor dd(numCoeff);
// linear residual
IndexVector_t otherNumCoeff(3);
otherNumCoeff << 15, 24, 4;
VolumeDescriptor otherDD(otherNumCoeff);
Vector randomData(dd.getNumberOfCoefficients());
randomData.setRandom();
DataContainer<TestType> b(dd, randomData);
Identity<TestType> A(dd);
LinearResidual<TestType> linRes(A, b);
// scaling operator
DataContainer<TestType> scaleFactors(otherDD);
scaleFactors = 1;
WHEN("instantiating an WeightedL1Norm object")
{
THEN("an InvalidArgumentError is thrown")
{
REQUIRE_THROWS_AS(WeightedL1Norm<TestType>(linRes, scaleFactors),
InvalidArgumentError);
}
}
}
GIVEN("weights of value 1 and a linear residual")
{
// linear residual
IndexVector_t numCoeff(2);
numCoeff << 47, 11;
VolumeDescriptor dd(numCoeff);
Vector randomData(dd.getNumberOfCoefficients());
randomData.setRandom();
DataContainer<TestType> b(dd, randomData);
Identity<TestType> A(dd);
LinearResidual<TestType> linRes(A, b);
// scaling operator
DataContainer<TestType> scaleFactors(dd);
scaleFactors = 1;
WHEN("instantiating an WeightedL1Norm object")
{
WeightedL1Norm<TestType> func(linRes, scaleFactors);
THEN("the functional is as expected")
{
REQUIRE(func.getDomainDescriptor() == dd);
REQUIRE(func.getWeightingOperator() == scaleFactors);
const auto* lRes =
dynamic_cast<const LinearResidual<TestType>*>(&func.getResidual());
REQUIRE(lRes);
REQUIRE(*lRes == linRes);
}
THEN("a clone behaves as expected")
{
auto wl1Clone = func.clone();
REQUIRE(wl1Clone.get() != &func);
REQUIRE(*wl1Clone == func);
}
THEN("the evaluate, gradient and Hessian work was expected")
{
Vector dataVec(dd.getNumberOfCoefficients());
dataVec.setRandom();
DataContainer<TestType> x(dd, dataVec);
REQUIRE(func.evaluate(x) == Approx(scaleFactors.dot(cwiseAbs(x - b))));
REQUIRE_THROWS_AS(func.getGradient(x), LogicError);
REQUIRE_THROWS_AS(func.getHessian(x), LogicError);
}
}
}
}
TEST_SUITE_END();
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment