Commit 1a259f78 authored by Jonas Bürger's avatar Jonas Bürger
Browse files

Fix and refactor KSVD::calculateSVD

parent fd8a4487
Pipeline #530612 failed with stages
in 14 minutes and 41 seconds
......@@ -6,9 +6,7 @@ namespace elsa
KSVD<data_t>::KSVD(/*const*/ DictionaryLearningProblem<data_t>& problem, data_t epsilon)
: /*Solver<data_t>(problem),*/ _problem{problem},
_epsilon{epsilon},
_nSamples{getNumberOfSamples(problem.getSignals())},
_firstLeftSingular(VolumeDescriptor({problem.getCurrentDictionary().getNumberOfAtoms()})),
_firstRightSingular(VolumeDescriptor{1})
_nSamples{getNumberOfSamples(problem.getSignals())}
{
}
......@@ -20,6 +18,18 @@ namespace elsa
return signalsDescriptor.getNumberOfBlocks();
}
template <typename data_t>
DataContainer<data_t>& KSVD<data_t>::solve(index_t iterations)
{
return solveImpl(iterations);
}
template <typename data_t>
Dictionary<data_t>& KSVD<data_t>::getLearnedDictionary()
{
return _problem.getCurrentDictionary();
}
template <typename data_t>
DataContainer<data_t>& KSVD<data_t>::solveImpl(index_t iterations)
{
......@@ -40,9 +50,10 @@ namespace elsa
for (index_t k = 0; k < dict.getNumberOfAtoms(); ++k) {
auto affectedSignals = getAffectedSignals(representations, k);
auto modifiedError = _problem.getRestrictedError(affectedSignals, k);
calculateSVD(modifiedError);
dict.updateAtom(k, _firstLeftSingular);
updateRepresentations(representations, affectedSignals, k);
auto svd = calculateSVD(modifiedError);
dict.updateAtom(k, getNextAtom(svd));
updateRepresentations(representations, getNextRepresentation(svd), affectedSignals,
k);
_problem.updateError();
if (_problem.getGlobalError().l2Norm() < _epsilon)
break;
......@@ -69,7 +80,7 @@ namespace elsa
}
template <typename data_t>
void KSVD<data_t>::calculateSVD(DataContainer<data_t> error)
auto KSVD<data_t>::calculateSVD(DataContainer<data_t> error)
{
const auto& errorDescriptor =
dynamic_cast<const IdenticalBlocksDescriptor&>(error.getDataDescriptor());
......@@ -83,20 +94,32 @@ namespace elsa
}
}
Eigen::JacobiSVD svd(errorMatrix, Eigen::ComputeThinU | Eigen::ComputeThinV);
_firstLeftSingular =
DataContainer<data_t>(VolumeDescriptor({nCoeffs}), svd.matrixU().col(0));
_firstSingularValue = svd.singularValues()[0];
_firstRightSingular =
DataContainer<data_t>(VolumeDescriptor({nBlocks}), svd.matrixV().col(0));
return errorMatrix.jacobiSvd(Eigen::ComputeThinU | Eigen::ComputeThinV);
}
template <typename data_t>
DataContainer<data_t> KSVD<data_t>::getNextAtom(
Eigen::JacobiSVD<Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>> svd)
{
auto firstLeft = svd.matrixU().col(0);
DataContainer<data_t> firstLeftSingular(VolumeDescriptor({firstLeft.rows()}), firstLeft);
return firstLeftSingular;
}
template <typename data_t>
DataContainer<data_t> KSVD<data_t>::getNextRepresentation(
Eigen::JacobiSVD<Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>> svd)
{
auto firstRight = svd.matrixV().col(0);
DataContainer<data_t> firstRightSingular(VolumeDescriptor({firstRight.rows()}), firstRight);
return firstRightSingular * svd.singularValues()[0];
}
template <typename data_t>
void KSVD<data_t>::updateRepresentations(DataContainer<data_t>& representations,
DataContainer<data_t> nextRepresentation,
IndexVector_t affectedSignals, index_t atom)
{
DataContainer<data_t> nextRepresentation = _firstSingularValue * _firstRightSingular;
index_t i = 0;
for (auto idx : affectedSignals) {
representations.getBlock(idx)[atom] = nextRepresentation[i];
......
#pragma once
#include "Solver.h"
//#include "Solver.h"
#include "RepresentationProblem.h"
#include "DictionaryLearningProblem.h"
#include "OMP.h"
......@@ -38,6 +38,11 @@ namespace elsa
/// default destructor
//~KSVD() override = default;
// public entry point, will be inherited from Solver
DataContainer<data_t>& solve(index_t iterations);
Dictionary<data_t>& getLearnedDictionary();
private:
/// variable affecting the stopping condition
data_t _epsilon;
......@@ -47,10 +52,6 @@ namespace elsa
/// lift the base class variable _problem
DictionaryLearningProblem<data_t>& _problem;
DataContainer<data_t> _firstLeftSingular;
DataContainer<data_t> _firstRightSingular;
data_t _firstSingularValue;
/**
* @brief Solve the representation problem, i.e. apply iterations number of iterations of
* matching pursuit
......@@ -65,9 +66,16 @@ namespace elsa
IndexVector_t getAffectedSignals(const DataContainer<data_t>& representations,
index_t atom);
void calculateSVD(DataContainer<data_t> data);
auto calculateSVD(DataContainer<data_t> data);
DataContainer<data_t> getNextAtom(
Eigen::JacobiSVD<Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>> svd);
DataContainer<data_t> getNextRepresentation(
Eigen::JacobiSVD<Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>> svd);
void updateRepresentations(DataContainer<data_t>& representations,
DataContainer<data_t> nextRepresentation,
IndexVector_t affectedSignals, index_t atom);
static index_t getNumberOfSamples(const DataContainer<data_t>& signals);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment