Commit f2c0b9e6 authored by David Tellenbach's avatar David Tellenbach
Browse files

Add test for FixedLayer backward semantics

parent 54bf92c1
Pipeline #205183 passed with stages
in 13 minutes and 37 seconds
......@@ -13,33 +13,31 @@ using namespace elsa;
TEST_CASE("FixedLayer", "elsa_ml")
{
SECTION("Forward semantics")
{
// A fixed layer's forward pass is defined by its applyAdjoint method
// generate 2d phantom
IndexVector_t size(2);
size << 32, 32;
auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
// generate 2d phantom
IndexVector_t size(2);
size << 32, 32;
auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
index_t noAngles{180}, arc{360};
auto [geometry, sinoDescriptor] = CircleTrajectoryGenerator::createTrajectory(
noAngles, phantom.getDataDescriptor(), arc, size(0) * 100, size(0));
index_t noAngles{180}, arc{360};
auto [geometry, sinoDescriptor] = CircleTrajectoryGenerator::createTrajectory(
noAngles, phantom.getDataDescriptor(), arc, size(0) * 100, size(0));
JosephsMethod projector(phantom.getDataDescriptor(), *sinoDescriptor, geometry);
JosephsMethod projector(phantom.getDataDescriptor(), *sinoDescriptor, geometry);
auto sinogram = projector.apply(phantom);
auto sinogram = projector.apply(phantom);
auto phantom2 = projector.applyAdjoint(sinogram);
auto phantom2 = projector.applyAdjoint(sinogram);
FixedLayer layer(*sinoDescriptor, projector);
FixedLayer layer(*sinoDescriptor, projector);
// The layer's input descriptor is the operator's range descriptor
REQUIRE(layer.getInputDescriptor() == projector.getRangeDescriptor());
// The layer's input descriptor is the operator's range descriptor
REQUIRE(layer.getInputDescriptor() == projector.getRangeDescriptor());
// The layer's output descriptor is the operator's domain descriptor
REQUIRE(layer.getOutputDescriptor() == projector.getDomainDescriptor());
// The layer's output descriptor is the operator's domain descriptor
REQUIRE(layer.getOutputDescriptor() == projector.getDomainDescriptor());
SECTION("Forward semantics")
{
// Setup layer
auto backend = layer.getBackend();
backend->setInput(sinogram);
......@@ -60,6 +58,23 @@ TEST_CASE("FixedLayer", "elsa_ml")
SECTION("Backward semantics")
{
// A fixed layer's forward pass is defined by its apply method
// Setup layer
auto backend = layer.getBackend();
backend->setInput(sinogram);
backend->setOutputGradient(phantom);
backend->initialize();
backend->compile();
// Forward propagate
auto engine = backend->getEngine();
dnnl::stream s(*engine);
backend->forwardPropagate(s);
backend->backwardPropagate(s);
// Get output
auto inputGradient = backend->getInputGradient();
for (index_t i = 0; i < inputGradient.getDataDescriptor().getNumberOfCoefficients(); ++i)
REQUIRE(inputGradient[i] == Approx(sinogram[i]));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment