22.1.2021, 9:00 - 11:00: Due to updates GitLab may be unavailable for some minutes between 09:00 and 11:00.

Commit 1d616932 authored by Andi Braimllari's avatar Andi Braimllari

Merge branch 'proximity_operators' into 'master'

Proximity Operators

See merge request !105
parents 939f1176 41d5b1d2
Pipeline #368627 passed with stages
in 38 minutes and 5 seconds
......@@ -190,6 +190,7 @@ add_subdirectory(io)
add_subdirectory(operators)
add_subdirectory(functionals)
add_subdirectory(problems)
add_subdirectory(proximity_operators)
add_subdirectory(solvers)
add_subdirectory(projectors)
if(ELSA_BUILD_CUDA_PROJECTORS)
......
......@@ -624,6 +624,41 @@ namespace elsa
using SinogramData2D = SinogramData<2>; ///< 2D sinogram data alias for 2D geometry
using SinogramData3D = SinogramData<3>; ///< 3D sinogram data alias for 3D geometry
/**
* @brief Strong type for a single value of type data_t used in proximity operators.
* Comparison, addition, subtraction are overridden by utilizing the private member
* _threshold.
* N.B. The threshold value is expected to be strictly greater than 0, otherwise an
* exception is thrown
*
* @tparam data_t data type of the threshold
*/
template <typename data_t = real_t>
class Threshold
{
public:
explicit Threshold(data_t threshold) : _threshold(threshold)
{
if (threshold <= 0) {
throw std::invalid_argument("threshold must be strictly greater than 0");
}
}
/// return computed comparison
auto operator<(const data_t t) const -> bool { return this->_threshold < t; };
/// return computed comparison
auto operator>(const data_t t) const -> bool { return this->_threshold > t; };
/// return computed subtraction
auto operator-(const data_t t) const -> data_t { return this->_threshold - t; };
/// return computed addition
auto operator+(const data_t t) const -> data_t { return this->_threshold + t; };
private:
data_t _threshold;
};
} // namespace geometry
} // namespace elsa
......
......@@ -42,6 +42,11 @@
#include "FiniteDifferences.h"
#include "BlockLinearOperator.h"
// Proximity Operator headers
#include "ProximityOperator.h"
#include "SoftThresholding.h"
#include "HardThresholding.h"
// Problem headers
#include "Problem.h"
#include "RegularizationTerm.h"
......
cmake_minimum_required(VERSION 3.9)
# set the name of the module
set(ELSA_MODULE_NAME proximity_operators)
set(ELSA_MODULE_TARGET_NAME elsa_proximity_operators)
set(ELSA_MODULE_EXPORT_TARGET elsa_${ELSA_MODULE_NAME}Targets)
# list all the headers of the module
set(MODULE_HEADERS
ProximityOperator.h
SoftThresholding.h
HardThresholding.h)
# list all the code files of the module
set(MODULE_SOURCES
ProximityOperator.cpp
SoftThresholding.cpp
HardThresholding.cpp)
# build the module library
add_library(${ELSA_MODULE_TARGET_NAME} ${MODULE_HEADERS} ${MODULE_SOURCES})
add_library(elsa::${ELSA_MODULE_NAME} ALIAS ${ELSA_MODULE_TARGET_NAME})
target_link_libraries(${ELSA_MODULE_TARGET_NAME} PUBLIC elsa_core elsa_logging)
# use OpenMP is available
find_package(OpenMP)
if (OpenMP_CXX_FOUND)
target_link_libraries(${ELSA_MODULE_TARGET_NAME} PRIVATE OpenMP::OpenMP_CXX)
endif ()
target_include_directories(${ELSA_MODULE_TARGET_NAME}
PUBLIC
$<INSTALL_INTERFACE:include/elsa/${ELSA_MODULE_NAME}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
)
# require C++17
target_compile_features(${ELSA_MODULE_TARGET_NAME} PUBLIC cxx_std_17)
# set -fPIC
set_target_properties(${ELSA_MODULE_TARGET_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# build the tests (if enabled)
if (ELSA_TESTING)
add_subdirectory(tests)
endif (ELSA_TESTING)
if (ELSA_BUILD_PYTHON_BINDINGS)
GENERATE_BINDINGS(${ELSA_MODULE_TARGET_NAME}
bind_${ELSA_MODULE_NAME}.cpp
${PROJECT_SOURCE_DIR}/tools/bindings_generation/hints/${ELSA_MODULE_NAME}_hints.hpp
${MODULE_SOURCES})
endif ()
# register the module
registerComponent(${ELSA_MODULE_NAME})
# install the module
InstallElsaModule(${ELSA_MODULE_NAME} ${ELSA_MODULE_TARGET_NAME} ${ELSA_MODULE_EXPORT_TARGET})
#include "HardThresholding.h"
namespace elsa
{
template <typename data_t>
HardThresholding<data_t>::HardThresholding(const DataDescriptor& descriptor)
: ProximityOperator<data_t>(descriptor)
{
}
template <typename data_t>
void HardThresholding<data_t>::applyImpl(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
if (v.getSize() != prox.getSize()) {
throw std::logic_error("HardThresholding: sizes of v and prox must match");
}
auto vIter = v.begin();
auto proxIter = prox.begin();
for (; vIter != v.end() && proxIter != prox.end(); vIter++, proxIter++) {
if ((t < *vIter) || (t < -(*vIter))) {
*proxIter = *vIter;
} else {
*proxIter = 0;
}
}
}
template <typename data_t>
auto HardThresholding<data_t>::cloneImpl() const -> HardThresholding<data_t>*
{
return new HardThresholding<data_t>(this->getRangeDescriptor());
}
template <typename data_t>
auto HardThresholding<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
if (!ProximityOperator<data_t>::isEqual(other)) {
return false;
}
auto otherDerived = dynamic_cast<const HardThresholding<data_t>*>(&other);
return static_cast<bool>(otherDerived);
}
// ------------------------------------------
// explicit template instantiation
template class HardThresholding<float>;
template class HardThresholding<double>;
} // namespace elsa
#include "ProximityOperator.h"
namespace elsa
{
/**
* \brief Class representing the proximity operator of the l1 norm
*
* \tparam data_t data type for the values of the operator, defaulting to real_t
*
* This class represents the soft thresholding operator, expressed by its apply method
* through the function i.e. \f$ prox(v) = v·1_{\{|v| > t\}}. \f$
*
* References:
* http://sfb649.wiwi.hu-berlin.de/fedc_homepage/xplore/tutorials/xlghtmlnode93.html
*/
template <typename data_t = real_t>
class HardThresholding : public ProximityOperator<data_t>
{
public:
/**
* \brief Construct a HardThresholding operator from the given DataDescriptor
*
* \param[in] descriptor DataDescriptor describing the operator values
*/
HardThresholding(const DataDescriptor& descriptor);
/// default destructor
~HardThresholding() override = default;
protected:
/**
* \brief apply the proximity operator of the l0 norm to an element in the operator's domain
*
* \param[in] x input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*/
void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const override;
/// implement the polymorphic clone operation
auto cloneImpl() const -> HardThresholding<data_t>* override;
/// implement the polymorphic comparison operation
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
#include "ProximityOperator.h"
#include "Timer.h"
namespace elsa
{
template <typename data_t>
ProximityOperator<data_t>::ProximityOperator(const DataDescriptor& descriptor)
: _rangeDescriptor{descriptor.clone()}
{
}
template <typename data_t>
auto ProximityOperator<data_t>::getRangeDescriptor() const -> const DataDescriptor&
{
return *_rangeDescriptor;
}
template <typename data_t>
auto ProximityOperator<data_t>::apply(const DataContainer<data_t>& x,
geometry::Threshold<data_t> t) const
-> DataContainer<data_t>
{
Timer timeguard("ProximityOperator", "apply");
DataContainer<data_t> prox(*_rangeDescriptor, x.getDataHandlerType());
apply(x, t, prox);
return prox;
}
template <typename data_t>
void ProximityOperator<data_t>::apply(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
Timer timeguard("ProximityOperator", "apply");
applyImpl(v, t, prox);
}
template <typename data_t>
auto ProximityOperator<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
return static_cast<bool>(*_rangeDescriptor == *other._rangeDescriptor);
}
// ------------------------------------------
// explicit template instantiation
template class ProximityOperator<float>;
template class ProximityOperator<double>;
} // namespace elsa
#pragma once
#include "elsaDefines.h"
#include "Cloneable.h"
#include "DataContainer.h"
#include "DataDescriptor.h"
#include "StrongTypes.h"
namespace elsa
{
/**
* \brief Base class representing a proximity operator prox.
*
* \tparam data_t data type for the values of the operator, defaulting to real_t
*
* This class represents a proximity operator prox, expressed through its apply methods,
* which implement the proximity operator of f with penalty r i.e.
* \f$ prox_{f,r}(v) = argmin_x(f(x) + (r/2)·(\| Ax - b \|^2_2)). \f$
*
* Concrete implementations of proximity operators will derive from this class and override the
* applyImpl method.
*/
template <typename data_t = real_t>
class ProximityOperator : public Cloneable<ProximityOperator<data_t>>
{
public:
/// delete no-args constructor to prevent creation of an object without a DataDescriptor
ProximityOperator() = delete;
/**
* \brief Override to construct an actual proximity operator for one of the derived classes
* from the given DataDescriptor descriptor
*
* \param[in] descriptor DataDescriptor describing the operator values
*/
ProximityOperator(const DataDescriptor& descriptor);
/// delete copy construction
ProximityOperator(const ProximityOperator<data_t>&) = delete;
/// default destructor
~ProximityOperator() override = default;
/// return the DataDescriptor
auto getRangeDescriptor() const -> const DataDescriptor&;
/**
* \brief apply the proximity operator to an element in the operator's domain
*
* \param[in] v input DataContainer
* \param[in] t input Threshold
*
* \returns prox DataContainer containing the application of the proximity operator to
* data v, i.e. in the range of the operator
*
* Please note: this method uses apply(v, t, prox(v)) to perform the actual operation.
*/
auto apply(const DataContainer<data_t>& v, geometry::Threshold<data_t> t) const
-> DataContainer<data_t>;
/**
* \brief apply the proximity operator to an element in the operator's domain
*
* \param[in] v input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*
* Please note: this method calls the method applyImpl that has to be overridden in derived
* classes. (Why is this method not virtual itself? Because you cannot have a non-virtual
* function overloading a virtual one [apply with one vs. two arguments]).
*/
void apply(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const;
/// delete implicitly declared copy assignment to prevent copy assignment
auto operator=(const ProximityOperator&) -> ProximityOperator& = delete;
protected:
std::unique_ptr<DataDescriptor> _rangeDescriptor;
/// the apply method that has to be overridden in derived classes
virtual void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const = 0;
/// overridden comparison method based on the DataDescriptor
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
#include "SoftThresholding.h"
namespace elsa
{
template <typename data_t>
SoftThresholding<data_t>::SoftThresholding(const DataDescriptor& descriptor)
: ProximityOperator<data_t>(descriptor)
{
}
template <typename data_t>
void SoftThresholding<data_t>::applyImpl(const DataContainer<data_t>& v,
geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const
{
if (v.getSize() != prox.getSize()) {
throw std::logic_error("SoftThresholding: sizes of v and prox must match");
}
auto vIter = v.begin();
auto proxIter = prox.begin();
for (; vIter != v.end() && proxIter != prox.end(); vIter++, proxIter++) {
if (t < *vIter) {
*proxIter = -(t - *vIter);
} else if (t < -(*vIter)) {
*proxIter = t + *vIter;
} else {
*proxIter = 0;
}
}
}
template <typename data_t>
auto SoftThresholding<data_t>::cloneImpl() const -> SoftThresholding<data_t>*
{
return new SoftThresholding<data_t>(this->getRangeDescriptor());
}
template <typename data_t>
auto SoftThresholding<data_t>::isEqual(const ProximityOperator<data_t>& other) const -> bool
{
if (!ProximityOperator<data_t>::isEqual(other)) {
return false;
}
auto otherDerived = dynamic_cast<const SoftThresholding<data_t>*>(&other);
return static_cast<bool>(otherDerived);
}
// ------------------------------------------
// explicit template instantiation
template class SoftThresholding<float>;
template class SoftThresholding<double>;
} // namespace elsa
#include "ProximityOperator.h"
namespace elsa
{
/**
* \brief Class representing the proximity operator of the l1 norm
*
* \tparam data_t data type for the values of the operator, defaulting to real_t
*
* This class represents the soft thresholding operator, expressed by its apply method
* through the function i.e. \f$ prox(v) = sign(v)·(|v| - t)_+. \f$
*
* References:
* http://sfb649.wiwi.hu-berlin.de/fedc_homepage/xplore/tutorials/xlghtmlnode93.html
*/
template <typename data_t = real_t>
class SoftThresholding : public ProximityOperator<data_t>
{
public:
/**
* \brief Construct a SoftThresholding operator from the given DataDescriptor
*
* \param[in] descriptor DataDescriptor describing the operator values
*/
SoftThresholding(const DataDescriptor& descriptor);
/// default destructor
~SoftThresholding() override = default;
protected:
/**
* \brief apply the proximity operator of the l1 norm to an element in the operator's domain
*
* \param[in] x input DataContainer
* \param[in] t input Threshold
* \param[out] prox output DataContainer
*/
void applyImpl(const DataContainer<data_t>& v, geometry::Threshold<data_t> t,
DataContainer<data_t>& prox) const override;
/// implement the polymorphic clone operation
auto cloneImpl() const -> SoftThresholding<data_t>* override;
/// implement the polymorphic comparison operation
auto isEqual(const ProximityOperator<data_t>& other) const -> bool override;
};
} // namespace elsa
cmake_minimum_required(VERSION 3.9)
# enable ctest and Catch test discovery
include(CTest)
include(Catch)
# the actual tests
ELSA_TEST(SoftThresholding)
ELSA_TEST(HardThresholding)
#include "HardThresholding.h"
#include "VolumeDescriptor.h"
#include <catch2/catch.hpp>
#include <testHelpers.h>
using namespace elsa;
SCENARIO("Constructing HardThresholding")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 8, 4, 52;
VolumeDescriptor volDescr(numCoeff);
WHEN("instantiating a HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("the DataDescriptors are equal")
{
REQUIRE(hThrOp.getRangeDescriptor() == volDescr);
}
}
WHEN("cloning a HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
auto hThrOpClone = hThrOp.clone();
THEN("cloned HardThresholding operator equals original HardThresholding operator")
{
REQUIRE(hThrOpClone.get() != &hThrOp);
REQUIRE(*hThrOpClone == hThrOp);
}
}
}
}
SCENARIO("Using HardThresholding in 1D")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(1);
numCoeff << 8;
VolumeDescriptor volDescr(numCoeff);
WHEN("Using HardThresholding operator in 1D")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("Values under threshold=4 are 0 and values above remain the same")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << -2, 3, 4, -7, 7, 8, 8, 3;
DataContainer<real_t> dC(volDescr, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 0, -7, 7, 8, 8, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(isApprox(dCRes, hThrOp.apply(dC, geometry::Threshold<real_t>{4})));
}
}
}
}
SCENARIO("Using HardThresholding in 3D")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 3, 2, 3;
VolumeDescriptor volumeDescriptor(numCoeff);
WHEN("Using HardThresholding operator in 3D")
{
HardThresholding<real_t> hThrOp(volumeDescriptor);
THEN("Values under threshold=5 are 0 and values above remain the same")
{
RealVector_t data(volumeDescriptor.getNumberOfCoefficients());
data << 2, 1, 6, 6, 1, 4, 2, -9, 7, 7, 7, 3, 1, 2, 8, 9, -4, 5;
DataContainer<real_t> dC(volumeDescriptor, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 6, 6, 0, 0, 0, -9, 7, 7, 7, 0, 0, 0, 8, 9, 0, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(isApprox(dCRes, hThrOp.apply(dC, geometry::Threshold<real_t>{5})));
}
}
}
}
SCENARIO("Using HardThresholding")
{
GIVEN("a DataDescriptor")
{
IndexVector_t numCoeff(1);
numCoeff << 8;
VolumeDescriptor volDescr(numCoeff);
WHEN("Using HardThresholding operator")
{
HardThresholding<real_t> hThrOp(volDescr);
THEN("The zero vector is returned when the zero vector is given")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dataContainer(volDescr, data);
RealVector_t expectedRes(hThrOp.getRangeDescriptor().getNumberOfCoefficients());
expectedRes << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dCRes(hThrOp.getRangeDescriptor(), expectedRes);
REQUIRE(
isApprox(dCRes, hThrOp.apply(dataContainer, geometry::Threshold<real_t>{4})));
}
THEN("HardThresholding operator throws exception for t = 0")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC(volDescr, data);
REQUIRE_THROWS_AS(hThrOp.apply(dC, geometry::Threshold<real_t>{0}),
std::invalid_argument);
}
THEN("HardThresholding operator throws exception for t < 0")
{
RealVector_t data(volDescr.getNumberOfCoefficients());
data << 0, 0, 0, 0, 0, 0, 0, 0;
DataContainer<real_t> dC(volDescr, data);