Employee Attrition Prediction Using Machine Learning Models

By Rohit Sharma

Updated on Aug 04, 2025 | 11 min read | 1.16K+ views


Employee attrition prediction helps organizations identify which employees are at risk of leaving, allowing HR teams to take proactive steps to improve retention.

In this project, you’ll build a machine learning model for employee attrition prediction using HR data. You’ll preprocess the data, explore key patterns, and apply classification algorithms to predict which employees are likely to leave.


Project Development Success: Essential Prerequisites

  • Basic Python programming knowledge (For writing scripts, defining functions, and handling control flows)
  • Experience with data manipulation using Pandas and NumPy (Required for reading data, handling missing values, and structuring datasets)
  • Understanding of data visualization with Matplotlib and Seaborn (Helps in generating charts like count plots, histograms, and heatmaps for EDA and model evaluation)
  • Knowledge of data preprocessing techniques (Such as dropping irrelevant columns, encoding categorical variables using one-hot encoding, and splitting datasets)
  • Familiarity with classification algorithms (Especially tree-based models like Random Forests, which provide feature importance scores and support class weighting for imbalanced datasets)
  • Ability to evaluate machine learning models (Using accuracy, confusion matrix, and classification metrics like precision, recall, and F1-score)


Technologies Enabling Employee Attrition Prediction: An In-Depth Analysis

To build and evaluate the attrition prediction model, you'll work with essential Python tools for data handling, visualization, and machine learning:

Tool / Library | Purpose
Python | Core scripting language for the entire machine learning workflow
Google Colab | Cloud platform to run Jupyter notebooks with built-in GPU/TPU support
Pandas | Loads the HR dataset, handles preprocessing, and one-hot encodes categorical columns with get_dummies
NumPy | Supports numerical computations and array operations
Matplotlib / Seaborn | Plots attrition trends, feature distributions, and the confusion matrix
scikit-learn | Handles data splitting, model training, and evaluation
RandomForestClassifier | Predicts attrition using ensemble learning and computes feature importance
Confusion Matrix / Classification Report | Measures model performance across precision, recall, and F1-score


Smart Insights: Techniques for Employee Attrition Prediction

To build a reliable Employee Attrition Prediction model, you’ll apply proven machine learning techniques that reveal patterns behind employee departures:

  • Data preprocessing and cleaning
  • Exploratory Data Analysis (EDA)
  • Classification algorithms (Random Forest)
  • Feature importance scoring
  • Model evaluation using accuracy, precision, recall, and F1-score


Time Required & How You'll Get It Done

You can complete this Employee Attrition Prediction project in around 5 to 6 hours. It’s ideal for beginners and intermediate learners with basic Python knowledge who want to apply data preprocessing, visualization, and classification to solve a real-world HR analytics problem.

The Process of Developing an Employee Attrition Prediction Model

Let’s walk through the steps to build this project from scratch:

  • Load the HR Dataset
    Import the employee attrition dataset for analysis and modeling.
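  • Clean and Preprocess the Data
    Check for missing values, drop constant and identifier columns, and one-hot encode categorical variables.
  • Explore and Visualize Key Patterns
    Use count plots and distribution plots to uncover trends linked to attrition.
  • Train a Classification Model
    Train a Random Forest classifier; Logistic Regression or a single Decision Tree are alternatives you could compare against it.
  • Evaluate the Model’s Performance
    Use accuracy, precision, recall, and a confusion matrix on held-out test data to assess results.
  • Identify the Key Drivers of Attrition
    Use the model's feature importance scores to see which factors it relies on most.
  • Make Predictions on New Employee Data
    Apply the trained model to new employee records, encoded the same way as the training data, to flag potential attrition.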

Let's get started!

Step 1: Import Required Libraries

To start building the Employee Attrition Prediction model, you first need to import all the necessary Python libraries.

Here's the code for importing:

import pandas as pd                      # For handling dataframes
import numpy as np                       # For numerical computations
import matplotlib.pyplot as plt          # For plotting graphs
import seaborn as sns                    # For advanced visualizations

from sklearn.model_selection import train_test_split   # To split data into training and testing sets
from sklearn.ensemble import RandomForestClassifier     # Classification model
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix  # Evaluation metrics


Step 2: Load and Inspect the Data

Now that the required libraries are imported, the next step is to load the employee data and understand its structure. The code for this step is below:

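try:
    # Load the HR dataset
    df = pd.read_csv('WA_Fn-UseC_-HR-Employee-Attrition.csv')
except FileNotFoundError:
    print("Error: 'WA_Fn-UseC_-HR-Employee-Attrition.csv' not found. Place it in the working directory first.")
    raise  # stop here: every later step needs df (exit() can kill a notebook kernel)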

# Preview the first few records
print("Dataset Head:")
print(df.head())

# Overview of dataset structure
print("\nDataset Info:")
df.info()

# Check for missing values
print("\nChecking for missing values:")
print(df.isnull().sum())

# Drop columns that are either constant or identifiers
df = df.drop(['EmployeeCount', 'StandardHours', 'EmployeeNumber', 'Over18'], axis=1)
print("\nDropped constant and identifier columns.")

Output: 

Dataset Head:
   Age Attrition     BusinessTravel  DailyRate              Department  \
0   41       Yes      Travel_Rarely       1102                   Sales
1   49        No  Travel_Frequently        279  Research & Development
2   37       Yes      Travel_Rarely       1373  Research & Development
3   33        No  Travel_Frequently       1392  Research & Development
4   27        No      Travel_Rarely        591  Research & Development

   DistanceFromHome  Education  EducationField  EmployeeCount  EmployeeNumber  \
0                 1          2   Life Sciences              1               1
1                 8          1   Life Sciences              1               2
2                 2          2           Other              1               4
3                 3          4   Life Sciences              1               5
4                 2          1         Medical              1               7

   ...  RelationshipSatisfaction  StandardHours  StockOptionLevel  \
0  ...                         1             80                 0
1  ...                         4             80                 1
2  ...                         2             80                 0
3  ...                         3             80                 0
4  ...                         4             80                 1

   TotalWorkingYears  TrainingTimesLastYear  WorkLifeBalance  YearsAtCompany  \
0                  8                      0                1               6
1                 10                      3                3              10
2                  7                      3                3               0
3                  8                      3                3               8
4                  6                      3                3               2

   YearsInCurrentRole  YearsSinceLastPromotion  YearsWithCurrManager
0                   4                        0                     5
1                   7                        1                     7
2                   0                        0                     0
3                   7                        3                     0
4                   2                        2                     2

[5 rows x 35 columns]

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1470 entries, 0 to 1469
Data columns (total 35 columns):
 #   Column                    Non-Null Count  Dtype
---  ------                    --------------  -----
 0   Age                       1470 non-null   int64
 1   Attrition                 1470 non-null   object
 2   BusinessTravel            1470 non-null   object
 3   DailyRate                 1470 non-null   int64
 4   Department                1470 non-null   object
 5   DistanceFromHome          1470 non-null   int64
 6   Education                 1470 non-null   int64
 7   EducationField            1470 non-null   object
 8   EmployeeCount             1470 non-null   int64
 9   EmployeeNumber            1470 non-null   int64
 10  EnvironmentSatisfaction   1470 non-null   int64
 11  Gender                    1470 non-null   object
 12  HourlyRate                1470 non-null   int64
 13  JobInvolvement            1470 non-null   int64
 14  JobLevel                  1470 non-null   int64
 15  JobRole                   1470 non-null   object
 16  JobSatisfaction           1470 non-null   int64
 17  MaritalStatus             1470 non-null   object
 18  MonthlyIncome             1470 non-null   int64
 19  MonthlyRate               1470 non-null   int64
 20  NumCompaniesWorked        1470 non-null   int64
 21  Over18                    1470 non-null   object
 22  OverTime                  1470 non-null   object
 23  PercentSalaryHike         1470 non-null   int64
 24  PerformanceRating         1470 non-null   int64
 25  RelationshipSatisfaction  1470 non-null   int64
 26  StandardHours             1470 non-null   int64
 27  StockOptionLevel          1470 non-null   int64
 28  TotalWorkingYears         1470 non-null   int64
 29  TrainingTimesLastYear     1470 non-null   int64
 30  WorkLifeBalance           1470 non-null   int64
 31  YearsAtCompany            1470 non-null   int64
 32  YearsInCurrentRole        1470 non-null   int64
 33  YearsSinceLastPromotion   1470 non-null   int64
 34  YearsWithCurrManager      1470 non-null   int64
dtypes: int64(26), object(9)
memory usage: 402.1+ KB

Checking for missing values:
Age                         0
Attrition                   0
BusinessTravel              0
DailyRate                   0
Department                  0
DistanceFromHome            0
Education                   0
EducationField              0
EmployeeCount               0
EmployeeNumber              0
EnvironmentSatisfaction     0
Gender                      0
HourlyRate                  0
JobInvolvement              0
JobLevel                    0
JobRole                     0
JobSatisfaction             0
MaritalStatus               0
MonthlyIncome               0
MonthlyRate                 0
NumCompaniesWorked          0
Over18                      0
OverTime                    0
PercentSalaryHike           0
PerformanceRating           0
RelationshipSatisfaction    0
StandardHours               0
StockOptionLevel            0
TotalWorkingYears           0
TrainingTimesLastYear       0
WorkLifeBalance             0
YearsAtCompany              0
YearsInCurrentRole          0
YearsSinceLastPromotion     0
YearsWithCurrManager        0
dtype: int64

Dropped constant and identifier columns.
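Step 2 hard-codes the four columns to drop. If you want to confirm that choice programmatically, here is a minimal sketch using two illustrative helpers (find_constant_columns and find_unique_id_columns are hypothetical names, not library functions), run on a small made-up DataFrame rather than the real CSV:

```python
import pandas as pd


def find_constant_columns(frame: pd.DataFrame) -> list:
    """Columns holding a single unique value carry no information for a model."""
    return [col for col in frame.columns if frame[col].nunique(dropna=False) <= 1]


def find_unique_id_columns(frame: pd.DataFrame) -> list:
    """Columns whose values are all distinct often act as row identifiers."""
    return [col for col in frame.columns if frame[col].nunique(dropna=False) == len(frame)]


# Hypothetical toy rows that mimic a few columns of the HR dataset
toy = pd.DataFrame({
    "EmployeeNumber": [1, 2, 4, 5],
    "EmployeeCount": [1, 1, 1, 1],
    "StandardHours": [80, 80, 80, 80],
    "Over18": ["Y", "Y", "Y", "Y"],
    "Age": [41, 49, 37, 33],
    "Attrition": ["Yes", "No", "Yes", "No"],
})

print("Constant columns:", find_constant_columns(toy))
print("All-unique columns:", find_unique_id_columns(toy))
```

On the toy frame, the first helper returns EmployeeCount, StandardHours, and Over18, while the second flags EmployeeNumber and also Age, only because its four toy values happen to be distinct. That is why the uniqueness check is a hint to review, not a rule. On the real data you would run both helpers on df before the drop.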


Step 3: Exploratory Data Analysis (EDA) and Visualization

In this step, you’ll explore patterns in the data that may be linked to employee attrition. Visualization helps reveal which factors differ between employees who stay and those who leave.

# Plot 1: Overall Attrition Count
plt.figure(figsize=(8, 6))
sns.countplot(x='Attrition', data=df)
plt.title('Attrition Distribution')
plt.xlabel('Attrition')
plt.ylabel('Count')
plt.show()

# Plot 2: Attrition by Department
plt.figure(figsize=(10, 7))
sns.countplot(x='Department', hue='Attrition', data=df)
plt.title('Attrition by Department')
plt.xticks(rotation=45)
plt.show()

# Plot 3: Attrition by OverTime
plt.figure(figsize=(8, 6))
sns.countplot(x='OverTime', hue='Attrition', data=df)
plt.title('Attrition by OverTime')
plt.show()

# Plot 4: Monthly Income Distribution
plt.figure(figsize=(10, 7))
sns.histplot(data=df, x='MonthlyIncome', hue='Attrition', kde=True, multiple="stack")
plt.title('Monthly Income Distribution by Attrition')
plt.show()
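Count plots show absolute numbers, and groups of very different sizes (such as the three departments) are easier to compare as attrition rates. A minimal sketch, assuming Attrition still holds the 'Yes'/'No' strings it has at this stage (Step 4 encodes it later) and using a few hypothetical rows in place of df; attrition_rate_by is an illustrative helper name:

```python
import pandas as pd


def attrition_rate_by(frame: pd.DataFrame, group_col: str, target: str = "Attrition") -> pd.Series:
    """Percentage of employees in each group whose target value is 'Yes'."""
    is_leaver = frame[target].eq("Yes")
    return (is_leaver.groupby(frame[group_col]).mean() * 100).round(1)


# Hypothetical sample rows; with the real data, pass df from Step 2 instead
sample = pd.DataFrame({
    "OverTime":  ["Yes", "Yes", "No", "No", "No", "Yes", "No", "No"],
    "Attrition": ["Yes", "No",  "No", "No", "Yes", "Yes", "No", "No"],
})

rates = attrition_rate_by(sample, "OverTime")
print(rates)
```

For the sample rows this gives 20.0% for OverTime = No and 66.7% for OverTime = Yes. With the real data you would call attrition_rate_by(df, 'OverTime') or attrition_rate_by(df, 'Department'); pd.crosstab(df['OverTime'], df['Attrition'], normalize='index') gives the same shares as a table covering both classes.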


Step 4: Data Preprocessing and Feature Engineering

Before training a model for Employee Attrition Prediction, you'll convert all non-numeric data into a format that machine learning algorithms can understand. You'll also prepare the target variable and separate features from labels.

The code for this step is as follows:

print("\nStep 4: Preprocessing data for the model...")

# Convert 'Attrition' from Yes/No to 1/0
df['Attrition'] = df['Attrition'].apply(lambda x: 1 if x == 'Yes' else 0)

# Identify all categorical columns
categorical_cols = df.select_dtypes(include=['object']).columns
print(f"\nCategorical columns to be one-hot encoded: {list(categorical_cols)}")

# Apply one-hot encoding to convert categorical columns into binary variables
df_encoded = pd.get_dummies(df, columns=categorical_cols, drop_first=True)

print("\nData after one-hot encoding:")
print(df_encoded.head())

# Define features (X) and target (y)
X = df_encoded.drop('Attrition', axis=1)
y = df_encoded['Attrition']

Output:

Categorical columns to be one-hot encoded: ['BusinessTravel', 'Department', 'EducationField', 'Gender', 'JobRole', 'MaritalStatus', 'OverTime']

Data after one-hot encoding:

   Age  Attrition  DailyRate  DistanceFromHome  Education  \
0   41          1       1102                 1          2
1   49          0        279                 8          1
2   37          1       1373                 2          2
3   33          0       1392                 3          4
4   27          0        591                 2          1

   EnvironmentSatisfaction  HourlyRate  JobInvolvement  JobLevel  \
0                        2          94               3         2
1                        3          61               2         2
2                        4          92               2         1
3                        4          56               3         1
4                        1          40               3         1

   JobSatisfaction  ...  JobRole_Laboratory Technician  JobRole_Manager  \
0                4  ...                          False            False
1                2  ...                          False            False
2                3  ...                           True            False
3                3  ...                          False            False
4                2  ...                           True            False

   JobRole_Manufacturing Director  JobRole_Research Director  \
0                           False                      False
1                           False                      False
2                           False                      False
3                           False                      False
4                           False                      False

   JobRole_Research Scientist  JobRole_Sales Executive  \
0                       False                     True
1                        True                    False
2                       False                    False
3                        True                    False
4                       False                    False

   JobRole_Sales Representative  MaritalStatus_Married  MaritalStatus_Single  \
0                         False                  False                  True
1                         False                   True                 False
2                         False                  False                  True
3                         False                   True                 False
4                         False                   True                 False

   OverTime_Yes
0          True
1         False
2          True
3          True
4         False

[5 rows x 45 columns]
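Because Step 4 uses pd.get_dummies with drop_first=True, any new employee records you score later must end up with exactly the same columns as X. A minimal sketch of one way to do that, using hypothetical toy frames and only two of the categorical columns: encode the new rows without drop_first, then reindex to the training columns so that dummies the model never saw are removed and missing dummy columns are filled with False.

```python
import pandas as pd

CATEGORICAL = ["OverTime", "MaritalStatus"]  # a subset of the real categorical columns

# Hypothetical training rows, encoded the same way as in Step 4
train = pd.DataFrame({
    "Age": [41, 49, 37],
    "OverTime": ["Yes", "No", "Yes"],
    "MaritalStatus": ["Single", "Married", "Divorced"],
})
train_encoded = pd.get_dummies(train, columns=CATEGORICAL, drop_first=True)
training_columns = train_encoded.columns  # the layout the model was trained on

# A hypothetical new employee to score later
new_employee = pd.DataFrame({"Age": [30], "OverTime": ["Yes"], "MaritalStatus": ["Single"]})

# Encode WITHOUT drop_first, then align: dummies absent from training (such as a
# dropped baseline level) are removed, and missing dummy columns become False.
new_encoded = pd.get_dummies(new_employee, columns=CATEGORICAL)
new_aligned = new_encoded.reindex(columns=training_columns, fill_value=False)

print(list(training_columns))
print(new_aligned)
# In the project, model.predict_proba(new_aligned) would then see the expected columns
```

Using drop_first=True on a single new row would be a mistake: get_dummies sees only one category per column and drops it, losing the employee's actual values. Categories that never appeared in training are also silently treated as the baseline after reindex; scikit-learn's OneHotEncoder with handle_unknown='ignore' inside a Pipeline is a more robust option if that matters.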

Step 5: Build and Train a Machine Learning Model

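Now that the data is ready, you'll train a Random Forest classifier. As an ensemble of decision trees, it works with the mix of numeric and one-hot encoded columns without feature scaling, and it reports feature importance scores you can inspect later.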

print("\nStep 5: Building and training the Random Forest model...")

# Split the dataset into training and testing subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Display shapes of resulting splits
print(f"Training set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")

# Initialize the Random Forest model
model = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
    class_weight='balanced'  # Handles class imbalance by adjusting weights
)

# Fit the model to the training data
model.fit(X_train, y_train)

print("\nModel training complete.")

Output:

Step 5: Building and training the Random Forest model...
Training set shape: (1176, 44)
Testing set shape: (294, 44)

Model training complete.

The model has now learned patterns from the training data that it will use to flag employees likely to leave. Next, you’ll evaluate how well it performs.
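The class_weight='balanced' setting in Step 5 gives each class a weight of n_samples / (n_classes × count of that class), so mistakes on the rarer Attrition class cost more during training. A small sketch on hypothetical labels, using scikit-learn's compute_class_weight helper alongside the same formula written out by hand:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical labels: 8 employees stayed (0), 2 left (1)
y_toy = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])

# What class_weight="balanced" computes from the training labels
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y_toy)

# The same formula by hand: n_samples / (n_classes * count_per_class)
manual = len(y_toy) / (2 * np.bincount(y_toy))

print(weights)
print(manual)
```

Here the 8 stayers get weight 0.625 and the 2 leavers get 2.5, a 1:4 ratio that mirrors the 8:2 imbalance.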


Step 6: Evaluate the Model

With your model trained, it’s time to assess how well it predicts employee attrition using unseen test data. You’ll check accuracy, analyze the confusion matrix, and interpret precision, recall, and F1-score.
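Accuracy alone can be misleading when one class dominates, so it helps to know what a model that always predicts the majority class would score. A minimal sketch using scikit-learn's DummyClassifier on hypothetical labels (the toy arrays stand in for y_train and y_test, and this strategy ignores the features entirely):

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical labels (0 = stayed, 1 = left) standing in for y_train and y_test
y_train_toy = np.array([0] * 80 + [1] * 20)
y_test_toy = np.array([0] * 84 + [1] * 16)

# "most_frequent" ignores the features, so a placeholder column is enough
X_train_toy = np.zeros((len(y_train_toy), 1))
X_test_toy = np.zeros((len(y_test_toy), 1))

baseline = DummyClassifier(strategy="most_frequent").fit(X_train_toy, y_train_toy)
y_base = baseline.predict(X_test_toy)

baseline_accuracy = accuracy_score(y_test_toy, y_base)
baseline_recall = recall_score(y_test_toy, y_base)  # recall for class 1 ("Attrition")
print(f"Baseline accuracy: {baseline_accuracy:.2f}")
print(f"Baseline recall (Attrition): {baseline_recall:.2f}")
```

In the project you would fit it on X_train, y_train and score it on X_test, y_test. On this project's test split, where 247 of 294 employees stayed, such a baseline would reach 247 / 294 ≈ 84.0% accuracy with zero recall on the Attrition class.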

Here is the code for evaluating model performance: 

print("\nStep 6: Evaluating the model...")

# Predict on test data
y_pred = model.predict(X_test)

# Accuracy score
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel Accuracy: {accuracy * 100:.2f}%")

# Confusion matrix
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=['No Attrition', 'Attrition'],
    yticklabels=['No Attrition', 'Attrition']
)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix')
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['No Attrition', 'Attrition']))

Output: 

Model Accuracy: 83.67%

Confusion Matrix:

Classification Report:
              precision    recall  f1-score   support

No Attrition       0.85      0.98      0.91       247
   Attrition       0.44      0.09      0.14        47

    accuracy                           0.84       294
   macro avg       0.65      0.53      0.53       294
weighted avg       0.78      0.84      0.79       294

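These results show that the headline accuracy is misleading here. The model correctly flags only 4 of the 47 employees who actually left (recall 0.09 for the Attrition class), and a trivial model that predicts "No Attrition" for everyone would reach 247 / 294 ≈ 84.0% accuracy, slightly higher than this model's 83.67%. For an HR team, recall on the Attrition class often matters more than overall accuracy, because every missed leaver is a missed chance to intervene.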
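One common way to raise recall on the minority class is to lower the decision threshold applied to predict_proba, accepting lower precision in exchange. The sketch below illustrates this on synthetic data from make_classification (about 16% positives, chosen to roughly mimic the test set's attrition rate, not the real HR features); the thresholds are arbitrary examples. In practice you would pick a threshold on a validation split or with cross-validation, not on the test set.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in for the HR data (about 16% positives)
X_syn, y_syn = make_classification(
    n_samples=1500, n_features=20, n_informative=5,
    weights=[0.84, 0.16], random_state=42,
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, test_size=0.2, random_state=42, stratify=y_syn
)

model_syn = RandomForestClassifier(n_estimators=100, random_state=42, class_weight="balanced")
model_syn.fit(X_tr, y_tr)

# Probability of the positive class ("Attrition") for each test row
proba = model_syn.predict_proba(X_te)[:, 1]

results = {}
for threshold in (0.5, 0.4, 0.3, 0.2):
    y_hat = (proba >= threshold).astype(int)  # flag a row as positive at or above the cut-off
    precision = precision_score(y_te, y_hat, zero_division=0)
    recall = recall_score(y_te, y_hat)
    results[threshold] = (precision, recall)
    print(f"threshold={threshold:.1f}  precision={precision:.2f}  recall={recall:.2f}")
```

Lowering the threshold can only add positive predictions, so recall never decreases as the cut-off drops. On the project itself you would replace model_syn, X_te, and y_te with model, X_test, and y_test; RandomForestClassifier.predict() corresponds roughly to a 0.5 threshold on these probabilities.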

Step 7: Feature Importance

Beyond its predictions, a Random Forest also reports how much each feature contributed to the splits in its trees (the feature_importances_ attribute). This shows which variables the model relies on most when separating employees who stayed from those who left.

print("\nStep 7: Identifying key features...")

# Extract feature importances
importances = model.feature_importances_
feature_names = X.columns

# Create a DataFrame to hold feature names and their importance scores
feature_importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': importances
})

# Sort by importance descending
feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)

# Plot top 15 features
plt.figure(figsize=(12, 8))
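sns.barplot(
    x='Importance',
    y='Feature',
    hue='Feature',   # seaborn >= 0.13: assign hue so palette works without a FutureWarning
    data=feature_importance_df.head(15),
    palette='viridis',
    legend=False     # one bar per feature, so a legend would be redundant
)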
plt.title('Top 15 Most Important Features for Predicting Attrition')
plt.xlabel('Importance Score')
plt.ylabel('Feature')
plt.tight_layout()
plt.show()

# Print top 5
print("\nTop 5 features driving attrition:")
print(feature_importance_df.head(5))

Output:

Top 5 features driving attrition:
              Feature  Importance
9       MonthlyIncome    0.075157
0                 Age    0.068066
16  TotalWorkingYears    0.053865
1           DailyRate    0.052450
19     YearsAtCompany    0.048968

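The model relies most heavily on Monthly Income, Age, Total Working Years, Daily Rate, and Years at Company. These impurity-based scores show how much the model uses each feature, not whether higher or lower values raise the risk of leaving, so on their own they cannot show that lower income or shorter tenure drives attrition. To check the direction, compare these features between employees who left and those who stayed (for example, with the Monthly Income histogram from Step 3) before using them to target retention efforts.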
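Impurity-based feature_importances_ are computed on the training data and, as the scikit-learn documentation notes, tend to favor features with many unique values, such as DailyRate or MonthlyIncome. Permutation importance instead measures how much a held-out score drops when one column is randomly shuffled. A minimal sketch on synthetic data with three known signal columns and three pure-noise columns (the column names are made up for the example):

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in: with shuffle=False, the 3 informative columns come first,
# followed by 3 columns of pure noise
X_arr, y_arr = make_classification(
    n_samples=1000, n_features=6, n_informative=3, n_redundant=0,
    shuffle=False, random_state=0,
)
names = [f"signal_{i}" for i in range(3)] + [f"noise_{i}" for i in range(3)]
X_df = pd.DataFrame(X_arr, columns=names)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_df, y_arr, test_size=0.25, random_state=0, stratify=y_arr
)

rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_tr, y_tr)

# Mean drop in test accuracy when each column is shuffled (10 shuffles per column)
perm = permutation_importance(rf, X_te, y_te, n_repeats=10, random_state=0)

comparison = pd.DataFrame(
    {"impurity": rf.feature_importances_, "permutation": perm.importances_mean},
    index=X_df.columns,
).sort_values("permutation", ascending=False)
print(comparison.round(3))
```

For the project you would call permutation_importance(model, X_test, y_test, ...), optionally with scoring='recall' to focus on how much each feature helps identify leavers.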

Final Conclusion

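The Employee Attrition Prediction project applied a Random Forest model to HR data to predict whether an employee is likely to leave the company. After cleaning the dataset, exploring patterns, encoding categorical features, and training the model, we reached 83.67% accuracy on the test set. That number needs context: it is slightly below the 84.0% obtained by predicting "No Attrition" for every employee, and the model identified only 9% of the employees who actually left (recall 0.09). Improving recall on the Attrition class is the clearest next step toward a model HR teams can act on.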

The analysis showed that the features the model relied on most were Monthly Income, Age, Total Working Years, Daily Rate, and Years at Company. These had the highest importance scores, which makes them useful starting points for HR analysis, although importance scores alone do not show that these factors cause employees to stay or leave.


Colab Link -
https://colab.research.google.com/drive/1V4gWwbt2_0MDGQg_BhXAF8VcKLuXoyuD?usp=sharing


Rohit Sharma

827 articles published

Rohit Sharma is the Head of Revenue & Programs (International), with over 8 years of experience in business analytics, EdTech, and program management. He holds an M.Tech from IIT Delhi and specializes...
