Illustrated examples of overfitting and underfitting, as well as how to detect & overcome them
Overfitting and underfitting are two common problems in machine learning where the model becomes too complex or too simple for the given dataset. This article illustrates both problems with simple examples and elaborates on ways to detect and overcome both challenges.
Finding the right balance between overfitting and underfitting is crucial for building a good machine learning model that can generalise well to new data.
What is overfitting in machine learning?
Overfitting is a common problem in machine learning where a model performs exceptionally well on the training data but poorly on new, unseen data. Overfitting occurs when a model becomes too complex and learns noise or irrelevant patterns in the training data rather than the true underlying patterns that generalise well to new data.
In other words, overfitting happens when a model memorises the training data instead of learning the underlying relationship between the input features and the output variable. This can result in a model that performs very well on the training data but poorly on new, unseen data because it has become too specialised for the training data.
Overfitting can be addressed with techniques such as regularisation and early stopping, with cross-validation used to detect it and to tune the model's complexity.
These techniques help the model generalise to new data by reducing its effective complexity and preventing it from memorising noise or irrelevant patterns in the training data.
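Early stopping can be sketched in a few lines of plain Python. The function below is a hypothetical helper, not part of any library: it scans per-epoch validation losses and returns the epoch with the lowest loss, stopping once the loss has failed to improve for `patience` consecutive epochs. The patience value and the example losses are assumptions for illustration.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the index of the best epoch under a simple patience rule.

    Hypothetical helper: training is assumed to stop once the validation
    loss has not improved for `patience` consecutive epochs.
    """
    best_epoch = 0
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # Validation loss has stalled or risen for too long: stop here
                break
    return best_epoch


# Made-up validation losses: they fall, then rise as the model starts to overfit
val_losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.67]
print(early_stopping_epoch(val_losses))  # → 3
```

In a real training loop you would typically also keep a copy of the model weights from the best epoch and restore them after stopping.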
What is underfitting in machine learning?
Underfitting is a common problem in machine learning. It happens when a model is too simple to capture the underlying patterns in the data. As a result, the model performs poorly on both the training data and new data it has never seen before.
In other words, underfitting occurs when a model is not complex enough to learn the underlying relationship between the input features and the output variable. Because it has not learned enough from the training data, it performs poorly on that data and cannot generalise to new, unseen data either.
Underfitting can be addressed by increasing the model's complexity (for example, adding hidden layers to a neural network), adding more input features, and training for longer.
However, it is vital to balance model complexity against generalisation, because a model that becomes too complex will overfit instead. Techniques such as cross-validation let you compare the model’s performance on the data it was trained on with its performance on held-out data, so you can choose a model that performs well on both.
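A minimal sketch of this comparison using only NumPy, assuming synthetic data generated from a quadratic trend with added noise (the coefficients, noise level, 70/30 split and candidate degrees are arbitrary choices for illustration). A large gap between training and holdout error points to overfitting; high error on both points to underfitting.

```python
import numpy as np

# Synthetic data (an assumption for illustration): a quadratic trend plus noise
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=40)
y = 2 + 1.5 * x - 0.1 * x**2 + rng.normal(scale=1.0, size=x.size)

# Hold out 30% of the points; the model never sees them during fitting
indices = rng.permutation(x.size)
n_train = int(0.7 * x.size)
train_idx, test_idx = indices[:n_train], indices[n_train:]


def mse(model, xs, ys):
    """Mean squared error of a fitted polynomial on (xs, ys)."""
    return float(np.mean((model(xs) - ys) ** 2))


results = {}
for degree in (1, 2, 12):
    # Polynomial.fit rescales x internally, which keeps high-degree fits stable
    model = np.polynomial.Polynomial.fit(x[train_idx], y[train_idx], degree)
    results[degree] = (mse(model, x[train_idx], y[train_idx]),
                       mse(model, x[test_idx], y[test_idx]))
    print(f"degree {degree:2d}: train MSE {results[degree][0]:.2f}, "
          f"holdout MSE {results[degree][1]:.2f}")
```

Because each higher degree contains the lower ones, training MSE can only fall as the degree rises, so the holdout column is the one to use when choosing between models.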
A simple example of overfitting and underfitting
Suppose you have a dataset with one input feature (e.g., the number of hours studied) and one output variable (e.g., the exam score) that looks like this:
| Hours Studied | Exam Score |
|---|---|
| 1 | 20 |
| 1.5 | 20 |
| 2 | 60 |
| 3 | 75 |
| 4 | 85 |
| 5 | 90 |
| 6 | 95 |
Let’s say you want to teach a machine learning model to predict the score on an exam based on how many hours you studied.
- Suppose you train a linear regression model on this dataset. In this case, the model may underfit the data because the relationship between the input feature and the output variable may not be linear. In other words, a straight line may not capture the underlying patterns in the data. The model may predict poorly on both the training data and new, unseen data, as shown in the following plot:
Example of underfitting
If you want to recreate the plot with your data, you can adapt the Python code below that uses Matplotlib to generate the plots quickly.
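```python
import numpy as np
import matplotlib.pyplot as plt

# Hours studied (x) and exam scores (y) from the table above
x = [1, 1.5, 2, 3, 4, 5, 6]
y = [20, 20, 60, 75, 85, 90, 95]

# Polynomial degree: 1 gives the underfitting line; change to 2 or 5 for the other plots below
degree = 1
mymodel = np.poly1d(np.polyfit(x, y, degree))

# Evaluate the fitted polynomial on a fine grid so the curve is drawn smoothly
myline = np.linspace(1, 6, 100)
plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()
```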
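To put numbers on the three plots, you can compute each model's error on the training points themselves. This sketch reuses the table's data with `numpy.polyfit`; training error can only fall (or stay equal) as the degree increases, which is why a low training error on its own cannot tell you whether a model is overfitting.

```python
import numpy as np

# Hours studied and exam scores from the table above
x = np.array([1, 1.5, 2, 3, 4, 5, 6])
y = np.array([20, 20, 60, 75, 85, 90, 95])

train_mse = {}
for degree in (1, 2, 5):
    model = np.poly1d(np.polyfit(x, y, degree))
    # Mean squared error on the same points the model was fitted to
    train_mse[degree] = float(np.mean((model(x) - y) ** 2))
    print(f"degree {degree}: training MSE = {train_mse[degree]:.2f}")
```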
- Suppose you train a polynomial regression model of degree 5 on this dataset. In this case, the model may overfit the data: it is flexible enough to pass almost exactly through every training point, so it fits the noise as well as the underlying pattern. The model may perform very well on the training data but poorly on new, unseen data, as shown in the following plot:
Example of overfitting
If we used this model to find the optimum study time, we would incorrectly conclude that 5.5 hours of study is better than 6 or more.
This shows the danger of overfitting. The model doesn’t represent the actual pattern you are trying to predict.
- If you train a polynomial regression model on this dataset with a degree of 2, the model may fit the data well and generalise well to new, unseen data, as shown in the following plot:
Example of a function generalising well
In short, underfitting happens when the model is too simple to capture the underlying patterns in the data.
Conversely, overfitting occurs when the model is too complex and starts to learn noise or irrelevant patterns in the data.
A good fit is when the model can capture the true patterns in the data without overfitting or underfitting.
Examples of underfitting and overfitting in real applications
Here are some real-world examples of underfitting and overfitting.
Underfitting
Suppose you have a dataset of images of handwritten digits, and you want to train a machine learning model to recognise the numbers. If you use a simple model, such as logistic regression, the model may underfit the data because it is too simple to capture the complex patterns in the images. As a result, the model may perform poorly on both the training data and new, unseen data.
To address underfitting, you can use more complex models, such as convolutional neural networks, which are better suited to capturing complex image patterns.
Overfitting
Let’s say you have a set of customer data, such as age, income, gender, and buying habits, and you want to train a machine learning model to predict which customers are likely to make a purchase. Suppose you use a complex model, such as a deep neural network with many layers. In that case, the model may overfit the data because it is too complex for the problem and starts to learn noise or irrelevant patterns in the data. As a result, the model may perform very well on the training data but poorly on new, unseen data.
To address overfitting, you can use regularisation techniques, such as L1 or L2 regularisation, that add a penalty term to the loss function to prevent the model from overfitting.
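For a model that is linear in its parameters, L2 (ridge) regularisation has a closed-form solution, which makes the effect of the penalty easy to see. The sketch below assumes the study-hours data from the earlier example and a degree-5 polynomial; `ridge_poly_fit` and its penalty strength `lam` are hypothetical names. L1 (lasso) regularisation is usually solved iteratively and is not shown.

```python
import numpy as np

# Study-hours data from the earlier example
x = np.array([1, 1.5, 2, 3, 4, 5, 6])
y = np.array([20, 20, 60, 75, 85, 90, 95], dtype=float)


def ridge_poly_fit(x, y, degree, lam):
    """Fit a polynomial with an L2 (ridge) penalty via the closed-form solution.

    Features are standardised powers of x; the intercept is left unpenalised
    by centring the features and the target before solving.
    """
    xs = (x - x.mean()) / x.std()
    X = np.column_stack([xs**p for p in range(1, degree + 1)])
    X_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - X_mean, y - y_mean
    # Ridge solution: w = (Xc^T Xc + lam * I)^(-1) Xc^T yc
    w = np.linalg.solve(Xc.T @ Xc + lam * np.eye(degree), Xc.T @ yc)
    intercept = y_mean - X_mean @ w
    return w, intercept


for lam in (0.0, 0.1, 10.0):
    w, b = ridge_poly_fit(x, y, degree=5, lam=lam)
    print(f"lambda = {lam:5.1f}   ||w|| = {np.linalg.norm(w):.2f}")
```

Larger penalties shrink the coefficients and smooth the degree-5 curve; a very large penalty flattens the curve towards the mean exam score, which reintroduces underfitting.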
How to check if the model is overfitting or underfitting
There are several ways to detect over- or underfitting in a machine learning model, and to correct it once found:
- Plot the learning curves: Learning curves show the model’s performance on training and validation data over time as the model is trained. If the model is overfitting, the training error continues to decrease while the validation error starts to increase after a certain point. This indicates that the model is beginning to memorise the training data and no longer generalises well to new, unseen data. If both the training and validation errors remain high, the model is likely underfitting.
- Evaluate the model on a holdout set: A holdout set is a subset of the data that is not used during training but is used to evaluate the model after training. If the model performs well on the training data but poorly on the holdout set, it may be overfitting the training data.
- Use cross-validation: Cross-validation divides the data into k folds; the model is trained on k − 1 folds and evaluated on the remaining one, repeating until every fold has served once as the validation set. If the model performs well on the training folds but poorly on the validation folds, it is likely overfitting.
- Regularise the model: Regularisation adds a penalty term to the loss function that discourages the model from fitting the training data too closely. By tuning the regularisation parameter, you control the model’s effective complexity: a stronger penalty reduces overfitting, while too strong a penalty makes the model too simple and causes underfitting.
- Use simpler models: If your complex model is overfitting the data, you can use simpler models less prone to overfitting, such as linear models or decision trees with low depth.
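The cross-validation procedure described above can be sketched with NumPy alone. This is a minimal illustration, assuming synthetic data from a noisy quadratic and treating the polynomial degree as the complexity knob; `k_fold_mse` is a hypothetical helper, and libraries such as scikit-learn provide ready-made equivalents.

```python
import numpy as np


def k_fold_mse(x, y, degree, k=5, seed=0):
    """Average validation MSE of a polynomial fit over k folds.

    Each fold is held out once while the model is fitted on the other k - 1 folds.
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(x)), k)
    errors = []
    for i, val_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        model = np.polynomial.Polynomial.fit(x[train_idx], y[train_idx], degree)
        errors.append(np.mean((model(x[val_idx]) - y[val_idx]) ** 2))
    return float(np.mean(errors))


# Synthetic data (an assumption for illustration): a quadratic trend plus noise
rng = np.random.default_rng(1)
x = rng.uniform(0, 10, size=60)
y = 2 + 1.5 * x - 0.1 * x**2 + rng.normal(scale=1.0, size=x.size)

cv_scores = {degree: k_fold_mse(x, y, degree) for degree in range(1, 9)}
best_degree = min(cv_scores, key=cv_scores.get)
print(cv_scores)
print("degree with the lowest cross-validated error:", best_degree)
```

Degrees that are too low score poorly because they underfit, and degrees that are too high tend to score poorly because they overfit; the lowest average validation error marks a reasonable balance.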
In general, it’s crucial to monitor the model’s performance during training and evaluation and to be aware of the trade-off between model complexity and generalisation performance.
Underfitting and overfitting in NLP
Underfitting and overfitting are common problems in tasks like text classification, sentiment analysis, and machine translation using Natural Language Processing (NLP). Here are some examples of underfitting and overfitting in NLP:
- Underfitting in NLP: If your NLP model is too simple, it might not fit the data well enough and miss the essential patterns in the text data. For example, suppose you are building a sentiment analysis model and only use simple bag-of-words features without context or semantic information. In that case, the model may perform poorly on both the training and validation data.
- Overfitting in NLP: If your NLP model is too complex and has too many parameters, it may overfit the data and start to learn noise or irrelevant patterns in the text data. For example, suppose you build a machine translation model using a very large neural network with too many layers. In that case, the model may overfit the training data and perform poorly on new, unseen data.
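A short illustration of why plain bag-of-words features can lead to underfitting in sentiment analysis: because word order is discarded, two reviews with opposite meanings can produce identical features, so no model trained on those features alone can tell them apart. The tokeniser below is a deliberately simple assumption, not a production one.

```python
from collections import Counter


def bag_of_words(text):
    """Lower-case, replace simple punctuation with spaces, and count each word.

    Word order is discarded, which is the defining property of bag-of-words.
    """
    cleaned = text.lower().replace(",", " ").replace(".", " ")
    return Counter(cleaned.split())


a = bag_of_words("The film was good, not bad.")
b = bag_of_words("The film was bad, not good.")
print(a == b)  # → True: both reviews map to the same feature vector
```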
You can use the same methods as other machine learning tasks to find underfitting and overfitting in NLP models.
Conclusion
Overfitting and underfitting are common challenges in machine learning. Overfitting occurs when a model is too complex and learns noise or irrelevant patterns in the data.
Underfitting, by contrast, occurs when a model is too simple and cannot capture the underlying patterns in the data. To detect overfitting and underfitting, you can use techniques such as plotting learning curves, evaluating the model on a holdout set, and using cross-validation.
To address overfitting and underfitting, you can use regularisation, simpler or more complex models, or add more input features. Ultimately, the goal is to find the right balance between model complexity and data fit to achieve optimal model performance on new, unseen data.