Commit 7004ac58 authored by Nikola Dinev's avatar Nikola Dinev Committed by Tobias Lasser
Browse files

Protect copy constructors in non-leaf and leaf classes, to avoid the...

Protect copy constructors in non-leaf and leaf classes, to avoid the implicitly generated ones by the compiler that could cause slicing. In Cloneable, the assignment operator was deleted to disable issues with implicitly generated assingment operators in derived classes to prevent slicing. (resolves #31)
parent aea3cb41
Pipeline #190113 passed with stages
in 7 minutes and 20 seconds
...@@ -31,6 +31,9 @@ namespace elsa ...@@ -31,6 +31,9 @@ namespace elsa
bool operator!=(const Derived& other) const { return !(*this == other); } bool operator!=(const Derived& other) const { return !(*this == other); }
/// delete implicitly declared copy assignment to prevent copy assignment of derived classes
Cloneable& operator=(const Cloneable&) = delete;
protected: protected:
/// actual clone implementation method, abstract to force override in derived classes /// actual clone implementation method, abstract to force override in derived classes
virtual Derived* cloneImpl() const = 0; virtual Derived* cloneImpl() const = 0;
...@@ -39,10 +42,8 @@ namespace elsa ...@@ -39,10 +42,8 @@ namespace elsa
virtual bool isEqual(const Derived& other) const = 0; virtual bool isEqual(const Derived& other) const = 0;
/// default copy constructor, protected to not be publicly available (but available for /// default copy constructor, protected to not be publicly available (but available for
/// cloneImpl) /// cloneImpl()
Cloneable(const Cloneable&) = default; Cloneable(const Cloneable&) = default;
/// default copy assignment, protected to not be publicly available
Cloneable& operator=(const Cloneable&) = default;
}; };
} // namespace elsa } // namespace elsa
...@@ -105,6 +105,9 @@ namespace elsa ...@@ -105,6 +105,9 @@ namespace elsa
/// vector containing the origin of the described volume (typically the center) /// vector containing the origin of the described volume (typically the center)
RealVector_t _locationOfOrigin; RealVector_t _locationOfOrigin;
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
DataDescriptor(const DataDescriptor&) = default;
/// implement the polymorphic clone operation /// implement the polymorphic clone operation
DataDescriptor* cloneImpl() const override; DataDescriptor* cloneImpl() const override;
......
...@@ -68,6 +68,9 @@ namespace elsa ...@@ -68,6 +68,9 @@ namespace elsa
~EmissionLogLikelihood() override = default; ~EmissionLogLikelihood() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
EmissionLogLikelihood(const EmissionLogLikelihood<data_t>&) = default;
/// the evaluation of the emission log-likelihood /// the evaluation of the emission log-likelihood
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -46,6 +46,9 @@ namespace elsa ...@@ -46,6 +46,9 @@ namespace elsa
~Huber() override = default; ~Huber() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
Huber(const Huber<data_t>&) = default;
/// the evaluation of the Huber norm /// the evaluation of the Huber norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -42,6 +42,9 @@ namespace elsa ...@@ -42,6 +42,9 @@ namespace elsa
~L1Norm() override = default; ~L1Norm() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
L1Norm(const L1Norm<data_t>&) = default;
/// the evaluation of the l1 norm /// the evaluation of the l1 norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -41,6 +41,9 @@ namespace elsa ...@@ -41,6 +41,9 @@ namespace elsa
~L2NormPow2() override = default; ~L2NormPow2() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
L2NormPow2(const L2NormPow2<data_t>&) = default;
/// the evaluation of the l2 norm (squared) /// the evaluation of the l2 norm (squared)
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -42,6 +42,9 @@ namespace elsa ...@@ -42,6 +42,9 @@ namespace elsa
~LInfNorm() override = default; ~LInfNorm() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
LInfNorm(const LInfNorm<data_t>&) = default;
/// the evaluation of the linf norm /// the evaluation of the linf norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -65,6 +65,9 @@ namespace elsa ...@@ -65,6 +65,9 @@ namespace elsa
const DataContainer<data_t>& getDataVector() const; const DataContainer<data_t>& getDataVector() const;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
LinearResidual(const LinearResidual<data_t>&) = default;
/// implement the polymorphic clone operation /// implement the polymorphic clone operation
LinearResidual<data_t>* cloneImpl() const override; LinearResidual<data_t>* cloneImpl() const override;
......
...@@ -47,6 +47,9 @@ namespace elsa ...@@ -47,6 +47,9 @@ namespace elsa
~PseudoHuber() override = default; ~PseudoHuber() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
PseudoHuber(const PseudoHuber<data_t>&) = default;
/// the evaluation of the Huber norm /// the evaluation of the Huber norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -67,6 +67,9 @@ namespace elsa ...@@ -67,6 +67,9 @@ namespace elsa
const LinearResidual<data_t>& getGradientExpression() const; const LinearResidual<data_t>& getGradientExpression() const;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
Quadric(const Quadric<data_t>&) = default;
/// the evaluation of the Quadric functional /// the evaluation of the Quadric functional
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -77,6 +77,9 @@ namespace elsa ...@@ -77,6 +77,9 @@ namespace elsa
~TransmissionLogLikelihood() override = default; ~TransmissionLogLikelihood() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
TransmissionLogLikelihood(const TransmissionLogLikelihood<data_t>&) = default;
/// the evaluation of the transmission log-likelihood /// the evaluation of the transmission log-likelihood
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -46,6 +46,9 @@ namespace elsa ...@@ -46,6 +46,9 @@ namespace elsa
const Scaling<data_t>& getWeightingOperator() const; const Scaling<data_t>& getWeightingOperator() const;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
WeightedL2NormPow2(const WeightedL2NormPow2<data_t>&) = default;
/// the evaluation of the weighted, squared l2 norm /// the evaluation of the weighted, squared l2 norm
data_t evaluateImpl(const DataContainer<data_t>& Rx) override; data_t evaluateImpl(const DataContainer<data_t>& Rx) override;
......
...@@ -5,9 +5,11 @@ ...@@ -5,9 +5,11 @@
namespace elsa namespace elsa
{ {
std::pair<std::vector<Geometry>, DataDescriptor> CircleTrajectoryGenerator::createTrajectory( std::pair<std::vector<Geometry>, std::unique_ptr<DataDescriptor>>
index_t numberOfPoses, const DataDescriptor& volumeDescriptor, index_t arcDegrees, CircleTrajectoryGenerator::createTrajectory(index_t numberOfPoses,
real_t sourceToCenter, real_t centerToDetector) const DataDescriptor& volumeDescriptor,
index_t arcDegrees, real_t sourceToCenter,
real_t centerToDetector)
{ {
// sanity check // sanity check
auto dim = volumeDescriptor.getNumberOfDimensions(); auto dim = volumeDescriptor.getNumberOfDimensions();
...@@ -48,7 +50,7 @@ namespace elsa ...@@ -48,7 +50,7 @@ namespace elsa
} }
} }
return std::make_pair(geometryList, sinoDescriptor); return std::make_pair(geometryList, sinoDescriptor.clone());
} }
} // namespace elsa } // namespace elsa
...@@ -36,7 +36,7 @@ namespace elsa ...@@ -36,7 +36,7 @@ namespace elsa
* *
* Please note: the sinogram size/spacing will match the volume size/spacing. * Please note: the sinogram size/spacing will match the volume size/spacing.
*/ */
static std::pair<std::vector<Geometry>, DataDescriptor> static std::pair<std::vector<Geometry>, std::unique_ptr<DataDescriptor>>
createTrajectory(index_t numberOfPoses, const DataDescriptor& volumeDescriptor, createTrajectory(index_t numberOfPoses, const DataDescriptor& volumeDescriptor,
index_t arcDegrees, real_t sourceToCenter, real_t centerToDetector); index_t arcDegrees, real_t sourceToCenter, real_t centerToDetector);
}; };
......
...@@ -44,7 +44,7 @@ SCENARIO("Create a Circular Trajectory") ...@@ -44,7 +44,7 @@ SCENARIO("Create a Circular Trajectory")
real_t angle = (1.0 / (numberOfAngles - 1)) * halfCircular; real_t angle = (1.0 / (numberOfAngles - 1)) * halfCircular;
for (int i = 0; i < numberOfAngles; ++i) { for (int i = 0; i < numberOfAngles; ++i) {
real_t currAngle = i * angle * pi / 180.0; real_t currAngle = i * angle * pi / 180.0;
Geometry tmpGeom(sourceToCenter, centerToDetector, currAngle, desc, sdesc); Geometry tmpGeom(sourceToCenter, centerToDetector, currAngle, desc, *sdesc);
REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm() REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm()
== Approx(0)); == Approx(0));
...@@ -77,7 +77,7 @@ SCENARIO("Create a Circular Trajectory") ...@@ -77,7 +77,7 @@ SCENARIO("Create a Circular Trajectory")
real_t angle = (1.0 / (numberOfAngles - 1)) * halfCircular; real_t angle = (1.0 / (numberOfAngles - 1)) * halfCircular;
for (int i = 0; i < numberOfAngles; ++i) { for (int i = 0; i < numberOfAngles; ++i) {
real_t currAngle = i * angle * pi / 180.0; real_t currAngle = i * angle * pi / 180.0;
Geometry tmpGeom(sourceToCenter, centerToDetector, currAngle, desc, sdesc); Geometry tmpGeom(sourceToCenter, centerToDetector, currAngle, desc, *sdesc);
REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm() REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm()
== Approx(0)); == Approx(0));
...@@ -118,7 +118,7 @@ SCENARIO("Create a Circular Trajectory") ...@@ -118,7 +118,7 @@ SCENARIO("Create a Circular Trajectory")
real_t angleInc = 1.0 * halfCircular / (numberOfAngles - 1); real_t angleInc = 1.0 * halfCircular / (numberOfAngles - 1);
for (int i = 0; i < numberOfAngles; ++i) { for (int i = 0; i < numberOfAngles; ++i) {
real_t angle = i * angleInc * pi / 180.0; real_t angle = i * angleInc * pi / 180.0;
Geometry tmpGeom(sourceToCenter, centerToDetector, desc, sdesc, angle); Geometry tmpGeom(sourceToCenter, centerToDetector, desc, *sdesc, angle);
REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm() REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm()
== Approx(0)); == Approx(0));
...@@ -151,7 +151,7 @@ SCENARIO("Create a Circular Trajectory") ...@@ -151,7 +151,7 @@ SCENARIO("Create a Circular Trajectory")
real_t angleInc = 1.0 * halfCircular / (numberOfAngles - 1); real_t angleInc = 1.0 * halfCircular / (numberOfAngles - 1);
for (int i = 0; i < numberOfAngles; ++i) { for (int i = 0; i < numberOfAngles; ++i) {
real_t angle = i * angleInc * pi / 180.0; real_t angle = i * angleInc * pi / 180.0;
Geometry tmpGeom(sourceToCenter, centerToDetector, desc, sdesc, angle); Geometry tmpGeom(sourceToCenter, centerToDetector, desc, *sdesc, angle);
REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm() REQUIRE((tmpGeom.getCameraCenter() - geomList[i].getCameraCenter()).norm()
== Approx(0)); == Approx(0));
......
...@@ -20,7 +20,7 @@ namespace elsa ...@@ -20,7 +20,7 @@ namespace elsa
auto [descriptor, dataType] = parseHeader(properties); auto [descriptor, dataType] = parseHeader(properties);
// read in the data // read in the data
DataContainer<data_t> dataContainer(descriptor); DataContainer<data_t> dataContainer(*descriptor);
if (dataType == DataUtils::DataType::UINT16) if (dataType == DataUtils::DataType::UINT16)
DataUtils::parseRawData<uint16_t, data_t>(file, dataContainer); DataUtils::parseRawData<uint16_t, data_t>(file, dataContainer);
...@@ -117,7 +117,7 @@ namespace elsa ...@@ -117,7 +117,7 @@ namespace elsa
return properties; return properties;
} }
std::pair<DataDescriptor, DataUtils::DataType> std::pair<std::unique_ptr<DataDescriptor>, DataUtils::DataType>
EDF::parseHeader(const std::map<std::string, std::string>& properties) EDF::parseHeader(const std::map<std::string, std::string>& properties)
{ {
// read the dimensions // read the dimensions
...@@ -208,10 +208,8 @@ namespace elsa ...@@ -208,10 +208,8 @@ namespace elsa
dimSpacingVec[i] = spacing[i]; dimSpacingVec[i] = spacing[i];
} }
// the data descriptor condensed from the info return std::make_pair(std::make_unique<DataDescriptor>(dimSizeVec, dimSpacingVec),
DataDescriptor dataDescriptor(dimSizeVec, dimSpacingVec); dataType);
return std::make_pair(dataDescriptor, dataType);
} }
template <typename data_t> template <typename data_t>
...@@ -225,7 +223,7 @@ namespace elsa ...@@ -225,7 +223,7 @@ namespace elsa
file << "ByteOrder = LowByteFirst;\n"; file << "ByteOrder = LowByteFirst;\n";
file << "DataType = " << getDataTypeName(data) << ";\n"; file << "DataType = " << getDataTypeName(data) << ";\n";
auto descriptor = data.getDataDescriptor(); auto& descriptor = data.getDataDescriptor();
// write dimension and size // write dimension and size
for (std::size_t i = 0; i < descriptor.getNumberOfDimensions(); ++i) for (std::size_t i = 0; i < descriptor.getNumberOfDimensions(); ++i)
......
...@@ -43,7 +43,7 @@ namespace elsa ...@@ -43,7 +43,7 @@ namespace elsa
static std::map<std::string, std::string> readHeader(std::ifstream& file); static std::map<std::string, std::string> readHeader(std::ifstream& file);
/// parse the EDF header property map into a DataDescriptor and DataType /// parse the EDF header property map into a DataDescriptor and DataType
static std::pair<DataDescriptor, DataUtils::DataType> static std::pair<std::unique_ptr<DataDescriptor>, DataUtils::DataType>
parseHeader(const std::map<std::string, std::string>& properties); parseHeader(const std::map<std::string, std::string>& properties);
/// write the EDF header to file /// write the EDF header to file
......
...@@ -25,7 +25,7 @@ namespace elsa ...@@ -25,7 +25,7 @@ namespace elsa
throw std::runtime_error("MHD::read: can not read from '" + dataPath + "'"); throw std::runtime_error("MHD::read: can not read from '" + dataPath + "'");
// read in the data // read in the data
DataContainer<data_t> dataContainer(descriptor); DataContainer<data_t> dataContainer(*descriptor);
if (dataType == DataUtils::DataType::UINT16) if (dataType == DataUtils::DataType::UINT16)
DataUtils::parseRawData<uint16_t, data_t>(dataFile, dataContainer); DataUtils::parseRawData<uint16_t, data_t>(dataFile, dataContainer);
...@@ -97,7 +97,7 @@ namespace elsa ...@@ -97,7 +97,7 @@ namespace elsa
return properties; return properties;
} }
std::tuple<DataDescriptor, std::string, DataUtils::DataType> std::tuple<std::unique_ptr<DataDescriptor>, std::string, DataUtils::DataType>
MHD::parseHeader(const std::map<std::string, std::string>& properties) MHD::parseHeader(const std::map<std::string, std::string>& properties)
{ {
// check the dimensions // check the dimensions
...@@ -183,17 +183,15 @@ namespace elsa ...@@ -183,17 +183,15 @@ namespace elsa
dimSpacing[i] = dimSpacingVec[i]; dimSpacing[i] = dimSpacingVec[i];
} }
// the data descriptor condensed form the info return std::make_tuple(std::make_unique<DataDescriptor>(dimSizes, dimSpacing), rawDataPath,
DataDescriptor dataDescriptor(dimSizes, dimSpacing); dataType);
return std::make_tuple(dataDescriptor, rawDataPath, dataType);
} }
template <typename data_t> template <typename data_t>
void MHD::writeHeader(std::ofstream& metaFile, const DataContainer<data_t>& data, void MHD::writeHeader(std::ofstream& metaFile, const DataContainer<data_t>& data,
std::string rawFilename) std::string rawFilename)
{ {
auto descriptor = data.getDataDescriptor(); auto& descriptor = data.getDataDescriptor();
// write dimension, size and spacing // write dimension, size and spacing
metaFile << "NDims = " << descriptor.getNumberOfDimensions() << "\n"; metaFile << "NDims = " << descriptor.getNumberOfDimensions() << "\n";
......
...@@ -46,7 +46,7 @@ namespace elsa ...@@ -46,7 +46,7 @@ namespace elsa
static std::map<std::string, std::string> readHeader(std::ifstream& metaFile); static std::map<std::string, std::string> readHeader(std::ifstream& metaFile);
/// parse the MHD header property map into a DataDescriptor and DataType /// parse the MHD header property map into a DataDescriptor and DataType
static std::tuple<DataDescriptor, std::string, DataUtils::DataType> static std::tuple<std::unique_ptr<DataDescriptor>, std::string, DataUtils::DataType>
parseHeader(const std::map<std::string, std::string>& properties); parseHeader(const std::map<std::string, std::string>& properties);
/// write the MHD header to file /// write the MHD header to file
......
...@@ -64,6 +64,9 @@ namespace elsa ...@@ -64,6 +64,9 @@ namespace elsa
~FiniteDifferences() override = default; ~FiniteDifferences() override = default;
protected: protected:
/// default copy constructor, hidden from non-derived classes to prevent potential slicing
FiniteDifferences(const FiniteDifferences<data_t>&) = default;
/// apply the finite differences operator /// apply the finite differences operator
void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const override; void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const override;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment