Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
IP
elsa
Commits
7f24cd6d
Commit
7f24cd6d
authored
Mar 16, 2020
by
Jens Petit
Browse files
DataHandlerGPU: Core functionality (
#21
)
parent
825e8c5b
Changes
13
Hide whitespace changes
Inline
Side-by-side
elsa/core/DataContainer.cpp
View file @
7f24cd6d
...
...
@@ -53,7 +53,7 @@ namespace elsa
_dataHandler
=
other
.
_dataHandler
->
clone
();
}
_dataH
andler
T
ype
=
other
.
_dataHandlerT
ype
;
// TODO: Check what to do with handler type if CPU copy assign to GPU type
}
return
*
this
;
...
...
@@ -81,7 +81,7 @@ namespace elsa
_dataHandler
=
std
::
move
(
other
.
_dataHandler
);
}
_dataH
andler
T
ype
=
std
::
move
(
other
.
_dataHandlerT
ype
);
// TODO: Check what to do with handler type if CPU move assign to GPU type
// leave other in a valid state
other
.
_dataDescriptor
=
nullptr
;
...
...
@@ -230,6 +230,12 @@ namespace elsa
return
std
::
make_unique
<
DataHandlerCPU
<
data_t
>>
(
std
::
forward
<
Args
>
(
args
)...);
case
DataHandlerType
::
MAP_CPU
:
return
std
::
make_unique
<
DataHandlerCPU
<
data_t
>>
(
std
::
forward
<
Args
>
(
args
)...);
#ifdef ELSA_CUDA_VECTOR
case
DataHandlerType
::
GPU
:
return
std
::
make_unique
<
DataHandlerGPU
<
data_t
>>
(
std
::
forward
<
Args
>
(
args
)...);
case
DataHandlerType
::
MAP_GPU
:
return
std
::
make_unique
<
DataHandlerGPU
<
data_t
>>
(
std
::
forward
<
Args
>
(
args
)...);
#endif
default:
throw
std
::
invalid_argument
(
"DataContainer: unknown handler type"
);
}
...
...
@@ -277,8 +283,13 @@ namespace elsa
const
auto
&
ithDesc
=
blockDesc
->
getDescriptorOfBlock
(
i
);
index_t
blockSize
=
ithDesc
.
getNumberOfCoefficients
();
DataHandlerType
newHandlerType
=
(
_dataHandlerType
==
DataHandlerType
::
CPU
||
_dataHandlerType
==
DataHandlerType
::
MAP_CPU
)
?
DataHandlerType
::
MAP_CPU
:
DataHandlerType
::
MAP_GPU
;
return
DataContainer
<
data_t
>
{
ithDesc
,
_dataHandler
->
getBlock
(
startIndex
,
blockSize
),
Data
HandlerType
::
MAP_CPU
};
new
HandlerType
};
}
template
<
typename
data_t
>
...
...
@@ -295,10 +306,15 @@ namespace elsa
const
auto
&
ithDesc
=
blockDesc
->
getDescriptorOfBlock
(
i
);
index_t
blockSize
=
ithDesc
.
getNumberOfCoefficients
();
DataHandlerType
newHandlerType
=
(
_dataHandlerType
==
DataHandlerType
::
CPU
||
_dataHandlerType
==
DataHandlerType
::
MAP_CPU
)
?
DataHandlerType
::
MAP_CPU
:
DataHandlerType
::
MAP_GPU
;
// getBlock() returns a pointer to non-const DH, but that's fine as it gets wrapped in a
// constant container
return
DataContainer
<
data_t
>
{
ithDesc
,
_dataHandler
->
getBlock
(
startIndex
,
blockSize
),
Data
HandlerType
::
MAP_CPU
};
new
HandlerType
};
}
template
<
typename
data_t
>
...
...
@@ -307,8 +323,13 @@ namespace elsa
if
(
dataDescriptor
.
getNumberOfCoefficients
()
!=
getSize
())
throw
std
::
invalid_argument
(
"DataContainer: view must have same size as container"
);
DataHandlerType
newHandlerType
=
(
_dataHandlerType
==
DataHandlerType
::
CPU
||
_dataHandlerType
==
DataHandlerType
::
MAP_CPU
)
?
DataHandlerType
::
MAP_CPU
:
DataHandlerType
::
MAP_GPU
;
return
DataContainer
<
data_t
>
{
dataDescriptor
,
_dataHandler
->
getBlock
(
0
,
getSize
()),
Data
HandlerType
::
MAP_CPU
};
new
HandlerType
};
}
template
<
typename
data_t
>
...
...
@@ -318,10 +339,15 @@ namespace elsa
if
(
dataDescriptor
.
getNumberOfCoefficients
()
!=
getSize
())
throw
std
::
invalid_argument
(
"DataContainer: view must have same size as container"
);
DataHandlerType
newHandlerType
=
(
_dataHandlerType
==
DataHandlerType
::
CPU
||
_dataHandlerType
==
DataHandlerType
::
MAP_CPU
)
?
DataHandlerType
::
MAP_CPU
:
DataHandlerType
::
MAP_GPU
;
// getBlock() returns a pointer to non-const DH, but that's fine as it gets wrapped in a
// constant container
return
DataContainer
<
data_t
>
{
dataDescriptor
,
_dataHandler
->
getBlock
(
0
,
getSize
()),
Data
HandlerType
::
MAP_CPU
};
new
HandlerType
};
}
template
<
typename
data_t
>
...
...
@@ -396,22 +422,6 @@ namespace elsa
return
const_reverse_iterator
(
cbegin
());
}
/// Wraps the raw pointer of the owned handler in a HandlerTypes_t variant, choosing the
/// alternative that matches the stored handler type tag.
template <typename data_t>
typename DataContainer<data_t>::HandlerTypes_t DataContainer<data_t>::getHandlerPtr() const
{
    auto* const rawHandler = _dataHandler.get();

    switch (_dataHandlerType) {
        case DataHandlerType::CPU:
            return static_cast<DataHandlerCPU<data_t>*>(rawHandler);
        case DataHandlerType::MAP_CPU:
            return static_cast<DataHandlerMapCPU<data_t>*>(rawHandler);
        default:
            // any other tag yields a default-constructed variant
            return HandlerTypes_t{};
    }
}
template
<
typename
data_t
>
DataHandlerType
DataContainer
<
data_t
>::
getDataHandlerType
()
const
{
...
...
elsa/core/DataContainer.h
View file @
7f24cd6d
...
...
@@ -8,6 +8,11 @@
#include "DataContainerIterator.h"
#include "Expression.h"
#ifdef ELSA_CUDA_VECTOR
#include "DataHandlerGPU.h"
#include "DataHandlerMapGPU.h"
#endif
#include <memory>
#include <type_traits>
...
...
@@ -33,9 +38,6 @@ namespace elsa
class
DataContainer
{
public:
/// union of all possible handler raw pointers
using
HandlerTypes_t
=
std
::
variant
<
DataHandlerCPU
<
data_t
>*
,
DataHandlerMapCPU
<
data_t
>*>
;
/// delete default constructor (without metadata there can be no valid container)
DataContainer
()
=
delete
;
...
...
@@ -46,7 +48,7 @@ namespace elsa
* \param[in] handlerType the data handler (default: CPU)
*/
explicit
DataContainer
(
const
DataDescriptor
&
dataDescriptor
,
DataHandlerType
handlerType
=
Data
HandlerType
::
CPU
);
DataHandlerType
handlerType
=
default
HandlerType
);
/**
* \brief Constructor for DataContainer, initializing it with a DataVector
...
...
@@ -57,7 +59,7 @@ namespace elsa
*/
DataContainer
(
const
DataDescriptor
&
dataDescriptor
,
const
Eigen
::
Matrix
<
data_t
,
Eigen
::
Dynamic
,
1
>&
data
,
DataHandlerType
handlerType
=
Data
HandlerType
::
CPU
);
DataHandlerType
handlerType
=
default
HandlerType
);
/**
* \brief Copy constructor for DataContainer
...
...
@@ -106,7 +108,21 @@ namespace elsa
template
<
typename
Source
,
typename
=
std
::
enable_if_t
<
isExpression
<
Source
>
>>
DataContainer
<
data_t
>&
operator
=
(
Source
const
&
source
)
{
_dataHandler
->
accessData
()
=
source
.
eval
();
if
(
auto
handler
=
dynamic_cast
<
DataHandlerCPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
handler
->
accessData
()
=
source
.
template
eval
<
false
>();
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerMapCPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
handler
->
accessData
()
=
source
.
template
eval
<
false
>();
#ifdef ELSA_CUDA_VECTOR
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerGPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
handler
->
accessData
().
eval
(
source
.
template
eval
<
true
>());
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerMapGPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
handler
->
accessData
().
eval
(
source
.
template
eval
<
true
>());
#endif
}
else
{
throw
std
::
logic_error
(
"Unknown handler type"
);
}
return
*
this
;
}
...
...
@@ -121,9 +137,9 @@ namespace elsa
*/
template
<
typename
Source
,
typename
=
std
::
enable_if_t
<
isExpression
<
Source
>
>>
DataContainer
<
data_t
>
(
Source
const
&
source
)
:
DataContainer
<
data_t
>
(
source
.
getDataMetaInfo
().
first
,
source
.
eval
(),
source
.
getDataMetaInfo
().
second
)
:
DataContainer
<
data_t
>
(
source
.
getDataMetaInfo
().
first
,
source
.
getDataMetaInfo
().
second
)
{
this
->
operator
=
(
source
);
}
/// return the current DataDescriptor
...
...
@@ -174,7 +190,23 @@ namespace elsa
template
<
typename
Source
,
typename
=
std
::
enable_if_t
<
isExpression
<
Source
>
>>
data_t
dot
(
const
Source
&
source
)
const
{
return
(
*
this
*
source
).
eval
().
sum
();
if
(
auto
handler
=
dynamic_cast
<
DataHandlerCPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
return
(
*
this
*
source
).
template
eval
<
false
>().
sum
();
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerMapCPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
return
(
*
this
*
source
).
template
eval
<
false
>().
sum
();
#ifdef ELSA_CUDA_VECTOR
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerGPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
DataContainer
temp
=
(
*
this
*
source
);
return
temp
.
sum
();
}
else
if
(
auto
handler
=
dynamic_cast
<
DataHandlerMapGPU
<
data_t
>*>
(
_dataHandler
.
get
()))
{
DataContainer
temp
=
(
*
this
*
source
);
return
temp
.
sum
();
#endif
}
else
{
throw
std
::
logic_error
(
"Unknown handler type"
);
}
}
/// return the squared l2 norm of this signal (dot product with itself)
...
...
@@ -336,13 +368,10 @@ namespace elsa
DataHandlerType
getDataHandlerType
()
const
;
/// friend constexpr function to implement expression templates
template
<
class
Operand
,
std
::
enable_if_t
<
isDataContainer
<
Operand
>,
int
>>
template
<
bool
GPU
,
class
Operand
,
std
::
enable_if_t
<
isDataContainer
<
Operand
>,
int
>>
friend
constexpr
auto
evaluateOrReturn
(
Operand
const
&
operand
);
private:
/// returns the underlying derived handler as a raw pointer in a std::variant
HandlerTypes_t
getHandlerPtr
()
const
;
/// the current DataDescriptor
std
::
unique_ptr
<
DataDescriptor
>
_dataDescriptor
;
...
...
@@ -361,35 +390,49 @@ namespace elsa
/// private constructor accepting a DataDescriptor and a DataHandler
explicit
DataContainer
(
const
DataDescriptor
&
dataDescriptor
,
std
::
unique_ptr
<
DataHandler
<
data_t
>>
dataHandler
,
DataHandlerType
dataType
=
Data
HandlerType
::
CPU
);
DataHandlerType
dataType
=
default
HandlerType
);
};
/// User-defined template argument deduction guide for the expression based constructor
template
<
typename
Source
>
DataContainer
(
Source
const
&
source
)
->
DataContainer
<
typename
Source
::
data_t
>
;
/// Bundles several callables into a single object: it inherits from each of them and
/// exposes all of their call operators as one overload set.
template <typename... Functors>
struct Callables : Functors... {
    using Functors::operator()...;
};

/// Deduction guide so that Callables{f, g, ...} deduces the functor types from its arguments
template <typename... Functors>
Callables(Functors...) -> Callables<Functors...>;
/// Multiplying two operands (including scalars)
template
<
typename
LHS
,
typename
RHS
,
typename
=
std
::
enable_if_t
<
isBinaryOpOk
<
LHS
,
RHS
>
>>
auto
operator
*
(
LHS
const
&
lhs
,
RHS
const
&
rhs
)
{
auto
multiplicationGPU
=
[](
auto
const
&
left
,
auto
const
&
right
,
bool
/**/
)
{
return
left
*
right
;
};
if
constexpr
(
isDcOrExpr
<
LHS
>
&&
isDcOrExpr
<
RHS
>
)
{
auto
multiplication
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
*
right
.
array
()).
matrix
();
};
return
Expression
{
multiplication
,
lhs
,
rhs
};
return
Expression
{
Callables
{
multiplication
,
multiplicationGPU
},
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
LHS
>
)
{
auto
multiplication
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
*
right
.
array
()).
matrix
();
};
return
Expression
{
multiplication
,
lhs
,
rhs
};
return
Expression
{
Callables
{
multiplication
,
multiplicationGPU
},
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
RHS
>
)
{
auto
multiplication
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
*
right
).
matrix
();
};
return
Expression
{
multiplication
,
lhs
,
rhs
};
return
Expression
{
Callables
{
multiplication
,
multiplicationGPU
},
lhs
,
rhs
};
}
else
{
auto
multiplication
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
*
right
;
};
return
Expression
{
multiplication
,
lhs
,
rhs
};
return
Expression
{
Callables
{
multiplication
,
multiplicationGPU
},
lhs
,
rhs
};
}
}
...
...
@@ -397,22 +440,26 @@ namespace elsa
template
<
typename
LHS
,
typename
RHS
,
typename
=
std
::
enable_if_t
<
isBinaryOpOk
<
LHS
,
RHS
>
>>
auto
operator
+
(
LHS
const
&
lhs
,
RHS
const
&
rhs
)
{
auto
additionGPU
=
[](
auto
const
&
left
,
auto
const
&
right
,
bool
/**/
)
{
return
left
+
right
;
};
if
constexpr
(
isDcOrExpr
<
LHS
>
&&
isDcOrExpr
<
RHS
>
)
{
auto
addition
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
+
right
;
};
return
Expression
{
addition
,
lhs
,
rhs
};
return
Expression
{
Callables
{
addition
,
additionGPU
}
,
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
LHS
>
)
{
auto
addition
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
+
right
.
array
()).
matrix
();
};
return
Expression
{
addition
,
lhs
,
rhs
};
return
Expression
{
Callables
{
addition
,
additionGPU
}
,
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
RHS
>
)
{
auto
addition
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
+
right
).
matrix
();
};
return
Expression
{
addition
,
lhs
,
rhs
};
return
Expression
{
Callables
{
addition
,
additionGPU
}
,
lhs
,
rhs
};
}
else
{
auto
addition
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
+
right
;
};
return
Expression
{
addition
,
lhs
,
rhs
};
return
Expression
{
Callables
{
addition
,
additionGPU
}
,
lhs
,
rhs
};
}
}
...
...
@@ -420,22 +467,26 @@ namespace elsa
template
<
typename
LHS
,
typename
RHS
,
typename
=
std
::
enable_if_t
<
isBinaryOpOk
<
LHS
,
RHS
>
>>
auto
operator
-
(
LHS
const
&
lhs
,
RHS
const
&
rhs
)
{
auto
subtractionGPU
=
[](
auto
const
&
left
,
auto
const
&
right
,
bool
/**/
)
{
return
left
-
right
;
};
if
constexpr
(
isDcOrExpr
<
LHS
>
&&
isDcOrExpr
<
RHS
>
)
{
auto
subtraction
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
-
right
;
};
return
Expression
{
subtraction
,
lhs
,
rhs
};
return
Expression
{
Callables
{
subtraction
,
subtractionGPU
},
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
LHS
>
)
{
auto
subtraction
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
-
right
.
array
()).
matrix
();
};
return
Expression
{
subtraction
,
lhs
,
rhs
};
return
Expression
{
Callables
{
subtraction
,
subtractionGPU
},
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
RHS
>
)
{
auto
subtraction
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
-
right
).
matrix
();
};
return
Expression
{
subtraction
,
lhs
,
rhs
};
return
Expression
{
Callables
{
subtraction
,
subtractionGPU
},
lhs
,
rhs
};
}
else
{
auto
subtraction
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
-
right
;
};
return
Expression
{
subtraction
,
lhs
,
rhs
};
return
Expression
{
Callables
{
subtraction
,
subtractionGPU
},
lhs
,
rhs
};
}
}
...
...
@@ -443,24 +494,28 @@ namespace elsa
template
<
typename
LHS
,
typename
RHS
,
typename
=
std
::
enable_if_t
<
isBinaryOpOk
<
LHS
,
RHS
>
>>
auto
operator
/
(
LHS
const
&
lhs
,
RHS
const
&
rhs
)
{
auto
divisionGPU
=
[](
auto
const
&
left
,
auto
const
&
right
,
bool
/**/
)
{
return
left
/
right
;
};
if
constexpr
(
isDcOrExpr
<
LHS
>
&&
isDcOrExpr
<
RHS
>
)
{
auto
division
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
/
right
.
array
()).
matrix
();
};
return
Expression
{
division
,
lhs
,
rhs
};
return
Expression
{
Callables
{
division
,
divisionGPU
}
,
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
LHS
>
)
{
auto
division
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
/
right
.
array
()).
matrix
();
};
return
Expression
{
division
,
lhs
,
rhs
};
return
Expression
{
Callables
{
division
,
divisionGPU
}
,
lhs
,
rhs
};
}
else
if
constexpr
(
isArithmetic
<
RHS
>
)
{
auto
division
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
(
left
.
array
()
/
right
).
matrix
();
};
return
Expression
{
division
,
lhs
,
rhs
};
return
Expression
{
Callables
{
division
,
divisionGPU
}
,
lhs
,
rhs
};
}
else
{
auto
division
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
*
right
;
};
return
Expression
{
division
,
lhs
,
rhs
};
auto
division
=
[](
auto
const
&
left
,
auto
const
&
right
)
{
return
left
/
right
;
};
return
Expression
{
Callables
{
division
,
divisionGPU
}
,
lhs
,
rhs
};
}
}
...
...
@@ -469,7 +524,12 @@ namespace elsa
auto square(Operand const& operand)
{
    // two-argument-free CPU callable: squares through array() and converts back via matrix()
    auto squareCPU = [](auto const& arg) { return (arg.array().square()).matrix(); };

#ifdef ELSA_CUDA_VECTOR
    // bool-tagged callable forwarding to quickvec, bundled with the CPU one
    auto squareGPU = [](auto const& arg, bool /**/) { return quickvec::square(arg); };
    return Expression{Callables{squareCPU, squareGPU}, operand};
#else
    return Expression{squareCPU, operand};
#endif
}
/// Element-wise square-root operation
...
...
@@ -477,7 +537,12 @@ namespace elsa
auto sqrt(Operand const& operand)
{
    // CPU callable: square root through array(), converted back via matrix()
    auto sqrtCPU = [](auto const& arg) { return (arg.array().sqrt()).matrix(); };

#ifdef ELSA_CUDA_VECTOR
    // bool-tagged callable forwarding to quickvec, bundled with the CPU one
    auto sqrtGPU = [](auto const& arg, bool /**/) { return quickvec::sqrt(arg); };
    return Expression{Callables{sqrtCPU, sqrtGPU}, operand};
#else
    return Expression{sqrtCPU, operand};
#endif
}
/// Element-wise exponenation operation with euler base
...
...
@@ -485,7 +550,12 @@ namespace elsa
auto exp(Operand const& operand)
{
    // CPU callable: exponential through array(), converted back via matrix()
    auto expCPU = [](auto const& arg) { return (arg.array().exp()).matrix(); };

#ifdef ELSA_CUDA_VECTOR
    // bool-tagged callable forwarding to quickvec, bundled with the CPU one
    auto expGPU = [](auto const& arg, bool /**/) { return quickvec::exp(arg); };
    return Expression{Callables{expCPU, expGPU}, operand};
#else
    return Expression{expCPU, operand};
#endif
}
/// Element-wise natural logarithm
...
...
@@ -493,7 +563,12 @@ namespace elsa
auto log(Operand const& operand)
{
    // CPU callable: logarithm through array(), converted back via matrix()
    auto logCPU = [](auto const& arg) { return (arg.array().log()).matrix(); };

#ifdef ELSA_CUDA_VECTOR
    // bool-tagged callable forwarding to quickvec, bundled with the CPU one
    auto logGPU = [](auto const& arg, bool /**/) { return quickvec::log(arg); };
    return Expression{Callables{logCPU, logGPU}, operand};
#else
    return Expression{logCPU, operand};
#endif
}
}
// namespace elsa
elsa/core/DataHandler.h
View file @
7f24cd6d
...
...
@@ -4,6 +4,10 @@
#include "Cloneable.h"
#include "ExpressionPredicates.h"
#ifdef ELSA_CUDA_VECTOR
#include "Quickvec.h"
#endif
#include <Eigen/Core>
namespace
elsa
...
...
@@ -29,13 +33,6 @@ namespace elsa
template
<
typename
data_t
=
real_t
>
class
DataHandler
:
public
Cloneable
<
DataHandler
<
data_t
>>
{
/// for enabling accessData()
template
<
class
Operand
,
std
::
enable_if_t
<
isDataContainer
<
Operand
>,
int
>>
friend
constexpr
auto
evaluateOrReturn
(
Operand
const
&
operand
);
/// for enabling accessData()
friend
DataContainer
<
data_t
>
;
protected:
/// convenience typedef for the Eigen::Matrix data vector
using
DataVector_t
=
Eigen
::
Matrix
<
data_t
,
Eigen
::
Dynamic
,
1
>
;
...
...
@@ -178,11 +175,5 @@ namespace elsa
/// derived classes should override this method to implement move assignment
virtual
void
assign
(
DataHandler
<
data_t
>&&
other
)
=
0
;
/// derived classes return underlying data
virtual
DataMap_t
accessData
()
=
0
;
/// derived classes return underlying data
virtual
DataMap_t
accessData
()
const
=
0
;
};
}
// namespace elsa
elsa/core/DataHandlerCPU.h
View file @
7f24cd6d
...
...
@@ -18,7 +18,7 @@ namespace elsa
class
DataHandlerCPU
;
// forward declaration, used for testing and defined in test file (declared as friend)
template
<
typename
data_t
>
long
useCount
(
const
DataHandlerCPU
<
data_t
>&
);
long
useCount
(
const
DataHandlerCPU
<
data_t
>&
/*dh*/
);
/**
* \brief Class representing and owning a vector stored in CPU main memory (using
...
...
@@ -45,9 +45,16 @@ namespace elsa
/// declare DataHandlerMapCPU as friend, allows the use of Eigen for improved performance
friend
DataHandlerMapCPU
<
data_t
>
;
/// for enabling accessData()
friend
DataContainer
<
data_t
>
;
/// used for testing only and defined in test file
friend
long
useCount
<>
(
const
DataHandlerCPU
<
data_t
>&
dh
);
/// friend constexpr function to implement expression templates
template
<
bool
GPU
,
class
Operand
,
std
::
enable_if_t
<
isDataContainer
<
Operand
>,
int
>>
friend
constexpr
auto
evaluateOrReturn
(
Operand
const
&
operand
);
protected:
/// convenience typedef for the Eigen::Matrix data vector
using
DataVector_t
=
Eigen
::
Matrix
<
data_t
,
Eigen
::
Dynamic
,
1
>
;
...
...
@@ -176,10 +183,10 @@ namespace elsa
void
assign
(
DataHandler
<
data_t
>&&
other
)
override
;
/// return non-const version of data
DataMap_t
accessData
()
override
;
DataMap_t
accessData
();
/// return const version of data
DataMap_t
accessData
()
const
override
;
DataMap_t
accessData
()
const
;
private:
/// creates the deep copy for the copy-on-write mechanism
...
...
elsa/core/DataHandlerGPU.cpp
0 → 100644
View file @
7f24cd6d
#include "DataHandlerGPU.h"
#include "DataHandlerMapGPU.h"
#include <cublas_v2.h>
namespace
elsa
{
template
<
typename
data_t
>
DataHandlerGPU
<
data_t
>::
DataHandlerGPU
(
index_t
size
)
:
_data
(
std
::
make_shared
<
quickvec
::
Vector
<
data_t
>>
(
size
))
{
}
template
<
typename
data_t
>
DataHandlerGPU
<
data_t
>::
DataHandlerGPU
(
DataVector_t
const
&
vector
)
:
_data
(
std
::
make_shared
<
quickvec
::
Vector
<
data_t
>>
(
vector
))
{
}
template
<
typename
data_t
>
DataHandlerGPU
<
data_t
>::
DataHandlerGPU
(
quickvec
::
Vector
<
data_t
>
const
&
vector
)
:
_data
(
std
::
make_shared
<
quickvec
::
Vector
<
data_t
>>
(
vector
.
clone
()))
{
}
/// Copy constructor: copies the data pointer of the other handler (so both refer to the
/// same vector) and starts out with an empty collection of associated maps.
template <typename data_t>
DataHandlerGPU<data_t>::DataHandlerGPU(const DataHandlerGPU<data_t>& other)
    : _data(other._data), _associatedMaps()
{
}
/// Move constructor: takes over the data pointer and the associated maps of the other
/// handler, then re-points every taken-over map at this handler.
template <typename data_t>
DataHandlerGPU<data_t>::DataHandlerGPU(DataHandlerGPU<data_t>&& other) noexcept
    : _data(std::move(other._data)), _associatedMaps(std::move(other._associatedMaps))
{
    for (auto& associatedMap : _associatedMaps) {
        associatedMap->_dataOwner = this;
    }
}
/// Destructor: clears the owner pointer of every associated map so that none of them keeps
/// pointing at this handler.
template <typename data_t>
DataHandlerGPU<data_t>::~DataHandlerGPU()
{
    for (auto& associatedMap : _associatedMaps) {
        associatedMap->_dataOwner = nullptr;
    }
}
/// Returns the size reported by the underlying vector, converted to index_t.
template <typename data_t>
index_t DataHandlerGPU<data_t>::getSize() const
{
    const auto vectorSize = _data->size();
    return static_cast<index_t>(vectorSize);
}
/// Mutable element access: calls detach() first, then returns a reference to the element
/// at the given index of the underlying vector.
template <typename data_t>
data_t& DataHandlerGPU<data_t>::operator[](index_t index)
{
    detach();

    auto& storage = *_data;
    const auto position = static_cast<size_t>(index);
    return storage[position];
}
/// Read-only element access: returns a const reference to the element at the given index
/// of the underlying vector (no detach() call).
template <typename data_t>
const data_t& DataHandlerGPU<data_t>::operator[](index_t index) const
{
    // deliberately a non-const reference: indexing goes through the same overload as before
    auto& storage = *_data;
    const auto position = static_cast<size_t>(index);
    return storage[position];
}
/// Dot product of this handler with another one.
///
/// \param[in] v the other handler
/// \returns the dot product
/// \throws std::invalid_argument if the sizes of the two handlers differ
template <typename data_t>
data_t DataHandlerGPU<data_t>::dot(const DataHandler<data_t>& v) const
{
    if (v.getSize() != getSize()) {
        throw std::invalid_argument("DataHandlerGPU: dot product argument has wrong size");
    }

    // other handler is a DataHandlerGPU: use the dot product of the underlying vectors
    if (const auto* gpuOther = dynamic_cast<const DataHandlerGPU*>(&v)) {
        return _data->dot(*gpuOther->_data);
    }

    // other handler is a DataHandlerMapGPU: use its map directly
    if (const auto* mapOther = dynamic_cast<const DataHandlerMapGPU<data_t>*>(&v)) {
        return _data->dot(mapOther->_map);
    }

    // anything else: slow fallback version
    return this->slowDotProduct(v);
}
/// Forwards to squaredL2Norm() of the underlying vector.
template <typename data_t>
GetFloatingPointType_t<data_t> DataHandlerGPU<data_t>::squaredL2Norm() const
{
    return (*_data).squaredL2Norm();
}
/// Forwards to l1Norm() of the underlying vector.
template <typename data_t>
GetFloatingPointType_t<data_t> DataHandlerGPU<data_t>::l1Norm() const
{
    return (*_data).l1Norm();
}
/// Forwards to lInfNorm() of the underlying vector.
template <typename data_t>
GetFloatingPointType_t<data_t> DataHandlerGPU<data_t>::lInfNorm() const
{
    return (*_data).lInfNorm();
}
template
<
typename
data_t
>
data_t
DataHandlerGPU
<
data_t
>::
sum
()
const
{