CLASSIFICATION ALGORITHM
Decision trees are everywhere in machine learning, prized for their intuitive output. Who doesn’t love a simple “if-then” flowchart? Despite their popularity, it’s surprising how difficult it is to find a clear, step-by-step explanation of how decision trees work. (I’m actually embarrassed by how long it took me to really understand how the algorithm works.)
In this article, I'll focus on the basics of tree building, explaining exactly what happens at each node and why, from the root down to the final leaves (with visuals, of course).
A decision tree classifier builds an inverted tree to make predictions, starting at the top with a question about an important feature of the data and then branching out based on the answers. As you follow these branches down, each stop asks another question, narrowing down the possibilities. This question-and-answer game continues until you reach the end (a leaf node), where you get your final prediction, or classification.
Throughout this article, we'll use an artificial golf dataset (inspired by (1)) as an example. The task is to predict whether a person will play golf based on the weather conditions.
```python
# Import libraries
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

# Load data
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

# Preprocess data: one-hot encode Outlook, turn Wind and Play into 0/1 integers
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)

# Reorder the columns
df = df[['sunny', 'overcast', 'rainy', 'Temperature', 'Humidity', 'Wind', 'Play']]

# Prepare features and target
x, y = df.drop(columns='Play'), df['Play']

# Split data: first half for training, second half for testing (no shuffling)
X_train, X_test, y_train, y_test = train_test_split(x, y, train_size=0.5, shuffle=False)

# Display results
print(pd.concat([X_train, y_train], axis=1), '\n')
print(pd.concat([X_test, y_test], axis=1))
```
The decision tree classifier works by recursively splitting the data based on the most informative features. Here's how it works:
1. Start with the entire dataset at the root node.
2. Select the best feature to split the data (based on a measure such as Gini impurity).
3. Create child nodes for each possible value of the selected feature.
4. Repeat steps 2 and 3 for each child node until a stopping criterion is met (e.g., maximum depth reached, minimum samples per leaf, or pure leaf nodes).
5. Assign the majority class to each leaf node.
scikit-learn uses an optimized version of the CART (Classification and Regression Trees) algorithm, which builds binary trees. Training typically follows these steps:
1. Start with all training samples at the root node.
2. For each feature:
a. Sort the values of the feature.
b. Consider every midpoint between adjacent unique values as a candidate split point.
```python
def potential_split_points(attr_name, attr_values):
    # np.unique returns the sorted unique values
    unique_values = np.unique(attr_values)
    # Candidate thresholds are the midpoints between adjacent unique values
    split_points = [(unique_values[i] + unique_values[i + 1]) / 2 for i in range(len(unique_values) - 1)]
    return {attr_name: split_points}

# Calculate and display potential split points for all columns
for column in X_train.columns:
    splits = potential_split_points(column, X_train[column])
    for attr, points in splits.items():
        print(f"{attr:11}: {points}")
```
3. For each potential split point:
a. Calculate the impurity (e.g., Gini impurity) of each of the two child nodes the split would create.
b. Calculate the weighted average of these impurities, weighting each child by its number of samples.
```python
def gini_impurity(y):
    # Gini impurity: 1 minus the sum of squared class proportions (labels are 0/1 integers)
    p = np.bincount(y) / len(y)
    return 1 - np.sum(p**2)

def weighted_average_impurity(y, split_index):
    # y must be sorted by the feature value; the first split_index labels form the left child
    y = np.asarray(y)
    n = len(y)
    left_impurity = gini_impurity(y[:split_index])
    right_impurity = gini_impurity(y[split_index:])
    return (split_index * left_impurity + (n - split_index) * right_impurity) / n

# Sort 'sunny' feature and corresponding labels
sunny = X_train['sunny']
sorted_indices = np.argsort(sunny.to_numpy(), kind='stable')
sorted_sunny = sunny.iloc[sorted_indices]
sorted_labels = y_train.iloc[sorted_indices]

# Find split index for 0.5 (number of samples with sunny <= 0.5)
split_index = np.searchsorted(sorted_sunny, 0.5, side='right')

# Calculate impurity
impurity = weighted_average_impurity(sorted_labels, split_index)
print(f"Weighted average impurity for 'sunny' at split point 0.5: {impurity:.3f}")
```
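In formula form, writing $p_k$ for the fraction of a node's samples that belong to class $k$, and $n_L$, $n_R$ for the number of samples sent to the left and right children ($n = n_L + n_R$, with $n_L$ playing the role of `split_index` in the code), the two quantities computed in step 3 are:

$$
G(S) = 1 - \sum_{k} p_k^2,
\qquad
G_{\text{split}} = \frac{n_L}{n}\,G(S_L) + \frac{n_R}{n}\,G(S_R)
$$

As a hypothetical example, a node with 4 "Yes" and 2 "No" samples has $G = 1 - \left((4/6)^2 + (2/6)^2\right) = 1 - 20/36 = 4/9 \approx 0.444$. A split that sends 3 "Yes" samples left and the remaining 1 "Yes" and 2 "No" samples right gives $G_{\text{split}} = \tfrac{3}{6}\cdot 0 + \tfrac{3}{6}\cdot\tfrac{4}{9} = 2/9 \approx 0.222$, so this split lowers the impurity.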
4. After calculating the weighted impurity for every feature and split point, choose the split with the lowest value.
```python
def calculate_split_impurities(x, y):
    split_data = []
    for feature in x.columns:
        # Sort the feature and the labels together
        sorted_indices = np.argsort(x[feature].to_numpy(), kind='stable')
        sorted_feature = x[feature].iloc[sorted_indices]
        sorted_y = y.iloc[sorted_indices]
        # Midpoints between adjacent unique values are the candidate thresholds
        unique_values = sorted_feature.unique()
        split_points = (unique_values[1:] + unique_values[:-1]) / 2
        for split in split_points:
            split_index = np.searchsorted(sorted_feature, split, side='right')
            impurity = weighted_average_impurity(sorted_y, split_index)
            split_data.append({
                'feature': feature,
                'split_point': split,
                'weighted_avg_impurity': impurity
            })
    return pd.DataFrame(split_data)

# Calculate split impurities for all features
print(calculate_split_impurities(X_train, y_train).round(3))
```
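Step 4 can also be written as one self-contained function. The sketch below is a hypothetical `best_split` helper (not part of the article's code or of scikit-learn) that works on plain NumPy arrays, tries every midpoint threshold for every feature, and keeps the split with the lowest weighted Gini impurity. The toy data is made up for illustration.

```python
import numpy as np

def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of an array of non-negative integer labels."""
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Return (feature_index, threshold, weighted_gini) with the lowest weighted Gini.

    Returns None if no feature has at least two distinct values.
    """
    n_samples, n_features = X.shape
    best = None
    for j in range(n_features):
        values = np.unique(X[:, j])                   # sorted unique values
        thresholds = (values[:-1] + values[1:]) / 2   # midpoints between neighbours
        for t in thresholds:
            left = X[:, j] <= t
            n_left = left.sum()
            score = (n_left * gini(y[left]) + (n_samples - n_left) * gini(y[~left])) / n_samples
            # Strict < keeps the first split found when two splits tie
            if best is None or score < best[2]:
                best = (j, t, score)
    return best

# Hypothetical toy data: feature 0 separates the classes, feature 1 does not
X = np.array([[1, 5], [2, 3], [3, 5], [4, 3]], dtype=float)
y = np.array([0, 0, 1, 1])
print(best_split(X, y))
```

On this toy data, feature 0 at threshold 2.5 separates the two classes perfectly, so the chosen split has a weighted impurity of 0.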
5. Create two child nodes based on the chosen feature and split point:
- Left child: samples with feature value <= split point
- Right child: samples with feature value > split point
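A minimal sketch of step 5, assuming the data is a NumPy array and using a hypothetical `partition` helper: a boolean mask sends samples whose feature value is `<= threshold` to the left child and all the others to the right child.

```python
import numpy as np

def partition(X, y, feature, threshold):
    """Split samples into a left (<= threshold) and a right (> threshold) child."""
    mask = X[:, feature] <= threshold
    return (X[mask], y[mask]), (X[~mask], y[~mask])

# Hypothetical example: column 1 holds a temperature-like value, split at 75.5
X = np.array([[0, 70.0], [1, 80.0], [0, 75.0], [1, 85.0]])
y = np.array([1, 0, 1, 0])
(X_left, y_left), (X_right, y_right) = partition(X, y, feature=1, threshold=75.5)
print(y_left, y_right)
```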
6. Repeat steps 2 through 5 recursively for each child node until a stopping criterion is met (e.g., the maximum depth is reached, a node has too few samples to split, or a split would reduce impurity by less than a minimum amount).
```python
# Calculate split impurities for a selected subset of training samples
# (iloc positions of the samples in the node you want to inspect; change as needed)
selected_index = [4, 8, 3, 13, 7, 9, 10]
print(calculate_split_impurities(X_train.iloc[selected_index], y_train.iloc[selected_index]).round(3))
```
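Putting steps 1 to 6 together, the sketch below grows a small tree recursively. It is a simplified illustration under stated assumptions, not scikit-learn's implementation: the helper names (`gini`, `best_split`, `build_tree`, `predict_one`) are hypothetical, nodes are plain dictionaries, labels must be non-negative integers, and the stopping rules (pure node, `max_depth`, `min_samples_split`, or no impurity reduction) are just one reasonable choice.

```python
import numpy as np

def gini(labels):
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    # Steps 2-4: try every midpoint of every feature, keep the lowest weighted Gini
    n = len(y)
    best = None
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])
        for t in (values[:-1] + values[1:]) / 2:
            left = X[:, j] <= t
            score = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / n
            if best is None or score < best[2]:
                best = (j, t, score)
    return best

def build_tree(X, y, depth=0, max_depth=3, min_samples_split=2):
    """Recursively grow a binary tree; every leaf stores its majority class."""
    majority = int(np.bincount(y).argmax())
    node_impurity = gini(y)
    # Stopping criteria: pure node, too few samples, or maximum depth reached
    if node_impurity == 0.0 or len(y) < min_samples_split or depth >= max_depth:
        return {"leaf": majority}
    split = best_split(X, y)
    # Also stop if no split exists or the best split does not lower the impurity
    if split is None or split[2] >= node_impurity:
        return {"leaf": majority}
    j, t, _ = split
    mask = X[:, j] <= t  # step 5: left child gets <= threshold
    return {
        "feature": j,
        "threshold": t,
        "left": build_tree(X[mask], y[mask], depth + 1, max_depth, min_samples_split),
        "right": build_tree(X[~mask], y[~mask], depth + 1, max_depth, min_samples_split),
    }

def predict_one(node, x):
    # Go left when x[feature] <= threshold, right otherwise, until a leaf is reached
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

# Hypothetical toy data: class 1 only when both features are 1 (an AND pattern)
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 0, 0, 1])
tree = build_tree(X, y)
print([predict_one(tree, row) for row in X])
```

On this toy data the tree first splits on feature 0, then splits the right child on feature 1, and reproduces the training labels.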
```python
from sklearn.tree import DecisionTreeClassifier

# The whole training phase above is done inside scikit-learn like this
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)
```
Final full tree
The class label of a leaf node is the majority class of the training samples that arrived at that node.
```python
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# Plot the decision tree
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=x.columns, class_names=['Not Play', 'Play'])
plt.show()
```
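If a plot is inconvenient (for example, in a terminal or a log file), scikit-learn's `export_text` prints the same structure as nested if-then rules. The sketch below uses a made-up six-sample dataset with hypothetical `Humidity` and `Wind` columns rather than the golf data, so it runs on its own.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical mini dataset: [Humidity, Wind] -> Play (1) / Not Play (0)
X = np.array([[70, 0], [90, 1], [65, 0], [95, 0], [80, 1], [75, 0]])
y = np.array([1, 0, 1, 0, 0, 1])

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Render the fitted tree as indented text rules
rules = export_text(clf, feature_names=["Humidity", "Wind"])
print(rules)
```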
This is how the prediction process works once the decision tree is trained:
1. Start at the root node of the trained decision tree.
2. Evaluate the split condition at the current node: if the sample's feature value is less than or equal to the threshold, move to the left child; otherwise, move to the right child.
3. Repeat step 2 at each subsequent node until you reach a leaf node.
4. The class label of that leaf node becomes the prediction for the new instance.
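To make these four steps concrete, the sketch below re-implements prediction by walking the arrays of a fitted scikit-learn tree (`clf.tree_.children_left`, `children_right`, `feature`, `threshold`, and `value`) and checks that the result agrees with `predict`. The helper `predict_by_walking` and the random toy data are hypothetical; in practice you would simply call `predict`.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def predict_by_walking(clf, X):
    """Predict by following split conditions from the root down to a leaf."""
    tree = clf.tree_
    X = np.asarray(X, dtype=np.float32)  # scikit-learn trees work on float32 features
    preds = []
    for row in X:
        node = 0  # step 1: start at the root
        # Steps 2-3: a node is a leaf when it has no left child (-1)
        while tree.children_left[node] != -1:
            if row[tree.feature[node]] <= tree.threshold[node]:
                node = tree.children_left[node]
            else:
                node = tree.children_right[node]
        # Step 4: the leaf's majority class becomes the prediction
        preds.append(clf.classes_[np.argmax(tree.value[node])])
    return np.array(preds)

# Hypothetical toy data
rng = np.random.default_rng(0)
X = rng.integers(0, 100, size=(40, 3)).astype(float)
y = (X[:, 0] + X[:, 1] > 100).astype(int)

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
manual = predict_by_walking(clf, X)
print((manual == clf.predict(X)).all())
```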
```python
# Make predictions
y_pred = dt_clf.predict(X_test)
print(y_pred)

# Evaluate the classifier
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
```
Decision trees have several important parameters that control their growth and complexity:
1. Maximum depth (max_depth): This sets the maximum depth of the tree and is a valuable tool for avoiding overfitting.
Helpful tip: Consider starting with a shallow tree (perhaps 3-5 levels deep) and gradually increasing the depth.
2. Minimum samples to split (min_samples_split): This parameter determines the minimum number of samples required to split an internal node.
Helpful tip: Setting this to a higher value (around 5-10% of your training data) can help prevent the tree from creating too many small, specific splits that might not generalize well to new data.
3. Minimum samples per leaf (min_samples_leaf): This specifies the minimum number of samples required at a leaf node.
Helpful tip: Choose a value that ensures each leaf represents a meaningful subset of your data (roughly 1% to 5% of your training data). This can help avoid overly specific predictions.
4. Criterion (criterion): The function used to measure the quality of a split (usually “gini” for Gini impurity or “entropy” for information gain).
Helpful tip: Gini impurity is slightly cheaper to compute because it needs no logarithms, while entropy is sometimes preferred for multi-class problems. In practice, the two often yield similar results.
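A quick way to see how these parameters constrain growth is to fit the same data with different settings and compare `get_depth()` and `get_n_leaves()`. This is a sketch on randomly generated, hypothetical data; the specific values (3, 20, 10) are illustrative choices that follow the tips above (20 and 10 are 10% and 5% of the 200 samples), not recommendations for the golf dataset.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy data: the label depends on two of four features plus noise
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
y = ((X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=200)) > 0).astype(int)

settings = {
    "unconstrained": dict(),
    "max_depth=3": dict(max_depth=3),
    "min_samples_split=20": dict(min_samples_split=20),
    "min_samples_leaf=10": dict(min_samples_leaf=10),
    "criterion='entropy'": dict(criterion="entropy"),
}

# Fit one tree per setting and report how large it grew
for name, params in settings.items():
    clf = DecisionTreeClassifier(random_state=0, **params).fit(X, y)
    print(f"{name:22} depth={clf.get_depth():2}  leaves={clf.get_n_leaves():3}")
```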
Like any machine learning algorithm, decision trees have their strengths and limitations.
Advantages:
- Interpretability: Easy to understand and visualize the decision-making process.
- No feature scaling: Can handle numerical and categorical data without normalization, because splits depend only on the ordering of values (in scikit-learn, categorical features must still be encoded, as with the one-hot Outlook columns here).
- Handles non-linear relationships: Can capture complex patterns in the data.
- Feature importance: Provides an indication of which features contribute most to the predictions.
Cons:
- Overfitting: Prone to creating overly complex trees that do not generalize well, especially with small datasets.
- Instability: Small changes in the data can produce a completely different tree.
- Bias with imbalanced datasets: Trees may be biased towards the majority classes.
- Inability to extrapolate: Predictions cannot go beyond what was seen in training; inputs outside the range of the training data simply fall into the outermost leaves.
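The feature-importance advantage listed above is exposed in scikit-learn through the fitted `feature_importances_` attribute (impurity-based importances that sum to 1). In this sketch on hypothetical data, only the first feature determines the label, so it should receive essentially all of the importance. Keep in mind that impurity-based importances can overstate features with many unique values.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical data: only the first of three features determines the label
rng = np.random.default_rng(1)
X = rng.uniform(size=(100, 3))
y = (X[:, 0] > 0.5).astype(int)

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Print the impurity-based importance of each feature
for name, importance in zip(["signal", "noise_1", "noise_2"], clf.feature_importances_):
    print(f"{name:8} {importance:.3f}")
```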
In our golf example, a decision tree can produce precise, interpretable rules for deciding whether to play golf based on the weather. However, with only 14 training samples, it can easily overfit to specific combinations of conditions unless it is constrained or pruned properly.
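One way to prune is scikit-learn's minimal cost-complexity pruning: `cost_complexity_pruning_path` returns the effective `ccp_alphas`, and refitting with a larger `ccp_alpha` yields a smaller tree. The sketch below uses a hypothetical 30-sample noisy dataset as a stand-in for a small dataset like ours; in practice you would choose `ccp_alpha` with cross-validation.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical small, noisy dataset: the label follows X[:, 0] > 0 with ~20% flipped
rng = np.random.default_rng(7)
X = rng.normal(size=(30, 2))
y = ((X[:, 0] > 0) ^ (rng.uniform(size=30) < 0.2)).astype(int)

# Effective alphas at which subtrees are pruned away
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)

# Refit with each alpha: larger alphas prune more aggressively
for alpha in path.ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    print(f"ccp_alpha={alpha:.4f}  leaves={clf.get_n_leaves()}")
```

The largest alpha in the path prunes the tree all the way back to a single leaf.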
Decision tree classifiers are a great tool for solving many types of problems in machine learning. They are easy to understand, can handle complex data, and show us how they make decisions. This makes them useful in many areas, from business to medicine. While decision trees are powerful and interpretable, they are often used as building blocks for more advanced ensemble methods like random forests or gradient boosting machines.
```python
# Import libraries
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.tree import plot_tree, DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load data
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

# Prepare data
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)
# Use the same column order as earlier in the article
df = df[['sunny', 'overcast', 'rainy', 'Temperature', 'Humidity', 'Wind', 'Play']]

# Split data
x, y = df.drop(columns='Play'), df['Play']
X_train, X_test, y_train, y_test = train_test_split(x, y, train_size=0.5, shuffle=False)

# Train model
dt_clf = DecisionTreeClassifier(
    max_depth=None,        # Maximum depth of the tree (None: expand until leaves are pure or too small to split)
    min_samples_split=2,   # Minimum number of samples required to split an internal node
    min_samples_leaf=1,    # Minimum number of samples required to be at a leaf node
    criterion='gini'       # Function to measure the quality of a split
)
dt_clf.fit(X_train, y_train)

# Make predictions
y_pred = dt_clf.predict(X_test)

# Evaluate model
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

# Visualize tree
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=x.columns,
          class_names=['Not Play', 'Play'], impurity=False)
plt.show()
```