diff --git a/DecisionTree.py b/DecisionTree.py index 3ce7c68..13f7640 100644 --- a/DecisionTree.py +++ b/DecisionTree.py @@ -48,7 +48,7 @@ def __init__( # Getting the GINI impurity based on the Y distribution self.gini_impurity = self.get_GINI() - # Sorting the counts and saving the final prediction of the node + # Sorting the counts and saving the final prediction counts_sorted = list(sorted(self.counts.items(), key=lambda item: item[1])) # Getting the last item @@ -312,4 +312,4 @@ def predict_obs(self, values: dict) -> int: # Predicting Xsubset = X.copy() Xsubset['yhat'] = root.predict(Xsubset) - print(Xsubset) \ No newline at end of file + print(Xsubset)