Commit 9ac7834c authored by David Frank's avatar David Frank
Browse files

Add sketch for lut projector

parent 69219b67
......@@ -8,6 +8,7 @@
#include "PartitionDescriptor.h"
#include "Error.h"
#include "TypeCasts.hpp"
#include "Assertions.h"
#include <utility>
......@@ -108,26 +109,53 @@ namespace elsa
template <typename data_t>
data_t& DataContainer<data_t>::operator[](index_t index)
{
ELSA_VERIFY(index >= 0);
ELSA_VERIFY(index < getSize());
return (*_dataHandler)[index];
}
template <typename data_t>
const data_t& DataContainer<data_t>::operator[](index_t index) const
{
ELSA_VERIFY(index >= 0);
ELSA_VERIFY(index < getSize());
return static_cast<const DataHandler<data_t>&>(*_dataHandler)[index];
}
template <typename data_t>
data_t DataContainer<data_t>::at(IndexVector_t coordinate) const
{
const auto arr = coordinate.array();
if ((arr < 0).any()
|| (arr >= _dataDescriptor->getNumberOfCoefficientsPerDimension().array()).any()) {
return 0;
}
return (*this)[_dataDescriptor->getIndexFromCoordinate(std::move(coordinate))];
}
template <typename data_t>
data_t& DataContainer<data_t>::operator()(IndexVector_t coordinate)
{
return (*_dataHandler)[_dataDescriptor->getIndexFromCoordinate(std::move(coordinate))];
const auto arr = coordinate.array();
const auto shape = _dataDescriptor->getNumberOfCoefficientsPerDimension().array();
ELSA_VERIFY((arr >= 0).all());
ELSA_VERIFY((arr < shape).all());
return (*this)[_dataDescriptor->getIndexFromCoordinate(std::move(coordinate))];
}
template <typename data_t>
const data_t& DataContainer<data_t>::operator()(IndexVector_t coordinate) const
{
return static_cast<const DataHandler<data_t>&>(
*_dataHandler)[_dataDescriptor->getIndexFromCoordinate(std::move(coordinate))];
const auto arr = coordinate.array();
const auto shape = _dataDescriptor->getNumberOfCoefficientsPerDimension().array();
ELSA_VERIFY((arr >= 0).all());
ELSA_VERIFY((arr < shape).all());
return (*this)[_dataDescriptor->getIndexFromCoordinate(coordinate)];
}
template <typename data_t>
......
......@@ -177,6 +177,8 @@ namespace elsa
/// return an element by n-dimensional coordinate as read-only (not bounds-checked!)
const data_t& operator()(IndexVector_t coordinate) const;
data_t at(IndexVector_t coordinate) const;
/// return an element by its coordinates (not bounds-checked!)
template <typename idx0_t, typename... idx_t,
typename = std::enable_if_t<
......
#include "Blobs.h"
namespace elsa
{
template <typename data_t>
constexpr Blob<data_t>::Blob(data_t radius, SelfType_t<data_t> alpha, SelfType_t<data_t> order)
: radius_(radius), alpha_(alpha), order_(order)
{
}
template <typename data_t>
constexpr data_t Blob<data_t>::operator()(data_t s)
{
return blobs::blob_evaluate(s, radius_, alpha_, order_);
}
template <typename data_t>
constexpr data_t Blob<data_t>::radius() const
{
return radius_;
}
template <typename data_t>
constexpr data_t Blob<data_t>::alpha() const
{
return alpha_;
}
template <typename data_t>
constexpr data_t Blob<data_t>::order() const
{
return order_;
}
// ------------------------------------------
// explicit template instantiation
namespace blobs
{
template float blob_evaluate<float>(float, SelfType_t<float>, SelfType_t<float>,
SelfType_t<float>);
template double blob_evaluate<double>(double, SelfType_t<double>, SelfType_t<double>,
SelfType_t<double>);
template float blob_projected<float>(float, SelfType_t<float>, SelfType_t<float>,
SelfType_t<float>);
template double blob_projected<double>(double, SelfType_t<double>, SelfType_t<double>,
SelfType_t<double>);
template float blob_projected<float>(float);
template double blob_projected<double>(double);
} // namespace blobs
template class Blob<float>;
template class Blob<double>;
} // namespace elsa
#pragma once
#include "elsaDefines.h"
namespace elsa
{
template <class T>
struct SelfType {
using type = T;
};
template <class T>
using SelfType_t = typename SelfType<T>::type;
namespace blobs
{
template <typename data_t>
constexpr data_t blob_evaluate(data_t r, SelfType_t<data_t> a, SelfType_t<data_t> alpha,
SelfType_t<data_t> m)
{
const auto w = static_cast<data_t>(1) - std::pow(r / a, static_cast<data_t>(2));
if (w >= 0) {
const data_t Im1 = std::cyl_bessel_i(m, alpha);
const data_t arg = std::sqrt(w);
const data_t Im2 = std::cyl_bessel_i(m, alpha * arg);
return Im2 / Im1 * std::pow(arg, m);
}
return 0;
}
/// @brief Compute line integral of blob through a straight line
/// @param distance distance of blob center to straight line, in literature often referred
/// to as `r`
/// @param radius radius of blob, often referred to as `a`
/// @param alpha smoothness factor of blob
/// @param order order of Bessel function, often referred to as `m`
/// Ref:
/// https://github.com/I2PC/xmipp/blob/3d4cc3f430cbc99a337635edbd95ebbcef33fc44/src/xmipp/libraries/data/blobs.cpp#L91A
/// Distance-Driven Projection and Backprojection for Spherically Symmetric Basis Functions
/// in CT - Levakhina
/// Spherically symmetric volume elements as basis functions for image reconstructions in
/// computed laminography - P. Trampert
/// Semi-Discrete Iteration Methods in X-Ray Tomography - Jonas Vogelgesang
template <typename data_t>
constexpr data_t blob_projected(data_t s, SelfType_t<data_t> a, SelfType_t<data_t> alpha,
SelfType_t<data_t> m)
{
// Equation derived in Lewitt 1990
const data_t w = static_cast<data_t>(1) - ((s * s) / (a * a));
// If `w` is close to zero or negative, `s` > `a`, and therefore just return 0
if (w > 1e-10) {
const data_t root = std::sqrt(w);
// First three terms of equation
const data_t q1 = a / std::cyl_bessel_i(m, alpha);
const data_t q2 = std::sqrt(2 * pi<data_t> / alpha);
const data_t q3 = std::pow(root, m + static_cast<data_t>(0.5));
const data_t q4 = std::cyl_bessel_i(m + static_cast<data_t>(0.5), alpha * root);
return q1 * q2 * q3 * q4;
}
return 0;
}
template <typename data_t>
constexpr data_t blob_projected(data_t s)
{
return blob_projected(s, 2.f, 10.83f, 2);
}
} // namespace blobs
template <typename data_t>
class Blob
{
public:
constexpr Blob(data_t radius, SelfType_t<data_t> alpha, SelfType_t<data_t> order);
constexpr data_t operator()(data_t s);
constexpr data_t radius() const;
constexpr data_t alpha() const;
constexpr data_t order() const;
private:
data_t radius_;
data_t alpha_;
data_t order_;
};
template <typename data_t>
class ProjectedBlob
{
public:
constexpr ProjectedBlob(data_t radius, SelfType_t<data_t> alpha, SelfType_t<data_t> order)
: radius_(radius), alpha_(alpha), order_(order)
{
}
constexpr data_t operator()(data_t s)
{
return blobs::blob_projected(s, radius_, alpha_, order_);
}
#
data_t radius() const { return radius_; }
data_t alpha() const { return alpha_; }
data_t order() const { return order_; }
private:
data_t radius_;
data_t alpha_;
data_t order_;
};
} // namespace elsa
......@@ -5,10 +5,12 @@ set(MODULE_HEADERS
TraverseAABB.h
TraverseAABBJosephsMethod.h
SliceTraversal.h
LutProjector.h
BinaryMethod.h
SiddonsMethod.h
JosephsMethod.h
SubsetSampler.h
Blobs.h
)
# list all the code files of the module
......@@ -18,10 +20,12 @@ set(MODULE_SOURCES
TraverseAABB.cpp
TraverseAABBJosephsMethod.cpp
SliceTraversal.cpp
LutProjector.cpp
BinaryMethod.cpp
SiddonsMethod.cpp
JosephsMethod.cpp
SubsetSampler.cpp
Blobs.cpp
)
list(APPEND MODULE_PUBLIC_DEPS elsa_core elsa_logging)
......
#include "LutProjector.h"
#include "Timer.h"
#include "SliceTraversal.h"
#include "Assertions.h"
namespace elsa
{
template <typename data_t>
BlobProjector<data_t>::BlobProjector(data_t radius, data_t alpha, data_t order,
const VolumeDescriptor& domainDescriptor,
const DetectorDescriptor& rangeDescriptor)
: LutProjector<data_t, BlobProjector<data_t>>(domainDescriptor, rangeDescriptor),
lut_(radius, alpha, order)
{
// sanity checks
auto dim = domainDescriptor.getNumberOfDimensions();
if (dim < 2 || dim > 3) {
throw InvalidArgumentError("BlobProjector: only supporting 2d/3d operations");
}
if (dim != rangeDescriptor.getNumberOfDimensions()) {
throw InvalidArgumentError("BlobProjector: domain and range dimension need to match");
}
if (_detectorDescriptor.getNumberOfGeometryPoses() == 0) {
throw InvalidArgumentError("BlobProjector: rangeDescriptor without any geometry");
}
}
template <typename data_t>
BlobProjector<data_t>::BlobProjector(const VolumeDescriptor& domainDescriptor,
const DetectorDescriptor& rangeDescriptor)
: BlobProjector(2, 10.83, 2, domainDescriptor, rangeDescriptor)
{
}
// ------------------------------------------
// explicit template instantiation
template class BlobProjector<float>;
template class BlobProjector<double>;
} // namespace elsa
#pragma once
#include "elsaDefines.h"
#include "Timer.h"
#include "Luts.hpp"
#include "SliceTraversal.h"
#include "LinearOperator.h"
#include "VolumeDescriptor.h"
#include "DetectorDescriptor.h"
#include "DataContainer.h"
#include "BoundingBox.h"
#include "Logger.h"
#include "Blobs.h"
namespace elsa
{
template <typename data_t, typename Derived>
class LutProjector : public LinearOperator<data_t>
{
public:
LutProjector(const VolumeDescriptor& domainDescriptor,
const DetectorDescriptor& rangeDescriptor)
: LinearOperator<data_t>(domainDescriptor, rangeDescriptor),
_boundingBox{domainDescriptor.getNumberOfCoefficientsPerDimension()},
_detectorDescriptor(static_cast<DetectorDescriptor&>(*_rangeDescriptor)),
_volumeDescriptor(static_cast<VolumeDescriptor&>(*_domainDescriptor))
{
// sanity checks
auto dim = _domainDescriptor->getNumberOfDimensions();
if (dim < 2 || dim > 3) {
throw InvalidArgumentError("LutProjector: only supporting 2d/3d operations");
}
if (dim != _rangeDescriptor->getNumberOfDimensions()) {
throw InvalidArgumentError(
"LutProjector: domain and range dimension need to match");
}
if (_detectorDescriptor.getNumberOfGeometryPoses() == 0) {
throw InvalidArgumentError("LutProjector: rangeDescriptor without any geometry");
}
}
/// default destructor
~LutProjector() override = default;
Derived& self() { return static_cast<Derived&>(*this); }
const Derived& self() const { return static_cast<const Derived&>(*this); }
protected:
/// apply the binary method (i.e. forward projection)
void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const override
{
Timer t("LutProjector", "apply");
// Be sure to zero out the result
Ax = 0;
const auto sizeRange = Ax.getSize();
const auto volume_shape = x.getDataDescriptor().getNumberOfCoefficientsPerDimension();
// Loop over all the poses, and for each pose loop over all detector pixels
#pragma omp parallel for
for (index_t rangeIndex = 0; rangeIndex < sizeRange; ++rangeIndex) {
// --> get the current ray to the detector center
auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex);
index_t leadingdir = 0;
ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
auto& rangeVal = Ax[rangeIndex];
// Expand bounding box as rays have larger support now
auto aabb = _boundingBox;
aabb._min.array() -= support;
aabb._min[leadingdir] += support;
aabb._max.array() += support;
aabb._max[leadingdir] -= support;
// --> setup traversal algorithm
SliceTraversal traversal(aabb, ray);
for (const auto [curPos, curVoxel, t] : traversal) {
for (auto neighbour :
neighbours_in_slice(curVoxel, support, leadingdir, lower, upper)) {
// Correct position, such that the distance is still correct
auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto distance = ray.distance(curPos);
const auto weight = self().weight(distance);
rangeVal += weight * x.at(curVoxel);
}
}
}
/// apply the adjoint of the binary method (i.e. backward projection)
void applyAdjointImpl(const DataContainer<data_t>& y,
DataContainer<data_t>& Aty) const override
{
Timer t("LutProjector", "apply");
const auto sizeRange = y.getSize();
Aty = 0;
const auto shape = _domainDescriptor->getNumberOfCoefficientsPerDimension();
#pragma omp parallel for
// Loop over all the poses, and for each pose loop over all detector pixels
for (index_t rangeIndex = 0; rangeIndex < sizeRange; ++rangeIndex) {
// --> get the current ray to the detector center (from reference to
// DetectorDescriptor)
auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex);
index_t leadingdir = 0;
ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
// Expand bounding box as rays have larger support now
auto aabb = _boundingBox;
aabb._min.array() -= support;
aabb._min[leadingdir] += support;
aabb._max.array() += support;
aabb._max[leadingdir] -= support;
// --> setup traversal algorithm
SliceTraversal traversal(aabb, ray);
const auto val = y[rangeIndex];
for (const auto [curPos, curVoxel, t] : traversal) {
for (auto neighbour :
neighbours_in_slice(curVoxel, support, leadingdir, lower, upper)) {
// Correct position, such that the distance is still correct
auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto distance = ray.distance(correctedPos);
const auto weight = self().weight(distance);
#pragma omp atomic
Aty(neighbour) += weight * val;
}
}
}
}
/// implement the polymorphic clone operation
LutProjector<data_t, Derived>* cloneImpl() const override
{
return new LutProjector(_volumeDescriptor, _detectorDescriptor);
}
/// implement the polymorphic comparison operation
bool isEqual(const LinearOperator<data_t>& other) const override
{
if (!LinearOperator<data_t>::isEqual(other))
return false;
auto otherOp = downcast_safe<LutProjector>(&other);
return static_cast<bool>(otherOp);
}
private:
/// the bounding box of the volume
BoundingBox _boundingBox;
/// Lift from base class
using LinearOperator<data_t>::_domainDescriptor;
/// Lift from base class
using LinearOperator<data_t>::_rangeDescriptor;
protected:
/// Reference to DetectorDescriptor stored in LinearOperator
DetectorDescriptor& _detectorDescriptor;
/// Reference to VolumeDescriptor stored in LinearOperator
VolumeDescriptor& _volumeDescriptor;
};
template <typename data_t = real_t>
class BlobProjector : public LutProjector<data_t, BlobProjector<data_t>>
{
public:
BlobProjector(data_t radius, data_t alpha, data_t order,
const VolumeDescriptor& domainDescriptor,
const DetectorDescriptor& rangeDescriptor);
BlobProjector(const VolumeDescriptor& domainDescriptor,
const DetectorDescriptor& rangeDescriptor);
data_t weight(data_t distance) const { return lut_(distance); }
/// implement the polymorphic clone operation
BlobProjector<data_t>* cloneImpl() const override
{
return new BlobProjector(_volumeDescriptor, _detectorDescriptor);
}
/// implement the polymorphic comparison operation
bool isEqual(const LinearOperator<data_t>& other) const override
{
if (!Base::isEqual(other))
return false;
auto otherOp = downcast_safe<BlobProjector>(&other);
return static_cast<bool>(otherOp);
}
private:
ProjectedBlobLut<data_t, 100> lut_;
using Base = LutProjector<data_t, BlobProjector<data_t>>;
using Base::_detectorDescriptor;
using Base::_volumeDescriptor;
};
} // namespace elsa
#pragma once
#include "Blobs.h"
#include "Logger.h"
#include <array>
namespace elsa
{
namespace detail
{
template <typename data_t, index_t N>
constexpr std::array<data_t, N> blob_lut(ProjectedBlob<data_t> blob)
{
std::array<data_t, N> lut;
auto t = static_cast<data_t>(0);
const auto step = blob.radius() / N;
for (std::size_t i = 0; i < N; ++i) {
lut[i] = blob(t);
t += step;
}
return lut;
}
template <typename data_t>
data_t lerp(data_t a, SelfType_t<data_t> b, SelfType_t<data_t> t)
{
if ((a <= 0 && b >= 0) || (a >= 0 && b <= 0))
return t * b + (1 - t) * a;
if (t == 1)
return b;
const data_t x = a + t * (b - a);
if ((t > 1) == (b > a))
return b < x ? x : b;
else
return x < b ? x : b;
}
} // namespace detail
template <typename data_t, std::size_t N>
class Lut
{
public:
Lut(std::array<data_t, N> data) : data_(std::move(data)) {}
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
data_t operator()(T index) const
{
if (index < 0 || index > N) {
return 0;
}
return data_[index];
}
/// TODO: Handle boundary conditions
/// lerp(last, last+1, t), for some t > 0, yields f(last) / 2, as f(last + 1) = 0,
/// this should be handled
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
data_t operator()(T index) const
{
if (index < 0 || index > N) {
return 0;
}
// Get the two closes indices
const auto a = static_cast<std::size_t>(std::floor(index));