Commit bcb0a543 authored by Jens Petit's avatar Jens Petit Committed by Tobias Lasser

Add correct return type deduction for norm operations (#38)

parent 62b45b6d
Pipeline #209087 passed with stages
in 23 minutes and 5 seconds
...@@ -134,19 +134,19 @@ namespace elsa ...@@ -134,19 +134,19 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
data_t DataContainer<data_t>::squaredL2Norm() const GetFloatingPointType_t<data_t> DataContainer<data_t>::squaredL2Norm() const
{ {
return _dataHandler->squaredL2Norm(); return _dataHandler->squaredL2Norm();
} }
template <typename data_t> template <typename data_t>
data_t DataContainer<data_t>::l1Norm() const GetFloatingPointType_t<data_t> DataContainer<data_t>::l1Norm() const
{ {
return _dataHandler->l1Norm(); return _dataHandler->l1Norm();
} }
template <typename data_t> template <typename data_t>
data_t DataContainer<data_t>::lInfNorm() const GetFloatingPointType_t<data_t> DataContainer<data_t>::lInfNorm() const
{ {
return _dataHandler->lInfNorm(); return _dataHandler->lInfNorm();
} }
......
...@@ -178,13 +178,13 @@ namespace elsa ...@@ -178,13 +178,13 @@ namespace elsa
} }
/// return the squared l2 norm of this signal (dot product with itself) /// return the squared l2 norm of this signal (dot product with itself)
data_t squaredL2Norm() const; GetFloatingPointType_t<data_t> squaredL2Norm() const;
/// return the l1 norm of this signal (sum of absolute values) /// return the l1 norm of this signal (sum of absolute values)
data_t l1Norm() const; GetFloatingPointType_t<data_t> l1Norm() const;
/// return the linf norm of this signal (maximum of absolute values) /// return the linf norm of this signal (maximum of absolute values)
data_t lInfNorm() const; GetFloatingPointType_t<data_t> lInfNorm() const;
/// return the sum of all elements of this signal /// return the sum of all elements of this signal
data_t sum() const; data_t sum() const;
......
...@@ -57,13 +57,13 @@ namespace elsa ...@@ -57,13 +57,13 @@ namespace elsa
virtual data_t dot(const DataHandler<data_t>& v) const = 0; virtual data_t dot(const DataHandler<data_t>& v) const = 0;
/// return the squared l2 norm of the data vector (dot product with itself) /// return the squared l2 norm of the data vector (dot product with itself)
virtual data_t squaredL2Norm() const = 0; virtual GetFloatingPointType_t<data_t> squaredL2Norm() const = 0;
/// return the l1 norm of the data vector (sum of absolute values) /// return the l1 norm of the data vector (sum of absolute values)
virtual data_t l1Norm() const = 0; virtual GetFloatingPointType_t<data_t> l1Norm() const = 0;
/// return the linf norm of the data vector (maximum of absolute values) /// return the linf norm of the data vector (maximum of absolute values)
virtual data_t lInfNorm() const = 0; virtual GetFloatingPointType_t<data_t> lInfNorm() const = 0;
/// return the sum of all elements of the data vector /// return the sum of all elements of the data vector
virtual data_t sum() const = 0; virtual data_t sum() const = 0;
......
...@@ -75,19 +75,19 @@ namespace elsa ...@@ -75,19 +75,19 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerCPU<data_t>::squaredL2Norm() const GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::squaredL2Norm() const
{ {
return _data->squaredNorm(); return _data->squaredNorm();
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerCPU<data_t>::l1Norm() const GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::l1Norm() const
{ {
return _data->array().abs().sum(); return _data->array().abs().sum();
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerCPU<data_t>::lInfNorm() const GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::lInfNorm() const
{ {
return _data->array().abs().maxCoeff(); return _data->array().abs().maxCoeff();
} }
......
...@@ -99,13 +99,13 @@ namespace elsa ...@@ -99,13 +99,13 @@ namespace elsa
data_t dot(const DataHandler<data_t>& v) const override; data_t dot(const DataHandler<data_t>& v) const override;
/// return the squared l2 norm of the data vector (dot product with itself) /// return the squared l2 norm of the data vector (dot product with itself)
data_t squaredL2Norm() const override; GetFloatingPointType_t<data_t> squaredL2Norm() const override;
/// return the l1 norm of the data vector (sum of absolute values) /// return the l1 norm of the data vector (sum of absolute values)
data_t l1Norm() const override; GetFloatingPointType_t<data_t> l1Norm() const override;
/// return the linf norm of the data vector (maximum of absolute values) /// return the linf norm of the data vector (maximum of absolute values)
data_t lInfNorm() const override; GetFloatingPointType_t<data_t> lInfNorm() const override;
/// return the sum of all elements of the data vector /// return the sum of all elements of the data vector
data_t sum() const override; data_t sum() const override;
......
...@@ -75,19 +75,19 @@ namespace elsa ...@@ -75,19 +75,19 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerMapCPU<data_t>::squaredL2Norm() const GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::squaredL2Norm() const
{ {
return _map.squaredNorm(); return _map.squaredNorm();
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerMapCPU<data_t>::l1Norm() const GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::l1Norm() const
{ {
return _map.array().abs().sum(); return _map.array().abs().sum();
} }
template <typename data_t> template <typename data_t>
data_t DataHandlerMapCPU<data_t>::lInfNorm() const GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::lInfNorm() const
{ {
return _map.array().abs().maxCoeff(); return _map.array().abs().maxCoeff();
} }
......
...@@ -75,13 +75,13 @@ namespace elsa ...@@ -75,13 +75,13 @@ namespace elsa
data_t dot(const DataHandler<data_t>& v) const override; data_t dot(const DataHandler<data_t>& v) const override;
/// return the squared l2 norm of the data vector (dot product with itself) /// return the squared l2 norm of the data vector (dot product with itself)
data_t squaredL2Norm() const override; GetFloatingPointType_t<data_t> squaredL2Norm() const override;
/// return the l1 norm of the data vector (sum of absolute values) /// return the l1 norm of the data vector (sum of absolute values)
data_t l1Norm() const override; GetFloatingPointType_t<data_t> l1Norm() const override;
/// return the linf norm of the data vector (maximum of absolute values) /// return the linf norm of the data vector (maximum of absolute values)
data_t lInfNorm() const override; GetFloatingPointType_t<data_t> lInfNorm() const override;
/// return the sum of all elements of the data vector /// return the sum of all elements of the data vector
data_t sum() const override; data_t sum() const override;
......
...@@ -37,4 +37,21 @@ namespace elsa ...@@ -37,4 +37,21 @@ namespace elsa
CPU, ///< data is stored as an Eigen::Matrix in CPU main memory CPU, ///< data is stored as an Eigen::Matrix in CPU main memory
MAP_CPU ///< data is not explicitly stored, but using an Eigen::Map to refer to other MAP_CPU ///< data is not explicitly stored, but using an Eigen::Map to refer to other
}; };
/// base case for deducing floating point type of std::complex
template <typename T>
struct GetFloatingPointType {
using type = T;
};
/// partial specialization to derive correct floating point type
template <typename T>
struct GetFloatingPointType<std::complex<T>> {
using type = T;
};
/// helper typedef to facilitate usage
template <typename T>
using GetFloatingPointType_t = typename GetFloatingPointType<T>::type;
} // namespace elsa } // namespace elsa
...@@ -25,4 +25,15 @@ SCENARIO("Testing PI") ...@@ -25,4 +25,15 @@ SCENARIO("Testing PI")
{ {
REQUIRE(pi<double> == 3.14159265358979323846); REQUIRE(pi<double> == 3.14159265358979323846);
} }
}
SCENARIO("Testing compile-time predicates")
{
static_assert(std::is_same_v<float, GetFloatingPointType_t<std::complex<float>>>);
static_assert(std::is_same_v<double, GetFloatingPointType_t<std::complex<double>>>);
static_assert(std::is_same_v<double, GetFloatingPointType_t<double>>);
static_assert(std::is_same_v<float, GetFloatingPointType_t<float>>);
static_assert(!std::is_same_v<float, GetFloatingPointType_t<double>>);
REQUIRE(true);
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment