Commit 3268337b authored by David Frank's avatar David Frank
Browse files

Add example that uses argument parsing

parent 9ac7834c
......@@ -27,16 +27,24 @@ add_example(example2d_admm example2d_admm.cpp)
# build the shearlet example program
add_example(shearlet_example example2d_shearlet.cpp)
if(ELSA_BUILD_CUDA_PROJECTORS)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
CPMAddPackage(
NAME argparse
GITHUB_REPOSITORY p-ranav/argparse
GIT_TAG v2.2
)
add_example(example_argparse example_argparse.cpp)
target_link_libraries(example_argparse PRIVATE argparse::argparse)
if(ELSA_BUILD_CUDA_PROJECTORS)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
# build the 3d example program
add_example(example3d example3d.cpp)
# build the GPU projector speed test program
add_example(speed_test speed_test.cpp)
endif()
endif()
endif()
......@@ -11,7 +11,8 @@ void example2d()
{
// generate 2d phantom
IndexVector_t size(2);
size << 128, 128;
// size << 128, 128;
size << 512, 512;
auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
auto& volumeDescriptor = phantom.getDataDescriptor();
......@@ -19,7 +20,7 @@ void example2d()
io::write(phantom, "2dphantom.pgm");
// generate circular trajectory
index_t numAngles{180}, arc{360};
index_t numAngles{512}, arc{360};
const auto distance = static_cast<real_t>(size(0));
auto sinoDescriptor = CircleTrajectoryGenerator::createTrajectory(
numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance);
......@@ -44,10 +45,11 @@ void example2d()
// solve the reconstruction problem
CG cgSolver(wlsProblem);
index_t noIterations{20};
index_t noIterations{10};
Logger::get("Info")->info("Solving reconstruction using {} iterations of conjugate gradient",
noIterations);
auto cgReconstruction = cgSolver.solve(noIterations);
std::cout << cgReconstruction.l2Norm() << "\n";
// write the reconstruction out
io::write(cgReconstruction, "2dreconstruction_cg.pgm");
......
......@@ -2,56 +2,148 @@
#include "LutProjector.h"
#include "Utilities/Statistics.hpp"
#include "IO.h"
#include "spdlog/fmt/bundled/core.h"
#include <argparse/argparse.hpp>
std::pair<elsa::DataContainer<elsa::real_t>, elsa::DataContainer<elsa::real_t>>
recon2d(elsa::index_t s, elsa::index_t numAngles, elsa::index_t arc, elsa::index_t iters)
elsa::DataContainer<elsa::real_t> get_phantom(elsa::index_t dims, elsa::index_t s,
std::string phantom_kind)
{
elsa::IndexVector_t size({{s, s}});
const auto size = elsa::IndexVector_t::Constant(dims, s);
if (phantom_kind == "SheppLogan") {
return elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size);
} else if (phantom_kind == "Rectangle") {
auto quarter = s / 4;
const auto lower = elsa::IndexVector_t::Constant(size.size(), quarter);
const auto upper = elsa::IndexVector_t::Constant(size.size(), s - quarter);
return elsa::PhantomGenerator<elsa::real_t>::createRectanglePhantom(size, lower, upper);
} else if (phantom_kind == "Circle") {
return elsa::PhantomGenerator<elsa::real_t>::createCirclePhantom(size, s / 4.f);
}
throw elsa::Error("Unknown phantom kind {}", phantom_kind);
}
std::unique_ptr<elsa::Solver<elsa::real_t>>
get_solver(std::string solver_kind, const elsa::LinearOperator<elsa::real_t>& projector,
const elsa::DataContainer<elsa::real_t>& sinogram)
{
if (solver_kind == "CG") {
elsa::TikhonovProblem problem(projector, sinogram, 0.1);
elsa::CG solver(problem);
return solver.clone();
} else if (solver_kind == "ISTA") {
elsa::LASSOProblem problem(projector, sinogram);
elsa::ISTA solver(problem);
return solver.clone();
} else if (solver_kind == "FISTA") {
elsa::LASSOProblem problem(projector, sinogram);
elsa::FISTA solver(problem);
return solver.clone();
} else {
throw elsa::Error("Unknown Solver {}", solver_kind);
}
}
const auto phantom = elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size);
std::unique_ptr<elsa::LinearOperator<elsa::real_t>>
get_projector(std::string projector_kind, const elsa::DataDescriptor& volume,
const elsa::DetectorDescriptor& sinogram)
{
const auto& vol = dynamic_cast<const elsa::VolumeDescriptor&>(volume);
if (projector_kind == "Blob") {
elsa::BlobProjector<elsa::real_t> projector(vol, sinogram);
return projector.clone();
} else if (projector_kind == "Siddon") {
elsa::SiddonsMethod<elsa::real_t> projector(vol, sinogram);
return projector.clone();
} else if (projector_kind == "Joseph") {
elsa::JosephsMethod<elsa::real_t> projector(vol, sinogram);
return projector.clone();
}
throw elsa::Error("Unknown projector {}", projector_kind);
}
elsa::DataContainer<elsa::real_t> reconstruct(const elsa::DataContainer<elsa::real_t>& phantom,
elsa::index_t numAngles, elsa::index_t arc,
elsa::index_t iters, std::string projector_kind,
std::string forward_kind, std::string solver_kind)
{
const auto size = phantom.getDataDescriptor().getNumberOfCoefficientsPerDimension();
auto& volumeDescriptor = phantom.getDataDescriptor();
// write the phantom out
elsa::io::write(phantom, "2dphantom.pgm");
const auto dims = size.size();
// generate circular trajectory
const auto distance = static_cast<elsa::real_t>(size(0));
auto sinoDescriptor = elsa::CircleTrajectoryGenerator::createTrajectory(
numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance);
// dynamic_cast to VolumeDescriptor is legal and will not throw, as PhantomGenerator returns a
// VolumeDescriptor
elsa::Logger::get("Info")->info("Create BlobProjector");
elsa::BlobProjector projector(dynamic_cast<const elsa::VolumeDescriptor&>(volumeDescriptor),
*sinoDescriptor);
elsa::Logger::get("Info")->info("Create {}-Projector", projector_kind);
auto projector_ptr = get_projector(projector_kind, volumeDescriptor, *sinoDescriptor);
auto& projector = *projector_ptr;
// Don't commit the inverse crim (i.e, use a different projector for the initial forward
// projection)
auto forward_ptr = get_projector(forward_kind, volumeDescriptor, *sinoDescriptor);
auto& forward = *forward_ptr;
// simulate the sinogram
elsa::Logger::get("Info")->info("Calculate sinogram");
auto sinogram = projector.apply(phantom);
elsa::io::write(sinogram, "2dsinogram.pgm");
elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", forward_kind);
auto sinogram = forward.apply(phantom);
// setup reconstruction problem
elsa::WLSProblem wlsProblem(projector, sinogram);
elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.edf", dims, forward_kind));
elsa::CG cgSolver(wlsProblem);
if (dims == 2) {
elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(sinogram.slice(i),
fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
}
}
auto cgReconstruction = cgSolver.solve(iters);
elsa::PGM::write(cgReconstruction, "2dreconstruction_cg.pgm");
// Only compute sinogram for the "main" projector, if actually a different forward projector is
// used
if (projector_kind != forward_kind) {
elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", projector_kind);
auto sino = projector.apply(phantom);
return {cgReconstruction, phantom};
// return {phantom, phantom};
}
elsa::io::write(sino, fmt::format("{}dsinogram_{}.edf", dims, projector_kind));
if (dims == 2) {
elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(sinogram.slice(i),
fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
}
}
}
std::pair<elsa::DataContainer<elsa::real_t>, elsa::DataContainer<elsa::real_t>>
recon3d(elsa::index_t s, elsa::index_t numAngles, elsa::index_t arc, elsa::index_t iters)
{
elsa::IndexVector_t size({{s, s}});
// setup reconstruction problem
elsa::Logger::get("Info")->info("Setting up solver {}", solver_kind);
auto solver_ptr = get_solver(solver_kind, projector, sinogram);
auto& solver = *solver_ptr;
const auto phantom = elsa::PhantomGenerator<elsa::real_t>::createModifiedSheppLogan(size);
elsa::Logger::get("Info")->info("Start reconstruction");
auto reconstruction = solver.solve(iters);
return {phantom, phantom};
elsa::io::write(reconstruction, fmt::format("{}dreconstruction_{}.edf", dims, projector_kind));
if (dims == 2) {
elsa::io::write(reconstruction,
fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
}
if (dims == 2) {
elsa::io::write(reconstruction,
fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
} else if (dims == 3) {
for (int i = 0; i < size[dims - 1]; ++i) {
elsa::io::write(reconstruction.slice(i),
fmt::format("{}dreconstruction_{:02}_{}.pgm", dims, i, projector_kind));
}
}
return reconstruction;
}
template <typename data_t>
......@@ -101,7 +193,8 @@ data_t standardDeviation(elsa::DataContainer<data_t> a)
template <typename data_t>
data_t meanSquaredError(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
return elsa::DataContainer(a - b).l2Norm() / a.getSize();
elsa::DataContainer diff = a - b;
return diff.l2Norm() / a.getSize();
}
template <typename data_t>
......@@ -182,11 +275,30 @@ int main(int argc, char* argv[])
.default_value(10)
.scan<'i', int>();
args.add_argument("--projector")
.help("Projector to use for reconstruction (\"Blob\", \"Siddon\", \"Joseph\")")
.default_value(std::string("Blob"));
args.add_argument("--forward")
.help("Choose different projector for forward proj (\"Blob\", \"Siddon\", \"Joseph\")")
.default_value(std::string("Joseph"));
args.add_argument("--solver")
.help("Choose different solver (\"CG\", \"ISTA\", \"FISTA\")")
.default_value(std::string("CG"));
args.add_argument("--phantom")
.help("Choose different solver (\"SheppLogan\", \"Rectangle\", \"Circle\")")
.default_value(std::string("SheppLogan"));
args.add_argument("--analyze")
.help("Analyze reconstruction")
.default_value(false)
.implicit_value(true);
args.add_argument("--baseline")
.help("Give a baseline file with which the current reconstruction is compared");
try {
args.parse_args(argc, argv);
} catch (const std::runtime_error& err) {
......@@ -204,20 +316,44 @@ int main(int argc, char* argv[])
}
return size;
}();
const elsa::index_t arc = args.get<int>("--arc");
const elsa::index_t iters = args.get<int>("--iters");
auto [recon, phantom] = [&]() {
if (dims == 2) {
return recon2d(size, num_angles, arc, iters);
} else {
return recon3d(size, num_angles, arc, iters);
const auto projector_kind = args.get<std::string>("--projector");
const auto forward_projection = [&]() {
if (args.is_used("--forward")) {
return args.get<std::string>("--forward");
}
return projector_kind;
}();
const auto solver_kind = args.get<std::string>("--solver");
const auto phantom_kind = args.get<std::string>("--phantom");
/// reconstruction setup
const auto phantom = get_phantom(dims, size, phantom_kind);
// write the phantom out
if (dims == 2) {
elsa::io::write(phantom, fmt::format("{}dphantom.pgm", dims));
}
elsa::io::write(phantom, fmt::format("{}dphantom.edf", dims));
auto recon = reconstruct(phantom, num_angles, arc, iters, projector_kind, forward_projection,
solver_kind);
if (args["--analyze"] == true) {
analyze(phantom, recon);
if (args.is_used("--baseline")) {
const auto baseline_file = args.get<std::string>("--baseline");
const auto baseline = elsa::io::read<elsa::real_t>(baseline_file);
analyze(baseline, recon);
}
}
return 0;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment