Commit 3268337b authored by David Frank's avatar David Frank
Browse files

Add example that uses argument parsing

parent 9ac7834c
...@@ -27,16 +27,24 @@ add_example(example2d_admm example2d_admm.cpp) ...@@ -27,16 +27,24 @@ add_example(example2d_admm example2d_admm.cpp)
# build the shearlet example program # build the shearlet example program
add_example(shearlet_example example2d_shearlet.cpp) add_example(shearlet_example example2d_shearlet.cpp)
if(ELSA_BUILD_CUDA_PROJECTORS) CPMAddPackage(
include(CheckLanguage) NAME argparse
check_language(CUDA) GITHUB_REPOSITORY p-ranav/argparse
if(CMAKE_CUDA_COMPILER) GIT_TAG v2.2
enable_language(CUDA) )
add_example(example_argparse example_argparse.cpp)
target_link_libraries(example_argparse PRIVATE argparse::argparse)
if(ELSA_BUILD_CUDA_PROJECTORS)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
# build the 3d example program # build the 3d example program
add_example(example3d example3d.cpp) add_example(example3d example3d.cpp)
# build the GPU projector speed test program # build the GPU projector speed test program
add_example(speed_test speed_test.cpp) add_example(speed_test speed_test.cpp)
endif() endif()
endif() endif()
...@@ -11,7 +11,8 @@ void example2d() ...@@ -11,7 +11,8 @@ void example2d()
{ {
// generate 2d phantom // generate 2d phantom
IndexVector_t size(2); IndexVector_t size(2);
size << 128, 128; // size << 128, 128;
size << 512, 512;
auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size); auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
auto& volumeDescriptor = phantom.getDataDescriptor(); auto& volumeDescriptor = phantom.getDataDescriptor();
...@@ -19,7 +20,7 @@ void example2d() ...@@ -19,7 +20,7 @@ void example2d()
io::write(phantom, "2dphantom.pgm"); io::write(phantom, "2dphantom.pgm");
// generate circular trajectory // generate circular trajectory
index_t numAngles{180}, arc{360}; index_t numAngles{512}, arc{360};
const auto distance = static_cast<real_t>(size(0)); const auto distance = static_cast<real_t>(size(0));
auto sinoDescriptor = CircleTrajectoryGenerator::createTrajectory( auto sinoDescriptor = CircleTrajectoryGenerator::createTrajectory(
numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance); numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance);
...@@ -44,10 +45,11 @@ void example2d() ...@@ -44,10 +45,11 @@ void example2d()
// solve the reconstruction problem // solve the reconstruction problem
CG cgSolver(wlsProblem); CG cgSolver(wlsProblem);
index_t noIterations{20}; index_t noIterations{10};
Logger::get("Info")->info("Solving reconstruction using {} iterations of conjugate gradient", Logger::get("Info")->info("Solving reconstruction using {} iterations of conjugate gradient",
noIterations); noIterations);
auto cgReconstruction = cgSolver.solve(noIterations); auto cgReconstruction = cgSolver.solve(noIterations);
std::cout << cgReconstruction.l2Norm() << "\n";
// write the reconstruction out // write the reconstruction out
io::write(cgReconstruction, "2dreconstruction_cg.pgm"); io::write(cgReconstruction, "2dreconstruction_cg.pgm");
......
...@@ -2,56 +2,148 @@ ...@@ -2,56 +2,148 @@
#include "LutProjector.h" #include "LutProjector.h"
#include "Utilities/Statistics.hpp" #include "Utilities/Statistics.hpp"
#include "IO.h" #include "IO.h"
#include "spdlog/fmt/bundled/core.h"
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
std::pair<elsa::DataContainer<elsa::real_t>, elsa::DataContainer<elsa::real_t>> elsa::DataContainer<elsa::real_t> get_phantom(elsa::index_t dims, elsa::index_t s,
recon2d(elsa::index_t s, elsa::index_t numAngles, elsa::index_t arc, elsa::index_t iters) std::string phantom_kind)
{ {
elsa::IndexVector_t size({{s, s}}); const auto size = elsa::IndexVector_t::Constant(dims, s);
if (phantom_kind == "SheppLogan") {
return elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size);
} else if (phantom_kind == "Rectangle") {
auto quarter = s / 4;
const auto lower = elsa::IndexVector_t::Constant(size.size(), quarter);
const auto upper = elsa::IndexVector_t::Constant(size.size(), s - quarter);
return elsa::PhantomGenerator<elsa::real_t>::createRectanglePhantom(size, lower, upper);
} else if (phantom_kind == "Circle") {
return elsa::PhantomGenerator<elsa::real_t>::createCirclePhantom(size, s / 4.f);
}
throw elsa::Error("Unknown phantom kind {}", phantom_kind);
}
std::unique_ptr<elsa::Solver<elsa::real_t>>
get_solver(std::string solver_kind, const elsa::LinearOperator<elsa::real_t>& projector,
const elsa::DataContainer<elsa::real_t>& sinogram)
{
if (solver_kind == "CG") {
elsa::TikhonovProblem problem(projector, sinogram, 0.1);
elsa::CG solver(problem);
return solver.clone();
} else if (solver_kind == "ISTA") {
elsa::LASSOProblem problem(projector, sinogram);
elsa::ISTA solver(problem);
return solver.clone();
} else if (solver_kind == "FISTA") {
elsa::LASSOProblem problem(projector, sinogram);
elsa::FISTA solver(problem);
return solver.clone();
} else {
throw elsa::Error("Unknown Solver {}", solver_kind);
}
}
const auto phantom = elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size); std::unique_ptr<elsa::LinearOperator<elsa::real_t>>
get_projector(std::string projector_kind, const elsa::DataDescriptor& volume,
const elsa::DetectorDescriptor& sinogram)
{
const auto& vol = dynamic_cast<const elsa::VolumeDescriptor&>(volume);
if (projector_kind == "Blob") {
elsa::BlobProjector<elsa::real_t> projector(vol, sinogram);
return projector.clone();
} else if (projector_kind == "Siddon") {
elsa::SiddonsMethod<elsa::real_t> projector(vol, sinogram);
return projector.clone();
} else if (projector_kind == "Joseph") {
elsa::JosephsMethod<elsa::real_t> projector(vol, sinogram);
return projector.clone();
}
throw elsa::Error("Unknown projector {}", projector_kind);
}
elsa::DataContainer<elsa::real_t> reconstruct(const elsa::DataContainer<elsa::real_t>& phantom,
elsa::index_t numAngles, elsa::index_t arc,
elsa::index_t iters, std::string projector_kind,
std::string forward_kind, std::string solver_kind)
{
const auto size = phantom.getDataDescriptor().getNumberOfCoefficientsPerDimension();
auto& volumeDescriptor = phantom.getDataDescriptor(); auto& volumeDescriptor = phantom.getDataDescriptor();
// write the phantom out const auto dims = size.size();
elsa::io::write(phantom, "2dphantom.pgm");
// generate circular trajectory // generate circular trajectory
const auto distance = static_cast<elsa::real_t>(size(0)); const auto distance = static_cast<elsa::real_t>(size(0));
auto sinoDescriptor = elsa::CircleTrajectoryGenerator::createTrajectory( auto sinoDescriptor = elsa::CircleTrajectoryGenerator::createTrajectory(
numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance); numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance);
// dynamic_cast to VolumeDescriptor is legal and will not throw, as PhantomGenerator returns a elsa::Logger::get("Info")->info("Create {}-Projector", projector_kind);
// VolumeDescriptor auto projector_ptr = get_projector(projector_kind, volumeDescriptor, *sinoDescriptor);
elsa::Logger::get("Info")->info("Create BlobProjector"); auto& projector = *projector_ptr;
elsa::BlobProjector projector(dynamic_cast<const elsa::VolumeDescriptor&>(volumeDescriptor),
*sinoDescriptor); // Don't commit the inverse crim (i.e, use a different projector for the initial forward
// projection)
auto forward_ptr = get_projector(forward_kind, volumeDescriptor, *sinoDescriptor);
auto& forward = *forward_ptr;
// simulate the sinogram // simulate the sinogram
elsa::Logger::get("Info")->info("Calculate sinogram"); elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", forward_kind);
auto sinogram = projector.apply(phantom); auto sinogram = forward.apply(phantom);
elsa::io::write(sinogram, "2dsinogram.pgm");
// setup reconstruction problem elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.edf", dims, forward_kind));
elsa::WLSProblem wlsProblem(projector, sinogram);
elsa::CG cgSolver(wlsProblem); if (dims == 2) {
elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(sinogram.slice(i),
fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
}
}
auto cgReconstruction = cgSolver.solve(iters); // Only compute sinogram for the "main" projector, if actually a different forward projector is
elsa::PGM::write(cgReconstruction, "2dreconstruction_cg.pgm"); // used
if (projector_kind != forward_kind) {
elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", projector_kind);
auto sino = projector.apply(phantom);
return {cgReconstruction, phantom}; elsa::io::write(sino, fmt::format("{}dsinogram_{}.edf", dims, projector_kind));
// return {phantom, phantom}; if (dims == 2) {
} elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(sinogram.slice(i),
fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
}
}
}
std::pair<elsa::DataContainer<elsa::real_t>, elsa::DataContainer<elsa::real_t>> // setup reconstruction problem
recon3d(elsa::index_t s, elsa::index_t numAngles, elsa::index_t arc, elsa::index_t iters) elsa::Logger::get("Info")->info("Setting up solver {}", solver_kind);
{ auto solver_ptr = get_solver(solver_kind, projector, sinogram);
elsa::IndexVector_t size({{s, s}}); auto& solver = *solver_ptr;
const auto phantom = elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size); elsa::Logger::get("Info")->info("Start reconstruction");
auto reconstruction = solver.solve(iters);
return {phantom, phantom}; elsa::io::write(reconstruction, fmt::format("{}dreconstruction_{}.edf", dims, projector_kind));
if (dims == 2) {
elsa::io::write(reconstruction,
fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
}
if (dims == 2) {
elsa::io::write(reconstruction,
fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(reconstruction.slice(i),
fmt::format("{}dreconstruction_{:02}_{}.pgm", dims, i, projector_kind));
}
}
return reconstruction;
} }
template <typename data_t> template <typename data_t>
...@@ -101,7 +193,8 @@ data_t standardDeviation(elsa::DataContainer<data_t> a) ...@@ -101,7 +193,8 @@ data_t standardDeviation(elsa::DataContainer<data_t> a)
template <typename data_t> template <typename data_t>
data_t meanSquaredError(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b) data_t meanSquaredError(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{ {
return elsa::DataContainer(a - b).l2Norm() / a.getSize(); elsa::DataContainer diff = a - b;
return diff.l2Norm() / a.getSize();
} }
template <typename data_t> template <typename data_t>
...@@ -182,11 +275,30 @@ int main(int argc, char* argv[]) ...@@ -182,11 +275,30 @@ int main(int argc, char* argv[])
.default_value(10) .default_value(10)
.scan<'i', int>(); .scan<'i', int>();
args.add_argument("--projector")
.help("Projector to use for reconstruction (\"Blob\", \"Siddon\", \"Joseph\")")
.default_value(std::string("Blob"));
args.add_argument("--forward")
.help("Choose different projector for forward proj (\"Blob\", \"Siddon\", \"Joseph\")")
.default_value(std::string("Joseph"));
args.add_argument("--solver")
.help("Choose different solver (\"CG\", \"ISTA\", \"FISTA\")")
.default_value(std::string("CG"));
args.add_argument("--phantom")
.help("Choose different solver (\"SheppLogan\", \"Rectangle\", \"Circle\")")
.default_value(std::string("SheppLogan"));
args.add_argument("--analyze") args.add_argument("--analyze")
.help("Analyze reconstruction") .help("Analyze reconstruction")
.default_value(false) .default_value(false)
.implicit_value(true); .implicit_value(true);
args.add_argument("--baseline")
.help("Give a baseline file with which the current reconstruction is compared");
try { try {
args.parse_args(argc, argv); args.parse_args(argc, argv);
} catch (const std::runtime_error& err) { } catch (const std::runtime_error& err) {
...@@ -204,20 +316,44 @@ int main(int argc, char* argv[]) ...@@ -204,20 +316,44 @@ int main(int argc, char* argv[])
} }
return size; return size;
}(); }();
const elsa::index_t arc = args.get<int>("--arc"); const elsa::index_t arc = args.get<int>("--arc");
const elsa::index_t iters = args.get<int>("--iters"); const elsa::index_t iters = args.get<int>("--iters");
auto [recon, phantom] = [&]() { const auto projector_kind = args.get<std::string>("--projector");
if (dims == 2) {
return recon2d(size, num_angles, arc, iters); const auto forward_projection = [&]() {
} else { if (args.is_used("--forward")) {
return recon3d(size, num_angles, arc, iters); return args.get<std::string>("--forward");
} }
return projector_kind;
}(); }();
const auto solver_kind = args.get<std::string>("--solver");
const auto phantom_kind = args.get<std::string>("--phantom");
/// reconstruction setup
const auto phantom = get_phantom(dims, size, phantom_kind);
// write the phantom out
if (dims == 2) {
elsa::io::write(phantom, fmt::format("{}dphantom.pgm", dims));
}
elsa::io::write(phantom, fmt::format("{}dphantom.edf", dims));
auto recon = reconstruct(phantom, num_angles, arc, iters, projector_kind, forward_projection,
solver_kind);
if (args["--analyze"] == true) { if (args["--analyze"] == true) {
analyze(phantom, recon); analyze(phantom, recon);
if (args.is_used("--baseline")) {
const auto baseline_file = args.get<std::string>("--baseline");
const auto baseline = elsa::io::read<elsa::real_t>(baseline_file);
analyze(baseline, recon);
}
} }
return 0; return 0;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment