example_argparse.cpp 12.4 KB
Newer Older
1
2
3
4
#include "elsa.h"
#include "LutProjector.h"
#include "Utilities/Statistics.hpp"
#include "IO.h"
5
#include "spdlog/fmt/bundled/core.h"
6
7
8

#include <argparse/argparse.hpp>

9
10
/// Build a `dims`-dimensional phantom with `s` voxels along every axis.
///
/// @param dims         number of dimensions of the phantom
/// @param s            number of voxels per dimension
/// @param phantom_kind one of "SheppLogan", "Rectangle" or "Circle"
/// @return the generated phantom
/// @throws elsa::Error if `phantom_kind` is none of the supported kinds
elsa::DataContainer<elsa::real_t> get_phantom(elsa::index_t dims, elsa::index_t s,
                                              std::string phantom_kind)
{
    using Generator = elsa::PhantomGenerator<elsa::real_t>;

    const elsa::IndexVector_t size = elsa::IndexVector_t::Constant(dims, s);

    if (phantom_kind == "SheppLogan") {
        return Generator::createModifiedSheppLogan(size);
    }

    if (phantom_kind == "Rectangle") {
        // Box with corners at s/4 and s - s/4 along every axis.
        const auto margin = s / 4;
        const auto lower = elsa::IndexVector_t::Constant(size.size(), margin);
        const auto upper = elsa::IndexVector_t::Constant(size.size(), s - margin);
        return Generator::createRectanglePhantom(size, lower, upper);
    }

    if (phantom_kind == "Circle") {
        // Second argument is s / 4 (presumably the circle radius).
        return Generator::createCirclePhantom(size, s / 4.f);
    }

    throw elsa::Error("Unknown phantom kind {}", phantom_kind);
}

/// Set up the iterative solver `solver_kind` for recovering a volume from `sinogram`.
///
/// - "CG":    CG on a TikhonovProblem (constructed with the extra argument 0.1, presumably the
///            regularization weight)
/// - "ISTA":  ISTA on a LASSOProblem
/// - "FISTA": FISTA on a LASSOProblem
///
/// The problem and the solver are local objects; the caller receives `clone()` of the solver.
/// @throws elsa::Error for any other `solver_kind`
std::unique_ptr<elsa::Solver<elsa::real_t>>
    get_solver(std::string solver_kind, const elsa::LinearOperator<elsa::real_t>& projector,
               const elsa::DataContainer<elsa::real_t>& sinogram)
{
    if (solver_kind == "CG") {
        elsa::TikhonovProblem tikhonov(projector, sinogram, 0.1);
        elsa::CG cg(tikhonov);
        return cg.clone();
    }

    if (solver_kind == "ISTA" || solver_kind == "FISTA") {
        elsa::LASSOProblem lasso(projector, sinogram);
        if (solver_kind == "ISTA") {
            elsa::ISTA ista(lasso);
            return ista.clone();
        }
        elsa::FISTA fista(lasso);
        return fista.clone();
    }

    throw elsa::Error("Unknown Solver {}", solver_kind);
}
/// Create the projector named `projector_kind` for the given volume and sinogram geometry.
///
/// @param projector_kind one of "Blob", "Siddon" or "Joseph"
/// @param volume         volume descriptor; must actually be an elsa::VolumeDescriptor, otherwise
///                       the reference dynamic_cast throws std::bad_cast (before the kind is checked)
/// @param sinogram       detector descriptor of the trajectory
/// @return a clone of the constructed projector
/// @throws elsa::Error if `projector_kind` is unknown
std::unique_ptr<elsa::LinearOperator<elsa::real_t>>
    get_projector(std::string projector_kind, const elsa::DataDescriptor& volume,
                  const elsa::DetectorDescriptor& sinogram)
{
    const auto& volDesc = dynamic_cast<const elsa::VolumeDescriptor&>(volume);

    if (projector_kind == "Blob") {
        elsa::BlobProjector<elsa::real_t> blob(volDesc, sinogram);
        return blob.clone();
    }

    if (projector_kind == "Siddon") {
        elsa::SiddonsMethod<elsa::real_t> siddon(volDesc, sinogram);
        return siddon.clone();
    }

    if (projector_kind == "Joseph") {
        elsa::JosephsMethod<elsa::real_t> joseph(volDesc, sinogram);
        return joseph.clone();
    }

    throw elsa::Error("Unknown projector {}", projector_kind);
}

elsa::DataContainer<elsa::real_t> reconstruct(const elsa::DataContainer<elsa::real_t>& phantom,
                                              elsa::index_t numAngles, elsa::index_t arc,
                                              elsa::index_t iters, std::string projector_kind,
                                              std::string forward_kind, std::string solver_kind)
{
    const auto size = phantom.getDataDescriptor().getNumberOfCoefficientsPerDimension();
73
74
    auto& volumeDescriptor = phantom.getDataDescriptor();

75
    const auto dims = size.size();
76
77
78
79
80
81

    // generate circular trajectory
    const auto distance = static_cast<elsa::real_t>(size(0));
    auto sinoDescriptor = elsa::CircleTrajectoryGenerator::createTrajectory(
        numAngles, phantom.getDataDescriptor(), arc, distance * 100.0f, distance);

82
83
84
85
86
87
88
89
    elsa::Logger::get("Info")->info("Create {}-Projector", projector_kind);
    auto projector_ptr = get_projector(projector_kind, volumeDescriptor, *sinoDescriptor);
    auto& projector = *projector_ptr;

    // Don't commit the inverse crim (i.e, use a different projector for the initial forward
    // projection)
    auto forward_ptr = get_projector(forward_kind, volumeDescriptor, *sinoDescriptor);
    auto& forward = *forward_ptr;
90
91

    // simulate the sinogram
92
93
    elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", forward_kind);
    auto sinogram = forward.apply(phantom);
94

95
    elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.edf", dims, forward_kind));
96

97
98
99
100
101
102
103
104
    if (dims == 2) {
        elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
    } else if (dims == 3) {
        for (int i = 0; i < size[dims - 1]; ++i) {
            elsa::io::write(sinogram.slice(i),
                            fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
        }
    }
105

106
107
108
109
110
    // Only compute sinogram for the "main" projector, if actually a different forward projector is
    // used
    if (projector_kind != forward_kind) {
        elsa::Logger::get("Info")->info("Calculate sinogram using {}-Projector", projector_kind);
        auto sino = projector.apply(phantom);
111

112
113
114
115
116
117
118
119
120
121
        elsa::io::write(sino, fmt::format("{}dsinogram_{}.edf", dims, projector_kind));
        if (dims == 2) {
            elsa::io::write(sinogram, fmt::format("{}dsinogram_{}.pgm", dims, forward_kind));
        } else if (dims == 3) {
            for (int i = 0; i < size[dims - 1]; ++i) {
                elsa::io::write(sinogram.slice(i),
                                fmt::format("{}dsinogram_{:02}_{}.pgm", dims, i, forward_kind));
            }
        }
    }
122

123
124
125
126
    // setup reconstruction problem
    elsa::Logger::get("Info")->info("Setting up solver {}", solver_kind);
    auto solver_ptr = get_solver(solver_kind, projector, sinogram);
    auto& solver = *solver_ptr;
127

128
129
    elsa::Logger::get("Info")->info("Start reconstruction");
    auto reconstruction = solver.solve(iters);
130

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
    elsa::io::write(reconstruction, fmt::format("{}dreconstruction_{}.edf", dims, projector_kind));
    if (dims == 2) {
        elsa::io::write(reconstruction,
                        fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
    }
    if (dims == 2) {
        elsa::io::write(reconstruction,
                        fmt::format("{}dreconstruction_{}.pgm", dims, projector_kind));
    } else if (dims == 3) {
        for (int i = 0; i < size[dims - 1]; ++i) {
            elsa::io::write(reconstruction.slice(i),
                            fmt::format("{}dreconstruction_{:02}_{}.pgm", dims, i, projector_kind));
        }
    }

    return reconstruction;
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
}

template <typename data_t>
data_t covariance(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b, data_t mean_a,
                  data_t mean_b)
{
    data_t sum = 0.0;
    for (int i = 0; i < a.getSize(); ++i) {
        sum += ((a[i] - mean_a) * (b[i] - mean_b));
    }
    return sum / a.getSize();
}

template <typename data_t>
data_t covariance(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
    const auto mean_a = a.sum() / a.getSize();
    const auto mean_b = b.sum() / b.getSize();

    return covariance(a, b, mean_a, mean_b);
}

template <typename data_t>
data_t mean(elsa::DataContainer<data_t> data)
{
    return data.sum() / data.getSize();
}

template <typename data_t>
data_t variance(elsa::DataContainer<data_t> a)
{
    const auto mean_a = mean(a);

    data_t sum = 0.0;
    for (int i = 0; i < a.getSize(); ++i) {
        sum += std::pow(a[i] - mean_a, 2);
    }
    return sum / a.getSize();
}

template <typename data_t>
data_t standardDeviation(elsa::DataContainer<data_t> a)
{
    return std::sqrt(variance(a));
}

template <typename data_t>
data_t meanSquaredError(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
196
197
    elsa::DataContainer diff = a - b;
    return diff.l2Norm() / a.getSize();
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
}

template <typename data_t>
data_t rootMeanSquaredError(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
    return std::sqrt(meanSquaredError(a, b));
}

template <typename data_t>
data_t peakSignalToNoiseRation(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
    const auto mse = meanSquaredError(a, b);
    return 20 * std::log10(a.maxElement()) - 10 * std::log10(mse);
}

template <typename data_t>
data_t structuralSimilarityIndex(elsa::DataContainer<data_t> a, elsa::DataContainer<data_t> b)
{
    auto mean_a = mean(a);
    auto mean_b = mean(b);

    const auto var_a = variance(a);
    const auto var_b = variance(a);

    auto covar = covariance(a, b);

    const auto L = std::pow(2, sizeof(data_t)) - 1;
    const auto k1 = 0.01;
    const auto k2 = 0.03;

    const auto c1 = std::pow(k1 * L, 2);
    const auto c2 = std::pow(k2 * L, 2);

    const auto q1 = 2 * mean_a * mean_b + c1;
    const auto q2 = 2 * covar + c1;
    const auto q3 = mean_a * mean_a + mean_b * mean_b + c1;
    const auto q4 = var_a * var_a + var_b * var_b + c2;

    return (q1 * q2) / (q3 * q4);
}

template <typename data_t>
void analyze(elsa::DataContainer<data_t> phantom, elsa::DataContainer<data_t> recon)
{
    // Compute mean squared difference
    auto mse = meanSquaredError(phantom, recon);
    auto rmse = rootMeanSquaredError(phantom, recon);

    elsa::Logger::get("Analyze")->info("MSE: {}", mse);
    elsa::Logger::get("Analyze")->info("RMSE: {}", rmse);

    auto psnr = peakSignalToNoiseRation(phantom, recon);
    elsa::Logger::get("Analyze")->info("PSNR: {} dB", psnr);

    const auto ssim = structuralSimilarityIndex(phantom, recon);
    elsa::Logger::get("Analyze")->info("SSIM: {}", ssim);
};

int main(int argc, char* argv[])
{
    argparse::ArgumentParser args("elsa", "0.7");

    args.add_argument("--dims").help("Dimension of the problem").default_value(2).scan<'i', int>();
    args.add_argument("--size").help("Size of the problem").default_value(256).scan<'i', int>();

    args.add_argument("--angles")
        .help("Number of poses for trajectory")
        .default_value(0)
        .scan<'i', int>();

    args.add_argument("--arc")
        .help("Arc for trajectory (in degree)")
        .default_value(360)
        .scan<'i', int>();

    args.add_argument("--iters")
        .help("Number of iterations for solver")
        .default_value(10)
        .scan<'i', int>();

278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
    args.add_argument("--projector")
        .help("Projector to use for reconstruction (\"Blob\", \"Siddon\", \"Joseph\")")
        .default_value(std::string("Blob"));

    args.add_argument("--forward")
        .help("Choose different projector for forward proj (\"Blob\", \"Siddon\", \"Joseph\")")
        .default_value(std::string("Joseph"));

    args.add_argument("--solver")
        .help("Choose different solver (\"CG\", \"ISTA\", \"FISTA\")")
        .default_value(std::string("CG"));

    args.add_argument("--phantom")
        .help("Choose different solver (\"SheppLogan\", \"Rectangle\", \"Circle\")")
        .default_value(std::string("SheppLogan"));

294
295
296
297
298
    args.add_argument("--analyze")
        .help("Analyze reconstruction")
        .default_value(false)
        .implicit_value(true);

299
300
301
    args.add_argument("--baseline")
        .help("Give a baseline file with which the current reconstruction is compared");

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
    try {
        args.parse_args(argc, argv);
    } catch (const std::runtime_error& err) {
        std::cerr << err.what() << std::endl;
        std::cerr << args;
        std::exit(1);
    }

    const elsa::index_t dims = args.get<int>("--dims");
    const elsa::index_t size = args.get<int>("--size");

    const auto num_angles = [&]() {
        if (args.is_used("--angles")) {
            return static_cast<elsa::index_t>(args.get<int>("--angles"));
        }
        return size;
    }();
319

320
321
322
323
    const elsa::index_t arc = args.get<int>("--arc");

    const elsa::index_t iters = args.get<int>("--iters");

324
325
326
327
328
    const auto projector_kind = args.get<std::string>("--projector");

    const auto forward_projection = [&]() {
        if (args.is_used("--forward")) {
            return args.get<std::string>("--forward");
329
        }
330
        return projector_kind;
331
332
    }();

333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
    const auto solver_kind = args.get<std::string>("--solver");

    const auto phantom_kind = args.get<std::string>("--phantom");

    /// reconstruction setup
    const auto phantom = get_phantom(dims, size, phantom_kind);

    // write the phantom out
    if (dims == 2) {
        elsa::io::write(phantom, fmt::format("{}dphantom.pgm", dims));
    }
    elsa::io::write(phantom, fmt::format("{}dphantom.edf", dims));

    auto recon = reconstruct(phantom, num_angles, arc, iters, projector_kind, forward_projection,
                             solver_kind);

349
350
    if (args["--analyze"] == true) {
        analyze(phantom, recon);
351
352
353
354
355
356

        if (args.is_used("--baseline")) {
            const auto baseline_file = args.get<std::string>("--baseline");
            const auto baseline = elsa::io::read<elsa::real_t>(baseline_file);
            analyze(baseline, recon);
        }
357
358
359
360
    }

    return 0;
}