Commit 5c098e26 authored by Andi Braimllari's avatar Andi Braimllari Committed by David Frank
Browse files

#139 Support scalar multiplication with LinearOperator

To shortcut scaling of operators, add a multiplication of left
hand sided scalar with a right hand sided LinearOperator.

Now, it's not necessary to create uniform scaling operations
to scale operators.
parent fb1409f5
Pipeline #702324 passed with stages
in 18 minutes and 53 seconds
......@@ -19,6 +19,7 @@ namespace elsa
: Cloneable<LinearOperator<data_t>>(),
_domainDescriptor{other._domainDescriptor->clone()},
_rangeDescriptor{other._rangeDescriptor->clone()},
_scalar{other._scalar},
_isLeaf{other._isLeaf},
_isAdjoint{other._isAdjoint},
_isComposite{other._isComposite},
......@@ -28,8 +29,14 @@ namespace elsa
_lhs = other._lhs->clone();
if (_isComposite) {
_lhs = other._lhs->clone();
_rhs = other._rhs->clone();
if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
_lhs = other._lhs->clone();
_rhs = other._rhs->clone();
}
if (_mode == CompositeMode::SCALAR_MULT) {
_rhs = other._rhs->clone();
}
}
}
......@@ -39,6 +46,7 @@ namespace elsa
if (*this != other) {
_domainDescriptor = other._domainDescriptor->clone();
_rangeDescriptor = other._rangeDescriptor->clone();
_scalar = other._scalar;
_isLeaf = other._isLeaf;
_isAdjoint = other._isAdjoint;
_isComposite = other._isComposite;
......@@ -48,8 +56,14 @@ namespace elsa
_lhs = other._lhs->clone();
if (_isComposite) {
_lhs = other._lhs->clone();
_rhs = other._rhs->clone();
if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
_lhs = other._lhs->clone();
_rhs = other._rhs->clone();
}
if (_mode == CompositeMode::SCALAR_MULT) {
_rhs = other._rhs->clone();
}
}
}
......@@ -136,6 +150,21 @@ namespace elsa
_lhs->apply(temp, Ax);
return;
}
if (_mode == CompositeMode::SCALAR_MULT) {
// sanity check the arguments for the intended evaluation tree leaf operation
if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
"sizes for scalar mult. leaf");
// sanity check the scalar in the optional
if (!_scalar.has_value())
throw InvalidArgumentError(
"LinearOperator::apply: no value found in the scalar optional");
_rhs->apply(x, Ax);
Ax *= _scalar.value();
return;
}
}
throw LogicError("LinearOperator: apply called on ill-formed object");
......@@ -209,6 +238,21 @@ namespace elsa
_rhs->applyAdjoint(temp, Aty);
return;
}
if (_mode == CompositeMode::SCALAR_MULT) {
// sanity check the arguments for the intended evaluation tree leaf operation
if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
"sizes for scalar mult. leaf");
// sanity check the scalar in the optional
if (!_scalar.has_value())
throw InvalidArgumentError(
"LinearOperator::apply: no value found in the scalar optional");
_rhs->applyAdjoint(y, Aty);
Aty *= _scalar.value();
return;
}
}
throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
......@@ -220,8 +264,15 @@ namespace elsa
if (_isLeaf)
return new LinearOperator<data_t>(*_lhs, _isAdjoint);
if (_isComposite)
return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
if (_isComposite) {
if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
}
if (_mode == CompositeMode::SCALAR_MULT) {
return new LinearOperator<data_t>(*_rhs, _isAdjoint);
}
}
return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
}
......@@ -242,8 +293,15 @@ namespace elsa
if (_isLeaf)
return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
if (_isComposite)
return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
if (_isComposite) {
if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
}
if (_mode == CompositeMode::SCALAR_MULT) {
return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
}
}
return true;
}
......@@ -255,6 +313,7 @@ namespace elsa
_rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
: op.getRangeDescriptor().clone()},
_lhs{op.clone()},
_scalar{op._scalar},
_isLeaf{true},
_isAdjoint{isAdjoint}
{
......@@ -293,6 +352,17 @@ namespace elsa
}
}
template <typename data_t>
LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
: _domainDescriptor{rhs.getDomainDescriptor().clone()},
_rangeDescriptor{rhs.getRangeDescriptor().clone()},
_rhs{rhs.clone()},
_scalar{scalar},
_isComposite{true},
_mode{CompositeMode::SCALAR_MULT}
{
}
// ------------------------------------------
// explicit template instantiation
template class LinearOperator<float>;
......
......@@ -123,6 +123,12 @@ namespace elsa
return LinearOperator(lhs, rhs, CompositeMode::MULT);
}
/// friend operator* to support composition of a scalar and a LinearOperator
friend LinearOperator<data_t> operator*(data_t scalar, const LinearOperator<data_t>& op)
{
return LinearOperator(scalar, op);
}
/// friend function to return the adjoint of a LinearOperator (and its derivatives)
friend LinearOperator<data_t> adjoint(const LinearOperator<data_t>& op)
{
......@@ -159,6 +165,8 @@ namespace elsa
/// pointers to nodes in the evaluation tree
std::unique_ptr<LinearOperator<data_t>> _lhs{}, _rhs{};
std::optional<data_t> _scalar = {};
/// flag whether this is a leaf-node
bool _isLeaf{false};
......@@ -169,7 +177,7 @@ namespace elsa
bool _isComposite{false};
/// enum class denoting the mode of composition (+, *)
enum class CompositeMode { ADD, MULT };
enum class CompositeMode { ADD, MULT, SCALAR_MULT };
/// variable storing the composition mode (+, *)
CompositeMode _mode{CompositeMode::MULT};
......@@ -180,6 +188,9 @@ namespace elsa
/// constructor to produce a composite (internal node) of the evaluation tree
LinearOperator(const LinearOperator<data_t>& lhs, const LinearOperator<data_t>& rhs,
CompositeMode mode);
/// constructor to produce a composite (internal node) of the evaluation tree
LinearOperator(data_t scalar, const LinearOperator<data_t>& op);
};
} // namespace elsa
......@@ -367,6 +367,68 @@ TEST_CASE("LinearOperator: Testing composite LinearOperator")
}
}
GIVEN("a scalar multiplicative composite linear operator")
{
IndexVector_t numCoeff(3);
numCoeff << 13, 47, 69;
IndexVector_t otherNumCoeff(2);
otherNumCoeff << 15, 28;
VolumeDescriptor ddDomain(numCoeff);
VolumeDescriptor ddRange(otherNumCoeff);
MockOperator op(ddDomain, ddRange);
real_t scalar = 8;
auto scalarMultOp = scalar * op;
WHEN("the operator is there")
{
THEN("the descriptors are set correctly")
{
REQUIRE_EQ(scalarMultOp.getDomainDescriptor(), ddDomain);
REQUIRE_EQ(scalarMultOp.getRangeDescriptor(), ddRange);
}
}
WHEN("given data")
{
DataContainer dcDomain(ddDomain);
DataContainer dcRange(ddRange);
THEN("the apply operations return the correct result")
{
auto resultApply = scalarMultOp.apply(dcDomain);
for (int i = 0; i < resultApply.getSize(); ++i)
REQUIRE_EQ(resultApply[i], static_cast<real_t>(8));
auto resultApplyAdjoint = scalarMultOp.applyAdjoint(dcRange);
for (int i = 0; i < resultApplyAdjoint.getSize(); ++i)
REQUIRE_EQ(resultApplyAdjoint[i], static_cast<real_t>(24));
}
THEN("the apply operations account for appropriately sized containers")
{
REQUIRE_THROWS_AS(scalarMultOp.apply(dcRange), InvalidArgumentError);
REQUIRE_THROWS_AS(scalarMultOp.applyAdjoint(dcDomain), InvalidArgumentError);
}
}
WHEN("copying/assigning")
{
const auto& newOp = scalarMultOp;
auto assignedOp = adjoint(newOp);
THEN("it should be identical")
{
REQUIRE_EQ(newOp, scalarMultOp);
REQUIRE_EQ(assignedOp, adjoint(newOp));
assignedOp = newOp;
REQUIRE_EQ(assignedOp, newOp);
}
}
}
GIVEN("a complex composite with multiple leafs and levels")
{
IndexVector_t numCoeff(2);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment