Commit 0f675515 authored by Jens Petit's avatar Jens Petit

Copy-on-Write for DataContainer

parent 6ccaae89
......@@ -30,6 +30,12 @@ set(MODULE_SOURCES
add_library(${ELSA_MODULE_TARGET_NAME} ${MODULE_HEADERS} ${MODULE_SOURCES})
add_library(elsa::${ELSA_MODULE_NAME} ALIAS ${ELSA_MODULE_TARGET_NAME})
# use OpenMP is available
find_package(OpenMP REQUIRED)
if(OpenMP_CXX_FOUND)
target_link_libraries(${ELSA_MODULE_TARGET_NAME} PRIVATE OpenMP::OpenMP_CXX)
endif()
target_include_directories(${ELSA_MODULE_TARGET_NAME}
PUBLIC
$<INSTALL_INTERFACE:include/${ELSA_MODULE_NAME}>
......
......@@ -28,7 +28,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>::DataContainer(const DataContainer<data_t> &other)
: _dataDescriptor{other._dataDescriptor->clone()},
_dataHandler{other._dataHandler->clone()}
_dataHandler{other._dataHandler}
{}
template <typename data_t>
......@@ -36,7 +36,7 @@ namespace elsa {
{
if (this != &other) {
_dataDescriptor = other._dataDescriptor->clone();
_dataHandler = other._dataHandler->clone();
_dataHandler = other._dataHandler;
}
return *this;
......@@ -85,6 +85,7 @@ namespace elsa {
template<typename data_t>
data_t &DataContainer<data_t>::operator[](index_t index)
{
detach();
return (*_dataHandler)[index];
}
......@@ -97,6 +98,7 @@ namespace elsa {
template<typename data_t>
data_t &DataContainer<data_t>::operator()(IndexVector_t coordinate)
{
detach();
return (*_dataHandler)[_dataDescriptor->getIndexFromCoordinate(coordinate)];
}
......@@ -166,6 +168,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator+=(const DataContainer<data_t>& dc)
{
detach();
*_dataHandler += *dc._dataHandler;
return *this;
}
......@@ -173,6 +176,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator-=(const DataContainer<data_t>& dc)
{
detach();
*_dataHandler -= *dc._dataHandler;
return *this;
}
......@@ -180,6 +184,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator*=(const DataContainer<data_t>& dc)
{
detach();
*_dataHandler *= *dc._dataHandler;
return *this;
}
......@@ -187,6 +192,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator/=(const DataContainer<data_t>& dc)
{
detach();
*_dataHandler /= *dc._dataHandler;
return *this;
}
......@@ -195,6 +201,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator+=(data_t scalar)
{
detach();
*_dataHandler += scalar;
return *this;
}
......@@ -202,6 +209,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator-=(data_t scalar)
{
detach();
*_dataHandler -= scalar;
return *this;
}
......@@ -209,6 +217,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator*=(data_t scalar)
{
detach();
*_dataHandler *= scalar;
return *this;
}
......@@ -216,6 +225,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator/=(data_t scalar)
{
detach();
*_dataHandler /= scalar;
return *this;
}
......@@ -224,6 +234,7 @@ namespace elsa {
template <typename data_t>
DataContainer<data_t>& DataContainer<data_t>::operator=(data_t scalar)
{
detach();
*_dataHandler = scalar;
return *this;
}
......@@ -269,6 +280,17 @@ namespace elsa {
return !(*this == other);
}
template <typename data_t>
void DataContainer<data_t>::detach()
{
if (_dataHandler.use_count() != 1) {
#pragma omp barrier
#pragma omp single
_dataHandler = _dataHandler->clone();
}
return;
}
// ------------------------------------------
// explicit template instantiation
template class DataContainer<float>;
......
......@@ -10,6 +10,11 @@
namespace elsa
{
/// forward declaration for friend function useCount
template <typename data_t = real_t> class DataContainer;
/// used for testing and defined in test file
template <typename data_t> int useCount(const DataContainer<data_t>&);
/**
* \brief class representing and storing a linearized n-dimensional signal
......@@ -23,8 +28,11 @@ namespace elsa
* This class provides a container for a signal that is stored in memory. This signal can
* be n-dimensional, and will be stored in memory in a linearized fashion. The information
* on how this linearization is performed is provided by an associated DataDescriptor.
*
* The class implements copy-on-write. Therefore any non-const functions should call the
* detach() function first to trigger the copy-on-write mechanism.
*/
template <typename data_t = real_t>
template <typename data_t>
class DataContainer {
public:
/**
......@@ -88,7 +96,6 @@ namespace elsa
*/
DataContainer<data_t>& operator=(DataContainer<data_t>&& other);
/// return the current DataDescriptor
const DataDescriptor& getDataDescriptor() const;
......@@ -191,12 +198,14 @@ namespace elsa
/// comparison with another DataContainer
bool operator!=(const DataContainer<data_t>& other) const;
/// used for testing and defined in test file
friend int useCount <> (const DataContainer<data_t>& dc);
private:
/// the current DataDescriptor
std::unique_ptr<DataDescriptor> _dataDescriptor;
/// the current DataHandler
std::unique_ptr<DataHandler<data_t>> _dataHandler;
std::shared_ptr<DataHandler<data_t>> _dataHandler;
/// factory method to create DataHandlers based on handlerType with perfect forwarding of constructor arguments
template <typename ... Args>
......@@ -204,6 +213,9 @@ namespace elsa
/// private constructor accepting a DataDescriptor and a DataHandler
explicit DataContainer(const DataDescriptor& dataDescriptor, std::unique_ptr<DataHandler<data_t>> dataHandler);
/// creates the deep copy for the copy-on-write mechanism
void detach();
};
......
......@@ -11,6 +11,12 @@
#include <catch2/catch.hpp>
#include "DataContainer.h"
template <typename data_t>
int elsa::useCount(const DataContainer<data_t>& dc)
{
return dc._dataHandler.use_count();
}
using namespace elsa;
using namespace Catch::literals; // to enable 0.0_a approximate floats
......@@ -371,6 +377,102 @@ SCENARIO("Testing the arithmetic operations with DataContainer arguments") {
REQUIRE(resultDivScalar[i] == dc[i] / scalar);
}
}
}
SCENARIO("Testing the copy-on-write mechanism") {
GIVEN("A random DataContainer") {
IndexVector_t numCoeff(3);
numCoeff << 52, 7, 29;
DataDescriptor desc(numCoeff);
DataContainer dc(desc);
Eigen::VectorXf randVec = Eigen::VectorXf::Random(dc.getSize());
for (index_t i = 0; i < dc.getSize(); ++i) {
dc[i] = randVec(i);
}
WHEN("const manipulating a copy constructed shallow copy") {
DataContainer dc2(dc);
dc + dc;
dc.sum();
dc.square();
dc.log();
dc.dot(dc);
dc.l1Norm();
THEN("the data is the same") {
REQUIRE(dc2 == dc);
REQUIRE(useCount(dc) == 2);
}
}
WHEN("non-const manipulating a copy constructed shallow copy") {
DataContainer dc2(dc);
REQUIRE(useCount(dc) == 2);
REQUIRE(useCount(dc2) == 2);
THEN("copy-on-write is invoked") {
dc2 += 2;
REQUIRE(dc2 != dc);
REQUIRE(useCount(dc2) == 1);
REQUIRE(useCount(dc) == 1);
}
THEN("copy-on-write is invoked") {
dc2 += dc;
REQUIRE(dc2 != dc);
REQUIRE(useCount(dc2) == 1);
REQUIRE(useCount(dc) == 1);
}
THEN("copy-on-write is invoked") {
dc2 -= 2;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc2 -= dc;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc2 /= 2;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc2 /= dc;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc2 *= 2;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc2 *= dc;
REQUIRE(dc2 != dc);
}
THEN("copy-on-write is invoked") {
dc[0] += 2;
REQUIRE(dc2 != dc);
}
}
WHEN("manipulating a non-shallow-copied container") {
for (index_t i = 0; i < dc.getSize(); ++i) {
dc[i] += 2;
}
THEN("copy-on-write should not be invoked") {
REQUIRE(useCount(dc) == 1);
}
}
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment