Commit 2737ad1b authored by Jens Petit's avatar Jens Petit Committed by Tobias Lasser

Add class template argument deduction guide to DataContainer for expression evaluation.

parent f216f987
Pipeline #201304 passed with stages
in 22 minutes and 34 seconds
......@@ -364,6 +364,10 @@ namespace elsa
DataHandlerType dataType = DataHandlerType::CPU);
};
/// User-defined template argument deduction guide for the expression based constructor
template <typename Source>
DataContainer(Source const& source)->DataContainer<typename Source::data_t>;
/// Multiplying two operands (including scalars)
template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
auto operator*(LHS const& lhs, RHS const& rhs)
......
......@@ -55,6 +55,9 @@ namespace elsa
/// type which bundles the meta information to create a new DataContainer
using MetaInfo_t = std::pair<DataDescriptor const&, DataHandlerType>;
/// indicates data type is used in the expression
using data_t = typename GetOperandsDataType<Operands...>::data_t;
/// Constructor
Expression(Callable func, Operands const&... args)
: _callable(func), _args(args...), _dataMetaInfo(initDescriptor(args...))
......
......@@ -72,4 +72,35 @@ namespace elsa
constexpr bool isBinaryOpOk = (isDcOrExpr<LHS> && isDcOrExpr<RHS>)
|| (isDcOrExpr<LHS> && isArithmetic<RHS>)
|| (isArithmetic<LHS> && isDcOrExpr<RHS>);
/// Default case to infer data_t of any operand
template <typename Operand>
struct GetOperandDataType {
using data_t = real_t;
};
/// Partial specialization to infer data_t from DataContainer
template <typename data_type>
struct GetOperandDataType<DataContainer<data_type>> {
using data_t = data_type;
};
/// Partial specialization to infer data_t from Expression
template <typename Callable, typename... Operands>
struct GetOperandDataType<Expression<Callable, Operands...>> {
using data_t = typename Expression<Callable, Operands...>::data_t;
};
/* Uses the data type used in the first or last operand depending on whether the first operand
* is an anrithmetic type
*/
template <typename... Operands>
struct GetOperandsDataType {
using data_t = std::conditional_t<
isArithmetic<std::tuple_element_t<0, std::tuple<Operands...>>>,
std::tuple_element_t<sizeof...(Operands) - 1,
std::tuple<typename GetOperandDataType<Operands>::data_t...>>,
std::tuple_element_t<0, std::tuple<typename GetOperandDataType<Operands>::data_t...>>>;
};
} // namespace elsa
......@@ -41,7 +41,7 @@ std::string type_name()
return r;
}
SCENARIO("Expression templates")
TEMPLATE_TEST_CASE("Scenario: Expression templates", "", float, double)
{
GIVEN("Three random data containers")
{
......@@ -50,15 +50,15 @@ SCENARIO("Expression templates")
IndexVector_t numCoeff(3);
numCoeff << dimension, dimension, dimension;
DataDescriptor desc(numCoeff);
DataContainer dc(desc);
DataContainer dc2(desc);
DataContainer dc3(desc);
DataContainer result(desc);
DataContainer<TestType> dc(desc);
DataContainer<TestType> dc2(desc);
DataContainer<TestType> dc3(desc);
DataContainer<TestType> result(desc);
for (index_t i = 0; i < dc.getSize(); ++i) {
dc[i] = static_cast<float>(rand()) / (static_cast<float>(RAND_MAX / 100.0));
dc2[i] = static_cast<float>(rand()) / (static_cast<float>(RAND_MAX / 100.0));
dc3[i] = static_cast<float>(rand()) / (static_cast<float>(RAND_MAX / 100.0));
dc[i] = static_cast<TestType>(rand()) / (static_cast<TestType>(RAND_MAX / 100.0));
dc2[i] = static_cast<TestType>(rand()) / (static_cast<TestType>(RAND_MAX / 100.0));
dc3[i] = static_cast<TestType>(rand()) / (static_cast<TestType>(RAND_MAX / 100.0));
}
WHEN("using auto with an expression")
......@@ -94,9 +94,23 @@ SCENARIO("Expression templates")
WHEN("Constructing a new DataContainer out of an expression")
{
DataContainer newDC = dc * dc2 + dc3 / dc2;
THEN("the type is a DataContainer again")
{
DataContainer newDC = dc * dc2 + dc3 / dc2;
INFO(type_name<decltype(newDC)>());
}
THEN("the type is a DataContainer again") { INFO(type_name<decltype(newDC)>()); }
THEN("the type is a DataContainer again")
{
DataContainer newDC2 = 2.8 * dc2;
INFO(type_name<decltype(newDC2)>());
}
THEN("the type is a DataContainer again")
{
DataContainer newDC2 = dc2 * 2.8;
INFO(type_name<decltype(newDC2)>());
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment