Heart Disease Prediction Using Logistic Regression and Random Forest

By Rohit Sharma

Updated on Aug 05, 2025 | 12 min read | 1.36K+ views


Heart disease is among the leading causes of death worldwide. Early detection, along with timely medical care, can greatly reduce the risk of life-threatening complications or death.

This project uses machine learning to predict heart disease from existing patient information, such as cholesterol, blood pressure, and age. By training models such as Logistic Regression and Random Forest on the UCI Heart Disease dataset, the goal is to develop an effective tool that supports rapid and accurate diagnosis.

Want to get into data science? upGrad offers Online Data Science Courses in Python, Machine Learning, AI, SQL, and Tableau, taught by experts. Enroll today!

For further exploration, consider this compilation of outstanding Python Data Science Projects catering to various skill levels, from beginner to advanced.

What Should You Know Before Starting?

To work efficiently on this heart disease prediction project, you should be comfortable with the following:

  • Basic Python programming knowledge (For writing scripts, defining functions, and handling control flows)
  • Experience with data manipulation using Pandas and NumPy (Required for reading data, handling missing values, and structuring datasets)
  • Understanding of data visualisation with Matplotlib and Seaborn (Helps in generating charts like countplots, histograms, and heatmaps for EDA)
  • Knowledge of data preprocessing techniques (Includes tasks like removing unnecessary columns, encoding categorical variables, normalising data, and splitting it into training and testing sets)
  • Familiarity with Regression Algorithms (Since this is a classification problem, you should understand models like Logistic Regression and Random Forest Classifier, which predict categorical outcomes such as the presence or absence of heart disease)

Also Read - Top 35 Linear Regression Projects in Machine Learning With Source Code

Learn data science with upGrad's industry-led, top-ranked courses for direct mentorship and career guidance.

Heart Disease Prediction: How Tech Is Making It Possible

To build and evaluate the heart disease prediction model, you'll use essential Python tools for data handling, visualisation, classification modelling, and evaluation:

  • Python: The core programming language used to write and run the entire project
  • Google Colab: An online platform to run notebooks with ready-to-use libraries and GPU support
  • Pandas: Reads the heart disease dataset and handles cleaning and preprocessing
  • NumPy: Supports numerical operations and works with arrays for model input
  • Matplotlib / Seaborn: Visualises data distribution, relationships, and correlation heatmaps
  • scikit-learn: Used for data splitting, encoding, training models, and evaluating results
  • LogisticRegression: A simple yet effective baseline model for binary classification
  • RandomForestClassifier: An advanced model that handles feature interactions and boosts accuracy
  • Accuracy, Precision, Recall, F1-Score: Key metrics to evaluate model performance and reliability

Also Read - Decision Tree vs Random Forest: Use Cases & Performance Metrics


How Are We Predicting Heart Disease?

To build an effective Heart Disease Prediction model, you'll apply key classification techniques that help detect the risk of heart disease based on patient health data:

  • Data preprocessing and cleaning
  • Exploratory Data Analysis (EDA)
  • Classification Algorithms (Logistic Regression & Random Forest Classifier)
  • Feature Importance Analysis
  • Model evaluation using accuracy, precision, recall, and F1-score

Also Read - Evaluation Metrics in Machine Learning: Top 10 Metrics You Should Know

Time Commitment Required 

You can complete this Heart Disease Prediction project in around 4 to 5 hours. It’s designed for beginners and intermediate learners who have basic Python skills.

Building a Heart Disease Prediction Model: A Step-by-Step Guide

Here’s how you can build this project from scratch:

  • Load the Heart Disease Dataset
    Start by importing the dataset that includes features such as age, cholesterol levels, and more, along with the target variable indicating heart disease presence.
  • Clean and Preprocess the Data
    Handle missing values (if any), encode categorical features, and scale numerical values to prepare the data for machine learning.
  • Explore and Visualise Key Insights
    Use visual tools like bar plots, histograms, pair plots, and heatmaps to analyse patterns and correlations between patient attributes and heart disease.
  • Train Classification Models
    Train models such as Logistic Regression and Random Forest Classifier to classify whether a patient is likely to have heart disease based on the given features.
  • Evaluate Model Performance
    Use classification metrics like accuracy, precision, recall, F1-score, and ROC-AUC to assess how well your model can predict heart disease outcomes.

Let's get started!

Step 1: Import Required Libraries

Before starting the heart disease prediction project, you need to import the necessary Python libraries for data handling, analysis, and visualisation.

Here's the code for importing:

import pandas as pd         # For data loading and manipulation
import numpy as np          # For numerical operations
import matplotlib.pyplot as plt  # For creating static plots
import seaborn as sns       # For attractive statistical plots

# Set default styles for plots
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')

Also Read - Libraries in Python Explained: List of Important Libraries

Step 2: Load and Inspect the Heart Disease Dataset

Now let’s load the heart disease dataset, check for missing values, and clean the data to make it ready for analysis and modelling.

# Load the dataset from the uploaded file
try:
    df = pd.read_csv('heart_disease_uci.csv')
    print("--- Dataset Loaded Successfully ---")
    print("Initial 5 rows of the dataset:")
    print(df.head())
except FileNotFoundError:
    print("Error: 'heart_disease_uci.csv' not found. Please make sure the file is in the correct directory.")
    exit()

# Check basic dataset info
print("\n--- Dataset Info ---")
df.info()

# Check for missing values represented as '?'
print("\n--- Checking for missing values ('?') ---")
for col in df.columns:
    missing_count = (df[col] == '?').sum()
    if missing_count > 0:
        print(f"Column '{col}' has {missing_count} missing values marked as '?'.")

# Replace '?' with NaN
df.replace('?', np.nan, inplace=True)

# Convert numeric columns to correct types
for col in ['trestbps', 'chol', 'thalch', 'oldpeak', 'ca']:
    df[col] = pd.to_numeric(df[col], errors='coerce')

# Fill missing numeric values with the column median
for col in ['trestbps', 'chol', 'thalch', 'oldpeak', 'ca']:
    df[col] = df[col].fillna(df[col].median())

# Fill missing categorical values with the column mode
for col in ['fbs', 'exang', 'restecg', 'slope', 'thal']:
    df[col] = df[col].fillna(df[col].mode()[0])

# Final check
print("\n--- Missing values handled. Checking info again: ---")
df.info()

Output:

--- Dataset Loaded Successfully ---
Initial 5 rows of the dataset:
  id  age     sex    dataset               cp              trestbps   chol    fbs  \
0   1   63    Male  Cleveland   typical angina     145.0  233.0   True   
1   2   67    Male  Cleveland     asymptomatic     160.0  286.0  False   
2   3   67    Male  Cleveland     asymptomatic     120.0  229.0  False   
3   4   37    Male  Cleveland      non-anginal     130.0  250.0  False   
4   5   41  Female  Cleveland  atypical angina     130.0  204.0  False  

         restecg  thalch  exang  oldpeak        slope   ca  \
0  lv hypertrophy   150.0  False      2.3  downsloping  0.0   
1  lv hypertrophy   108.0   True      1.5         flat  3.0   
2  lv hypertrophy   129.0   True      2.6         flat  2.0   
3          normal   187.0  False      3.5  downsloping  0.0   
4  lv hypertrophy   172.0  False      1.4    upsloping  0.0  

               thal  num  
0       fixed defect    0  
1             normal    2  
2  reversable defect    1  
3             normal    0  
4             normal    0  

--- Dataset Info ---
RangeIndex: 920 entries, 0 to 919
Data columns (total 16 columns):
#   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
0   id        920 non-null    int64  
1   age       920 non-null    int64  
2   sex       920 non-null    object 
3   dataset   920 non-null    object 
4   cp        920 non-null    object 
5   trestbps  861 non-null    float64
6   chol      890 non-null    float64
7   fbs       830 non-null    object 
8   restecg   918 non-null    object 
9   thalch    865 non-null    float64
10  exang     865 non-null    object 
11  oldpeak   858 non-null    float64
12  slope     611 non-null    object 
13  ca        309 non-null    float64
14  thal      434 non-null    object 
15  num       920 non-null    int64  
dtypes: float64(5), int64(3), object(8)

--- Checking for missing values ('?') ---

--- Missing values handled. Checking info again: ---

RangeIndex: 920 entries, 0 to 919
Data columns (total 16 columns):
#   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
0   id        920 non-null    int64  
1   age       920 non-null    int64  
2   sex       920 non-null    object 
3   dataset   920 non-null    object 
4   cp        920 non-null    object 
5   trestbps  920 non-null    float64
6   chol      920 non-null    float64
7   fbs       920 non-null    bool   
8   restecg   920 non-null    object 
9   thalch    920 non-null    float64
10  exang     920 non-null    bool   
11  oldpeak   920 non-null    float64
12  slope     920 non-null    object 
13  ca        920 non-null    float64
14  thal      920 non-null    object 
15  num       920 non-null    int64  
dtypes: bool(2), float64(5), int64(3), object(6)
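
Note that the '?' scan above prints nothing: in this Kaggle export of the UCI data, the gaps are stored as empty cells (read in as NaN) rather than as '?'. The first df.info() call is what reveals them, through the reduced non-null counts. To quantify the gaps explicitly, a quick check like this (run before the fillna steps) also works:

# Count missing (NaN) values per column before imputation
print(df.isnull().sum().sort_values(ascending=False))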

Also Read - Data Cleaning Techniques: 15 Simple & Effective Ways To Clean Data

Step 3: Prepare the Data for Modelling

In this step, you’ll convert the target variable into binary form (0 = no heart disease, 1 = disease), map gender to numeric values, and one-hot encode the categorical features. This prepares the data for machine learning.

# Convert the target column: 0 = No Disease, 1 = Disease (for all values > 0)
df['target'] = df['num'].apply(lambda x: 1 if x > 0 else 0)

# Encode 'sex': Male → 1, Female → 0
df['sex'] = df['sex'].map({'Male': 1, 'Female': 0})

# One-hot encode categorical columns
df = pd.get_dummies(df, columns=['cp', 'restecg', 'slope', 'thal'], drop_first=True)

# Drop unnecessary columns
df.drop(['id', 'dataset', 'num'], axis=1, inplace=True)

# Confirm preprocessing
print("\n--- Data Preprocessing Complete ---")
print("New dataset shape:", df.shape)
print("Columns after encoding:", df.columns.tolist())

Output:

New dataset shape: (920, 19)
Columns after encoding: ['age', 'sex', 'trestbps', 'chol', 'fbs', 'thalch', 'exang', 'oldpeak', 'ca', 'target', 'cp_atypical angina', 'cp_non-anginal', 'cp_typical angina', 'restecg_normal', 'restecg_st-t abnormality', 'slope_flat', 'slope_upsloping', 'thal_normal', 'thal_reversable defect']
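
Before training, it is also worth checking how balanced the new binary target is, since heavily skewed classes would change how you read accuracy later on. A quick check:

# Check class balance of the binary target
print(df['target'].value_counts())
print(df['target'].value_counts(normalize=True).round(2))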

Also Read - Label Encoder vs One Hot Encoder in Machine Learning

Step 4: Exploratory Data Analysis (EDA)

In this step, you’ll explore the data to understand how features like age and other variables relate to heart disease. Visualisations help uncover patterns and feature importance.

# 1. Target Variable Distribution
plt.figure(figsize=(8, 6))
sns.countplot(x='target', data=df, palette='viridis')
plt.title('Heart Disease Presence Distribution')
plt.xlabel('Heart Disease (0 = No, 1 = Yes)')
plt.ylabel('Patient Count')
plt.show()

# 2. Correlation Matrix Heatmap
plt.figure(figsize=(20, 15))
sns.heatmap(df.corr(), annot=True, cmap='coolwarm', fmt='.2f')
plt.title('Correlation Matrix of All Features')
plt.show()

# 3. Age Distribution vs. Target
plt.figure(figsize=(12, 7))
sns.histplot(data=df, x='age', hue='target', multiple='stack', palette='pastel')
plt.title('Age Distribution by Heart Disease Presence')
plt.xlabel('Age')
plt.ylabel('Count')
plt.show()

Output:
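
The step-by-step overview also mentioned pair plots. If you want one, here is a small optional sketch over a handful of numeric columns (a full pair plot on all 19 features would be slow and hard to read, so a subset is used):

# Optional: pair plot of selected numeric features, coloured by target
sns.pairplot(df[['age', 'trestbps', 'chol', 'thalch', 'target']],
             hue='target', palette='viridis')
plt.show()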

Also Read - Comprehensive Guide to Exploratory Data Analysis (EDA) in 2025: Tools, Types, and Best Practices

Step 5: Model Training and Evaluation – Logistic Regression

In this step, you'll train a Logistic Regression model as a baseline to predict heart disease. You'll evaluate it using accuracy, classification metrics, and a confusion matrix.

The code for this step is as follows:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

print("\n--- Starting Logistic Regression Training ---")

# Define features (X) and target (y)
X = df.drop('target', axis=1)
y = df['target']

# Split data into training and testing sets (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

# Scale numerical features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Train Logistic Regression model
log_reg = LogisticRegression(random_state=42, max_iter=1000)
log_reg.fit(X_train, y_train)
y_pred_lr = log_reg.predict(X_test)

# Evaluation
print("\n--- Logistic Regression Evaluation ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred_lr):.4f}")
print("Classification Report:")
print(classification_report(y_test, y_pred_lr))

# Confusion Matrix
plt.figure(figsize=(8, 6))
cm_lr = confusion_matrix(y_test, y_pred_lr)
sns.heatmap(cm_lr, annot=True, fmt='d', cmap='Blues',
            xticklabels=['No Disease', 'Disease'],
            yticklabels=['No Disease', 'Disease'])
plt.title('Logistic Regression - Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

Output:

--- Starting Logistic Regression Training ---

--- Logistic Regression Evaluation ---
Accuracy: 0.8424
Classification Report:
             precision    recall  f1-score   support

          0       0.84      0.79      0.82        82
          1       0.84      0.88      0.86       102

   accuracy                           0.84       184
  macro avg       0.84      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184
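
As a quick cross-check of the report above, you can derive the positive-class precision and recall directly from the confusion matrix computed earlier (cm_lr); the numbers should match the classification report:

# Unpack the confusion matrix: rows are true labels, columns are predictions
tn, fp, fn, tp = cm_lr.ravel()
print(f"Precision (class 1): {tp / (tp + fp):.2f}")
print(f"Recall    (class 1): {tp / (tp + fn):.2f}")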

Step 6: Model Training and Evaluation – Random Forest Classifier

Now you’ll train a Random Forest Classifier, which is a more powerful and flexible model compared to Logistic Regression. It builds multiple decision trees and averages the results to improve prediction accuracy and control overfitting.

# 6. --- MODEL TRAINING AND EVALUATION: RANDOM FOREST ---

from sklearn.ensemble import RandomForestClassifier

print("\n--- Training Random Forest Classifier Model ---")

# Initialize and train the model
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42)
rf_clf.fit(X_train, y_train)

# Make predictions
y_pred_rf = rf_clf.predict(X_test)

# Evaluation
print("\n--- Random Forest Evaluation ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred_rf):.4f}")
print("Classification Report:")
print(classification_report(y_test, y_pred_rf))

# Confusion Matrix
plt.figure(figsize=(8, 6))
cm_rf = confusion_matrix(y_test, y_pred_rf)
sns.heatmap(cm_rf, annot=True, fmt='d', cmap='Greens',
            xticklabels=['No Disease', 'Disease'],
            yticklabels=['No Disease', 'Disease'])
plt.title('Random Forest - Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

Output:

--- Training Random Forest Classifier Model ---

--- Random Forest Evaluation ---
Accuracy: 0.8533
Classification Report:
             precision    recall  f1-score   support

          0       0.87      0.79      0.83        82
          1       0.84      0.90      0.87       102

   accuracy                           0.85       184
  macro avg       0.86      0.85      0.85       184
weighted avg       0.85      0.85      0.85       184
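
A single train/test split can be lucky or unlucky, so before trusting the ~85% figure you could sanity-check it with 5-fold cross-validation. Here is a minimal sketch, using a Pipeline so the scaler is fitted inside each fold rather than on the full dataset:

from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

# Scaling is fitted per fold, so no information leaks across folds
cv_pipe = make_pipeline(StandardScaler(),
                        RandomForestClassifier(n_estimators=100, random_state=42))
cv_scores = cross_val_score(cv_pipe, X, y, cv=5, scoring='accuracy')
print(f"CV accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")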

Explore this project: Airline Passenger Traffic Analysis Project Using Python

Step 7: ROC Curve Comparison – Logistic Regression vs Random Forest

To compare the performance of both classifiers, you’ll use the Receiver Operating Characteristic (ROC) curve. It shows how well each model distinguishes between the two classes (presence vs absence of heart disease). The Area Under the Curve (AUC) summarises this separability: the higher the AUC, the better the model.

Here is the code for evaluating model performance: 

# 7. --- ROC CURVE COMPARISON ---

from sklearn.metrics import roc_curve, auc

# Logistic Regression ROC values
fpr_lr, tpr_lr, _ = roc_curve(y_test, log_reg.predict_proba(X_test)[:, 1])
roc_auc_lr = auc(fpr_lr, tpr_lr)

# Random Forest ROC values
fpr_rf, tpr_rf, _ = roc_curve(y_test, rf_clf.predict_proba(X_test)[:, 1])
roc_auc_rf = auc(fpr_rf, tpr_rf)

# Plotting both ROC Curves
plt.figure(figsize=(10, 8))
plt.plot(fpr_lr, tpr_lr, color='blue', lw=2,
         label=f'Logistic Regression (AUC = {roc_auc_lr:.2f})')
plt.plot(fpr_rf, tpr_rf, color='green', lw=2,
         label=f'Random Forest (AUC = {roc_auc_rf:.2f})')
plt.plot([0, 1], [0, 1], color='red', lw=2, linestyle='--', label='Chance')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

Output: 

Also Read - What is AUC ROC Curve? Implementation, Comparison & Applications
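
One practical use of the ROC data: the default 0.5 decision threshold is not fixed, and in a screening setting you might prefer a threshold that favours recall (catching more true cases at the cost of more false alarms). Here is a small illustrative sketch, reusing roc_curve but keeping the thresholds this time (the 95% sensitivity target is just an example):

# Find the first threshold that reaches at least 95% sensitivity (TPR)
fpr_rf, tpr_rf, thresholds = roc_curve(y_test, rf_clf.predict_proba(X_test)[:, 1])
idx = np.argmax(tpr_rf >= 0.95)  # index of the first TPR value crossing 0.95
print(f"Threshold: {thresholds[idx]:.3f} | TPR: {tpr_rf[idx]:.2f} | FPR: {fpr_rf[idx]:.2f}")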

Step 8: Feature Importance – What Drives the Prediction?

Random Forest gives insights into which features are most influential in predicting heart disease. This step helps you understand what matters most in your dataset.

# 8. --- FEATURE IMPORTANCE ANALYSIS ---

# Get feature importances from Random Forest
importances = rf_clf.feature_importances_
feature_names = X.columns
feature_importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': importances
})

# Sort features by importance
feature_importance_df = feature_importance_df.sort_values(by='importance', ascending=False)

# Plot top 15 features
plt.figure(figsize=(12, 10))
sns.barplot(x='importance', y='feature',
            data=feature_importance_df.head(15),
            palette='plasma')
plt.title('Top 15 Most Important Features (Random Forest)')
plt.xlabel('Importance Score')
plt.ylabel('Feature')
plt.show()

Output:

Key insights from this step:

  • Cholesterol (chol), Maximum Heart Rate (thalch), and Age are the top three most important features.
  • Other strong predictors include ST depression (oldpeak), Exercise-induced angina (exang), and resting blood pressure (trestbps).
  • Categorical features like chest pain types (cp_atypical angina, cp_non-anginal) and thalassemia indicators also play a significant role.

This tells you which patient attributes your model relies on most, which is useful both for model refinement and for practical decision-making in health diagnostics.
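
To see how much of the performance those top attributes carry, one optional refinement experiment is to retrain on only the highest-ranked features and compare accuracy against the full model. A minimal sketch, assuming the objects from the earlier steps are still in memory (the cut-off of 8 features is arbitrary):

# Retrain Random Forest on only the top 8 features (illustrative cut-off)
top_features = feature_importance_df['feature'].head(8).tolist()
X_top = df[top_features]

# Random Forest does not require feature scaling, so raw values are fine here
Xt_train, Xt_test, yt_train, yt_test = train_test_split(
    X_top, y, test_size=0.2, random_state=42, stratify=y)

rf_top = RandomForestClassifier(n_estimators=100, random_state=42)
rf_top.fit(Xt_train, yt_train)
print(f"Accuracy with top 8 features: {accuracy_score(yt_test, rf_top.predict(Xt_test)):.4f}")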

Final Conclusion

This Heart Disease Prediction project successfully demonstrated how machine learning models like Logistic Regression and Random Forest can be used to predict the presence of heart disease based on patient data. The Random Forest model outperformed Logistic Regression in terms of accuracy and AUC score, and also provided valuable insights into the most important health features influencing the prediction. Overall, this analysis highlights the potential of data-driven approaches in supporting medical decision-making and identifying key risk factors.
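
If you want to reuse the trained model outside the notebook, for example in a small script or demo app, you can persist the classifier together with the fitted scaler. Here is a minimal sketch with joblib (the file names are hypothetical):

import joblib

# Persist the trained model and the scaler fitted in Step 5
joblib.dump(rf_clf, 'heart_rf_model.joblib')
joblib.dump(scaler, 'heart_scaler.joblib')

# Later, in another session: reload and predict on new, already-encoded data
model = joblib.load('heart_rf_model.joblib')
loaded_scaler = joblib.load('heart_scaler.joblib')
print(model.predict(loaded_scaler.transform(X.iloc[[0]])))  # prediction for the first patient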


Colab Link -
https://colab.research.google.com/drive/1_bbWwxH2y-lys6mo7Xw35ycYNR6xkfvK?usp=sharing

Frequently Asked Questions (FAQs)

1. What was the goal of this project?

2. Which model performed better in this analysis?

3. What were the most important features influencing heart disease prediction?

4. How was model performance evaluated?

5. Can this model be used for real-world diagnosis?

6. What are some similar machine learning projects that beginners can try?

Rohit Sharma

827 articles published

Rohit Sharma is the Head of Revenue & Programs (International), with over 8 years of experience in business analytics, EdTech, and program management. He holds an M.Tech from IIT Delhi and specializes...

