Commit f6b6baec authored by David Frank's avatar David Frank
Browse files

Optimize Traversal slightly to avoid unnecessary allocations

parent 84988350
Pipeline #862144 passed with stages
in 58 minutes and 59 seconds
......@@ -78,6 +78,9 @@ namespace elsa
index_t leadingdir = 0;
ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support);
distvec[leadingdir] = 0;
auto rangeVal = Ax[rangeIndex];
// Expand bounding box as rays have larger support now
......@@ -88,19 +91,21 @@ namespace elsa
aabb._max.array() += static_cast<real_t>(support);
aabb._max[leadingdir] -= static_cast<real_t>(support);
// Keep this here, as it saves us a couple of allocations on clang
CartesianIndices neighbours(upper);
// --> setup traversal algorithm
SliceTraversal traversal(aabb, ray);
for (const auto [curPos, curVoxel, t] : traversal) {
for (auto neighbour :
neighbours_in_slice(curVoxel, support, leadingdir, lower, upper)) {
for (const auto& curVoxel : traversal) {
neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper);
for (auto neighbour : neighbours) {
// Correct position, such that the distance is still correct
auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto distance = ray.distance(correctedPos);
const auto weight = self().weight(distance);
rangeVal += weight * x.at(neighbour);
rangeVal += weight * x(neighbour);
}
}
......@@ -112,7 +117,7 @@ namespace elsa
void applyAdjointImpl(const DataContainer<data_t>& y,
DataContainer<data_t>& Aty) const override
{
Timer t("LutProjector", "apply");
Timer t("LutProjector", "applyAdjoint");
const auto sizeRange = y.getSize();
Aty = 0;
......@@ -133,6 +138,9 @@ namespace elsa
index_t leadingdir = 0;
ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support);
distvec[leadingdir] = 0;
// Expand bounding box as rays have larger support now
auto aabb = _boundingBox;
aabb._min.array() -= static_cast<real_t>(support);
......@@ -141,19 +149,22 @@ namespace elsa
aabb._max.array() += static_cast<real_t>(support);
aabb._max[leadingdir] -= static_cast<real_t>(support);
// Keep this here, as it saves us a couple of allocations on clang
CartesianIndices neighbours(upper);
// --> setup traversal algorithm
SliceTraversal traversal(aabb, ray);
const auto val = y[rangeIndex];
for (const auto [curPos, curVoxel, t] : traversal) {
for (auto neighbour :
neighbours_in_slice(curVoxel, support, leadingdir, lower, upper)) {
for (const auto& curVoxel : traversal) {
neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper);
for (auto neighbour : neighbours) {
// Correct position, such that the distance is still correct
auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
const auto distance = ray.distance(correctedPos);
const auto weight = self().weight(distance);
#pragma omp atomic
Aty(neighbour) += weight * val;
}
......
......@@ -196,9 +196,16 @@ namespace elsa
return idx;
}
SliceTraversal::Iter SliceTraversal::begin() const { return {startIndex_, ray_, tDelta_, t_}; }
real_t SliceTraversal::t() const { return t_; }
SliceTraversal::Iter SliceTraversal::end() const { return {endIndex_, ray_, tDelta_, t_}; }
real_t SliceTraversal::tDelta() const { return tDelta_; }
SliceTraversal::Iter SliceTraversal::begin() const
{
return {startIndex_, ray_.pointAt(t_), ray_.direction() * tDelta_};
}
SliceTraversal::Iter SliceTraversal::end() const { return {endIndex_}; }
index_t SliceTraversal::startIndex() const { return startIndex_; }
......@@ -206,15 +213,13 @@ namespace elsa
SliceTraversal::Iter::value_type SliceTraversal::Iter::operator*() const
{
const RealVector_t curPos = ray_.pointAt(t_);
const IndexVector_t curVoxel = curPos.template cast<index_t>();
return {curPos, curVoxel, t_};
return cur_.template cast<index_t>();
}
SliceTraversal::Iter& SliceTraversal::Iter::operator++()
{
++pos_;
t_ += tDelta_;
cur_ += dir_;
return *this;
}
......
......@@ -173,14 +173,6 @@ namespace elsa
RealRay_t ray_;
public:
/// The Dereference type of the iterator
/// TODO: With C++20 support for proxy iterators is given, maybe this could change then
struct IterValue {
RealVector_t curPosition_;
IndexVector_t curVoxel_;
real_t t_;
};
/// Traversal iterator, models forward iterator, maybe this should actually be an input
/// iterator due to the non reference type of the dereference type. IDK, as we use it, this
/// works, but in the future this might should be different.
......@@ -188,15 +180,14 @@ namespace elsa
public:
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = IterValue;
using value_type = IndexVector_t;
using pointer = value_type*;
using reference = value_type&;
private:
index_t pos_;
RealRay_t ray_;
real_t tDelta_;
real_t t_;
index_t pos_{};
RealVector_t cur_{};
RealVector_t dir_{};
public:
/// Construct iterator
......@@ -205,11 +196,13 @@ namespace elsa
/// @param ray traversed ray used to compute exact position on dereference
/// @param deltat increment of t each increment
/// @param t position along the ray
Iter(index_t pos, RealRay_t ray, real_t deltat, real_t t)
: pos_(pos), ray_(ray), tDelta_(deltat), t_(t)
Iter(index_t pos, const RealVector_t& entry, const RealVector_t& dir)
: pos_(pos), cur_(entry), dir_(dir)
{
}
Iter(index_t pos) : pos_(pos) {}
/// Dereference iterator
value_type operator*() const;
......@@ -226,25 +219,31 @@ namespace elsa
friend bool operator!=(const Iter& lhs, const Iter& rhs);
};
// Delete default construction
/// Delete default construction
SliceTraversal() = delete;
// Construct traversal from bounding box and ray
/// Construct traversal from bounding box and ray
SliceTraversal(BoundingBox aabb, RealRay_t ray);
// Get the first visited voxel
/// Get the first visited voxel
Iter begin() const;
// Get on past the end
/// Get on past the end
Iter end() const;
// Get the index of the first visited voxel
/// Get the index of the first visited voxel
index_t startIndex() const;
// Get the index of the one past the last visited voxel
/// Get the index of the one past the last visited voxel
index_t endIndex() const;
// Get the leading direction of the ray
/// Get the leading direction of the ray
index_t leadingDirection() const;
/// Get the initial computed t value for the entry point
real_t t() const;
/// Get the computed delta t value
real_t tDelta() const;
};
} // namespace elsa
......@@ -405,6 +405,8 @@ index_t checkTraversal(BoundingBox aabb, RealRay_t ray, std::deque<RealVector_t>
// INFO("exitPoint := ", traversal.exitPoint_.format(fmt));
index_t counter = 0;
auto t = traversal.t();
auto tdelta = traversal.tDelta();
for (auto iter = traversal.begin(); iter != traversal.end(); ++iter) {
auto value = *iter;
......@@ -412,9 +414,9 @@ index_t checkTraversal(BoundingBox aabb, RealRay_t ray, std::deque<RealVector_t>
REQUIRE_MESSAGE(!visitedVoxels.empty(), "Visiting more voxels than expected");
RealVector_t expected = visitedVoxels.front();
RealVector_t point = ray.pointAt(value.t_);
RealVector_t point = ray.pointAt(t);
CAPTURE(value.t_);
CAPTURE(t);
INFO("RealRay_t hit: ", point.format(fmt));
INFO("Should hit: ", expected.format(fmt));
......@@ -425,6 +427,7 @@ index_t checkTraversal(BoundingBox aabb, RealRay_t ray, std::deque<RealVector_t>
// increment counter
++counter;
t += tdelta;
}
INFO("Voxels left in list: ", visitedVoxels.size());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment