Commit b84ee79d authored by Jens Petit's avatar Jens Petit

DataHandlerGPU: Change tests for default generic DataHandlerType (#21)

parent c62f0d34
......@@ -89,7 +89,11 @@ SCENARIO("Testing Problem without regularization")
* (scaling.array() * x0Vec.array() - dataVec.array())
.matrix()
.squaredNorm()));
REQUIRE(prob.getGradient() == dcScaling * (dcScaling * dcX0 - dcData));
DataContainer gradientDirect = dcScaling * (dcScaling * dcX0 - dcData);
auto gradient = prob.getGradient();
for (index_t i = 0; i < gradientDirect.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
auto hessian = prob.getHessian();
auto result = hessian.apply(dcData);
......@@ -179,8 +183,12 @@ SCENARIO("Testing Problem with one regularization term")
0.5
* (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
REQUIRE(prob.evaluate() == Approx(valueData + weight * 0.5 * x0Vec.squaredNorm()));
REQUIRE(prob.getGradient()
== dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0);
DataContainer gradientDirect =
dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
auto gradient = prob.getGradient();
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
auto hessian = prob.getHessian();
auto result = hessian.apply(dcData);
......@@ -279,9 +287,12 @@ SCENARIO("Testing Problem with several regularization terms")
REQUIRE(prob.evaluate()
== Approx(valueData + weight1 * 0.5 * x0Vec.squaredNorm()
+ weight2 * 0.5 * x0Vec.squaredNorm()));
REQUIRE(prob.getGradient()
== dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0
+ weight2 * dcX0);
auto gradient = prob.getGradient();
DataContainer gradientDirect =
dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]).margin(0.00001));
auto hessian = prob.getHessian();
auto result = hessian.apply(dcData);
......
......@@ -93,7 +93,11 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
REQUIRE(prob.evaluate()
== Approx(static_cast<data_t>(0.5 * scaleFactor) * x0.squaredL2Norm()
- x0.dot(dc)));
REQUIRE(prob.getGradient() == scaleFactor * x0 - dc);
auto gradient = prob.getGradient();
DataContainer gradientDirect = scaleFactor * x0 - dc;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
auto hessian = prob.getHessian();
REQUIRE(hessian == leaf(scalingOp));
......@@ -163,7 +167,11 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
REQUIRE(prob.evaluate()
== Approx(static_cast<data_t>(0.5 * scaleFactor) * x0.squaredL2Norm()
- x0.dot(dc)));
REQUIRE(prob.getGradient() == scaleFactor * x0 - dc);
auto gradient = prob.getGradient();
DataContainer gradientDirect = scaleFactor * x0 - dc;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
auto hessian = prob.getHessian();
REQUIRE(hessian == leaf(scalingOp));
......@@ -234,7 +242,11 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
== Approx(static_cast<data_t>(0.5 * scaleFactor * scaleFactor)
* x0.squaredL2Norm()
- scaleFactor * x0.dot(dc)));
REQUIRE(prob.getGradient() == scaleFactor * (scaleFactor * x0) - scaleFactor * dc);
auto gradient = prob.getGradient();
DataContainer gradientDirect = scaleFactor * (scaleFactor * x0) - scaleFactor * dc;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]).margin(0.00001));
auto hessian = prob.getHessian();
REQUIRE(hessian == leaf(adjoint(scalingOp) * scalingOp));
......@@ -497,15 +509,22 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
== Approx(0.5 * scaleFactor * scaleFactor * weightFactor
* x0.squaredL2Norm()
- scaleFactor * weightFactor * x0.dot(dc)));
REQUIRE(prob.getGradient()
== scaleFactor * (weightFactor * (scaleFactor * x0))
- scaleFactor * (weightFactor * dc));
auto gradient = prob.getGradient();
DataContainer gradientDirect =
scaleFactor * (weightFactor * (scaleFactor * x0))
- scaleFactor * (weightFactor * dc);
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
} else {
REQUIRE(prob.evaluate()
== Approx(0.5 * scaleFactor * scaleFactor * weightFactor
* x0.squaredL2Norm()));
REQUIRE(prob.getGradient()
== scaleFactor * (weightFactor * (scaleFactor * x0)));
auto gradient = prob.getGradient();
DataContainer gradientDirect =
scaleFactor * (weightFactor * (scaleFactor * x0));
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
}
REQUIRE(prob.getHessian()
== leaf(adjoint(scalingOp) * weightingOp * scalingOp));
......@@ -519,7 +538,11 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
} else {
REQUIRE(prob.evaluate()
== Approx(0.5 * weightFactor * x0.squaredL2Norm()));
REQUIRE(prob.getGradient() == weightFactor * x0);
auto gradient = prob.getGradient();
DataContainer gradientDirect = weightFactor * x0;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
}
REQUIRE(prob.getHessian() == leaf(weightingOp));
}
......@@ -530,13 +553,19 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
== Approx(0.5 * scaleFactor * scaleFactor
* x0.squaredL2Norm()
- scaleFactor * x0.dot(dc)));
REQUIRE(prob.getGradient()
== scaleFactor * (scaleFactor * x0) - scaleFactor * dc);
auto gradient = prob.getGradient();
DataContainer gradientDirect =
scaleFactor * (scaleFactor * x0) - scaleFactor * dc;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
} else {
REQUIRE(prob.evaluate()
== Approx(0.5 * scaleFactor * scaleFactor
* x0.squaredL2Norm()));
REQUIRE(prob.getGradient() == scaleFactor * (scaleFactor * x0));
auto gradient = prob.getGradient();
DataContainer gradientDirect = scaleFactor * (scaleFactor * x0);
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
}
REQUIRE(prob.getHessian() == leaf(adjoint(scalingOp) * scalingOp));
} else {
......@@ -731,7 +760,11 @@ TEMPLATE_TEST_CASE("Scenario: Testing QuadricProblem", "", QuadricProblem<float>
}
}
REQUIRE(prob.evaluate() == Approx(0.5 * x0.dot(Ax) - x0.dot(b)));
REQUIRE(prob.getGradient() == Ax - b);
auto gradient = prob.getGradient();
DataContainer gradientDirect = Ax - b;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
}
}
}
......
......@@ -125,8 +125,13 @@ TEMPLATE_TEST_CASE("Scenario: Testing TikhonovProblem with one regularization te
0.5
* (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
REQUIRE(prob.evaluate() == Approx(valueData + weight * 0.5 * x0Vec.squaredNorm()));
REQUIRE(prob.getGradient()
== dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0);
auto gradient = prob.getGradient();
DataContainer gradientDirect =
dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]));
auto hessian = prob.getHessian();
auto result = hessian.apply(dcData);
......@@ -255,9 +260,12 @@ TEMPLATE_TEST_CASE("Scenario: Testing TikhonovProblem with several regularizatio
REQUIRE(prob.evaluate()
== Approx(valueData + weight1 * 0.5 * x0Vec.squaredNorm()
+ weight2 * 0.5 * x0Vec.squaredNorm()));
REQUIRE(prob.getGradient()
== dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0
+ weight2 * dcX0);
auto gradient = prob.getGradient();
DataContainer gradientDirect =
dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
for (index_t i = 0; i < gradient.getSize(); ++i)
REQUIRE(gradient[i] == Approx(gradientDirect[i]).margin(0.00001));
auto hessian = prob.getHessian();
auto result = hessian.apply(dcData);
......
......@@ -86,7 +86,7 @@ TEMPLATE_TEST_CASE("Scenario: Solving a simple linear problem", "", CG<float>, C
DataContainer<data_t> resultsDifference = scalingOp.apply(solution) - dcB;
// should have converged for the given number of iterations
REQUIRE((resultsDifference).squaredL2Norm()
REQUIRE(Approx((resultsDifference).squaredL2Norm()).margin(0.00001)
<= epsilon * epsilon * dcB.squaredL2Norm());
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment