Reduce complexity linear_discriminant_analysis. (#2452)

* Reduce complexity linear_discriminant_analysis.

* Fix whitespace

* Update machine_learning/linear_discriminant_analysis.py

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>

* fixup! Format Python code with psf/black push

* Fix format to surpass pre-commit tests

* updating DIRECTORY.md

* Update machine_learning/linear_discriminant_analysis.py

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>

* fixup! Format Python code with psf/black push

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
This commit is contained in:
poloso 2020-11-10 21:35:11 -05:00 committed by GitHub
parent a6ad25c351
commit 4851942ec0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@
Linear Discriminant Analysis Linear Discriminant Analysis
Assumptions About Data : Assumptions About Data :
1. The input variables has a gaussian distribution. 1. The input variables has a gaussian distribution.
2. The variance calculated for each input variables by class grouping is the 2. The variance calculated for each input variables by class grouping is the
@ -44,6 +45,7 @@
from math import log from math import log
from os import name, system from os import name, system
from random import gauss, seed from random import gauss, seed
from typing import Callable, TypeVar
# Make a training dataset drawn from a gaussian distribution # Make a training dataset drawn from a gaussian distribution
@ -245,6 +247,40 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
return (correct / len(actual_y)) * 100 return (correct / len(actual_y)) * 100
num = TypeVar("num")
def valid_input(
input_type: Callable[[object], num], # Usually float or int
input_msg: str,
err_msg: str,
condition: Callable[[num], bool] = lambda x: True,
default: str = None,
) -> num:
"""
Ask for user value and validate that it fulfill a condition.
:input_type: user input expected type of value
:input_msg: message to show user in the screen
:err_msg: message to show in the screen in case of error
:condition: function that represents the condition that user input is valid.
:default: Default value in case the user does not type anything
:return: user's input
"""
while True:
try:
user_input = input_type(input(input_msg).strip() or default)
if condition(user_input):
return user_input
else:
print(f"{user_input}: {err_msg}")
continue
except ValueError:
print(
f"{user_input}: Incorrect input type, expected {input_type.__name__!r}"
)
# Main Function # Main Function
def main(): def main():
""" This function starts execution phase """ """ This function starts execution phase """
@ -254,48 +290,26 @@ def main():
print("First of all we should specify the number of classes that") print("First of all we should specify the number of classes that")
print("we want to generate as training dataset") print("we want to generate as training dataset")
# Trying to get number of classes # Trying to get number of classes
n_classes = 0 n_classes = valid_input(
while True: input_type=int,
try: condition=lambda x: x > 0,
user_input = int( input_msg="Enter the number of classes (Data Groupings): ",
input("Enter the number of classes (Data Groupings): ").strip() err_msg="Number of classes should be positive!",
) )
if user_input > 0:
n_classes = user_input
break
else:
print(
f"Your entered value is {user_input} , Number of classes "
f"should be positive!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
print("-" * 100) print("-" * 100)
std_dev = 1.0 # Default value for standard deviation of dataset
# Trying to get the value of standard deviation # Trying to get the value of standard deviation
while True: std_dev = valid_input(
try: input_type=float,
user_sd = float( condition=lambda x: x >= 0,
input( input_msg=(
"Enter the value of standard deviation" "Enter the value of standard deviation"
"(Default value is 1.0 for all classes): " "(Default value is 1.0 for all classes): "
).strip() ),
or "1.0" err_msg="Standard deviation should not be negative!",
) default="1.0",
if user_sd >= 0.0: )
std_dev = user_sd
break
else:
print(
f"Your entered value is {user_sd}, Standard deviation should "
f"not be negative!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
print("-" * 100) print("-" * 100)
@ -303,38 +317,24 @@ def main():
# dataset # dataset
counts = [] # An empty list to store instance counts of classes in dataset counts = [] # An empty list to store instance counts of classes in dataset
for i in range(n_classes): for i in range(n_classes):
while True: user_count = valid_input(
try: input_type=int,
user_count = int( condition=lambda x: x > 0,
input(f"Enter The number of instances for class_{i+1}: ") input_msg=(f"Enter The number of instances for class_{i+1}: "),
) err_msg="Number of instances should be positive!",
if user_count > 0: )
counts.append(user_count) counts.append(user_count)
break
else:
print(
f"Your entered value is {user_count}, Number of "
"instances should be positive!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
print("-" * 100) print("-" * 100)
# An empty list to store values of user-entered means of classes # An empty list to store values of user-entered means of classes
user_means = [] user_means = []
for a in range(n_classes): for a in range(n_classes):
while True: user_mean = valid_input(
try: input_type=float,
user_mean = float( input_msg=(f"Enter the value of mean for class_{a+1}: "),
input(f"Enter the value of mean for class_{a+1}: ") err_msg="This is an invalid value.",
) )
if isinstance(user_mean, float): user_means.append(user_mean)
user_means.append(user_mean)
break
print(f"You entered an invalid value: {user_mean}")
except ValueError:
print("Your entered value is not numerical!")
print("-" * 100) print("-" * 100)
print("Standard deviation: ", std_dev) print("Standard deviation: ", std_dev)