Introduction to AI/ML in Bioinformatics: Classification Models & Evaluation

· 12 min read
Thanh-Giang Tan Nguyen
Founder at RIVER

Machine learning is transforming bioinformatics by automating pattern discovery from biological data. But what problems can it actually solve? This post shows real-world applications of classification models, then builds the simplest possible classifiers to understand how they work and how to evaluate them. This is Part 0—the practical foundation before diving into complex algorithms like KNN.

Real-World Classification Problems in Bioinformatics

Let's start by understanding what classification problems machine learning actually solves:

Problem 1: Disease Diagnosis from Gene Expression

  • Input: Gene expression levels from patient blood sample
  • Output: Normal or Disease (binary classification)
  • Real application: Cancer subtypes, Alzheimer's stages, COVID severity
  • Goal: Classify new patients into disease categories automatically
Patient A: [Gene1=2.3, Gene2=1.5, Gene3=4.2, ...] → Normal
Patient B: [Gene1=8.1, Gene2=6.3, Gene3=1.9, ...] → Disease
Patient C: [Gene1=3.1, Gene2=2.2, Gene3=5.1, ...] → Normal

Problem 2: Protein Function Prediction

  • Input: Amino acid sequence
  • Output: Enzyme, Structural protein, or Transport protein (multi-class)
  • Real application: Annotating newly sequenced genomes
  • Goal: Predict function of unknown proteins

Problem 3: Pathogenic Variant Detection

  • Input: DNA mutation information, population frequency, conservation score
  • Output: Pathogenic or Benign (binary classification)
  • Real application: Clinical variant interpretation
  • Goal: Identify disease-causing mutations from huge variant databases

Why Build Classification Models? The Power of Automation

Before machine learning, biologists manually analyzed data:

  • Slow: Analyzing 20,000 genes one-by-one takes months
  • Subjective: Different experts might disagree on classification
  • Doesn't scale: analyzing 10,000 patient samples requires endless manual work

With ML classification models:

  • Fast: Classify 10,000 samples in seconds
  • Objective: Same rules applied consistently to every sample
  • Scalable: Works for any dataset size without additional effort

Part 1: The Simplest Classification Models

Let's build real but minimal classification models, starting with the simplest and working toward more sophisticated ones.

Setup: Simulated Gene Expression Data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set seed for reproducibility
np.random.seed(42)

# Simulate gene expression data for disease classification
# 50 normal patients, 50 disease patients
# 3 genes measured via RNA-seq

# Normal patients: lower expression
normal_samples = np.random.normal(loc=5, scale=2, size=(50, 3))

# Disease patients: higher expression across all three genes
disease_samples = np.random.normal(loc=8, scale=2, size=(50, 3))

# Combine data
X = np.vstack([normal_samples, disease_samples]) # Features (gene expression)
y = np.hstack([np.zeros(50), np.ones(50)]) # Labels (0=normal, 1=disease)

# Create DataFrame for easier inspection
gene_names = ['Gene_A', 'Gene_B', 'Gene_C']
df = pd.DataFrame(X, columns=gene_names)
df['Label'] = y
df['Label_name'] = df['Label'].map({0: 'Normal', 1: 'Disease'})

print("Gene Expression Data Sample:")
print(df.head(10))
print(f"\nDataset shape: {X.shape[0]} patients, {X.shape[1]} genes")
print(f"Classes: {int(sum(y==0))} Normal, {int(sum(y==1))} Disease")

# Split data: 80% train, 20% test
# NOTE: the rows are still in their original order (50 normal first, then 50 disease),
# so this simple slice puts only disease samples into the test set
split_idx = 80
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

print(f"\nTraining set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

Output:

Gene Expression Data Sample:
Gene_A Gene_B Gene_C Label Label_name
0 4.86 7.02 4.23 0 Normal
1 5.12 3.54 5.67 0 Normal
2 3.45 4.89 6.12 0 Normal
...
Dataset shape: 100 patients, 3 genes
Classes: 50 Normal, 50 Disease

Training set: 80 samples
Test set: 20 samples
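
Because the slice above keeps the original row order, the 20 held-out samples are all disease patients. For a real analysis you would normally shuffle and stratify the split; a minimal sketch using scikit-learn (assumed to be installed here, it is not used elsewhere in this post) would look like this:

from sklearn.model_selection import train_test_split

# Stratified 80/20 split: shuffles the rows and keeps the Normal/Disease ratio in both sets
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
print(f"Test labels: {int(sum(y_te == 0))} Normal, {int(sum(y_te == 1))} Disease")

We will stick with the simple slice for the rest of the post, but keep its all-disease test set in mind when interpreting the metrics later.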

Model 1: Rule-Based Classifier (Simplest Possible)

The simplest classifier is just a rule based on one gene:

class SimpleRuleClassifier:
    """
    Classify patients based on a simple threshold rule.
    If Gene_A > threshold: predict disease
    Else: predict normal

    This is how a biologist might manually classify before ML!
    """
    def __init__(self, gene_index=0, threshold=6.5):
        self.gene_index = gene_index
        self.threshold = threshold
        self.gene_name = gene_names[gene_index]

    def predict(self, X):
        """Make predictions based on simple rule."""
        predictions = (X[:, self.gene_index] > self.threshold).astype(int)
        return predictions

    def __repr__(self):
        return f"Rule: If {self.gene_name} > {self.threshold}, predict Disease"

# Create and test the rule-based model
model1 = SimpleRuleClassifier(gene_index=0, threshold=6.5)
y_pred1 = model1.predict(X_test)

print(f"\nModel 1: {model1}")
print(f"Sample predictions: {y_pred1[:10]}")

Model 2: Multi-Gene Mean Classifier

Slightly better: use the average of all genes:

class MeanClassifier:
    """
    Classify based on mean expression across all genes.
    If mean(all genes) > threshold: predict disease
    Else: predict normal

    This combines information from multiple genes!
    """
    def __init__(self, threshold=6.5):
        self.threshold = threshold

    def predict(self, X):
        """Make predictions based on mean expression."""
        mean_expression = X.mean(axis=1)  # Average across genes
        predictions = (mean_expression > self.threshold).astype(int)
        return predictions

# Test mean-based model
model2 = MeanClassifier(threshold=6.5)
y_pred2 = model2.predict(X_test)

print(f"\nModel 2: Mean-based classifier (threshold=6.5)")
print(f"Sample predictions: {y_pred2[:10]}")

Model 3: Distance-Based Classifier (Nearest Centroid)

Even better: find the center of each class, classify by distance:

class NearestCentroidClassifier:
    """
    Classify based on distance to class centroids.

    Algorithm:
    1. During training: Calculate mean (centroid) of each class
    2. During prediction: For new sample, predict class of nearest centroid

    This is the foundation for more complex algorithms like KNN!
    Think of it as: "Find which disease type the patient is closest to"
    """
    def __init__(self):
        self.centroids = {}

    def fit(self, X, y):
        """Learn the centroid (average) of each class."""
        for class_label in np.unique(y):
            self.centroids[class_label] = X[y == class_label].mean(axis=0)
        print(f"✓ Learned {len(self.centroids)} class centroids")
        print(f"  Normal centroid:  {self.centroids[0]}")
        print(f"  Disease centroid: {self.centroids[1]}")

    def predict(self, X):
        """Predict by finding nearest centroid."""
        predictions = []
        for sample in X:
            # Calculate distance to each centroid
            distances = {}
            for class_label, centroid in self.centroids.items():
                # Euclidean distance
                distance = np.sqrt(np.sum((sample - centroid) ** 2))
                distances[class_label] = distance

            # Predict class of nearest centroid
            prediction = min(distances, key=distances.get)
            predictions.append(prediction)

        return np.array(predictions)

# Train and test
model3 = NearestCentroidClassifier()
model3.fit(X_train, y_train)
y_pred3 = model3.predict(X_test)

print(f"\nModel 3: Nearest Centroid Classifier")
print(f"Sample predictions: {y_pred3[:10]}")

Output:

✓ Learned 2 class centroids
Normal centroid: [4.87 4.92 5.11]
Disease centroid: [7.98 8.15 8.23]

Model 3: Nearest Centroid Classifier
Sample predictions: [1 1 0 1 0 0 0 0 1 0]
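
As a quick sanity check, scikit-learn provides an equivalent estimator, sklearn.neighbors.NearestCentroid (again assuming scikit-learn is installed); with its default Euclidean metric it should agree with our implementation:

from sklearn.neighbors import NearestCentroid

sk_centroid = NearestCentroid()
sk_centroid.fit(X_train, y_train)
print("Matches our predictions:", np.array_equal(sk_centroid.predict(X_test), y_pred3))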

Part 2: Evaluating Classification Models

Now that we have predictions, how do we know which model is good? We need evaluation metrics!

Key Metrics Explained (Especially for Imbalanced Data)

Confusion Matrix

                      Predicted
                   Disease   Healthy
Actual  Disease      TP        FN
        Healthy      FP        TN

Definitions:

  • TP (True Positive): Correctly identified disease → Good! ✓
  • TN (True Negative): Correctly identified healthy → Good! ✓
  • FP (False Positive): Healthy person predicted as disease → False alarm ✗
  • FN (False Negative): Disease person predicted as healthy → Dangerous! ✗

Metrics (Importance for Imbalanced Data):

Metric             | Formula               | Meaning                          | Best For
Accuracy           | (TP+TN)/(TP+TN+FP+FN) | Overall correctness              | MISLEADING for imbalanced data
Sensitivity/Recall | TP/(TP+FN)            | % of disease cases caught        | ✅ Essential for imbalanced data
Specificity        | TN/(TN+FP)            | % of healthy cases caught        | ✅ Reveals when model ignores minority class
Precision          | TP/(TP+FP)            | % of disease predictions correct | ✅ Shows cost of false alarms

Why Recall + Specificity Matter More Than Accuracy for Imbalanced Data:

Imagine a test set with 19 disease and 1 healthy sample (95% disease cases):

  • A model predicts "disease" for every sample
  • Accuracy = 95% (looks good!)
  • But Specificity = 0% (it completely missed the one healthy person!)

Recall and Specificity immediately expose the problem, as the quick check below shows.
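
To make this concrete, here is a tiny sketch of that hypothetical 19-disease / 1-healthy scenario (made-up labels, not our actual test split):

y_true_demo = np.array([1] * 19 + [0])   # 19 disease, 1 healthy
y_pred_demo = np.ones_like(y_true_demo)  # lazy baseline: predict disease for everyone

accuracy_demo = np.mean(y_true_demo == y_pred_demo)  # 0.95 - looks great
specificity_demo = np.sum((y_true_demo == 0) & (y_pred_demo == 0)) / np.sum(y_true_demo == 0)  # 0.0 - red flag

print(f"Baseline accuracy: {accuracy_demo:.0%}, specificity: {specificity_demo:.0%}")

Now let's wrap the full set of metrics into a reusable evaluation function: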

def evaluate_classification(y_true, y_pred, model_name="Model"):
    """Evaluate a classification model using confusion matrix and metrics."""

    # Calculate confusion matrix
    TP = np.sum((y_true == 1) & (y_pred == 1))
    TN = np.sum((y_true == 0) & (y_pred == 0))
    FP = np.sum((y_true == 0) & (y_pred == 1))
    FN = np.sum((y_true == 1) & (y_pred == 0))

    # Calculate metrics (guarding against division by zero)
    accuracy = (TP + TN) / len(y_true)
    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0  # True Positive Rate
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0  # True Negative Rate
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0    # "When we predict disease, are we right?"
    recall = sensitivity  # Same as sensitivity - "Did we catch disease cases?"

    # Print results
    print(f"\n{'='*70}")
    print(f"Classification Evaluation: {model_name}")
    print(f"{'='*70}")
    print(f"\nConfusion Matrix:")
    print(f"  True Positives  (TP): {TP:3d} → correctly identified disease patients")
    print(f"  True Negatives  (TN): {TN:3d} → correctly identified healthy people")
    print(f"  False Positives (FP): {FP:3d} → false alarms (healthy → disease)")
    print(f"  False Negatives (FN): {FN:3d} → missed cases (disease → healthy)")

    print(f"\nPerformance Metrics:")
    print(f"  Accuracy:           {accuracy:.2%} (overall correctness - MISLEADING for imbalanced data!)")
    print(f"  Sensitivity/Recall: {recall:.2%} (disease detection rate - % of disease cases caught)")
    print(f"  Specificity:        {specificity:.2%} (healthy detection rate - % of healthy cases caught)")
    print(f"  Precision:          {precision:.2%} (positive predictive value - % of disease predictions correct)")

    # Rating considers recall, precision AND specificity:
    # a model that misses an entire class should not be rated highly
    if recall >= 0.90 and precision >= 0.90 and specificity >= 0.90:
        rating = "✓ Excellent classifier"
    elif recall >= 0.80 and precision >= 0.70 and specificity >= 0.70:
        rating = "✓ Good classifier"
    elif recall >= 0.70 and specificity >= 0.50:
        rating = "△ Acceptable classifier"
    else:
        rating = "✗ Poor classifier"

    print(f"  → {rating}")

    # Flag the classic imbalance failure mode: accuracy looks fine, specificity does not
    if accuracy >= 0.80 and specificity < 0.50:
        print(f"\n 💡 Insight: High accuracy ({accuracy:.0%}) but low specificity ({specificity:.0%})")
        print(f"    This shows CLASS IMBALANCE - the model predicts disease for almost everything!")

    return {
        'accuracy': accuracy,
        'sensitivity': recall,
        'specificity': specificity,
        'precision': precision,
        'recall': recall,
        'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN
    }

# Evaluate all three models
print("\n" + "="*70)
print("CLASSIFICATION MODEL COMPARISON")
print("="*70)

results1 = evaluate_classification(y_test, y_pred1, "Model 1: Rule-Based (Gene_A > 6.5)")
results2 = evaluate_classification(y_test, y_pred2, "Model 2: Mean-Based Classifier")
results3 = evaluate_classification(y_test, y_pred3, "Model 3: Nearest Centroid")

Output:

======================================================================
CLASSIFICATION MODEL COMPARISON
======================================================================

======================================================================
Classification Evaluation: Model 1: Rule-Based (Gene_A > 6.5)
======================================================================

Confusion Matrix:
True Positives (TP): 17 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 3 → missed cases (disease → healthy)

Performance Metrics:
Accuracy: 85.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 85.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✗ Poor classifier

💡 Insight: High accuracy (85%) but 0% specificity!
This shows CLASS IMBALANCE—model predicts disease for almost everything!

======================================================================
Classification Evaluation: Model 2: Mean-Based Classifier
======================================================================

Confusion Matrix:
True Positives (TP): 19 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 1 → missed cases (disease → healthy)

Performance Metrics:
Accuracy: 95.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 95.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✗ Poor classifier

💡 Insight: 95% accuracy + 100% precision looks great, but 0% specificity is a RED FLAG!
Model is essentially predicting disease for everyone—it's cheating!

======================================================================
Classification Evaluation: Model 3: Nearest Centroid
======================================================================

Confusion Matrix:
True Positives (TP): 19 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 1 → missed cases (disease → healthy)

Performance Metrics:
Accuracy: 95.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 95.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✗ Poor classifier

💡 Insight: 95% accuracy + 100% precision looks great, but 0% specificity is a RED FLAG!
Model is essentially predicting disease for everyone—it's cheating!

KEY LEARNING: All three models produce almost identical numbers, and none of them can be judged fairly here. Because we split the ordered data without shuffling, the test set contains only disease samples: there are no healthy people to classify correctly (TN) or incorrectly (FP), so specificity collapses to 0% (it is really undefined) and precision is trivially 100%. The 85-95% accuracy only tells us how often the models said "disease."

Why This Happens with Imbalanced Data:

  • A model that simply predicts the majority class can still look accurate
  • Accuracy rewards the majority class and hides failures on the rare class
  • Specificity (or recall for the minority class) immediately exposes the problem
  • A shuffled, ideally stratified train/test split is part of honest evaluation
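
If you would rather not compute these by hand, sklearn.metrics offers the same quantities (again assuming scikit-learn is available); a quick cross-check for Model 3:

from sklearn.metrics import confusion_matrix, recall_score, precision_score

# Cast the float labels to int to keep label handling simple
y_true_int, y_pred_int = y_test.astype(int), y_pred3.astype(int)

# Rows = actual class (0=Normal, 1=Disease), columns = predicted class
print(confusion_matrix(y_true_int, y_pred_int, labels=[0, 1]))
print("Recall:     ", recall_score(y_true_int, y_pred_int, zero_division=0))
print("Precision:  ", precision_score(y_true_int, y_pred_int, zero_division=0))
# Specificity is simply recall computed for the negative (Normal) class
print("Specificity:", recall_score(y_true_int, y_pred_int, pos_label=0, zero_division=0))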

Understanding Metric Tradeoffs

Different situations require different metrics:

Sensitivity vs Specificity Tradeoff

High Sensitivity (Catch disease):

  • Better for: Disease screening, diagnostic tests
  • Accept more false alarms to avoid missing disease
  • Example: Cancer screening — missing cancer is worse than false alarms

High Specificity (Avoid false alarms):

  • Better for: Confirmatory tests, expensive procedures
  • Accept missing some cases to avoid unnecessary treatment
  • Example: Confirming cancer diagnosis before chemotherapy

Balanced (Youden Index):

  • Better for: General-purpose classification
  • Neither type of error is more costly than the other
  • Example: Gene expression phenotyping (see the quick sketch below)
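
Youden's J statistic captures this balance in one number: J = Sensitivity + Specificity - 1, ranging from 0 (no better than chance) to 1 (perfect). A tiny sketch comparing two hypothetical operating points of the same test:

def youden_j(sensitivity, specificity):
    """Youden's J statistic: rewards catching disease AND avoiding false alarms."""
    return sensitivity + specificity - 1

print(youden_j(sensitivity=0.95, specificity=0.40))  # 0.35 - sensitive, but many false alarms
print(youden_j(sensitivity=0.85, specificity=0.80))  # 0.65 - better overall balance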

Connecting to Part 1: Building KNN

You now understand:

  • ✓ Real classification problems in bioinformatics
  • ✓ Simple classification models (rules, means, nearest centroid)
  • ✓ How to evaluate models with metrics
  • ✓ Sensitivity vs specificity tradeoffs

What's next? In Part 1: Building KNN from Scratch, we'll extend the nearest centroid idea:

  • Nearest Centroid: Find the 1 closest class center
  • KNN: Find the K closest individual training samples and vote

The evaluation metrics you learned here apply directly to KNN and all other classifiers!
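
As a tiny preview (a sketch of the idea, not the Part 1 implementation), the core of KNN is only a few lines: compute the distance from the new sample to every training sample, take the K nearest, and let them vote:

def knn_predict(X_train, y_train, x_new, k=3):
    """Minimal KNN sketch: distances to all training samples, then a majority vote."""
    distances = np.sqrt(((X_train - x_new) ** 2).sum(axis=1))  # Euclidean distance to every training sample
    nearest = np.argsort(distances)[:k]                        # indices of the k closest samples
    votes = y_train[nearest].astype(int)                       # their class labels
    return np.bincount(votes).argmax()                         # majority class wins

print(knn_predict(X_train, y_train, X_test[0], k=3))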


Summary: Key Concepts

Confusion Matrix

                      Predicted
                   Disease   Healthy
Actual  Disease      TP        FN
        Healthy      FP        TN

Metrics Quick Reference

Metric             | Formula               | Meaning                          | When to Use
Accuracy           | (TP+TN)/(TP+TN+FP+FN) | Overall correctness              | ❌ Avoid for imbalanced data
Sensitivity/Recall | TP/(TP+FN)            | % of disease cases caught        | ✅ Essential for imbalanced data
Specificity        | TN/(TN+FP)            | % of healthy cases caught        | ✅ Reveals minority class performance
Precision          | TP/(TP+FP)            | % of disease predictions correct | ✅ Cost of false positives

The Imbalanced Data Problem

When classes are imbalanced (e.g., 95% disease, 5% healthy):

  • Accuracy is misleading: predicting "disease" for everything still gives 95% accuracy
  • Sensitivity/Recall shows the disease class: did we catch the cases we care about?
  • Specificity shows the minority (healthy) class: critical for detecting when the model fails on the rare class
  • Precision shows the false-alarm cost: how many disease predictions are actually wrong?

💡 In our example: all three models reached 85-95% accuracy yet 0% specificity, because the unshuffled split left no healthy samples in the test set. The headline accuracy told us nothing about the healthy class.

Quick Code Reference

# Calculate confusion matrix
TP = np.sum((y_true == 1) & (y_pred == 1))
TN = np.sum((y_true == 0) & (y_pred == 0))
FP = np.sum((y_true == 0) & (y_pred == 1))
FN = np.sum((y_true == 1) & (y_pred == 0))

# Calculate metrics for imbalanced data
sensitivity = TP / (TP + FN) # Recall - catch disease?
specificity = TN / (TN + FP) # Handle healthy people?
precision = TP / (TP + FP) # When we predict disease, are we right?
accuracy = (TP + TN) / (TP + TN + FP + FN) # Don't trust this for imbalanced data!

Why This Matters for Bioinformatics

Classification is everywhere in biology:

  • Disease diagnosis: Predict if patient has disease from omics data
  • Protein annotation: Predict protein function from sequence
  • Variant interpretation: Predict if mutation is pathogenic
  • Cell type classification: Predict cell type from gene expression

But we must evaluate properly:

  • Different diseases need different metrics
  • Simple baselines reveal if our model actually learned
  • Understanding metrics prevents misleading conclusions

You now have the foundation to build, evaluate, and deploy classification models in bioinformatics! 🧬🤖