Commit 7238d8d4 authored by Jonas Jelten's avatar Jonas Jelten 🥕
Browse files

add fourier transform operator

parent 939f1176
Pipeline #370153 failed with stages
in 21 minutes and 37 seconds
......@@ -11,6 +11,7 @@ set(MODULE_HEADERS
Identity.h
Scaling.h
FiniteDifferences.h
FourierTransform.h
BlockLinearOperator.h)
# list all the code files of the module
......@@ -18,6 +19,7 @@ set(MODULE_SOURCES
Identity.cpp
Scaling.cpp
FiniteDifferences.cpp
FourierTransform.cpp
BlockLinearOperator.cpp)
......
#include "FourierTransform.h"
#include "Timer.h"
#include <unsupported/Eigen/FFT>
namespace elsa {
template <typename data_t>
FourierTransform<data_t>::FourierTransform(const DataDescriptor &domainDescriptor)
: B(domainDescriptor, domainDescriptor) {
}
template <typename data_t>
void FourierTransform<data_t>::applyImpl(
const DataContainer<data_t> &x,
DataContainer<data_t> &Ax) const {
Timer<> timeguard("FourierTransform", "apply()");
auto x_size = x.getDataDescriptor().getNumberOfCoefficientsPerDimension();
auto x_dims = x.getDataDescriptor().getNumberOfDimensions();
// input container size must match dimensionality of operator setup
assert(x_size.size() == this->_domainDescriptor.getNumberOfDimensions());
fftvector_t input{x.getSize()};
input.setZero();
// TODO: avoid this copy, instead access datahandler or get an eigen-matrix directly
// convert to datacontainer to eigen-matrix
// another approach: move fft implementation to datahandler
for (index_t i = 0; i < x.getSize(); ++i) {
input(i) = x(i);
}
fftvector_t output(x.getSize());
output.setZero();
this->fft(input, output, x_dims);
for (index_t i = 0; i < output.size(); ++i) {
Ax(i) = output(i);
}
}
template <typename data_t>
void FourierTransform<data_t>::fft(
const fftvector_t &in,
fftvector_t &out,
index_t dims) const {
auto dim_idx = dims - 1;
if (dim_idx == 0) {
this->fft1d(in, out);
/*
} else if (dim_idx == 1) {
// TODO: use fftw 2d transformation directly!
*/
} else {
const auto &in_coeffs_per_dim = this->_domainDescriptor->getNumberOfCoefficientsPerDimension();
const auto &out_coeffs_per_dim = this->_rangeDescriptor->getNumberOfCoefficientsPerDimension();
const index_t in_stride = in_coeffs_per_dim.head(dim_idx).prod();
const index_t out_stride = out_coeffs_per_dim.head(dim_idx).prod();
// number of coefficients for the current dimension
const index_t dim_size = out_coeffs_per_dim(dim_idx);
// number of coefficients for the other dimensions
// e.g. [10x10x10], and dim_idx=2 -> out.size() == 1000 -> /10 = 100
const index_t dims_remaining_size = out.size() / dim_size;
// TODO: in 2d-case, no need to copy
#pragma omp parallel for
for (index_t i = 0; i < dim_size; ++i) {
fftvector_t in_tmp(in_stride);
for (index_t j = 0; j < in_stride; j++) {
in_tmp(j) = in[in_stride * i + j];
}
fftvector_t out_tmp(out_stride);
out_tmp.setZero();
// recursive call!
fft(in_tmp, out_tmp, dims - 1);
// store out_stride elements at output row i
out.segment(i * out_stride, out_stride) = out_tmp;
}
#pragma omp parallel for
for (index_t i = 0; i < dims_remaining_size; ++i) {
// for the column-calculations,
// map the correct indices, i.e. use InnerStride
// to specify the discance between two consecutive indices
// yes, use out as input, since out contains the
// calculation of the rows from the above loop.
const Eigen::Map<fftvector_t, 0, Eigen::InnerStride<>>
input_map(const_cast<data_t*>(out.data() + i),
dim_size,
Eigen::InnerStride<>(out_stride));
Eigen::Map<fftvector_t, 0, Eigen::InnerStride<>>
output_map(out.data() + i,
dim_size,
Eigen::InnerStride<>(out_stride));
// copy the data into a new vector
const fftvector_t in_tmp(input_map);
fftvector_t out_tmp(output_map);
this->fft1d(in_tmp, out_tmp);
// bring temporary results into the output vector
// the map is just holding pointers to the output vector
// update happens implicitly
output_map = out_tmp;
}
}
}
template <typename data_t>
void FourierTransform<data_t>::applyAdjointImpl(const DataContainer<data_t>& x,
DataContainer<data_t>& Atx) const
{
Timer<> timeguard("FourierTransform", "applyAdjoint()");
// TODO ifft
}
template <typename data_t>
FourierTransform<data_t>* FourierTransform<data_t>::cloneImpl() const
{
auto& domainDescriptor = static_cast<const DataDescriptor&>(*this->_domainDescriptor);
return new FourierTransform(domainDescriptor);
}
template <typename data_t>
bool FourierTransform<data_t>::isEqual(const B& other) const
{
if (!B::isEqual(other))
return false;
auto otherOP = dynamic_cast<const FourierTransform *>(&other);
if (!otherOP)
return false;
// TODO actually check for equality!
return true;
}
template <typename data_t>
void FourierTransform<data_t>::fft1d(const fftvector_t &in,
fftvector_t &out) const {
Eigen::FFT<typename data_t::value_type> fft;
fft.fwd(out, in);
}
template <typename data_t>
void FourierTransform<data_t>::ifft1d(const fftvector_t &in,
fftvector_t &out) const {
Eigen::FFT<typename data_t::value_type> fft;
fft.inv(out, in);
}
template class FourierTransform<std::complex<float>>;
template class FourierTransform<std::complex<double>>;
} // namespace elsa
#pragma once
#include <complex>
#include "LinearOperator.h"
namespace elsa
{
/**
* \brief Operator for applying multi-dimensional fourier transforms.
*
* \author Jonas Jelten - initial code
*
* \tparam data_t data type for the domain and range of the transformation,
* defaulting to real_t
*
* Implements the n-dimensional signal fourier transformation.
* Can support multiple backends, by default uses Eigen::FFT with FFTW.
*/
template <typename data_t = std::complex<real_t>>
class FourierTransform : public LinearOperator<data_t> {
private:
using B = LinearOperator<data_t>;
/** working data container for the fft.
like in datacontainer, we operate on the vector in n dimensions. */
using fftvector_t = Eigen::Matrix<data_t, Eigen::Dynamic, 1>;
public:
/**
* \brief create a fourier transform operator
*
* \param[in] domainDescriptor metadata defining the domain and range of the transformation
*/
explicit FourierTransform(const DataDescriptor& domainDescriptor);
~FourierTransform() override = default;
protected:
/**
* \brief perform the fourier transformation
* \param x inputData (image matrix)
* \param Ax outputData (fourier transformed image matrix)
*/
void applyImpl(const DataContainer<data_t> &x, DataContainer<data_t> &Ax) const override;
/**
* \brief TODO ifft
* \param x inputData (XXX)
* \param Atx outputData (XXX)
*/
void applyAdjointImpl(const DataContainer<data_t> &y, DataContainer<data_t> &Aty) const override;
/// implement the polymorphic clone operation
FourierTransform *cloneImpl() const override;
/// implement the polymorphic comparison operation
bool isEqual(const B &other) const override;
/// recursively called fft implementation
void fft(const fftvector_t &in,
fftvector_t &out,
index_t dims) const;
void fft1d(const fftvector_t &in,
fftvector_t& out) const;
void ifft1d(const fftvector_t &in,
fftvector_t& out) const;
};
} // namespace elsa
......@@ -9,3 +9,4 @@ ELSA_TEST(Identity)
ELSA_TEST(Scaling)
ELSA_TEST(FiniteDifferences)
ELSA_TEST(BlockLinearOperator)
ELSA_TEST(FourierTransform)
/**
* \file test_FourierTransform.cpp
*
* \brief Tests for the fourier transform operator
*
* \author Jonas Jelten
*/
#include <catch2/catch.hpp>
#include "FourierTransform.h"
#include "VolumeDescriptor.h"
using namespace elsa;
SCENARIO("Constructing a FourierTransform operator")
{
GIVEN("a descriptor")
{
IndexVector_t numCoeff(3);
numCoeff << 13, 45, 28;
VolumeDescriptor dd(numCoeff);
WHEN("instantiating a FourierTransform operator")
{
FourierTransform fftOp(dd);
THEN("the DataDescriptors are as expected")
{
REQUIRE(fftOp.getDomainDescriptor() == dd);
REQUIRE(fftOp.getRangeDescriptor() == dd);
}
}
WHEN("cloning a FourierTransform operator")
{
FourierTransform fftOp(dd);
auto fftOpClone = fftOp.clone();
THEN("everything matches")
{
REQUIRE(fftOpClone.get() != &fftOp);
REQUIRE(*fftOpClone == fftOp);
}
}
}
}
SCENARIO("Using FourierTransform")
{
GIVEN("some data")
{
IndexVector_t numCoeff(2);
numCoeff << 11, 13;
VolumeDescriptor dd(numCoeff);
DataContainer<complex_t> input(dd);
input = 3.3f;
FourierTransform fftOp(dd);
WHEN("applying the fft")
{
auto output = fftOp.apply(input);
THEN("the result is as expected")
{
REQUIRE(output.getSize() == input.getSize());
//REQUIRE(input == output);
}
}
WHEN("applying the adjoint of fft")
{
auto output = fftOp.applyAdjoint(input);
THEN("the results is as expected")
{
REQUIRE(output.getSize() == input.getSize());
//REQUIRE(input == output);
}
}
WHEN("applying the fft and inverse fft")
{
auto output = fftOp.apply(fftOp.applyAdjoint(input));
THEN("the results is as expected")
{
REQUIRE(output.getSize() == input.getSize());
REQUIRE(input == output);
}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment