Employee Attrition Prediction Using Machine Learning Models
By Rohit Sharma
Updated on Aug 04, 2025 | 11 min read | 1.16K+ views
Employee attrition prediction helps organizations identify which employees are at risk of leaving, allowing HR teams to take proactive steps to improve retention.
In this project, you’ll build a machine learning model for employee attrition prediction using HR data. You’ll preprocess the data, explore key patterns, and apply classification algorithms to predict which employees are likely to leave.
To build and evaluate the attrition prediction model, you'll work with essential Python tools for data handling, visualization, and machine learning:
Tool / Library | Purpose |
Python | Core scripting language for the entire machine learning workflow |
Google Colab | Cloud platform to run Jupyter notebooks with built-in GPU/TPU support |
Pandas | Loads the HR dataset, handles preprocessing, and one-hot encodes categorical columns with get_dummies |
NumPy | Supports numerical computations and array operations |
Matplotlib / Seaborn | Plots attrition trends, feature distributions, confusion matrices, and feature importances |
scikit-learn | Handles data splitting, model training, and evaluation |
RandomForestClassifier | Predicts attrition using ensemble learning and computes feature importance |
Confusion Matrix / Classification Report | Measure model performance through precision, recall, and F1-score |
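To build a reliable Employee Attrition Prediction model, you’ll apply standard machine learning techniques that reveal patterns behind employee departures: data cleaning, exploratory visualization, one-hot encoding of categorical features, a stratified train/test split, a class-weighted Random Forest classifier, and feature-importance analysis.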
You can complete this Employee Attrition Prediction project in around 5 to 6 hours. It’s ideal for beginners and intermediate learners with basic Python knowledge who want to apply data preprocessing, visualization, and classification to solve a real-world HR analytics problem.
Let’s walk through the steps to build this project from scratch:
Let's get started!
To start building the Employee Attrition Prediction model, you first need to import all the necessary Python libraries.
Here's the code for importing:
import pandas as pd # For handling dataframes
import numpy as np # For numerical computations
import matplotlib.pyplot as plt # For plotting graphs
import seaborn as sns # For advanced visualizations
from sklearn.model_selection import train_test_split # To split data into training and testing sets
from sklearn.ensemble import RandomForestClassifier # Classification model
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix # Evaluation metrics
Now that the required libraries are imported, the next step is to load the employee data and understand its structure. The code for this step is below:
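try:
    # Load the HR dataset
    df = pd.read_csv('WA_Fn-UseC_-HR-Employee-Attrition.csv')
except FileNotFoundError:
    # Stop with a clear message if the CSV is not in the working directory
    # (exit() can shut down a notebook kernel, so raise SystemExit instead)
    raise SystemExit("Error: 'WA_Fn-UseC_-HR-Employee-Attrition.csv' not found.")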
# Preview the first few records
print("Dataset Head:")
print(df.head())
# Overview of dataset structure
print("\nDataset Info:")
df.info()
# Check for missing values
print("\nChecking for missing values:")
print(df.isnull().sum())
# Drop columns that are either constant or identifiers
df = df.drop(['EmployeeCount', 'StandardHours', 'EmployeeNumber', 'Over18'], axis=1)
print("\nDropped constant and identifier columns.")
Output:
Dataset Head:
Age Attrition BusinessTravel DailyRate Department \
0 41 Yes Travel_Rarely 1102 Sales
1 49 No Travel_Frequently 279 Research & Development
2 37 Yes Travel_Rarely 1373 Research & Development
3 33 No Travel_Frequently 1392 Research & Development
4 27 No Travel_Rarely 591 Research & Development
DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \
0 1 2 Life Sciences 1 1
1 8 1 Life Sciences 1 2
2 2 2 Other 1 4
3 3 4 Life Sciences 1 5
4 2 1 Medical 1 7
... RelationshipSatisfaction StandardHours StockOptionLevel \
0 ... 1 80 0
1 ... 4 80 1
2 ... 2 80 0
3 ... 3 80 0
4 ... 4 80 1
TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany \
0 8 0 1 6
1 10 3 3 10
2 7 3 3 0
3 8 3 3 8
4 6 3 3 2
YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager
0 4 0 5
1 7 1 7
2 0 0 0
3 7 3 0
4 2 2 2
[5 rows x 35 columns]
Dataset Info:
RangeIndex: 1470 entries, 0 to 1469
Data columns (total 35 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Age 1470 non-null int64
1 Attrition 1470 non-null object
2 BusinessTravel 1470 non-null object
3 DailyRate 1470 non-null int64
4 Department 1470 non-null object
5 DistanceFromHome 1470 non-null int64
6 Education 1470 non-null int64
7 EducationField 1470 non-null object
8 EmployeeCount 1470 non-null int64
9 EmployeeNumber 1470 non-null int64
10 EnvironmentSatisfaction 1470 non-null int64
11 Gender 1470 non-null object
12 HourlyRate 1470 non-null int64
13 JobInvolvement 1470 non-null int64
14 JobLevel 1470 non-null int64
15 JobRole 1470 non-null object
16 JobSatisfaction 1470 non-null int64
17 MaritalStatus 1470 non-null object
18 MonthlyIncome 1470 non-null int64
19 MonthlyRate 1470 non-null int64
20 NumCompaniesWorked 1470 non-null int64
21 Over18 1470 non-null object
22 OverTime 1470 non-null object
23 PercentSalaryHike 1470 non-null int64
24 PerformanceRating 1470 non-null int64
25 RelationshipSatisfaction 1470 non-null int64
26 StandardHours 1470 non-null int64
27 StockOptionLevel 1470 non-null int64
28 TotalWorkingYears 1470 non-null int64
29 TrainingTimesLastYear 1470 non-null int64
30 WorkLifeBalance 1470 non-null int64
31 YearsAtCompany 1470 non-null int64
32 YearsInCurrentRole 1470 non-null int64
33 YearsSinceLastPromotion 1470 non-null int64
34 YearsWithCurrManager 1470 non-null int64
dtypes: int64(26), object(9)
memory usage: 402.1+ KB
Checking for missing values:
Age 0
Attrition 0
BusinessTravel 0
DailyRate 0
Department 0
DistanceFromHome 0
Education 0
EducationField 0
EmployeeCount 0
EmployeeNumber 0
EnvironmentSatisfaction 0
Gender 0
HourlyRate 0
JobInvolvement 0
JobLevel 0
JobRole 0
JobSatisfaction 0
MaritalStatus 0
MonthlyIncome 0
MonthlyRate 0
NumCompaniesWorked 0
Over18 0
OverTime 0
PercentSalaryHike 0
PerformanceRating 0
RelationshipSatisfaction 0
StandardHours 0
StockOptionLevel 0
TotalWorkingYears 0
TrainingTimesLastYear 0
WorkLifeBalance 0
YearsAtCompany 0
YearsInCurrentRole 0
YearsSinceLastPromotion 0
YearsWithCurrManager 0
dtype: int64
Dropped constant and identifier columns.
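The four dropped columns above were chosen by hand. A minimal sketch of how you might detect constant columns programmatically, assuming a small hypothetical DataFrame `toy` that stands in for the HR data: `nunique()` counts distinct values per column, so any column with a single value carries no information for a classifier. Identifier columns such as `EmployeeNumber` are the opposite case (every value is unique), so in this sketch they are still listed by name.

```python
import pandas as pd

# Hypothetical toy data standing in for the HR dataset (invented values)
toy = pd.DataFrame({
    "Age": [41, 49, 37, 33],
    "EmployeeCount": [1, 1, 1, 1],        # constant column
    "StandardHours": [80, 80, 80, 80],    # constant column
    "Over18": ["Y", "Y", "Y", "Y"],       # constant column
    "EmployeeNumber": [1, 2, 4, 5],       # unique identifier
    "Attrition": ["Yes", "No", "Yes", "No"],
})

# Columns with exactly one distinct value carry no signal for a model
constant_cols = [col for col in toy.columns if toy[col].nunique() == 1]
print("Constant columns:", constant_cols)

# Identifier columns are unique per row; this one is dropped by name
id_cols = ["EmployeeNumber"]
toy_clean = toy.drop(columns=constant_cols + id_cols)
print("Remaining columns:", list(toy_clean.columns))
```

On the real data you would run the same `nunique()` check on `df` right after loading it.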
In this step, you’ll explore patterns and trends in the data that may be linked to attrition. Visualization helps uncover which factors are associated with an employee's decision to stay or leave.
# Plot 1: Overall Attrition Count
plt.figure(figsize=(8, 6))
sns.countplot(x='Attrition', data=df)
plt.title('Attrition Distribution')
plt.xlabel('Attrition')
plt.ylabel('Count')
plt.show()
# Plot 2: Attrition by Department
plt.figure(figsize=(10, 7))
sns.countplot(x='Department', hue='Attrition', data=df)
plt.title('Attrition by Department')
plt.xticks(rotation=45)
plt.show()
# Plot 3: Attrition by OverTime
plt.figure(figsize=(8, 6))
sns.countplot(x='OverTime', hue='Attrition', data=df)
plt.title('Attrition by OverTime')
plt.show()
# Plot 4: Monthly Income Distribution
plt.figure(figsize=(10, 7))
sns.histplot(data=df, x='MonthlyIncome', hue='Attrition', kde=True, multiple="stack")
plt.title('Monthly Income Distribution by Attrition')
plt.show()
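Count plots show absolute numbers, which are hard to compare when groups differ in size. A minimal sketch for turning them into attrition rates, using a hypothetical toy DataFrame in place of `df`: `pd.crosstab(..., normalize="index")` divides each row by its total, so each row gives the share of employees in that group who stayed or left. The `groupby` line computes the same "left" share as a single number per group.

```python
import pandas as pd

# Hypothetical toy data standing in for df (invented values)
toy = pd.DataFrame({
    "OverTime":  ["Yes", "Yes", "Yes", "No", "No", "No", "No", "No"],
    "Attrition": ["Yes", "Yes", "No",  "No", "No", "No", "Yes", "No"],
})

# Row-normalized crosstab: each row sums to 1
rates = pd.crosstab(toy["OverTime"], toy["Attrition"], normalize="index")
print(rates)

# Fraction of 'Yes' in Attrition for each OverTime group
yes_rate = toy.groupby("OverTime")["Attrition"].apply(lambda s: (s == "Yes").mean())
print(yes_rate)
```

To use this on the real data, swap in `df` and a column such as `'OverTime'` or `'Department'`, and run it before the preprocessing step converts `Attrition` to 1/0.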
Before training a model for Employee Attrition Prediction, you'll convert all non-numeric data into a format that machine learning algorithms can understand. You'll also prepare the target variable and separate features from labels.
The code for this step is as follows:
print("\nStep 4: Preprocessing data for the model...")
# Convert 'Attrition' from Yes/No to 1/0
df['Attrition'] = df['Attrition'].apply(lambda x: 1 if x == 'Yes' else 0)
# Identify all categorical columns
categorical_cols = df.select_dtypes(include=['object']).columns
print(f"\nCategorical columns to be one-hot encoded: {list(categorical_cols)}")
# Apply one-hot encoding to convert categorical columns into binary variables
df_encoded = pd.get_dummies(df, columns=categorical_cols, drop_first=True)
print("\nData after one-hot encoding:")
print(df_encoded.head())
# Define features (X) and target (y)
X = df_encoded.drop('Attrition', axis=1)
y = df_encoded['Attrition']
Output:
Categorical columns to be one-hot encoded: ['BusinessTravel', 'Department', 'EducationField', 'Gender', 'JobRole', 'MaritalStatus', 'OverTime']
Data after one-hot encoding:
Age Attrition DailyRate DistanceFromHome Education \
0 41 1 1102 1 2
1 49 0 279 8 1
2 37 1 1373 2 2
3 33 0 1392 3 4
4 27 0 591 2 1
EnvironmentSatisfaction HourlyRate JobInvolvement JobLevel \
0 2 94 3 2
1 3 61 2 2
2 4 92 2 1
3 4 56 3 1
4 1 40 3 1
JobSatisfaction ... JobRole_Laboratory Technician JobRole_Manager \
0 4 ... False False
1 2 ... False False
2 3 ... True False
3 3 ... False False
4 2 ... True False
JobRole_Manufacturing Director JobRole_Research Director \
0 False False
1 False False
2 False False
3 False False
4 False False
JobRole_Research Scientist JobRole_Sales Executive \
0 False True
1 True False
2 False False
3 True False
4 False False
JobRole_Sales Representative MaritalStatus_Married MaritalStatus_Single \
0 False False True
1 False True False
2 False False True
3 False True False
4 False True False
OverTime_Yes
0 True
1 False
2 True
3 True
4 False
[5 rows x 45 columns]
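`drop_first=True` removes one dummy column per categorical feature, because the dropped category is implied when all the remaining dummies are False; that is why `MaritalStatus` appears above only as `MaritalStatus_Married` and `MaritalStatus_Single`. A minimal sketch on a hypothetical toy column shows which category is dropped (the first in sorted order). It also uses `Series.map` with an explicit dictionary as an alternative to the `apply`/lambda used for the target; with `map`, an unexpected label becomes NaN instead of silently turning into 0.

```python
import pandas as pd

# Hypothetical toy data (invented values)
toy = pd.DataFrame({
    "MaritalStatus": ["Single", "Married", "Divorced", "Single"],
    "Attrition": ["Yes", "No", "No", "Yes"],
})

# Explicit target mapping; labels outside the dict would become NaN
toy["Attrition"] = toy["Attrition"].map({"Yes": 1, "No": 0})

# drop_first=True drops the first sorted category ('Divorced' here)
encoded = pd.get_dummies(toy, columns=["MaritalStatus"], drop_first=True)
print(encoded)
print(list(encoded.columns))
```

The "Divorced" row (index 2) is represented by False in both remaining dummy columns.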
Now that the data is ready, you'll split it into training and test sets and train a Random Forest classifier, a robust ensemble model that also reports how much each feature contributes to its predictions.
print("\nStep 5: Building and training the Random Forest model...")
# Split the dataset into training and testing subsets
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Display shapes of resulting splits
print(f"Training set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")
# Initialize the Random Forest model
model = RandomForestClassifier(
n_estimators=100,
random_state=42,
class_weight='balanced' # Handles class imbalance by adjusting weights
)
# Fit the model to the training data
model.fit(X_train, y_train)
print("\nModel training complete.")
Output:
Step 5: Building and training the Random Forest model...
Training set shape: (1176, 44)
Testing set shape: (294, 44)
Model training complete.
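Two settings in this step deserve a closer look. `stratify=y` keeps the share of leavers the same in the training and test sets, and `class_weight='balanced'` weights each class by `n_samples / (n_classes * count_of_class)`, so the minority Attrition class counts more during training. A minimal sketch on hypothetical toy labels (16 stayers, 4 leavers; not the real data), using scikit-learn's `compute_class_weight` to show the weights the balanced setting corresponds to:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical toy labels: 16 stayers (0) and 4 leavers (1), i.e. 20% attrition
y = np.array([0] * 16 + [1] * 4)
X = np.arange(len(y)).reshape(-1, 1)  # dummy single-feature matrix

# Stratified split keeps the 20% leaver share in both subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)
print("Train leaver share:", y_train.mean())
print("Test leaver share:", y_test.mean())

# 'balanced' weights computed on the training labels
weights = compute_class_weight("balanced", classes=np.array([0, 1]), y=y_train)
print("Class weights:", weights)

# The same numbers from the formula n_samples / (n_classes * class_count)
manual = len(y_train) / (2 * np.bincount(y_train))
print("Manual weights:", manual)
```

Here each leaver counts four times as much as a stayer (2.5 vs 0.625), which offsets the 4:1 imbalance in the training labels.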
The model has now learned patterns from the employee data that help identify who is more likely to leave. Next, you’ll evaluate how well it performs.
With your model trained, it’s time to assess how well it predicts employee attrition using unseen test data. You’ll check accuracy, analyze the confusion matrix, and interpret precision, recall, and F1-score.
Here is the code for evaluating model performance:
print("\nStep 6: Evaluating the model...")
# Predict on test data
y_pred = model.predict(X_test)
# Accuracy score
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel Accuracy: {accuracy * 100:.2f}%")
# Confusion matrix
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(
cm,
annot=True,
fmt='d',
cmap='Blues',
xticklabels=['No Attrition', 'Attrition'],
yticklabels=['No Attrition', 'Attrition']
)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix')
plt.show()
# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['No Attrition', 'Attrition']))
Output:
Model Accuracy: 83.67%
Confusion Matrix:
Classification Report:
precision recall f1-score support
No Attrition 0.85 0.98 0.91 247
Attrition 0.44 0.09 0.14 47
accuracy 0.84 294
macro avg 0.65 0.53 0.53 294
weighted avg 0.78 0.84 0.79 294
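To see where these numbers come from, the sketch below recomputes them from the four confusion-matrix cells. The cell counts are an assumption back-calculated from the report above (47 actual leavers at recall 0.09 implies 4 caught, and precision 0.44 then implies 5 false alarms); the heatmap itself is not reproduced in the text. It also computes the accuracy of a baseline that predicts "No Attrition" for everyone.

```python
# Confusion-matrix cells inferred from the classification report above
# (assumption: back-calculated from the rounded precision/recall values)
tn, fp = 242, 5   # actual No Attrition: 247 employees
fn, tp = 43, 4    # actual Attrition: 47 employees

total = tn + fp + fn + tp
accuracy = (tn + tp) / total

# Metrics for the positive 'Attrition' class
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# Baseline that predicts 'No Attrition' for every employee
baseline_accuracy = (tn + fp) / total

print(f"Accuracy:          {accuracy:.4f}")
print(f"Baseline accuracy: {baseline_accuracy:.4f}")
print(f"Attrition precision {precision:.2f}, recall {recall:.2f}, F1 {f1:.2f}")
```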
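Read these results with the class imbalance in mind. Only 47 of the 294 test employees left, so a model that always predicted "No Attrition" would already score 247/294 ≈ 84.0%, slightly above the 83.67% achieved here. The Attrition row is more informative: a recall of 0.09 means the model caught only about 4 of the 47 employees who actually left, and a precision of 0.44 means fewer than half of its attrition alerts were correct.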
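To see which employees the model gets wrong, you can filter the test set wherever the prediction differs from the true label. A minimal sketch with hypothetical stand-ins for `X_test`, `y_test`, and `y_pred` (same names as in the evaluation code, invented values): because `y_pred` from `model.predict` is a plain array, it is first wrapped in a Series that shares the test index.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for X_test, y_test, y_pred (invented values)
X_test = pd.DataFrame(
    {"MonthlyIncome": [2500, 8000, 3100, 12000, 2200],
     "OverTime_Yes": [True, False, True, False, True]},
    index=[10, 11, 12, 13, 14],
)
y_test = pd.Series([1, 0, 1, 0, 0], index=X_test.index)
y_pred = np.array([0, 0, 1, 0, 1])

# Align predictions with the test index, then select each error type
pred = pd.Series(y_pred, index=X_test.index)
missed_leavers = X_test[(y_test == 1) & (pred == 0)]   # false negatives
false_alarms = X_test[(y_test == 0) & (pred == 1)]     # false positives

print("Leavers the model missed:\n", missed_leavers)
print("Stayers flagged as leavers:\n", false_alarms)
```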
Besides making predictions, a Random Forest reports how much each feature contributed to its splits. This helps you see which variables the model relies on most when predicting attrition.
print("\nStep 7: Identifying key features...")
# Extract feature importances
importances = model.feature_importances_
feature_names = X.columns
# Create a DataFrame to hold feature names and their importance scores
feature_importance_df = pd.DataFrame({
'Feature': feature_names,
'Importance': importances
})
# Sort by importance descending
feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)
# Plot top 15 features
plt.figure(figsize=(12, 8))
sns.barplot(
x='Importance',
y='Feature',
data=feature_importance_df.head(15),
palette='viridis'
)
plt.title('Top 15 Most Important Features for Predicting Attrition')
plt.xlabel('Importance Score')
plt.ylabel('Feature')
plt.tight_layout()
plt.show()
# Print top 5
print("\nTop 5 features driving attrition:")
print(feature_importance_df.head(5))
Output:
Top 5 features driving attrition:
Feature Importance
9 MonthlyIncome 0.075157
0 Age 0.068066
16 TotalWorkingYears 0.053865
1 DailyRate 0.052450
19 YearsAtCompany 0.048968
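The `feature_importances_` used here are impurity-based (mean decrease in impurity), which scikit-learn's documentation notes can favor features with many unique values, such as continuous columns like `MonthlyIncome` or `DailyRate`. A common cross-check is permutation importance measured on held-out data. The sketch below, assuming scikit-learn 0.22 or later, uses synthetic data from `make_classification` (not the HR data) to show both measures side by side; the F1 scoring choice is an assumption made to focus on the minority class.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced data standing in for the encoded HR features
X, y = make_classification(
    n_samples=600, n_features=6, n_informative=3, n_redundant=0,
    weights=[0.84], random_state=42,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

model = RandomForestClassifier(
    n_estimators=100, random_state=42, class_weight="balanced"
)
model.fit(X_train, y_train)

# Impurity-based importances come from the training fit and sum to 1
impurity = model.feature_importances_

# Permutation importance: drop in test F1 when one column is shuffled
perm = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42, scoring="f1"
)

for i in range(X.shape[1]):
    print(f"feature_{i}: impurity={impurity[i]:.3f}  "
          f"permutation={perm.importances_mean[i]:.3f} "
          f"(+/- {perm.importances_std[i]:.3f})")
```

Features that rank high on impurity but near zero on permutation importance are candidates for a closer look.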
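The five features with the highest importance scores are Monthly Income, Age, Total Working Years, Daily Rate, and Years at Company. Keep in mind that an importance score measures how useful a feature was for the trees' splits, not whether higher or lower values raise attrition risk. The reading that employees with lower income, less experience, or shorter tenure are more likely to leave is therefore a hypothesis to confirm against the data, for example by comparing group averages, before using these insights to target retention efforts.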
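A minimal sketch of that direction check, assuming a DataFrame in which `Attrition` has already been mapped to 1/0 as in the preprocessing step. A hypothetical toy DataFrame stands in for it here; its numbers are invented to show the mechanics and say nothing about the real dataset. Comparing each feature's median for stayers (0) and leavers (1) shows the direction of the association, which importance scores do not.

```python
import pandas as pd

# Hypothetical toy data standing in for df after Attrition was mapped to 1/0
toy = pd.DataFrame({
    "MonthlyIncome":     [2100, 2600, 3000, 5200, 6100, 7400],
    "YearsAtCompany":    [1, 2, 1, 7, 5, 10],
    "TotalWorkingYears": [2, 4, 3, 12, 9, 15],
    "Attrition":         [1, 1, 0, 0, 0, 0],
})

top_features = ["MonthlyIncome", "YearsAtCompany", "TotalWorkingYears"]

# Median of each top feature for stayers (0) and leavers (1)
summary = toy.groupby("Attrition")[top_features].median()
print(summary)
```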
The Employee Attrition Prediction project applied a Random Forest model to HR data to predict the likelihood of an employee leaving the company. After cleaning the dataset, exploring patterns, encoding categorical features, and training the model, we reached 83.67% accuracy on the test set. Recall on the Attrition class was only 0.09, however, so the model in its current form still misses most employees who actually leave.
The analysis showed that the features with the highest importance scores were Monthly Income, Age, Total Working Years, Daily Rate, and Years at the Company, suggesting they carry the most predictive signal about whether an employee stays or leaves.
Colab Link -
https://colab.research.google.com/drive/1V4gWwbt2_0MDGQg_BhXAF8VcKLuXoyuD?usp=sharing
Rohit Sharma is the Head of Revenue & Programs (International), with over 8 years of experience in business analytics, EdTech, and program management. He holds an M.Tech from IIT Delhi and specializes...