In this article we're going to introduce the problem of dataset class imbalance, which often occurs in real-world classification problems. We'll then look at oversampling as a possible solution and demonstrate it with a coded example on an imbalanced dataset.
Let's assume we have a dataset where the data points are classified into two categories: Class A and Class B. In an ideal scenario, the data points would be divided equally between the two categories, e.g.:
- Class A accounts for 50% of the dataset.
- Class B accounts for the other 50% of the dataset.
In this scenario, classification accuracy would be a reasonable measure of a classification model's performance.
Unfortunately, this ideal balance often isn't the case when working with real-world problems, where the categories of interest may occur far less often than the rest. For context, here are some examples:
- Credit Card Fraud: The majority of credit card transactions are genuine, whereas a small minority are fraudulent.
- Medical Scans: The majority of medical scans are normal, whereas a small minority indicate something pathological.
- Weapons Detection: The majority of body scans are normal, whereas a small minority detect a concealed weapon.
An imbalanced dataset could consist of data points divided as follows:
- Class A accounts for 90% of the dataset.
- Class B accounts for 10% of the dataset.
Let's say in this case that Class B represents the suspect category, e.g. a weapon, disease, or fraud has been detected. If a model scored a classification accuracy of 90%, we may decide we're happy. After all, the model appears to be correct 90% of the time.
However, this measurement is misleading when dealing with an imbalanced dataset. Another way to look at it: I could write a simple function that classifies everything as Class A and still achieve a classification accuracy of 90% when tested against this imbalanced dataset:
def classify_data(data):
    # Ignore the input entirely and always predict the majority class
    return "Class A"
Unfortunately, my solution is useless for detecting anything meaningful in the real-world.
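To make this concrete, here is a minimal sketch using scikit-learn's metrics on hypothetical labels with the same 90/10 split. The always-"Class A" predictor scores 90% accuracy, but its recall on Class B, the class we actually care about, is zero, and balanced accuracy (the mean of the per-class recalls) drops to 50%:

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score

# Hypothetical ground truth mirroring the 90/10 split: 90 "Class A", 10 "Class B"
y_true = np.array(["Class A"] * 90 + ["Class B"] * 10)

# The trivial classifier: every sample is predicted as "Class A"
y_pred = np.array(["Class A"] * len(y_true))

accuracy = accuracy_score(y_true, y_pred)
# Recall for Class B: the fraction of actual Class B samples that were detected
recall_b = recall_score(y_true, y_pred, pos_label="Class B")
# Balanced accuracy: the mean of the per-class recalls
balanced = balanced_accuracy_score(y_true, y_pred)

print(f"accuracy:          {accuracy:.2f}")  # 0.90
print(f"recall (Class B):  {recall_b:.2f}")  # 0.00
print(f"balanced accuracy: {balanced:.2f}")  # 0.50
```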
This is a problem.
One solution to this problem is to use a sampling technique to either:
- Oversample - this will create new synthetic samples that resemble the minority class to balance the dataset.
- Undersample - this will remove samples from the majority class, e.g. at random or according to some selection scheme, to balance the dataset.
For this article we will focus on oversampling to create a balanced training set for a machine learning algorithm. Because this involves creating synthetic samples, it is important not to include these in the test set: testing a model must rely entirely on real data.
It's also important to note that oversampling is not always a suitable solution to the imbalanced dataset problem. Whether it helps depends on factors such as the characteristics of the dataset and the problem domain.
Let's demonstrate the oversampling approach using a dataset and some Python libraries. We will be employing the imbalanced-learn package, which contains many oversampling and undersampling methods. A handy feature is that it follows scikit-learn's API conventions, so it fits naturally alongside scikit-learn models. Specifically, we will be using the Adaptive Synthetic (ADASYN) oversampling method based on the publication below, but other popular methods, e.g. the Synthetic Minority Oversampling Technique (SMOTE), may work just as well.
He, Haibo, Yang Bai, Edwardo A. Garcia, and Shutao Li. “ADASYN: Adaptive synthetic sampling approach for imbalanced learning.” In IEEE International Joint Conference on Neural Networks (IEEE World Congress on Computational Intelligence), pp. 1322-1328, 2008.
First, we import our packages. We have the usual suspects, numpy, pandas, matplotlib, seaborn, and scikit-learn, with the addition of the package that contains implementations of the sampling methods: imblearn. If you're using a pre-configured environment, e.g. Kaggle Kernels, Anaconda, or various Docker images, you will likely need to install imblearn before you can import it; it is distributed on PyPI as imbalanced-learn (pip install imbalanced-learn).
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt  # plotting
import seaborn as sns  # statistical visualisation on top of matplotlib
from sklearn.decomposition import PCA
from imblearn.over_sampling import ADASYN
Moving forward, we need to load our dataset. For this example, we have a CSV file, dataset.csv, which contains our input variables and the correct/desired classification labels.
# load the dataset
data = pd.read_csv("dataset.csv")
A quick call to data.head() gives us some idea of the form of the data, with input variables a to o and the final column labelled class:
Confirming the balance of the dataset
Before we decide if the dataset needs oversampling, we need to investigate the current balance of the samples according to their classification. Depending on the size and complexity of your dataset, you could get away with simply outputting the classification labels and observing the balance.
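A minimal sketch of this check with pandas, using a hypothetical stand-in DataFrame with the same 457/47 split as this dataset (the real values would come from dataset.csv). value_counts() gives the number of samples per class, and normalize=True turns those counts into proportions:

```python
import pandas as pd

# Hypothetical stand-in for the class column of dataset.csv: 457 of class 1, 47 of class 0
data = pd.DataFrame({"class": [1] * 457 + [0] * 47})

counts = data["class"].value_counts()  # number of samples per class
proportions = data["class"].value_counts(normalize=True)  # fraction of samples per class

print(counts)
print(proportions.round(3))

# Ratio of majority to minority class size
imbalance_ratio = counts.max() / counts.min()
print(f"imbalance ratio: {imbalance_ratio:.1f} : 1")
```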
Outputting the labels tells us that there is certainly an imbalance in the dataset, where the majority class, 1, significantly outnumbers the minority class, 0. To be sure of this, we can take a closer look using value_counts() on the class column, the output of which will be:
As you can see, value_counts() has listed the number of instances per class, and it confirms what we were expecting: with 47 samples in the minority class and 457 in the majority class, this is clearly an imbalanced dataset. Let's visualise this before moving on. We're going to use Principal Component Analysis (PCA) through sklearn.decomposition.PCA() to reduce our input variables to two dimensions for easier visualisation.
pca = PCA(n_components=2)
# Project all columns except the final 'class' column onto two components
data_2d = pd.DataFrame(pca.fit_transform(data.iloc[:, :-1]))
The output will look something like:
The final column containing the classifications has been omitted from the transformation using pandas.DataFrame.iloc. After the transformation, we add the classification label column back to the DataFrame for use in the visualisations later on. We also name our columns for easy reference:
data_2d = pd.concat([data_2d, data['class']], axis=1)
data_2d.columns = ['x', 'y', 'class']
This can be confirmed by outputting the DataFrame again:
With our DataFrame in the desired form, we can create a quick scatterplot visualisation which again confirms the imbalance of the dataset.
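One way to produce such a scatterplot is seaborn's scatterplot() with the class column as the hue. A minimal sketch, assuming a hypothetical stand-in for dataset.csv (15 random input columns named a to o, 457 samples of class 1 and 47 of class 0), so the resulting plot will not match the article's data. It also prints how much of the variance the two components retain, which hints at how faithful the 2D picture is:

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Hypothetical stand-in for dataset.csv: 15 inputs (a to o) plus a class column
n_major, n_minor = 457, 47
inputs = np.vstack([
    rng.normal(0.0, 1.0, size=(n_major, 15)),  # class 1 (majority)
    rng.normal(1.0, 1.0, size=(n_minor, 15)),  # class 0 (minority), shifted mean
])
data = pd.DataFrame(inputs, columns=list("abcdefghijklmno"))
data["class"] = [1] * n_major + [0] * n_minor

# Project every input column (all but the last) onto 2 principal components
pca = PCA(n_components=2)
data_2d = pd.DataFrame(pca.fit_transform(data.iloc[:, :-1]))
data_2d = pd.concat([data_2d, data["class"]], axis=1)
data_2d.columns = ["x", "y", "class"]

# Colour each point by its class label
ax = sns.scatterplot(data=data_2d, x="x", y="y", hue="class")
ax.set_title("Original dataset (PCA projection)")
plt.savefig("imbalanced_scatter.png")  # hypothetical output filename
print("variance explained by 2 components:", pca.explained_variance_ratio_.sum())
```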
ADASYN for oversampling
Using ADASYN through imblearn.over_sampling is straightforward. An ADASYN object is instantiated, and then its fit_resample() method (named fit_sample() in older releases of imbalanced-learn) is invoked with the input variables and output classifications as the parameters:
ada = ADASYN()
# Input variables are every column except the final 'class' column
X_resampled, y_resampled = ada.fit_resample(data.iloc[:, :-1], data['class'])
The oversampled input variables have been stored in X_resampled and their corresponding output classifications in y_resampled. Once again, we restore our data into DataFrame form for easy interrogation and visualisation:
data_oversampled = pd.concat([pd.DataFrame(X_resampled), pd.DataFrame(y_resampled)], axis=1)
data_oversampled.columns = data.columns
Using value_counts() on data_oversampled['class'] once more, we can look at the new balance:
Now we have our oversampled and more balanced dataset. Let's visualise this on a scatterplot using our earlier approach.
# Reuse the PCA fitted on the original data so both plots share the same projection
data_2d_oversampled = pd.DataFrame(pca.transform(data_oversampled.iloc[:, :-1]))
data_2d_oversampled = pd.concat([data_2d_oversampled, data_oversampled['class']], axis=1)
data_2d_oversampled.columns = ['x', 'y', 'class']
As before, we've used PCA to reduce the dimensionality of our newly oversampled dataset for easier visualisation, this time calling pca.transform() rather than fit_transform() so that the projection matches the original plot. We've also restored the data into a DataFrame with the desired column names. If we plot this data, we can see there is no longer a significant majority class:
In this article we've had a quick look at the problem of imbalanced datasets and suggested one approach to the problem through oversampling. We've implemented a coded example which applied ADASYN to an imbalanced dataset and visualised the difference before and after. If you plan to use this approach in practice, don't forget to first split your data into the training and testing sets before applying oversampling techniques to the training set only.