mirror of
https://hub.njuu.cf/TheAlgorithms/C-Plus-Plus.git
synced 2023-10-11 13:05:55 +08:00
added test case to classify points lying within a sphere
This commit is contained in:
parent
2d44fb1e76
commit
5e6c374445
@ -255,7 +255,57 @@ void test2(double eta = 0.01) {
|
|||||||
|
|
||||||
std::cout << "Predict for x=(" << x0 << "," << x1 << "): " << predict;
|
std::cout << "Predict for x=(" << x0 << "," << x1 << "): " << predict;
|
||||||
|
|
||||||
int expected_val = (x0 + x1) > -1 ? 1 : -1;
|
int expected_val = (x0 + 3. * x1) > -1 ? 1 : -1;
|
||||||
|
assert(predict == expected_val);
|
||||||
|
std::cout << " ...passed" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test function to predict points in a 3D coordinate system lying within the
|
||||||
|
* sphere of radius 1 and centre at origin as +1 and others as -1. Note that
|
||||||
|
* each point is defined by 3 values but we use 6 features. The function will
|
||||||
|
* create random sample points for training and test purposes.
|
||||||
|
* \param[in] eta learning rate (optional, default=0.01)
|
||||||
|
*/
|
||||||
|
void test3(double eta = 0.01) {
|
||||||
|
adaline ada(6, eta); // 2 features
|
||||||
|
|
||||||
|
const int N = 100; // number of sample points
|
||||||
|
|
||||||
|
std::vector<double> X[N];
|
||||||
|
int Y[N]; // corresponding y-values
|
||||||
|
|
||||||
|
// generate sample points in the interval
|
||||||
|
// [-range2/100 , (range2-1)/100]
|
||||||
|
int range = 200; // sample points full-range
|
||||||
|
int range2 = range >> 1; // sample points half-range
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
double x0 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
double x1 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
double x2 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
X[i] = {x0, x1, x2, x0 * x0, x1 * x1, x2 * x2};
|
||||||
|
Y[i] = ((x0 * x0) + (x1 * x1) + (x2 * x2)) <= 1.f ? 1 : -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "------- Test 3 -------" << std::endl;
|
||||||
|
std::cout << "Model before fit: " << ada << std::endl;
|
||||||
|
|
||||||
|
ada.fit(X, Y);
|
||||||
|
std::cout << "Model after fit: " << ada << std::endl;
|
||||||
|
|
||||||
|
int N_test_cases = 5;
|
||||||
|
for (int i = 0; i < N_test_cases; i++) {
|
||||||
|
double x0 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
double x1 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
double x2 = ((std::rand() % range) - range2) / 100.f;
|
||||||
|
|
||||||
|
int predict = ada.predict({x0, x1, x2, x0 * x0, x1 * x1, x2 * x2});
|
||||||
|
|
||||||
|
std::cout << "Predict for x=(" << x0 << "," << x1 << "," << x2
|
||||||
|
<< "): " << predict;
|
||||||
|
|
||||||
|
int expected_val = ((x0 * x0) + (x1 * x1) + (x2 * x2)) <= 1.f ? 1 : -1;
|
||||||
assert(predict == expected_val);
|
assert(predict == expected_val);
|
||||||
std::cout << " ...passed" << std::endl;
|
std::cout << " ...passed" << std::endl;
|
||||||
}
|
}
|
||||||
@ -265,7 +315,7 @@ void test2(double eta = 0.01) {
|
|||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
std::srand(std::time(nullptr)); // initialize random number generator
|
std::srand(std::time(nullptr)); // initialize random number generator
|
||||||
|
|
||||||
double eta = 0.2; // default value of eta
|
double eta = 0.1; // default value of eta
|
||||||
if (argc == 2) // read eta value from commandline argument if present
|
if (argc == 2) // read eta value from commandline argument if present
|
||||||
eta = strtof(argv[1], nullptr);
|
eta = strtof(argv[1], nullptr);
|
||||||
|
|
||||||
@ -276,5 +326,10 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
test2(eta);
|
test2(eta);
|
||||||
|
|
||||||
|
std::cout << "Press ENTER to continue..." << std::endl;
|
||||||
|
std::cin.get();
|
||||||
|
|
||||||
|
test3(eta);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user