Commit 19abf8f3 authored by Andi Braimllari's avatar Andi Braimllari Committed by Jonas Jelten
Browse files

add fft norm config arguments

parent 112dcb2e
...@@ -184,15 +184,15 @@ namespace elsa ...@@ -184,15 +184,15 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
void DataContainer<data_t>::fft() const void DataContainer<data_t>::fft(FFTNorm norm) const
{ {
this->_dataHandler->fft(*this->_dataDescriptor); this->_dataHandler->fft(*this->_dataDescriptor, norm);
} }
template <typename data_t> template <typename data_t>
void DataContainer<data_t>::ifft() const void DataContainer<data_t>::ifft(FFTNorm norm) const
{ {
this->_dataHandler->ifft(*this->_dataDescriptor); this->_dataHandler->ifft(*this->_dataDescriptor, norm);
} }
template <typename data_t> template <typename data_t>
......
...@@ -249,10 +249,10 @@ namespace elsa ...@@ -249,10 +249,10 @@ namespace elsa
data_t maxElement() const; data_t maxElement() const;
/// convert to the fourier transformed signal /// convert to the fourier transformed signal
void fft() const; void fft(FFTNorm norm) const;
/// convert to the inverse fourier transformed signal /// convert to the inverse fourier transformed signal
void ifft() const; void ifft(FFTNorm norm) const;
/// if the datacontainer is already complex, return itself. /// if the datacontainer is already complex, return itself.
template <typename _data_t = data_t> template <typename _data_t = data_t>
......
...@@ -84,10 +84,10 @@ namespace elsa ...@@ -84,10 +84,10 @@ namespace elsa
virtual data_t maxElement() const = 0; virtual data_t maxElement() const = 0;
/// in-place create the fourier transformed of the data vector /// in-place create the fourier transformed of the data vector
virtual DataHandler<data_t>& fft(const DataDescriptor& source_desc) = 0; virtual DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) = 0;
/// in-place create the inverse fourier transformed of the data vector /// in-place create the inverse fourier transformed of the data vector
virtual DataHandler<data_t>& ifft(const DataDescriptor& source_desc) = 0; virtual DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) = 0;
/// compute in-place element-wise addition of another vector v /// compute in-place element-wise addition of another vector v
virtual DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) = 0; virtual DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) = 0;
......
...@@ -141,16 +141,18 @@ namespace elsa ...@@ -141,16 +141,18 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerCPU<data_t>::fft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerCPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
this->base_fft<true>(source_desc); this->base_fft<true>(source_desc, norm);
return *this; return *this;
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerCPU<data_t>::ifft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerCPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
this->base_fft<false>(source_desc); this->base_fft<false>(source_desc, norm);
return *this; return *this;
} }
...@@ -485,7 +487,7 @@ namespace elsa ...@@ -485,7 +487,7 @@ namespace elsa
template <typename data_t> template <typename data_t>
template <bool is_forward> template <bool is_forward>
void DataHandlerCPU<data_t>::base_fft(const DataDescriptor& source_desc) void DataHandlerCPU<data_t>::base_fft(const DataDescriptor& source_desc, FFTNorm norm)
{ {
if constexpr (isComplex<data_t>) { if constexpr (isComplex<data_t>) {
...@@ -539,6 +541,11 @@ namespace elsa ...@@ -539,6 +541,11 @@ namespace elsa
Eigen::FFT<GetFloatingPointType_t<typename DataVector_t::Scalar>> fft_op; Eigen::FFT<GetFloatingPointType_t<typename DataVector_t::Scalar>> fft_op;
// disable any scaling in eigen - normally it does 1/n for ifft
fft_op.SetFlag(
Eigen::FFT<
GetFloatingPointType_t<typename DataVector_t::Scalar>>::Flag::Unscaled);
Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_in{dim_size}; Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_in{dim_size};
Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_out{dim_size}; Eigen::Matrix<data_t, Eigen::Dynamic, 1> fft_out{dim_size};
...@@ -554,9 +561,18 @@ namespace elsa ...@@ -554,9 +561,18 @@ namespace elsa
// they will corrupt wildly otherwise. // they will corrupt wildly otherwise.
if constexpr (is_forward) { if constexpr (is_forward) {
fft_op.fwd(fft_out, fft_in); fft_op.fwd(fft_out, fft_in);
if (norm == FFTNorm::FORWARD) {
fft_out /= dim_size;
} else if (norm == FFTNorm::ORTHO) {
fft_out /= std::sqrt(dim_size);
}
} else { } else {
// eigen inv-fft already scales down by dim_size
fft_op.inv(fft_out, fft_in); fft_op.inv(fft_out, fft_in);
if (norm == FFTNorm::BACKWARD) {
fft_out /= dim_size;
} else if (norm == FFTNorm::ORTHO) {
fft_out /= std::sqrt(dim_size);
}
} }
} }
......
...@@ -131,10 +131,10 @@ namespace elsa ...@@ -131,10 +131,10 @@ namespace elsa
data_t maxElement() const override; data_t maxElement() const override;
/// create the fourier transformed of the data vector /// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override; DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector /// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override; DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// copy assign another DataHandlerCPU to this, other types handled in assign() /// copy assign another DataHandlerCPU to this, other types handled in assign()
DataHandlerCPU<data_t>& operator=(const DataHandlerCPU<data_t>& v); DataHandlerCPU<data_t>& operator=(const DataHandlerCPU<data_t>& v);
...@@ -222,6 +222,6 @@ namespace elsa ...@@ -222,6 +222,6 @@ namespace elsa
void attach(std::shared_ptr<DataVector_t>&& data); void attach(std::shared_ptr<DataVector_t>&& data);
template <bool is_forward> template <bool is_forward>
void base_fft(const DataDescriptor& source_desc); void base_fft(const DataDescriptor& source_desc, FFTNorm norm = FFTNorm::BACKWARD);
}; };
} // namespace elsa } // namespace elsa
...@@ -131,14 +131,16 @@ namespace elsa ...@@ -131,14 +131,16 @@ namespace elsa
return _data->maxElement(); return _data->maxElement();
} }
DataHandler<data_t>& DataHandlerGPU<data_t>::fft(const DataDescriptor& source_desc) template <typename data_t>
DataHandler<data_t>& DataHandlerGPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
// until we have a gpu fft implementation, use the cpu version. // until we have a gpu fft implementation, use the cpu version.
DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU}; DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU};
for (index_t i = 0; i < this->getSize(); i++) { for (index_t i = 0; i < this->getSize(); i++) {
tmp[i] = this->operator[](i); tmp[i] = this->operator[](i);
} }
tmp.fft(); tmp.fft(norm);
for (index_t i = 0; i < this->getSize(); i++) { for (index_t i = 0; i < this->getSize(); i++) {
this->operator[](i) = tmp[i]; this->operator[](i) = tmp[i];
} }
...@@ -146,14 +148,15 @@ namespace elsa ...@@ -146,14 +148,15 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerGPU<data_t>::ifft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerGPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
// until we have a gpu fft implementation, use the cpu version. // until we have a gpu fft implementation, use the cpu version.
DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU}; DataContainer<data_t> tmp{source_desc, DataHandlerType::CPU};
for (index_t i = 0; i < this->getSize(); i++) { for (index_t i = 0; i < this->getSize(); i++) {
tmp[i] = this->operator[](i); tmp[i] = this->operator[](i);
} }
tmp.ifft(); tmp.ifft(norm);
for (index_t i = 0; i < this->getSize(); i++) { for (index_t i = 0; i < this->getSize(); i++) {
this->operator[](i) = tmp[i]; this->operator[](i) = tmp[i];
} }
......
...@@ -135,10 +135,10 @@ namespace elsa ...@@ -135,10 +135,10 @@ namespace elsa
data_t maxElement() const override; data_t maxElement() const override;
/// create the fourier transformed of the data vector /// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override; DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector /// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override; DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// copy assign another DataHandlerGPU /// copy assign another DataHandlerGPU
DataHandlerGPU<data_t>& operator=(const DataHandlerGPU<data_t>& v); DataHandlerGPU<data_t>& operator=(const DataHandlerGPU<data_t>& v);
......
...@@ -133,18 +133,20 @@ namespace elsa ...@@ -133,18 +133,20 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerMapCPU<data_t>::fft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerMapCPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
// detaches internally // detaches internally
this->_dataOwner->fft(source_desc); this->_dataOwner->fft(source_desc, norm);
return *this; return *this;
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerMapCPU<data_t>::ifft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerMapCPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
// detaches internally // detaches internally
this->_dataOwner->ifft(source_desc); this->_dataOwner->ifft(source_desc, norm);
return *this; return *this;
} }
......
...@@ -106,10 +106,10 @@ namespace elsa ...@@ -106,10 +106,10 @@ namespace elsa
data_t maxElement() const override; data_t maxElement() const override;
/// create the fourier transformed of the data vector /// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override; DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector /// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override; DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// compute in-place element-wise addition of another vector v /// compute in-place element-wise addition of another vector v
DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override; DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override;
......
...@@ -142,13 +142,15 @@ namespace elsa ...@@ -142,13 +142,15 @@ namespace elsa
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerMapGPU<data_t>::fft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerMapGPU<data_t>::fft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
throw std::runtime_error{"todo implement"}; throw std::runtime_error{"todo implement"};
} }
template <typename data_t> template <typename data_t>
DataHandler<data_t>& DataHandlerMapGPU<data_t>::ifft(const DataDescriptor& source_desc) DataHandler<data_t>& DataHandlerMapGPU<data_t>::ifft(const DataDescriptor& source_desc,
FFTNorm norm)
{ {
throw std::runtime_error{"todo implement"}; throw std::runtime_error{"todo implement"};
} }
......
...@@ -125,10 +125,10 @@ namespace elsa ...@@ -125,10 +125,10 @@ namespace elsa
data_t maxElement() const override; data_t maxElement() const override;
/// create the fourier transformed of the data vector /// create the fourier transformed of the data vector
DataHandler<data_t>& fft(const DataDescriptor& source_desc) override; DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// create the inverse fourier transformed of the data vector /// create the inverse fourier transformed of the data vector
DataHandler<data_t>& ifft(const DataDescriptor& source_desc) override; DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) override;
/// compute in-place element-wise addition of another vector v /// compute in-place element-wise addition of another vector v
DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override; DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) override;
......
...@@ -37,6 +37,9 @@ namespace elsa ...@@ -37,6 +37,9 @@ namespace elsa
/// global constexpr for the number pi /// global constexpr for the number pi
constexpr auto pi_t = pi<real_t>; constexpr auto pi_t = pi<real_t>;
/// various values of the different norms of the Fourier transforms
enum class FFTNorm { FORWARD, ORTHO, BACKWARD };
/// type of the DataHandler used to store the actual data /// type of the DataHandler used to store the actual data
enum class DataHandlerType { enum class DataHandlerType {
CPU, ///< data is stored as an Eigen::Matrix in CPU main memory CPU, ///< data is stored as an Eigen::Matrix in CPU main memory
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
namespace elsa namespace elsa
{ {
template <typename data_t> template <typename data_t>
FourierTransform<data_t>::FourierTransform(const DataDescriptor& domainDescriptor) FourierTransform<data_t>::FourierTransform(const DataDescriptor& domainDescriptor, FFTNorm norm)
: B(domainDescriptor, domainDescriptor) : B(domainDescriptor, domainDescriptor), norm{norm}
{ {
} }
...@@ -28,7 +28,7 @@ namespace elsa ...@@ -28,7 +28,7 @@ namespace elsa
// copy the input and fouriertransform it // copy the input and fouriertransform it
Ax = x; Ax = x;
Ax.fft(); Ax.fft(this->norm);
} }
template <typename data_t> template <typename data_t>
...@@ -43,7 +43,7 @@ namespace elsa ...@@ -43,7 +43,7 @@ namespace elsa
// copy the input and inverse-fouriertransform it // copy the input and inverse-fouriertransform it
Atx = x; Atx = x;
Atx.ifft(); Atx.ifft(this->norm);
} }
template <typename data_t> template <typename data_t>
......
...@@ -31,8 +31,11 @@ namespace elsa ...@@ -31,8 +31,11 @@ namespace elsa
* @brief create a fourier transform operator * @brief create a fourier transform operator
* *
* @param[in] domainDescriptor metadata defining the domain and range of the transformation * @param[in] domainDescriptor metadata defining the domain and range of the transformation
* @param[in] norm metadata indicating which forward/inverse transform is scaled and
* on which of the predefined normalization factors
*/ */
explicit FourierTransform(const DataDescriptor& domainDescriptor); explicit FourierTransform(const DataDescriptor& domainDescriptor,
FFTNorm norm = FFTNorm::BACKWARD);
~FourierTransform() override = default; ~FourierTransform() override = default;
...@@ -57,6 +60,9 @@ namespace elsa ...@@ -57,6 +60,9 @@ namespace elsa
/// implement the polymorphic comparison operation /// implement the polymorphic comparison operation
bool isEqual(const B& other) const override; bool isEqual(const B& other) const override;
private:
FFTNorm norm;
}; };
} // namespace elsa } // namespace elsa
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment