Commit a5f6e13c authored by Andi Braimllari's avatar Andi Braimllari

Merge branch 'proximity_operators' into 'master'

Proximity Operators

See merge request !105
parents 939f1176 e459aa55
Pipeline #367974 passed with stages
in 35 minutes and 26 seconds
......@@ -190,6 +190,7 @@ add_subdirectory(io)
add_subdirectory(operators)
add_subdirectory(functionals)
add_subdirectory(problems)
add_subdirectory(proximity_operators)
add_subdirectory(solvers)
add_subdirectory(projectors)
if(ELSA_BUILD_CUDA_PROJECTORS)
......
......@@ -624,6 +624,28 @@ namespace elsa
using SinogramData2D = SinogramData<2>; ///< 2D sinogram data alias for 2D geometry
using SinogramData3D = SinogramData<3>; ///< 3D sinogram data alias for 3D geometry
template <typename data_t = real_t>
class Threshold
{
public:
explicit Threshold(data_t threshold) : _threshold(threshold)
{
if (threshold <= 0) {
throw std::invalid_argument("threshold must be strictly greater than 0");
}
}
auto operator<(const data_t t) const -> bool { return this->_threshold < t; };
auto operator>(const data_t t) const -> bool { return this->_threshold > t; };
auto operator-(const data_t t) const -> data_t { return this->_threshold - t; };
auto operator+(const data_t t) const -> data_t { return this->_threshold + t; };
private:
data_t _threshold;
};
} // namespace geometry
} // namespace elsa
......
......@@ -42,6 +42,11 @@
#include "FiniteDifferences.h"
#include "BlockLinearOperator.h"
// Proximity Operator headers
#include "ProximityOperator.h"
#include "SoftThresholding.h"
#include "HardThresholding.h"
// Problem headers
#include "Problem.h"
#include "RegularizationTerm.h"
......
cmake_minimum_required(VERSION 3.9)
# set the name of the module
set(ELSA_MODULE_NAME proximity_operators)
set(ELSA_MODULE_TARGET_NAME elsa_proximity_operators)
set(ELSA_MODULE_EXPORT_TARGET elsa_${ELSA_MODULE_NAME}Targets)
# list all the headers of the module
set(MODULE_HEADERS
ProximityOperator.h
SoftThresholding.h
HardThresholding.h)
# list all the code files of the module
set(MODULE_SOURCES
ProximityOperator.cpp
SoftThresholding.cpp
HardThresholding.cpp)
# build the module library
add_library(${ELSA_MODULE_TARGET_NAME} ${MODULE_HEADERS} ${MODULE_SOURCES})
add_library(elsa::${ELSA_MODULE_NAME} ALIAS ${ELSA_MODULE_TARGET_NAME})
target_link_libraries(${ELSA_MODULE_TARGET_NAME} PUBLIC elsa_core elsa_logging)
# use OpenMP is available
find_package(OpenMP)
if (OpenMP_CXX_FOUND)
target_link_libraries(${ELSA_MODULE_TARGET_NAME} PRIVATE OpenMP::OpenMP_CXX)
endif ()
target_include_directories(${ELSA_MODULE_TARGET_NAME}
PUBLIC
$<INSTALL_INTERFACE:include/elsa/${ELSA_MODULE_NAME}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
)
# require C++17
target_compile_features(${ELSA_MODULE_TARGET_NAME} PUBLIC cxx_std_17)
# set -fPIC
set_target_properties(${ELSA_MODULE_TARGET_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# build the tests (if enabled)
if (ELSA_TESTING)
add_subdirectory(tests)
endif (ELSA_TESTING)
if (ELSA_BUILD_PYTHON_BINDINGS)
GENERATE_BINDINGS(${ELSA_MODULE_TARGET_NAME}
bind_${ELSA_MODULE_NAME}.cpp
${PROJECT_SOURCE_DIR}/tools/bindings_generation/hints/${ELSA_MODULE_NAME}_hints.hpp
${MODULE_SOURCES})
endif ()
# register the module
registerComponent(${ELSA_MODULE_NAME})
# install the module
InstallElsaModule(${ELSA_MODULE_NAME} ${ELSA_MODULE_TARGET_NAME} ${ELSA_MODULE_EXPORT_TARGET})
#include "HardThresholding.h"
namespace elsa
{
template <typename data_t>
HardThresholding<data_t>::HardThresholding(const DataDescriptor& descriptor)
: ProximityOperator<data_t>(descriptor)
{
}
template <typename data_t>
void HardThresholding<data_t>::applyImpl(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
if (v.getSize() != prox.getSize()) {
throw std::logic_error("HardThresholding: sizes of v and prox must match");
}
auto vIter = v.begin();
auto proxIter = prox.begin();
for (; vIter != v.end() && proxIter != prox.end(); vIter++, proxIter++) {
if ((t < *vIter) || (t < -(*vIter))) {
*proxIter = *vIter;
} else {
*proxIter = 0;
}
}
}
template <typename data_t>
auto HardThresholding<data_t>::cloneImpl() const -> HardThresholding<data_t>*
{
return new HardThresholding<data_t>(this->getRangeDescriptor());
}
template <typename data_t>
auto HardThresholding<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
if (!ProximityOperator<data_t>::isEqual(other)) {
return false;
}
auto otherDerived = dynamic_cast<const HardThresholding<data_t>*>(&other);
return static_cast<bool>(otherDerived);
}
// ------------------------------------------
// explicit template instantiation
template class HardThresholding<float>;
template class HardThresholding<double>;
} // namespace elsa
#include "ProximityOperator.h"
namespace elsa
{
template <typename data_t = real_t>
class HardThresholding : public ProximityOperator<data_t>
{
public:
HardThresholding(const DataDescriptor& descriptor);
/// default destructor
~HardThresholding() override = default;
protected:
/**
* \brief apply the proximity operator of the ℓ0 norm to an element in the operator's domain
*
* \param[in] x input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*/
void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const override;
/// implement the polymorphic clone operation
auto cloneImpl() const -> HardThresholding<data_t>* override;
/// implement the polymorphic comparison operation
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
#include "ProximityOperator.h"
#include "Timer.h"
namespace elsa
{
template <typename data_t>
ProximityOperator<data_t>::ProximityOperator(const DataDescriptor& descriptor)
: _rangeDescriptor{descriptor.clone()}
{
}
template <typename data_t>
auto ProximityOperator<data_t>::getRangeDescriptor() const -> const DataDescriptor&
{
return *_rangeDescriptor;
}
template <typename data_t>
auto ProximityOperator<data_t>::apply(const DataContainer<data_t>& x,
geometry::Threshold<data_t> t) const
-> DataContainer<data_t>
{
Timer timeguard("ProximityOperator", "apply");
DataContainer<data_t> prox(*_rangeDescriptor, x.getDataHandlerType());
apply(x, t, prox);
return prox;
}
template <typename data_t>
void ProximityOperator<data_t>::apply(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
Timer timeguard("ProximityOperator", "apply");
applyImpl(v, t, prox);
}
template <typename data_t>
auto ProximityOperator<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
return static_cast<bool>(*_rangeDescriptor == *other._rangeDescriptor);
}
// ------------------------------------------
// explicit template instantiation
template class ProximityOperator<float>;
template class ProximityOperator<double>;
} // namespace elsa
#pragma once
#include "elsaDefines.h"
#include "Cloneable.h"
#include "DataContainer.h"
#include "DataDescriptor.h"
#include "StrongTypes.h"
namespace elsa
{
/**
* \brief Base class representing a proximity operator prox.
*
* \tparam data_t data type for the domain and range of the operator, defaulting to real_t
*
* This class represents a proximity operator prox, expressed through its apply methods,
* which implement the proximity operator of f with penalty ρ i.e.
* prox<SUB>f,ρ</SUB>(v) = argmin<SUB>x</SUB>(f(x) + (ρ/2)(||x - v||<SUP>2</SUP><SUB>2</SUB>)).
* Concrete implementations of proximity operators will derive from this class and override the
* applyImpl method.
*/
template <typename data_t = real_t>
class ProximityOperator : public Cloneable<ProximityOperator<data_t>>
{
public:
ProximityOperator() = delete;
ProximityOperator(const DataDescriptor& descriptor);
ProximityOperator(const ProximityOperator<data_t>&) = delete;
~ProximityOperator() override = default;
auto getRangeDescriptor() const -> const DataDescriptor&;
/**
* \brief apply the proximity operator to an element in the operator's domain
*
* \param[in] v input DataContainer
* \param[in] t input Threshold
*
* \returns prox DataContainer containing the application of the proximity operator to
* data v, i.e. in the range of the operator
*
* Please note: this method uses apply(v, t, prox(v)) to perform the actual operation.
*/
auto apply(const DataContainer<data_t>& v, geometry::Threshold<data_t> t) const
-> DataContainer<data_t>;
/**
* \brief apply the proximity operator to an element in the operator's domain
*
* \param[in] v input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*
* Please note: this method calls the method applyImpl that has to be overridden in derived
* classes. (Why is this method not virtual itself? Because you cannot have a non-virtual
* function overloading a virtual one [apply with one vs. two arguments]).
*/
void apply(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const;
auto operator=(const ProximityOperator&) -> ProximityOperator& = delete;
protected:
std::unique_ptr<DataDescriptor> _rangeDescriptor;
/// the apply method that has to be overridden in derived classes
virtual void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const = 0;
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
#include "SoftThresholding.h"
namespace elsa
{
template <typename data_t>
SoftThresholding<data_t>::SoftThresholding(const DataDescriptor& descriptor)
: ProximityOperator<data_t>(descriptor)
{
}
template <typename data_t>
void SoftThresholding<data_t>::applyImpl(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
if (v.getSize() != prox.getSize()) {
throw std::logic_error("SoftThresholding: sizes of v and prox must match");
}
auto vIter = v.begin();
auto proxIter = prox.begin();
for (; vIter != v.end() && proxIter != prox.end(); vIter++, proxIter++) {
if (t < *vIter) {
*proxIter = -(t - *vIter);
} else if (t < -(*vIter)) {
*proxIter = t + *vIter;
} else {
*proxIter = 0;
}
}
}
template <typename data_t>
auto SoftThresholding<data_t>::cloneImpl() const -> SoftThresholding<data_t>*
{
return new SoftThresholding<data_t>(this->getRangeDescriptor());
}
template <typename data_t>
auto SoftThresholding<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
if (!ProximityOperator<data_t>::isEqual(other)) {
return false;
}
auto otherDerived = dynamic_cast<const SoftThresholding<data_t>*>(&other);
return static_cast<bool>(otherDerived);
}
// ------------------------------------------
// explicit template instantiation
template class SoftThresholding<float>;
template class SoftThresholding<double>;
} // namespace elsa
#include "ProximityOperator.h"
namespace elsa
{
template <typename data_t = real_t>
class SoftThresholding : public ProximityOperator<data_t>
{
public:
SoftThresholding(const DataDescriptor& descriptor);
/// default destructor
~SoftThresholding() override = default;
protected:
/**
* \brief apply the proximity operator of the ℓ1 norm to an element in the operator's domain
*
* \param[in] x input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*/
void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const override;
/// implement the polymorphic clone operation
auto cloneImpl() const -> SoftThresholding<data_t>* override;
/// implement the polymorphic comparison operation
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
cmake_minimum_required(VERSION 3.9)
# enable ctest and Catch test discovery
include(CTest)
include(Catch)
# the actual tests
ELSA_TEST(SoftThresholding)
ELSA_TEST(HardThresholding)
#include "HardThresholding.h"
#include "VolumeDescriptor.h"
#include <catch2/catch.hpp>
#include <testHelpers.h>
using namespace elsa;
SCENARIO("Constructing HardThresholding")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 8, 4, 52;
VolumeDescriptor volDescr(numCoeff);
WHEN("instantiating a HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("the DataDescriptors are equal")
{
REQUIRE(hThrOp.getRangeDescriptor() == volDescr);
}
}
WHEN("cloning a HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
auto hThrOpClone = hThrOp.clone();
THEN("cloned HardThresholding operator equals original HardThresholding operator")
{
REQUIRE(hThrOpClone.get() != &hThrOp);
REQUIRE(*hThrOpClone == hThrOp);
}
}
}
}
SCENARIO("Using HardThresholding in 1D")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(1);
numCoeff << 8;
VolumeDescriptor volDescr(numCoeff);
WHEN("Using HardThresholding operator in 1D")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("Values under threshold=4 are 0 and values above remain the same")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << -2, 3, 4, -7, 7, 8, 8, 3;
DataContainer<real_t> dC(volDescr, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 0, -7, 7, 8, 8, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(isApprox(dCRes, hThrOp.apply(dC, geometry::Threshold<real_t>{4})));
}
}
}
}
SCENARIO("Using HardThresholding in 3D")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 3, 2, 3;
VolumeDescriptor volumeDescriptor(numCoeff);
WHEN("Using HardThresholding operator in 3D")
{
HardThresholding<real_t> hThrOp(volumeDescriptor);
THEN("Values under threshold=5 are 0 and values above remain the same")
{
RealVector_t data(volumeDescriptor.getNumberOfCoefficients());
data << 2, 1, 6, 6, 1, 4, 2, -9, 7, 7, 7, 3, 1, 2, 8, 9, -4, 5;
DataContainer<real_t> dC(volumeDescriptor, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 6, 6, 0, 0, 0, -9, 7, 7, 7, 0, 0, 0, 8, 9, 0, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(isApprox(dCRes, hThrOp.apply(dC, geometry::Threshold<real_t>{5})));
}
}
}
}
SCENARIO("Using HardThresholding")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(1);
numCoeff << 8;
VolumeDescriptor volDescr(numCoeff);
WHEN("Using HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("The zero vector is returned when the zero vector is given")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dataContainer(volDescr, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(
isApprox(dCRes, hThrOp.apply(dataContainer, geometry::Threshold<real_t>{4})));
}
THEN("HardThresholding operator throws exception for t = 0")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC(volDescr, data);
REQUIRE_THROWS_AS(hThrOp.apply(dC, geometry::Threshold<real_t>{0}),
std::invalid_argument);
}
THEN("HardThresholding operator throws exception for t < 0")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC(volDescr, data);
REQUIRE_THROWS_AS(hThrOp.apply(dC, geometry::Threshold<real_t>{-1}),
std::invalid_argument);
}
THEN("HardThresholding operator throws exception for differently sized v and prox")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC(volDescr, data);
IndexVector_t numCoeff1(1);
numCoeff1 << 9;
VolumeDescriptor volDescr1(numCoeff1);
RealVector_t data1(volDescr1.getNumberOfCoefficients());
data1 << 0, 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC1(volDescr1, data1);
REQUIRE_THROWS_AS(hThrOp.apply(dC, geometry::Threshold<real_t>{-1}, dC1),
std::logic_error);
}
}
}
}
#include "SoftThresholding.h"
#include "VolumeDescriptor.h"
#include <catch2/catch.hpp>
#include <testHelpers.h>
using namespace elsa;
SCENARIO("Constructing SoftThresholding")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 45, 11, 7;
VolumeDescriptor volDescr(numCoeff);
WHEN("instantiating a SoftThresholding operator")
{
SoftThresholding<real_t> sThrOp(volDescr);
THEN("the DataDescriptors are equal")
{
REQUIRE(sThrOp.getRangeDescriptor() == volDescr);
}
}
WHEN("cloning a SoftThresholding operator")
{
SoftThresholding<real_t> sThrOp(volDescr);
auto sThrOpClone = sThrOp.clone();