Commit 19abf8f3 authored by Andi Braimllari's avatar Andi Braimllari Committed by Jonas Jelten
Browse files

add fft norm config arguments

parent 112dcb2e
......@@ -184,15 +184,15 @@ namespace elsa
}
template <typename data_t>
void DataContainer<data_t>::fft() const
void DataContainer<data_t>::fft(FFTNorm norm) const
{
this->_dataHandler->fft(*this->_dataDescriptor);
this->_dataHandler->fft(*this->_dataDescriptor, norm);
}
template <typename data_t>
void DataContainer<data_t>::ifft() const
void DataContainer<data_t>::ifft(FFTNorm norm) const
{
this->_dataHandler->ifft(*this->_dataDescriptor);
this->_dataHandler->ifft(*this->_dataDescriptor, norm);
}
template <typename data_t>
......
......@@ -249,10 +249,10 @@ namespace elsa
data_t maxElement() const;
/// convert to the fourier transformed signal
void fft() const;
void fft(FFTNorm norm) const;
/// convert to the inverse fourier transformed signal
void ifft() const;
void ifft(FFTNorm norm) const;
/// if the datacontainer is already complex, return itself.
template <typename _data_t = data_t>
......
......@@ -84,10 +84,10 @@ namespace elsa
virtual data_t maxElement() const = 0;
/// in-place create the fourier transformed of the data vector
virtual DataHandler<data_t>& fft(const DataDescriptor& source_desc) = 0;
virtual DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) = 0;
/// in-place create the inverse fourier transformed of the data vector
virtual DataHandler<data_t>& ifft(const DataDescriptor& source_desc) = 0;
virtual DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) = 0;
/// compute in-place element-wise addition of another vector v
virtual DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) = 0;
......
......@@ -141,16 +141,18 @@ namespace elsa
}
template <typename data_t>
DataHandler<data_t>& DataHandlerCPU<data_t>::fft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerCPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{
this->base_fft<true>(source_desc);
this->base_fft<true>(source_desc, norm);
return *this;
}
template <typename data_t>
DataHandler<data_t>& DataHandlerCPU<data_t>::ifft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerCPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{
this->base_fft<false>(source_desc);
this->base_fft<false>(source_desc, norm);
return *this;
}
......@@ -485,7 +487,7 @@ namespace elsa
template <typename data_t>
template <bool is_forward>
void DataHandlerCPU<data_t>::base_fft(const DataDescriptor& source_desc)
void DataHandlerCPU<data_t>::base_fft(const DataDescriptor& source_desc, FFTNorm norm)
{
if constexpr (isComplex<data_t>) {
......@@ -539,6 +541,11 @@ namespace elsa
Eigen::FFT<GetFloatingPointType_t<typename DataVector_t::Scalar>> fft_op;
// disable any scaling in eigen - normally it does 1/n for ifft
fft_op.SetFlag(
Eigen::FFT<
GetFloatingPointType_t<typename DataVector_t::Scalar>>::Flag::Unscaled);
Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_in{dim_size};
Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_out{dim_size};
......@@ -554,9 +561,18 @@ namespace elsa
// they will corrupt wildly otherwise.
if constexpr (is_forward) {
fft_op.fwd(fft_out, fft_in);
if (norm == FFTNorm::FORWARD) {
fft_out /= dim_size;
} else if (norm == FFTNorm::ORTHO) {
fft_out /= std::sqrt(dim_size);
}
} else {
// eigen inv-fft already scales down by dim_size
fft_op.inv(fft_out, fft_in);
if (norm == FFTNorm::BACKWARD) {
fft_out /= dim_size;
} else if (norm == FFTNorm::ORTHO) {
fft_out /= std::sqrt(dim_size);
}
}
}
......
......@@ -131,10 +131,10 @@ namespace elsa
data_t maxElement() const override;
/// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// copy assign another DataHandlerCPU to this, other types handled in assign()
DataHandlerCPU<data_t>& operator=(const DataHandlerCPU<data_t>& v);
......@@ -222,6 +222,6 @@ namespace elsa
void attach(std::shared_ptr<DataVector_t>&& data);
template <bool is_forward>
void base_fft(const DataDescriptor& source_desc);
void base_fft(const DataDescriptor& source_desc, FFTNorm norm = FFTNorm::BACKWARD);
};
} // namespace elsa
......@@ -131,14 +131,16 @@ namespace elsa
return _data->maxElement();
}
DataHandler<data_t>& DataHandlerGPU<data_t>::fft(const DataDescriptor& source_desc)
template <typename data_t>
DataHandler<data_t>& DataHandlerGPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{
// until we have a gpu fft implementation, use the cpu version.
DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU};
for (index_t i = 0; i < this->getSize(); i++) {
tmp[i] = this->operator[](i);
}
tmp.fft();
tmp.fft(norm);
for (index_t i = 0; i < this->getSize(); i++) {
this->operator[](i) = tmp[i];
}
......@@ -146,14 +148,15 @@ namespace elsa
}
template <typename data_t>
DataHandler<data_t>& DataHandlerGPU<data_t>::ifft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerGPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{
// until we have a gpu fft implementation, use the cpu version.
DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU};
for (index_t i = 0; i < this->getSize(); i++) {
tmp[i] = this->operator[](i);
}
tmp.ifft();
tmp.ifft(norm);
for (index_t i = 0; i < this->getSize(); i++) {
this->operator[](i) = tmp[i];
}
......
......@@ -135,10 +135,10 @@ namespace elsa
data_t maxElement() const override;
/// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// copy assign another DataHandlerGPU
DataHandlerGPU<data_t>& operator=(const DataHandlerGPU<data_t>& v);
......
......@@ -133,18 +133,20 @@ namespace elsa
}
template <typename data_t>
DataHandler<data_t>& DataHandlerMapCPU<data_t>::fft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerMapCPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{
// detaches internally
this->_dataOwner->fft(source_desc);
this->_dataOwner->fft(source_desc, norm);
return *this;
}
template <typename data_t>
DataHandler<data_t>& DataHandlerMapCPU<data_t>::ifft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerMapCPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{
// detaches internally
this->_dataOwner->ifft(source_desc);
this->_dataOwner->ifft(source_desc, norm);
return *this;
}
......
......@@ -106,10 +106,10 @@ namespace elsa
data_t maxElement() const override;
/// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// compute in-place element-wise addition of another vector v
DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override;
......
......@@ -142,13 +142,15 @@ namespace elsa
}
template <typename data_t>
DataHandler<data_t>& DataHandlerMapGPU<data_t>::fft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerMapGPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{
throw std::runtime_error{"todo implement"};
}
template <typename data_t>
DataHandler<data_t>& DataHandlerMapGPU<data_t>::ifft(const DataDescriptor& source_desc)
DataHandler<data_t>& DataHandlerMapGPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{
throw std::runtime_error{"todo implement"};
}
......
......@@ -125,10 +125,10 @@ namespace elsa
data_t maxElement() const override;
/// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override;
DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// compute in-place element-wise addition of another vector v
DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override;
......
......@@ -37,6 +37,9 @@ namespace elsa
/// global constexpr for the number pi
constexpr auto pi_t = pi<real_t>;
/// various values of the different norms of the Fourier transforms
enum class FFTNorm { FORWARD, ORTHO, BACKWARD };
/// type of the DataHandler used to store the actual data
enum class DataHandlerType {
CPU, ///< data is stored as an Eigen::Matrix in CPU main memory
......
......@@ -9,8 +9,8 @@
namespace elsa
{
template <typename data_t>
FourierTransform<data_t>::FourierTransform(const DataDescriptor& domainDescriptor)
: B(domainDescriptor, domainDescriptor)
FourierTransform<data_t>::FourierTransform(const DataDescriptor& domainDescriptor, FFTNorm norm)
: B(domainDescriptor, domainDescriptor), norm{norm}
{
}
......@@ -28,7 +28,7 @@ namespace elsa
// copy the input and fouriertransform it
Ax = x;
Ax.fft();
Ax.fft(this->norm);
}
template <typename data_t>
......@@ -43,7 +43,7 @@ namespace elsa
// copy the input and inverse-fouriertransform it
Atx = x;
Atx.ifft();
Atx.ifft(this->norm);
}
template <typename data_t>
......
......@@ -31,8 +31,11 @@ namespace elsa
* @brief create a fourier transform operator
*
* @param[in] domainDescriptor metadata defining the domain and range of the transformation
* @param[in] norm metadata indicating which forward/inverse transform is scaled and
* on which of the predefined normalization factors
*/
explicit FourierTransform(const DataDescriptor& domainDescriptor);
explicit FourierTransform(const DataDescriptor& domainDescriptor,
FFTNorm norm = FFTNorm::BACKWARD);
~FourierTransform() override = default;
......@@ -57,6 +60,9 @@ namespace elsa
/// implement the polymorphic comparison operation
bool isEqual(const B& other) const override;
private:
FFTNorm norm;
};
} // namespace elsa
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment