Commit c62f0d34 authored by Jens Petit's avatar Jens Petit

DataHandlerGPU: Adapt for generic default DataHandlerType (#21)

parent bfb8b866
......@@ -69,7 +69,7 @@ namespace elsa
template <typename data_t>
DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
{
DataContainer<data_t> result(*_rangeDescriptor);
DataContainer<data_t> result(*_rangeDescriptor, x.getDataHandlerType());
apply(x, result);
return result;
}
......@@ -129,7 +129,7 @@ namespace elsa
throw std::invalid_argument(
"LinearOperator::apply: incorrect input/output sizes for mult leaf");
DataContainer<data_t> temp(_rhs->getRangeDescriptor());
DataContainer<data_t> temp(_rhs->getRangeDescriptor(), x.getDataHandlerType());
_rhs->apply(x, temp);
_lhs->apply(temp, Ax);
return;
......@@ -142,7 +142,7 @@ namespace elsa
template <typename data_t>
DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
{
DataContainer<data_t> result(*_domainDescriptor);
DataContainer<data_t> result(*_domainDescriptor, y.getDataHandlerType());
applyAdjoint(y, result);
return result;
}
......@@ -202,7 +202,7 @@ namespace elsa
throw std::invalid_argument(
"LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
DataContainer<data_t> temp(_lhs->getDomainDescriptor());
DataContainer<data_t> temp(_lhs->getDomainDescriptor(), y.getDataHandlerType());
_lhs->applyAdjoint(y, temp);
_rhs->applyAdjoint(temp, Aty);
return;
......
......@@ -50,7 +50,7 @@ namespace elsa
template <typename data_t>
DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x)
{
DataContainer<data_t> result(_residual->getRangeDescriptor());
DataContainer<data_t> result(_residual->getRangeDescriptor(), x.getDataHandlerType());
getGradient(x, result);
return result;
}
......
......@@ -26,7 +26,7 @@ namespace elsa
template <typename data_t>
DataContainer<data_t> Residual<data_t>::evaluate(const DataContainer<data_t>& x)
{
DataContainer<data_t> result(*_rangeDescriptor);
DataContainer<data_t> result(*_rangeDescriptor, x.getDataHandlerType());
evaluate(x, result);
return result;
}
......
......@@ -62,7 +62,7 @@ namespace elsa
const RegularizationTerm<data_t>& regTerm)
: _dataTerm{dataTerm.clone()},
_regTerms{regTerm},
_currentSolution{dataTerm.getDomainDescriptor()}
_currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}
{
// sanity check
if (dataTerm.getDomainDescriptor().getNumberOfCoefficients()
......@@ -85,7 +85,8 @@ namespace elsa
template <typename data_t>
Problem<data_t>::Problem(const Functional<data_t>& dataTerm)
: _dataTerm{dataTerm.clone()}, _currentSolution{dataTerm.getDomainDescriptor()}
: _dataTerm{dataTerm.clone()},
_currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}
{
_currentSolution = 0;
}
......@@ -193,7 +194,8 @@ namespace elsa
template <typename data_t>
DataContainer<data_t> Problem<data_t>::getGradient()
{
DataContainer<data_t> result(_currentSolution.getDataDescriptor());
DataContainer<data_t> result(_currentSolution.getDataDescriptor(),
_currentSolution.getDataHandlerType());
getGradient(result);
return result;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment