With the evolving digital landscape, a large amount of data is generated and captured from various sources. While immensely valuable, this vast universe of information often reflects the imbalanced distribution of real-world phenomena. The problem of imbalanced data is not simply a statistical challenge; it has far-reaching implications for the accuracy and reliability of data-driven models.
Take, for example, the growing and prevalent concern over fraud detection in the financial industry. As much as we want to prevent fraud because of its highly harmful nature, machines (and even humans) inevitably need to learn from examples of fraudulent transactions, rare as they are, to distinguish them from the vast number of legitimate transactions processed every day.
This imbalance in the distribution of data between fraudulent and non-fraudulent transactions poses significant challenges for machine learning models aimed at detecting such anomalous activities. Without proper handling of data imbalance, these models risk being biased toward predicting transactions as legitimate, potentially missing rare cases of fraud.
Healthcare is another field where machine learning models are leveraged to predict imbalanced outcomes, such as cancer or rare genetic disorders. These outcomes occur much less frequently than their benign counterparts, so models trained on such imbalanced data are more susceptible to incorrect predictions and missed diagnoses. A missed health alert defeats the very purpose of the model, which is early disease detection.
These are just a few examples that highlight the profound impact of data imbalance, that is, when one class significantly outnumbers the other. Oversampling and undersampling are the two standard data preprocessing techniques for balancing a data set; this article focuses on undersampling.
Let's discuss some popular methods for undersampling a given distribution.
Let's start with an illustrative example to better understand why undersampling techniques matter. The following visualization demonstrates the effect of the relative number of points per class on the decision boundary of a Support Vector Machine (SVM) with a linear kernel. The code and plots below are from the accompanying Kaggle Notebook.
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC


def create_dataset(
    n_samples=1000, weights=(0.01, 0.01, 0.98), n_classes=3, class_sep=0.8, n_clusters=1
):
    # Generate a synthetic 2D classification dataset with the given class weights
    return make_classification(
        n_samples=n_samples,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_repeated=0,
        n_classes=n_classes,
        n_clusters_per_class=n_clusters,
        weights=list(weights),
        class_sep=class_sep,
        random_state=0,
    )


def plot_decision_function(X, y, clf, ax):
    # Evaluate the classifier on a grid and draw its decision regions
    plot_step = 0.02
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step)
    )
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax.contourf(xx, yy, Z, alpha=0.4)
    ax.scatter(X[:, 0], X[:, 1], alpha=0.8, c=y, edgecolor="k")


fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
ax_arr = (ax1, ax2, ax3, ax4)
weights_arr = (
    (0.01, 0.01, 0.98),
    (0.01, 0.05, 0.94),
    (0.2, 0.1, 0.7),
    (0.33, 0.33, 0.33),
)
for ax, weights in zip(ax_arr, weights_arr):
    X, y = create_dataset(n_samples=1000, weights=weights)
    clf = LinearSVC().fit(X, y)
    plot_decision_function(X, y, clf, ax)
    ax.set_title("Linear SVC with y={}".format(Counter(y)))
The code above generates plots for four different class distributions: a highly imbalanced data set in which one class dominates 97% of the instances; second and third data sets in which a single class holds 93% and 69% of the instances, respectively; and a perfectly balanced data set in which each of the three classes contributes one third of the instances. The plots are ordered from most imbalanced to least.
On fitting an SVM to this data, the hyperplane in the first (highly imbalanced) plot is pushed to one side of the chart. This happens mainly because the algorithm treats each instance equally, regardless of class, and tries to separate the classes with the maximum margin; the large yellow majority near the center therefore pushes the hyperplane toward the corner, causing the algorithm to misclassify the minority classes.
The algorithm successfully classifies all classes of interest as we move towards a more balanced distribution.
In summary, when a data set is dominated by one or a few classes, the resulting model tends to misclassify more of the minority instances. The classifier's bias decreases as the distribution of observations per class approaches an even split.
In this case, undersampling the yellow points is the simplest way to address the model errors caused by the rare-class problem. Not every data set suffers from this issue, but for those that do, rectifying the imbalance is a crucial preliminary step before modeling the data.
We will use the Python Imbalanced-Learn library (imbalanced-learn or imblearn). We can install it using pip:
pip install -U imbalanced-learn
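Before moving to the more targeted methods, here is a minimal sketch of the simplest fix mentioned earlier: randomly dropping majority-class points with imblearn's RandomUnderSampler. It assumes the create_dataset() helper from the SVM example above is still in scope; it is an illustration rather than part of the original notebook.

from collections import Counter
from imblearn.under_sampling import RandomUnderSampler

# Assumes create_dataset() from the SVM example above is available
X, y = create_dataset(n_samples=1000, weights=(0.01, 0.01, 0.98))

# Randomly discard points until every class matches the size of the smallest one
rus = RandomUnderSampler(random_state=0)
X_res, y_res = rus.fit_resample(X, y)

print("Before:", Counter(y))
print("After: ", Counter(y_res))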
Let's discuss and experiment with some of the most popular undersampling techniques. Suppose you have a binary classification data set where class '0' significantly outnumbers class '1'.
NearMiss Undersampling
NearMiss is an undersampling technique that reduces the number of majority-class samples, retaining the ones that lie closest to the minority class. This facilitates a clean classification by any algorithm that relies on spatial separation or partitions the feature space between the two classes. There are three versions of NearMiss:
NearMiss-1: Retains majority class samples with the minimum average distance to the three closest minority class samples.
NearMiss-2: Retains majority class samples with the minimum average distance to the three furthest minority class samples.
NearMiss-3: Retains majority class samples with the minimum distance to each minority class sample.
Let's demonstrate the NearMiss-1 undersampling algorithm with a code example:
# Import necessary libraries and modules
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import NearMiss

# Generate an imbalanced binary dataset (95% class 0, 5% class 1)
features, labels = make_classification(
    n_samples=1000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=(0.95, 0.05),
    flip_y=0,
    random_state=0,
)

# Print the distribution of classes
dist_classes = Counter(labels)
print("Before Undersampling:")
print(dist_classes)

# Generate a scatter plot of instances, labeled by class
for class_label, _ in dist_classes.items():
    instances = np.where(labels == class_label)[0]
    plt.scatter(features[instances, 0], features[instances, 1], label=str(class_label))
plt.legend()
plt.show()

# Set up the undersampling method (NearMiss-1)
undersampler = NearMiss(version=1, n_neighbors=3)

# Apply the transformation to the dataset
features, labels = undersampler.fit_resample(features, labels)

# Print the new distribution of classes
dist_classes = Counter(labels)
print("After Undersampling:")
print(dist_classes)

# Generate a scatter plot of instances, labeled by class
for class_label, _ in dist_classes.items():
    instances = np.where(labels == class_label)[0]
    plt.scatter(features[instances, 0], features[instances, 1], label=str(class_label))
plt.legend()
plt.show()
Change version=1 to version=2 or version=3 in the NearMiss() class to use the NearMiss-2 or NearMiss-3 undersampling algorithm.
NearMiss-2 selects instances at the center of the overlap region between the two classes. With the NearMiss-3 algorithm, each minority class instance that overlaps with the majority class region retains up to three majority class neighbors, as defined by the n_neighbors attribute in the code example above.
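To see how the other two versions behave, here is a short sketch that regenerates the original imbalanced data (the earlier example overwrote features and labels in place) and compares the resulting class counts; the loop and variable names are illustrative only.

from collections import Counter
from imblearn.under_sampling import NearMiss
from sklearn.datasets import make_classification

# Regenerate the original imbalanced dataset; the earlier example resampled it in place
features, labels = make_classification(
    n_samples=1000, n_features=2, n_redundant=0, n_clusters_per_class=1,
    weights=(0.95, 0.05), flip_y=0, random_state=0,
)

for version in (2, 3):
    # NearMiss-3 also accepts n_neighbors_ver3, which pre-selects that many
    # majority neighbors for each minority sample before the final selection
    nm = NearMiss(version=version, n_neighbors=3)
    X_res, y_res = nm.fit_resample(features, labels)
    print(f"NearMiss-{version}:", Counter(y_res))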
Condensed Nearest Neighbor (CNN) undersampling starts by considering a subset of the majority class as noise. It then uses a 1-nearest-neighbor rule to classify instances; if a majority class instance is misclassified, it is added to the subset. The process continues until no more instances are added to the subset.
from imblearn.under_sampling import CondensedNearestNeighbour

# Apply CNN undersampling to the imbalanced features X and labels y
cnn = CondensedNearestNeighbour(random_state=42)
X_res, y_res = cnn.fit_resample(X, y)
print('Resampled dataset shape:', Counter(y_res))
Tomek links are pairs of instances from opposite classes that lie in close proximity, each being the other's nearest neighbor. Removing the majority class instance from each pair increases the separation between the two classes, facilitating the classification process.
from imblearn.under_sampling import TomekLinks

# Remove the majority class instance from every Tomek link pair
tl = TomekLinks()
X_res, y_res = tl.fit_resample(X, y)

print('Original dataset shape:', Counter(y))
print('Resampled dataset shape:', Counter(y_res))
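When these undersamplers are used to train a model, it is usually safest to resample only the training data. Below is a hedged sketch, assuming X and y hold the imbalanced binary dataset used above, that wires Tomek links into imblearn's Pipeline with a plain scikit-learn LogisticRegression so that resampling happens inside each cross-validation fold; the estimator and scoring choices are for illustration only.

from imblearn.pipeline import Pipeline  # applies the sampler during fit only
from imblearn.under_sampling import TomekLinks
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Resampling happens inside each training fold; the held-out fold keeps its
# original, imbalanced distribution, which avoids overly optimistic scores.
model = Pipeline(steps=[("undersample", TomekLinks()), ("clf", LogisticRegression())])
scores = cross_val_score(model, X, y, scoring="f1", cv=5)
print("Mean F1 across folds:", scores.mean())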
With this, we have delved into the nitty-gritty of undersampling techniques in Python, covering three prominent methods: NearMiss undersampling, Condensed Nearest Neighbor, and Tomek links undersampling.
Undersampling is a crucial data processing step for addressing class imbalance in machine learning, and it also helps improve model performance and fairness. Each of these techniques offers unique advantages and can be tailored to specific data sets and the goals of a machine learning project.
This article provides a comprehensive understanding of undersampling methods and their application in Python. I hope it enables you to make informed decisions about how to address class imbalance challenges in your machine learning projects.
Vidhi Chugh is an AI strategist and digital transformation leader working at the intersection of product, science, and engineering to build scalable machine learning systems. She is an award-winning innovation leader, author, and international speaker. Her mission is to democratize machine learning and break down the jargon so everyone can be a part of this transformation.