mirror of
https://hub.njuu.cf/TheAlgorithms/Python.git
synced 2023-10-11 13:06:12 +08:00
Reduce complexity linear_discriminant_analysis. (#2452)
* Reduce complexity linear_discriminant_analysis. * Fix whitespace * Update machine_learning/linear_discriminant_analysis.py Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com> * fixup! Format Python code with psf/black push * Fix format to surpass pre-commit tests * updating DIRECTORY.md * Update machine_learning/linear_discriminant_analysis.py Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com> * fixup! Format Python code with psf/black push Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com> Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
This commit is contained in:
parent
a6ad25c351
commit
4851942ec0
@ -2,6 +2,7 @@
|
||||
Linear Discriminant Analysis
|
||||
|
||||
|
||||
|
||||
Assumptions About Data :
|
||||
1. The input variables has a gaussian distribution.
|
||||
2. The variance calculated for each input variables by class grouping is the
|
||||
@ -44,6 +45,7 @@
|
||||
from math import log
|
||||
from os import name, system
|
||||
from random import gauss, seed
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
|
||||
# Make a training dataset drawn from a gaussian distribution
|
||||
@ -245,6 +247,40 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
|
||||
return (correct / len(actual_y)) * 100
|
||||
|
||||
|
||||
num = TypeVar("num")
|
||||
|
||||
|
||||
def valid_input(
|
||||
input_type: Callable[[object], num], # Usually float or int
|
||||
input_msg: str,
|
||||
err_msg: str,
|
||||
condition: Callable[[num], bool] = lambda x: True,
|
||||
default: str = None,
|
||||
) -> num:
|
||||
"""
|
||||
Ask for user value and validate that it fulfill a condition.
|
||||
|
||||
:input_type: user input expected type of value
|
||||
:input_msg: message to show user in the screen
|
||||
:err_msg: message to show in the screen in case of error
|
||||
:condition: function that represents the condition that user input is valid.
|
||||
:default: Default value in case the user does not type anything
|
||||
:return: user's input
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
user_input = input_type(input(input_msg).strip() or default)
|
||||
if condition(user_input):
|
||||
return user_input
|
||||
else:
|
||||
print(f"{user_input}: {err_msg}")
|
||||
continue
|
||||
except ValueError:
|
||||
print(
|
||||
f"{user_input}: Incorrect input type, expected {input_type.__name__!r}"
|
||||
)
|
||||
|
||||
|
||||
# Main Function
|
||||
def main():
|
||||
""" This function starts execution phase """
|
||||
@ -254,48 +290,26 @@ def main():
|
||||
print("First of all we should specify the number of classes that")
|
||||
print("we want to generate as training dataset")
|
||||
# Trying to get number of classes
|
||||
n_classes = 0
|
||||
while True:
|
||||
try:
|
||||
user_input = int(
|
||||
input("Enter the number of classes (Data Groupings): ").strip()
|
||||
)
|
||||
if user_input > 0:
|
||||
n_classes = user_input
|
||||
break
|
||||
else:
|
||||
print(
|
||||
f"Your entered value is {user_input} , Number of classes "
|
||||
f"should be positive!"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
print("Your entered value is not numerical!")
|
||||
n_classes = valid_input(
|
||||
input_type=int,
|
||||
condition=lambda x: x > 0,
|
||||
input_msg="Enter the number of classes (Data Groupings): ",
|
||||
err_msg="Number of classes should be positive!",
|
||||
)
|
||||
|
||||
print("-" * 100)
|
||||
|
||||
std_dev = 1.0 # Default value for standard deviation of dataset
|
||||
# Trying to get the value of standard deviation
|
||||
while True:
|
||||
try:
|
||||
user_sd = float(
|
||||
input(
|
||||
"Enter the value of standard deviation"
|
||||
"(Default value is 1.0 for all classes): "
|
||||
).strip()
|
||||
or "1.0"
|
||||
)
|
||||
if user_sd >= 0.0:
|
||||
std_dev = user_sd
|
||||
break
|
||||
else:
|
||||
print(
|
||||
f"Your entered value is {user_sd}, Standard deviation should "
|
||||
f"not be negative!"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
print("Your entered value is not numerical!")
|
||||
std_dev = valid_input(
|
||||
input_type=float,
|
||||
condition=lambda x: x >= 0,
|
||||
input_msg=(
|
||||
"Enter the value of standard deviation"
|
||||
"(Default value is 1.0 for all classes): "
|
||||
),
|
||||
err_msg="Standard deviation should not be negative!",
|
||||
default="1.0",
|
||||
)
|
||||
|
||||
print("-" * 100)
|
||||
|
||||
@ -303,38 +317,24 @@ def main():
|
||||
# dataset
|
||||
counts = [] # An empty list to store instance counts of classes in dataset
|
||||
for i in range(n_classes):
|
||||
while True:
|
||||
try:
|
||||
user_count = int(
|
||||
input(f"Enter The number of instances for class_{i+1}: ")
|
||||
)
|
||||
if user_count > 0:
|
||||
counts.append(user_count)
|
||||
break
|
||||
else:
|
||||
print(
|
||||
f"Your entered value is {user_count}, Number of "
|
||||
"instances should be positive!"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
print("Your entered value is not numerical!")
|
||||
user_count = valid_input(
|
||||
input_type=int,
|
||||
condition=lambda x: x > 0,
|
||||
input_msg=(f"Enter The number of instances for class_{i+1}: "),
|
||||
err_msg="Number of instances should be positive!",
|
||||
)
|
||||
counts.append(user_count)
|
||||
print("-" * 100)
|
||||
|
||||
# An empty list to store values of user-entered means of classes
|
||||
user_means = []
|
||||
for a in range(n_classes):
|
||||
while True:
|
||||
try:
|
||||
user_mean = float(
|
||||
input(f"Enter the value of mean for class_{a+1}: ")
|
||||
)
|
||||
if isinstance(user_mean, float):
|
||||
user_means.append(user_mean)
|
||||
break
|
||||
print(f"You entered an invalid value: {user_mean}")
|
||||
except ValueError:
|
||||
print("Your entered value is not numerical!")
|
||||
user_mean = valid_input(
|
||||
input_type=float,
|
||||
input_msg=(f"Enter the value of mean for class_{a+1}: "),
|
||||
err_msg="This is an invalid value.",
|
||||
)
|
||||
user_means.append(user_mean)
|
||||
print("-" * 100)
|
||||
|
||||
print("Standard deviation: ", std_dev)
|
||||
|
Loading…
Reference in New Issue
Block a user