Linear Discriminant Analysis in Machine Learning
Last Updated: 10 Feb, 2025
When working with high-dimensional datasets, it is important to apply dimensionality reduction techniques to make data exploration and modeling more efficient. One such technique is Linear Discriminant Analysis (LDA), which reduces the dimensionality of the data while retaining the features that matter most for classification. It works by finding the linear combinations of features that best separate the classes in the dataset. In this article we will learn how LDA works and how to implement it in Python.
Maximizing Class Separability: The Role of LDA
Linear Discriminant Analysis (LDA), also known as Normal Discriminant Analysis, is a supervised technique that separates two or more classes by projecting higher-dimensional data into a lower-dimensional space. It identifies a linear combination of features that best separates the classes within a dataset.

For example, suppose we have two classes that need to be separated efficiently. Each class may have multiple features, and using a single feature to classify them may result in overlap. LDA addresses this by using multiple features together to improve classification accuracy.
LDA relies on a few assumptions, and understanding them gives a clearer picture of how it works.
Core Assumptions of LDA
For LDA to perform effectively certain assumptions are made:
- Gaussian Distribution: Data within each class should follow a Gaussian distribution.
- Equal Covariance Matrices: Covariance matrices of the different classes should be equal.
- Linear Separability: A linear decision boundary should be sufficient to separate the classes.
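A rough way to check the equal-covariance assumption is to compute the covariance matrix of each class separately and compare them. The short sketch below does this on the Iris dataset (the same dataset used in the implementation later in this article); it is only an informal check, not a formal test.
Python
import numpy as np
from sklearn.datasets import load_iris

# Informal check of the equal-covariance assumption:
# compute the covariance matrix of each class and compare them by eye.
X, y = load_iris(return_X_y=True)

for label in np.unique(y):
    class_cov = np.cov(X[y == label], rowvar=False)  # features in columns
    print(f"Covariance matrix of class {label}:\n{np.round(class_cov, 2)}\n")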
For example, when data points belonging to two classes are plotted and they are not cleanly separable in the original feature space, LDA will attempt to find a projection that maximizes class separability.

Linearly Separable Dataset
The image shows an example where the classes (black and green circles) are not linearly separable. LDA attempts to separate them using the red dashed line. It uses both axes (X and Y) to generate a new axis in such a way that the distance between the means of the two classes is maximized while the variation within each class is minimized. This transforms the dataset into a space where the classes are better separated.
After transforming the data points along a new axis LDA maximizes the class separation. This new axis allows for clearer classification by projecting the data along a line that enhances the distance between the means of the two classes.

The perpendicular distance between the line and points
The perpendicular distance between the decision boundary and the data points helps visualize how LDA works: it reduces within-class variation while increasing the separation between classes.
After generating this new axis using the above-mentioned criteria, all the data points of the classes are plotted on this new axis and are shown in the figure given below.

It shows how LDA creates a new axis onto which the data is projected, separating the two classes along a linear path. However, LDA fails when the means of the class distributions coincide, because no new axis can then make the classes linearly separable. In such cases, non-linear discriminant analysis is used instead.
How does LDA work?
LDA works by finding directions in the feature space that best separate the classes. It does this by maximizing the difference between the class means while minimizing the spread within each class.
Mathematical foundation of LDA
Let's assume we have two classes with [Tex]d[/Tex]-dimensional samples [Tex]x_1, x_2, \dots, x_n[/Tex] where:
- [Tex]n_1[/Tex] samples belong to class [Tex] c_1[/Tex]
- [Tex]n_2[/Tex] samples belong to class [Tex]c_2[/Tex].
If [Tex]x_i[/Tex] represents a data point, its projection onto the line represented by the unit vector v is [Tex]v^T x_i[/Tex].
Let the means of class [Tex]c_1[/Tex] and class [Tex]c_2[/Tex] before projection be [Tex]\mu_1[/Tex] and [Tex]\mu_2[/Tex] respectively. After projection, the new means are [Tex]\hat{\mu}_1 = v^T \mu_1[/Tex] and [Tex]\hat{\mu}_2 = v^T \mu_2[/Tex].
We want the projected class means to be as far apart as possible, i.e. we aim to maximize the difference [Tex]|\hat{\mu}_1 - \hat{\mu}_2|[/Tex], while keeping the spread of each projected class small.
The scatter of the projected samples of class [Tex]c_1[/Tex] measures how spread out they are around their projected mean:
[Tex]\hat{s}_1^2 = \sum_{x_i \in c_1} (v^T x_i - \hat{\mu}_1)^2[/Tex]
Similarly, for class [Tex]c_2[/Tex]:
[Tex]\hat{s}_2^2 = \sum_{x_i \in c_2} (v^T x_i - \hat{\mu}_2)^2[/Tex]
The goal is to maximize the ratio of the between-class scatter to the within-class scatter, which leads to Fisher's criterion:
[Tex]J(v) = \frac{(\hat{\mu}_1 - \hat{\mu}_2)^2}{\hat{s}_1^2 + \hat{s}_2^2}[/Tex]
For the best separation, we compute the eigenvector corresponding to the largest eigenvalue of [Tex]S_W^{-1} S_B[/Tex], where [Tex]S_W[/Tex] is the within-class scatter matrix and [Tex]S_B[/Tex] is the between-class scatter matrix.
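To make the derivation concrete, here is a minimal NumPy sketch of the two-class case: it builds the within-class and between-class scatter matrices from synthetic data (the data and variable names are illustrative assumptions, not part of the derivation above) and takes the top eigenvector of [Tex]S_W^{-1} S_B[/Tex] as the projection direction.
Python
import numpy as np

# Synthetic two-class data (illustrative only)
rng = np.random.default_rng(0)
X1 = rng.normal(loc=[0, 0], scale=1.0, size=(50, 2))  # class c1
X2 = rng.normal(loc=[3, 2], scale=1.0, size=(50, 2))  # class c2

mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

# Within-class scatter matrix S_W: spread of each class around its own mean
S_W = (X1 - mu1).T @ (X1 - mu1) + (X2 - mu2).T @ (X2 - mu2)

# Between-class scatter matrix S_B: outer product of the mean difference
diff = (mu1 - mu2).reshape(-1, 1)
S_B = diff @ diff.T

# The best direction v is the eigenvector of S_W^{-1} S_B
# with the largest eigenvalue
eigvals, eigvecs = np.linalg.eig(np.linalg.inv(S_W) @ S_B)
v = eigvecs[:, np.argmax(eigvals.real)].real
v /= np.linalg.norm(v)

# Projected class means should be well separated along v
print("Projected mean of c1:", (X1 @ v).mean())
print("Projected mean of c2:", (X2 @ v).mean())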
Extensions to LDA
- Quadratic Discriminant Analysis (QDA): Each class uses its own estimate of variance (or covariance) allowing it to handle more complex relationships.
- Flexible Discriminant Analysis (FDA): Uses non-linear combinations of inputs such as splines to handle non-linear separability.
- Regularized Discriminant Analysis (RDA): Introduces regularization into the covariance estimate to prevent overfitting.
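Scikit-learn covers two of these directly: QuadraticDiscriminantAnalysis implements QDA, and LinearDiscriminantAnalysis with a shrinkage parameter regularizes the covariance estimate in the spirit of RDA (FDA has no direct Scikit-learn estimator). Below is a brief sketch comparing the two on the Iris dataset.
Python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis,
                                            QuadraticDiscriminantAnalysis)
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# QDA: each class gets its own covariance estimate
qda = QuadraticDiscriminantAnalysis()

# Shrinkage LDA: regularized covariance estimate, similar in spirit to RDA
shrunk_lda = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')

for name, model in [('QDA', qda), ('Shrinkage LDA', shrunk_lda)]:
    scores = cross_val_score(model, X, y, cv=5)
    print(f'{name}: mean CV accuracy = {scores.mean():.3f}')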
Python Code Implementation of LDA
In this implementation we will perform linear discriminant analysis using the Scikit-learn library on the Iris dataset.
Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix

# Load the Iris dataset into a DataFrame
iris = load_iris()
dataset = pd.DataFrame(columns=iris.feature_names, data=iris.data)
dataset['target'] = iris.target

# Split into features (first four columns) and target (last column)
X = dataset.iloc[:, 0:4].values
y = dataset.iloc[:, 4].values

# Standardize the features to zero mean and unit variance
sc = StandardScaler()
X = sc.fit_transform(X)

# Encode the target labels as integers
le = LabelEncoder()
y = le.fit_transform(y)

# Hold out 20% of the data for testing
# (no random_state is fixed, so exact results vary between runs)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Fit LDA on the training data and project both sets onto 2 discriminants
lda = LinearDiscriminantAnalysis(n_components=2)
X_train = lda.fit_transform(X_train, y_train)
X_test = lda.transform(X_test)

# Visualize the training data in the 2D LDA space
plt.scatter(X_train[:, 0], X_train[:, 1],
            c=y_train, cmap='rainbow',
            alpha=0.7, edgecolors='b')
plt.show()

# Train a classifier on the LDA-transformed features
classifier = RandomForestClassifier(max_depth=2, random_state=0)
classifier.fit(X_train, y_train)

# Evaluate on the transformed test set
y_pred = classifier.predict(X_test)
print('Accuracy : ' + str(accuracy_score(y_test, y_pred)))

conf_m = confusion_matrix(y_test, y_pred)
print(conf_m)
- StandardScaler(): Standardizes the features to ensure they have a mean of 0 and a standard deviation of 1 removing the influence of different scales.
- fit_transform(): Fits the scaler to the feature data and applies the standardization in one step, ensuring each feature contributes equally.
- LabelEncoder(): Converts categorical labels into numerical values that machine learning models can process.
- fit_transform() on y: Transforms the target labels into numerical values for use in classification models.
- LinearDiscriminantAnalysis(): Reduces the dimensionality of the data by projecting it into a lower-dimensional space while maximizing the separation between classes.
- transform() on X_test: Applies the learned LDA transformation to the test data to maintain consistency with the training data.
Output:
Accuracy : 0.9333333333333333
[[10 0 0]
[ 2 10 0]
[ 0 0 8]]

Scatter plot of the iris data mapped into 2D
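As a follow-up, the fitted LDA estimator also exposes explained_variance_ratio_, which tells you how much of the between-class variance each discriminant axis captures (this snippet assumes the lda object fitted in the code above is still in scope):
Python
# How much between-class variance each of the two discriminants explains
print('Explained variance ratio:', lda.explained_variance_ratio_)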
Advantages of LDA
- Simple and computationally efficient.
- Can perform reasonably well even with relatively few training samples, provided its assumptions approximately hold.
- Can handle multicollinearity.
Disadvantages of LDA
- Assumes Gaussian distribution of data which may not always be the case.
- Assumes equal covariance matrices for different classes which may not hold in all datasets.
- Assumes linear separability which is not always true.
- May not always perform well in high-dimensional feature spaces.
Applications of LDA
- Face Recognition: It is used to reduce the high-dimensional feature space of pixel values in face recognition applications, helping to identify faces more efficiently.
- Medical Diagnosis: It classifies disease severity as mild, moderate or severe based on patient parameters, helping in decision-making for treatment.
- Customer Identification: It can help identify customer segments most likely to purchase a specific product based on survey data.
Linear Discriminant Analysis (LDA) is a dimensionality reduction technique that not only simplifies high-dimensional data but also improves model performance by maximizing class separability. By projecting data into a lower-dimensional space, it helps improve the accuracy of classification tasks.