Commit bcb0a543 authored by Jens Petit's avatar Jens Petit Committed by Tobias Lasser

Add correct return type deduction for norm operations (#38)

parent 62b45b6d
Pipeline #209087 passed with stages
in 23 minutes and 5 seconds
......@@ -134,19 +134,19 @@ namespace elsa
}
template <typename data_t>
data_t DataContainer<data_t>::squaredL2Norm() const
GetFloatingPointType_t<data_t> DataContainer<data_t>::squaredL2Norm() const
{
return _dataHandler->squaredL2Norm();
}
template <typename data_t>
data_t DataContainer<data_t>::l1Norm() const
GetFloatingPointType_t<data_t> DataContainer<data_t>::l1Norm() const
{
return _dataHandler->l1Norm();
}
template <typename data_t>
data_t DataContainer<data_t>::lInfNorm() const
GetFloatingPointType_t<data_t> DataContainer<data_t>::lInfNorm() const
{
return _dataHandler->lInfNorm();
}
......
......@@ -178,13 +178,13 @@ namespace elsa
}
/// return the squared l2 norm of this signal (dot product with itself)
data_t squaredL2Norm() const;
GetFloatingPointType_t<data_t> squaredL2Norm() const;
/// return the l1 norm of this signal (sum of absolute values)
data_t l1Norm() const;
GetFloatingPointType_t<data_t> l1Norm() const;
/// return the linf norm of this signal (maximum of absolute values)
data_t lInfNorm() const;
GetFloatingPointType_t<data_t> lInfNorm() const;
/// return the sum of all elements of this signal
data_t sum() const;
......
......@@ -57,13 +57,13 @@ namespace elsa
virtual data_t dot(const DataHandler<data_t>& v) const = 0;
/// return the squared l2 norm of the data vector (dot product with itself)
virtual data_t squaredL2Norm() const = 0;
virtual GetFloatingPointType_t<data_t> squaredL2Norm() const = 0;
/// return the l1 norm of the data vector (sum of absolute values)
virtual data_t l1Norm() const = 0;
virtual GetFloatingPointType_t<data_t> l1Norm() const = 0;
/// return the linf norm of the data vector (maximum of absolute values)
virtual data_t lInfNorm() const = 0;
virtual GetFloatingPointType_t<data_t> lInfNorm() const = 0;
/// return the sum of all elements of the data vector
virtual data_t sum() const = 0;
......
......@@ -75,19 +75,19 @@ namespace elsa
}
template <typename data_t>
data_t DataHandlerCPU<data_t>::squaredL2Norm() const
GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::squaredL2Norm() const
{
return _data->squaredNorm();
}
template <typename data_t>
data_t DataHandlerCPU<data_t>::l1Norm() const
GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::l1Norm() const
{
return _data->array().abs().sum();
}
template <typename data_t>
data_t DataHandlerCPU<data_t>::lInfNorm() const
GetFloatingPointType_t<data_t> DataHandlerCPU<data_t>::lInfNorm() const
{
return _data->array().abs().maxCoeff();
}
......
......@@ -99,13 +99,13 @@ namespace elsa
data_t dot(const DataHandler<data_t>& v) const override;
/// return the squared l2 norm of the data vector (dot product with itself)
data_t squaredL2Norm() const override;
GetFloatingPointType_t<data_t> squaredL2Norm() const override;
/// return the l1 norm of the data vector (sum of absolute values)
data_t l1Norm() const override;
GetFloatingPointType_t<data_t> l1Norm() const override;
/// return the linf norm of the data vector (maximum of absolute values)
data_t lInfNorm() const override;
GetFloatingPointType_t<data_t> lInfNorm() const override;
/// return the sum of all elements of the data vector
data_t sum() const override;
......
......@@ -75,19 +75,19 @@ namespace elsa
}
template <typename data_t>
data_t DataHandlerMapCPU<data_t>::squaredL2Norm() const
GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::squaredL2Norm() const
{
return _map.squaredNorm();
}
template <typename data_t>
data_t DataHandlerMapCPU<data_t>::l1Norm() const
GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::l1Norm() const
{
return _map.array().abs().sum();
}
template <typename data_t>
data_t DataHandlerMapCPU<data_t>::lInfNorm() const
GetFloatingPointType_t<data_t> DataHandlerMapCPU<data_t>::lInfNorm() const
{
return _map.array().abs().maxCoeff();
}
......
......@@ -75,13 +75,13 @@ namespace elsa
data_t dot(const DataHandler<data_t>& v) const override;
/// return the squared l2 norm of the data vector (dot product with itself)
data_t squaredL2Norm() const override;
GetFloatingPointType_t<data_t> squaredL2Norm() const override;
/// return the l1 norm of the data vector (sum of absolute values)
data_t l1Norm() const override;
GetFloatingPointType_t<data_t> l1Norm() const override;
/// return the linf norm of the data vector (maximum of absolute values)
data_t lInfNorm() const override;
GetFloatingPointType_t<data_t> lInfNorm() const override;
/// return the sum of all elements of the data vector
data_t sum() const override;
......
......@@ -37,4 +37,21 @@ namespace elsa
CPU, ///< data is stored as an Eigen::Matrix in CPU main memory
MAP_CPU ///< data is not explicitly stored, but using an Eigen::Map to refer to other
};
/// base case for deducing floating point type of std::complex
template <typename T>
struct GetFloatingPointType {
using type = T;
};
/// partial specialization to derive correct floating point type
template <typename T>
struct GetFloatingPointType<std::complex<T>> {
using type = T;
};
/// helper typedef to facilitate usage
template <typename T>
using GetFloatingPointType_t = typename GetFloatingPointType<T>::type;
} // namespace elsa
......@@ -25,4 +25,15 @@ SCENARIO("Testing PI")
{
REQUIRE(pi<double> == 3.14159265358979323846);
}
}
SCENARIO("Testing compile-time predicates")
{
static_assert(std::is_same_v<float, GetFloatingPointType_t<std::complex<float>>>);
static_assert(std::is_same_v<double, GetFloatingPointType_t<std::complex<double>>>);
static_assert(std::is_same_v<double, GetFloatingPointType_t<double>>);
static_assert(std::is_same_v<float, GetFloatingPointType_t<float>>);
static_assert(!std::is_same_v<float, GetFloatingPointType_t<double>>);
REQUIRE(true);
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment