From 11a6542bf2704de3cbba22843fa8322e87569cc0 Mon Sep 17 00:00:00 2001 From: Krishna Vedala <7001608+kvedala@users.noreply.github.com> Date: Fri, 26 Jun 2020 08:04:01 -0400 Subject: [PATCH] added test cases --- .../ordinary_least_squares_regressor.cpp | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/numerical_methods/ordinary_least_squares_regressor.cpp b/numerical_methods/ordinary_least_squares_regressor.cpp index bbd75a742..d36a84042 100644 --- a/numerical_methods/ordinary_least_squares_regressor.cpp +++ b/numerical_methods/ordinary_least_squares_regressor.cpp @@ -9,6 +9,7 @@ * * \author [Krishna Vedala](https://github.com/kvedala) */ +#include #include // for print formatting #include #include @@ -352,10 +353,48 @@ std::vector predict_OLS_regressor(std::vector> const &X, return result; } +/** Self test checks */ +void test() { + int F = 3, N = 5; + + // test function = x^2 -5 + std::cout << "Test 1 (quadratic function)...."; + std::vector> data1( + {{-5, 25, -125}, {-1, 1, -1}, {0, 0, 0}, {1, 1, 1}, {6, 36, 216}}); + std::vector Y1({20, -4, -5, -4, 31}); + std::vector beta1 = fit_OLS_regressor(data1, Y1); + std::vector> test1( + {{-2, 4, -8}, {2, 4, 8}, {-10, 100, -1000}, {10, 100, 1000}}); + std::vector expected1({-1, -1, 95, 95}); + std::vector out1 = predict_OLS_regressor(test1, beta1); + for (size_t rows = 0; rows < out1.size(); rows++) + assert(std::abs(out1[rows] - expected1[rows]) < 0.01); // accuracy + std::cout << "passed\n"; + + // test function = x^3 + x^2 - 100 + std::cout << "Test 2 (cubic function)...."; + std::vector> data2( + {{-5, 25, -125}, {-1, 1, -1}, {0, 0, 0}, {1, 1, 1}, {6, 36, 216}}); + std::vector Y2({-200, -100, -100, 98, 152}); + std::vector beta2 = fit_OLS_regressor(data2, Y2); + std::vector> test2( + {{-2, 4, -8}, {2, 4, 8}, {-10, 100, -1000}, {10, 100, 1000}}); + std::vector expected2({-104, -88, -1000, 1000}); + std::vector out2 = predict_OLS_regressor(test2, beta2); + for (size_t rows = 0; rows < out2.size(); rows++) + assert(std::abs(out2[rows] - expected2[rows]) < 0.01); // accuracy + std::cout << "passed\n"; + + std::cout << std::endl; // ensure test results are displayed on screen + // (flush stdout) +} + /** * main function */ int main() { + test(); + size_t N, F; std::cout << "Enter number of features: ";