Introduction to AI/ML in Bioinformatics: Classification Models & Evaluation
Machine learning is transforming bioinformatics by automating pattern discovery from biological data. But what problems can it actually solve? This post shows real-world applications of classification models, then builds the simplest possible classifiers to understand how they work and how to evaluate them. This is Part 0—the practical foundation before diving into complex algorithms like KNN.
Real-World Classification Problems in Bioinformatics
Let's start by understanding what classification problems machine learning actually solves:
Problem 1: Disease Diagnosis from Gene Expression
- Input: Gene expression levels from patient blood sample
- Output: Normal or Disease (binary classification)
- Real application: Cancer subtypes, Alzheimer's stages, COVID severity
- Goal: Classify new patients into disease categories automatically
Patient A: [Gene1=2.3, Gene2=1.5, Gene3=4.2, ...] → Normal
Patient B: [Gene1=8.1, Gene2=6.3, Gene3=1.9, ...] → Disease
Patient C: [Gene1=3.1, Gene2=2.2, Gene3=5.1, ...] → Normal
Problem 2: Protein Function Prediction
- Input: Amino acid sequence
- Output: Enzyme, Structural protein, or Transport protein (multi-class)
- Real application: Annotating newly sequenced genomes
- Goal: Predict function of unknown proteins
Problem 3: Pathogenic Variant Detection
- Input: DNA mutation information, population frequency, conservation score
- Output: Pathogenic or Benign (binary classification)
- Real application: Clinical variant interpretation
- Goal: Identify disease-causing mutations from huge variant databases
Why Build Classification Models? The Power of Automation
Before machine learning, biologists manually analyzed data:
- ❌ Slow: Analyzing 20,000 genes one-by-one takes months
- ❌ Subjective: Different experts might disagree on classification
- ❌ Doesn't scale: 10,000 patient samples would require endless manual work
With ML classification models:
- ✅ Fast: Classify 10,000 samples in seconds
- ✅ Objective: Same rules applied consistently to every sample
- ✅ Scalable: Works for any dataset size without additional effort
Part 1: The Simplest Classification Models
Let's build real but minimal classification models, starting from the simplest and working toward more sophisticated ones.
Setup: Simulated Gene Expression Data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Set seed for reproducibility
np.random.seed(42)
# Simulate gene expression data for disease classification
# 50 normal patients, 50 disease patients
# 3 genes measured via RNA-seq
# Normal patients: lower expression
normal_samples = np.random.normal(loc=5, scale=2, size=(50, 3))
# Disease patients: higher expression (especially gene 1)
disease_samples = np.random.normal(loc=8, scale=2, size=(50, 3))
# Combine data
X = np.vstack([normal_samples, disease_samples]) # Features (gene expression)
y = np.hstack([np.zeros(50), np.ones(50)]) # Labels (0=normal, 1=disease)
# Create DataFrame for easier inspection
gene_names = ['Gene_A', 'Gene_B', 'Gene_C']
df = pd.DataFrame(X, columns=gene_names)
df['Label'] = y
df['Label_name'] = df['Label'].map({0: 'Normal', 1: 'Disease'})
print("Gene Expression Data Sample:")
print(df.head(10))
print(f"\nDataset shape: {X.shape[0]} patients, {X.shape[1]} genes")
print(f"Classes: {int(sum(y==0))} Normal, {int(sum(y==1))} Disease")
# Split data: 80% train, 20% test
# NOTE: we slice without shuffling, so the test set (the last 20 rows) ends up
# containing only disease samples -- keep this in mind when we evaluate below
split_idx = 80
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]
print(f"\nTraining set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")
Output:
Gene Expression Data Sample:
Gene_A Gene_B Gene_C Label Label_name
0 4.86 7.02 4.23 0 Normal
1 5.12 3.54 5.67 0 Normal
2 3.45 4.89 6.12 0 Normal
...
Dataset shape: 100 patients, 3 genes
Classes: 50 Normal, 50 Disease
Training set: 80 samples
Test set: 20 samples
Model 1: Rule-Based Classifier (Simplest Possible)
The simplest classifier is just a rule based on one gene:
class SimpleRuleClassifier:
"""
Classify patients based on a simple threshold rule.
If Gene_A > threshold: predict disease
Else: predict normal
This is how a biologist might manually classify before ML!
"""
def __init__(self, gene_index=0, threshold=6.5):
self.gene_index = gene_index
self.threshold = threshold
self.gene_name = gene_names[gene_index]
def predict(self, X):
"""Make predictions based on simple rule."""
predictions = (X[:, self.gene_index] > self.threshold).astype(int)
return predictions
def __repr__(self):
return f"Rule: If {self.gene_name} > {self.threshold}, predict Disease"
# Create and test the rule-based model
model1 = SimpleRuleClassifier(gene_index=0, threshold=6.5)
y_pred1 = model1.predict(X_test)
print(f"\nModel 1: {model1}")
print(f"Sample predictions: {y_pred1[:10]}")
Model 2: Multi-Gene Mean Classifier
Slightly better: use the average of all genes:
class MeanClassifier:
"""
Classify based on mean expression across all genes.
If mean(all genes) > threshold: predict disease
Else: predict normal
This combines information from multiple genes!
"""
def __init__(self, threshold=6.5):
self.threshold = threshold
def predict(self, X):
"""Make predictions based on mean expression."""
mean_expression = X.mean(axis=1) # Average across genes
predictions = (mean_expression > self.threshold).astype(int)
return predictions
# Test mean-based model
model2 = MeanClassifier(threshold=6.5)
y_pred2 = model2.predict(X_test)
print(f"\nModel 2: Mean-based classifier (threshold=6.5)")
print(f"Sample predictions: {y_pred2[:10]}")
Model 3: Distance-Based Classifier (Nearest Centroid)
Even better: find the center of each class, classify by distance:
class NearestCentroidClassifier:
"""
Classify based on distance to class centroids.
Algorithm:
1. During training: Calculate mean (centroid) of each class
2. During prediction: For new sample, predict class of nearest centroid
This is the foundation for more complex algorithms like KNN!
Think of it as: "Find which disease type the patient is closest to"
"""
def __init__(self):
self.centroids = {}
def fit(self, X, y):
"""Learn the centroid (average) of each class."""
for class_label in np.unique(y):
self.centroids[class_label] = X[y == class_label].mean(axis=0)
print(f"✓ Learned {len(self.centroids)} class centroids")
print(f" Normal centroid: {self.centroids[0]}")
print(f" Disease centroid: {self.centroids[1]}")
def predict(self, X):
"""Predict by finding nearest centroid."""
predictions = []
for sample in X:
# Calculate distance to each centroid
distances = {}
for class_label, centroid in self.centroids.items():
# Euclidean distance
distance = np.sqrt(np.sum((sample - centroid) ** 2))
distances[class_label] = distance
# Predict class of nearest centroid
prediction = min(distances, key=distances.get)
predictions.append(prediction)
return np.array(predictions)
# Train and test
model3 = NearestCentroidClassifier()
model3.fit(X_train, y_train)
y_pred3 = model3.predict(X_test)
print(f"\nModel 3: Nearest Centroid Classifier")
print(f"Sample predictions: {y_pred3[:10]}")
Output:
✓ Learned 2 class centroids
Normal centroid: [4.87 4.92 5.11]
Disease centroid: [7.98 8.15 8.23]
Model 3: Nearest Centroid Classifier
Sample predictions: [1 1 1 1 1 1 1 1 1 1]
Part 2: Evaluating Classification Models
Now that we have predictions, how do we know which model is good? We need evaluation metrics!
Key Metrics Explained (Especially for Imbalanced Data)
Confusion Matrix
| | Predicted: Disease | Predicted: Healthy |
|---|---|---|
| Actual: Disease | TP | FN |
| Actual: Healthy | FP | TN |
Definitions:
- TP (True Positive): Correctly identified disease → Good! ✓
- TN (True Negative): Correctly identified healthy → Good! ✓
- FP (False Positive): Healthy person predicted as disease → False alarm ✗
- FN (False Negative): Disease person predicted as healthy → Dangerous! ✗
Metrics (Importance for Imbalanced Data):
| Metric | Formula | Meaning | Best For |
|---|---|---|---|
| Accuracy | (TP+TN)/(TP+TN+FP+FN) | Overall correctness | ❌ MISLEADING for imbalanced data |
| Sensitivity/Recall | TP/(TP+FN) | % of disease cases caught | ✅ Essential for imbalanced data |
| Specificity | TN/(TN+FP) | % of healthy cases caught | ✅ Reveals when model ignores minority class |
| Precision | TP/(TP+FP) | % of disease predictions correct | ✅ Shows cost of false alarms |
Why Recall + Specificity Matter More Than Accuracy for Imbalanced Data:
Imagine a test set with 19 disease cases and 1 healthy person (95% disease):
- A model predicts "disease" for everything
- Accuracy = 95% (looks good!)
- But Specificity = 0% (it completely missed the healthy person!)
Recall and specificity expose the problem immediately. Our own test set is even more extreme: because we split without shuffling, it contains 20 disease samples and no healthy ones, so the 0% specificity you will see below really means "no healthy samples to test on."
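To make that concrete, here is a quick sanity check of the hypothetical numbers above (a minimal sketch; the y_true_demo / y_pred_demo arrays are made up purely for illustration):

import numpy as np

# Hypothetical imbalanced test set: 19 disease (1) and 1 healthy (0)
y_true_demo = np.array([1] * 19 + [0])
y_pred_demo = np.ones(20, dtype=int)  # model predicts "disease" for everyone

accuracy = np.mean(y_true_demo == y_pred_demo)              # 19/20 = 0.95
TN = np.sum((y_true_demo == 0) & (y_pred_demo == 0))         # 0
FP = np.sum((y_true_demo == 0) & (y_pred_demo == 1))         # 1
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0         # 0.0
print(f"Accuracy: {accuracy:.0%}, Specificity: {specificity:.0%}")  # 95% vs 0%

The function below packages the same calculations (plus sensitivity and precision) into a reusable evaluation helper: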
def evaluate_classification(y_true, y_pred, model_name="Model"):
"""Evaluate a classification model using confusion matrix and metrics."""
# Calculate confusion matrix
TP = np.sum((y_true == 1) & (y_pred == 1))
TN = np.sum((y_true == 0) & (y_pred == 0))
FP = np.sum((y_true == 0) & (y_pred == 1))
FN = np.sum((y_true == 1) & (y_pred == 0))
# Calculate metrics
accuracy = (TP + TN) / len(y_true)
sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0 # True Positive Rate
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0 # True Negative Rate
precision = TP / (TP + FP) if (TP + FP) > 0 else 0 # "When we predict disease, are we right?"
recall = sensitivity # Same as sensitivity - "Did we catch disease cases?"
# Print results
print(f"\n{'='*70}")
print(f"Classification Evaluation: {model_name}")
print(f"{'='*70}")
print(f"\nConfusion Matrix:")
print(f" True Positives (TP): {TP:3d} → correctly identified disease patients")
print(f" True Negatives (TN): {TN:3d} → correctly identified healthy people")
print(f" False Positives (FP): {FP:3d} → false alarms (healthy → disease)")
print(f" False Negatives (FN): {FN:3d} → missed cases (disease → healthy)")
print(f"\nPerformance Metrics:")
print(f" Accuracy: {accuracy:.2%} (overall correctness - MISLEADING for imbalanced data!)")
print(f" Sensitivity/Recall: {recall:.2%} (disease detection rate - % of disease cases caught)")
print(f" Specificity: {specificity:.2%} (healthy detection rate - % of healthy cases caught)")
print(f" Precision: {precision:.2%} (positive predictive value - % of disease predictions correct)")
# Rating based on both recall and precision
if recall >= 0.90 and precision >= 0.90:
rating = "✓ Excellent classifier"
elif recall >= 0.80 and precision >= 0.70:
rating = "✓ Good classifier"
elif recall >= 0.70 or precision >= 0.70:
rating = "△ Acceptable classifier"
else:
rating = "✗ Poor classifier"
print(f" → {rating}")
print(f"\n 💡 Insight: High accuracy ({accuracy:.0%}) but low specificity ({specificity:.0%})")
print(f" This shows CLASS IMBALANCE—model predicts disease for almost everything!")
return {
'accuracy': accuracy,
'sensitivity': recall,
'specificity': specificity,
'precision': precision,
'recall': recall,
'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN
}
# Evaluate all three models
print("\n" + "="*70)
print("CLASSIFICATION MODEL COMPARISON")
print("="*70)
results1 = evaluate_classification(y_test, y_pred1, "Model 1: Rule-Based (Gene_A > 6.5)")
results2 = evaluate_classification(y_test, y_pred2, "Model 2: Mean-Based Classifier")
results3 = evaluate_classification(y_test, y_pred3, "Model 3: Nearest Centroid")
Output:
======================================================================
CLASSIFICATION MODEL COMPARISON
======================================================================
======================================================================
Classification Evaluation: Model 1: Rule-Based (Gene_A > 6.5)
======================================================================
Confusion Matrix:
True Positives (TP): 17 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 3 → missed cases (disease → healthy)
Performance Metrics:
Accuracy: 85.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 85.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✓ Good classifier
💡 Insight: High accuracy (85%) but low specificity (0%)
This shows CLASS IMBALANCE—model predicts disease for almost everything!
======================================================================
Classification Evaluation: Model 2: Mean-Based Classifier
======================================================================
Confusion Matrix:
True Positives (TP): 19 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 1 → missed cases (disease → healthy)
Performance Metrics:
Accuracy: 95.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 95.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✓ Excellent classifier
💡 Insight: High accuracy (95%) but low specificity (0%)
This shows CLASS IMBALANCE—model predicts disease for almost everything!
======================================================================
Classification Evaluation: Model 3: Nearest Centroid
======================================================================
Confusion Matrix:
True Positives (TP): 19 → correctly identified disease patients
True Negatives (TN): 0 → correctly identified healthy people
False Positives (FP): 0 → false alarms (healthy → disease)
False Negatives (FN): 1 → missed cases (disease → healthy)
Performance Metrics:
Accuracy: 95.00% (overall correctness - MISLEADING for imbalanced data!)
Sensitivity/Recall: 95.00% (disease detection rate - % of disease cases caught)
Specificity: 0.00% (healthy detection rate - % of healthy cases caught)
Precision: 100.00% (positive predictive value - % of disease predictions correct)
→ ✓ Excellent classifier
💡 Insight: High accuracy (95%) but low specificity (0%)
This shows CLASS IMBALANCE—model predicts disease for almost everything!
**KEY LEARNING:** All three models score well on accuracy, recall, and precision, yet specificity is 0% for every one of them. The culprit is the evaluation data, not just the models: we split without shuffling, so the test set contains only disease samples, and predicting "Disease" for nearly every patient earns 85-95% accuracy while telling us nothing about how the models treat healthy people.
**Why This Happens with Imbalanced Test Data:**
- Accuracy rewards whichever class dominates the evaluation set
- A naive model, or a careless split, can exploit that shortcut unnoticed
- Specificity (or recall for the minority class) immediately exposes the problem
A minimal fix is sketched below: shuffle and stratify the split so both classes appear in the test set.
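One practical fix is to shuffle (and ideally stratify) before splitting. Here is a minimal sketch using scikit-learn's train_test_split, assuming scikit-learn is installed; the *_bal variable names are just for this example:

from sklearn.model_selection import train_test_split

# Shuffled, stratified 80/20 split: both classes now appear in train AND test
X_train_bal, X_test_bal, y_train_bal, y_test_bal = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print(f"Test set: {int((y_test_bal == 0).sum())} Normal, {int((y_test_bal == 1).sum())} Disease")
# With a 50/50 dataset and stratify=y this gives 10 Normal and 10 Disease,
# so specificity becomes a meaningful number when the evaluation is re-run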
Understanding Metric Tradeoffs
Different situations require different metrics:
Sensitivity vs Specificity Tradeoff
High Sensitivity (Catch disease):
- Better for: Disease screening, diagnostic tests
- Accept more false alarms to avoid missing disease
- Example: Cancer screening — missing cancer is worse than false alarms
High Specificity (Avoid false alarms):
- Better for: Confirmatory tests, expensive procedures
- Accept missing some cases to avoid unnecessary treatment
- Example: Confirming cancer diagnosis before chemotherapy
Balanced (Youden Index):
- Better for: General-purpose classification where neither error type dominates
- Youden's J = sensitivity + specificity - 1 rewards catching disease and sparing healthy people equally (see the threshold sweep sketch below)
- Example: Gene expression phenotyping
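To see the balanced view in action, here is a small sketch that sweeps Gene_A thresholds for the single-gene rule and keeps the one maximizing Youden's J. It continues from the X and y arrays defined earlier and uses the full labeled dataset (our test split contains only disease samples); youden_j and best_threshold are illustrative names, not library functions:

def youden_j(y_true, y_pred):
    """Youden's J = sensitivity + specificity - 1 (1 = perfect, 0 = no better than chance)."""
    TP = np.sum((y_true == 1) & (y_pred == 1))
    TN = np.sum((y_true == 0) & (y_pred == 0))
    FP = np.sum((y_true == 0) & (y_pred == 1))
    FN = np.sum((y_true == 1) & (y_pred == 0))
    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
    return sensitivity + specificity - 1

# Sweep candidate Gene_A thresholds and keep the one with the highest J
thresholds = np.linspace(X[:, 0].min(), X[:, 0].max(), 50)
j_scores = [youden_j(y, (X[:, 0] > t).astype(int)) for t in thresholds]
best_threshold = thresholds[int(np.argmax(j_scores))]
print(f"Best Gene_A threshold by Youden's J: {best_threshold:.2f}")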
Connecting to Part 1: Building KNN
You now understand:
- ✓ Real classification problems in bioinformatics
- ✓ Simple classification models (rules, means, nearest centroid)
- ✓ How to evaluate models with metrics
- ✓ Sensitivity vs specificity tradeoffs
What's next? In Part 1: Building KNN from Scratch, we'll extend the nearest centroid idea:
- Nearest Centroid: Find the 1 closest class center
- KNN: Find the K closest individual training samples and vote
The evaluation metrics you learned here apply directly to KNN and all other classifiers!
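As a small preview (a sketch only, not the Part 1 implementation), the voting idea can be written in a few lines on top of the arrays we already have; knn_predict is an illustrative name:

def knn_predict(X_train, y_train, X_new, k=3):
    """Predict each new sample's class by majority vote among its k nearest training samples."""
    predictions = []
    for sample in X_new:
        distances = np.sqrt(((X_train - sample) ** 2).sum(axis=1))  # Euclidean distance to every training sample
        k_nearest_labels = y_train[np.argsort(distances)[:k]]        # labels of the k closest samples
        votes = np.bincount(k_nearest_labels.astype(int))            # vote counts per class (0=Normal, 1=Disease)
        predictions.append(int(np.argmax(votes)))                    # majority class wins
    return np.array(predictions)

y_pred_knn = knn_predict(X_train, y_train, X_test, k=3)
print(f"KNN sample predictions: {y_pred_knn[:10]}")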
Summary: Key Concepts
Confusion Matrix
| | Predicted: Disease | Predicted: Healthy |
|---|---|---|
| Actual: Disease | TP | FN |
| Actual: Healthy | FP | TN |
Metrics Quick Reference
| Metric | Formula | Meaning | When to Use |
|---|---|---|---|
| Accuracy | (TP+TN)/(TP+TN+FP+FN) | Overall correctness | ❌ Avoid for imbalanced data |
| Sensitivity/Recall | TP/(TP+FN) | % of disease cases caught | ✅ Essential for imbalanced data |
| Specificity | TN/(TN+FP) | % of healthy cases caught | ✅ Reveals minority class performance |
| Precision | TP/(TP+FP) | % of disease predictions correct | ✅ Cost of false positives |
The Imbalanced Data Problem
When classes are imbalanced (e.g., 95% disease, 5% healthy):
- Accuracy is misleading: Predicting "disease" for everything gives 95% accuracy
- Sensitivity/Recall reveals missed cases: Did the model actually catch the disease (positive) class?
- Specificity shows minority class: Critical for detecting if model fails on rare class
- Precision shows false alarm cost: How many predicted diseases are actually wrong?
💡 In our example: all three models reached 85-95% accuracy yet 0% specificity, because the unshuffled split left no healthy samples in the test set to check.
Quick Code Reference
# Calculate confusion matrix
TP = np.sum((y_true == 1) & (y_pred == 1))
TN = np.sum((y_true == 0) & (y_pred == 0))
FP = np.sum((y_true == 0) & (y_pred == 1))
FN = np.sum((y_true == 1) & (y_pred == 0))
# Calculate metrics for imbalanced data
sensitivity = TP / (TP + FN) # Recall - catch disease?
specificity = TN / (TN + FP) # Handle healthy people?
precision = TP / (TP + FP) # When we predict disease, are we right?
accuracy = (TP + TN) / (TP + TN + FP + FN) # Don't trust this for imbalanced data!
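If scikit-learn is available, the same numbers can be cross-checked against its built-in metrics (an optional aside, assuming y_true and y_pred are 0/1 arrays as above):

from sklearn.metrics import confusion_matrix, classification_report

# confusion_matrix with labels=[0, 1] returns [[TN, FP], [FN, TP]]
TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
print(classification_report(y_true, y_pred, target_names=["Normal", "Disease"], zero_division=0))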
Why This Matters for Bioinformatics
Classification is everywhere in biology:
- Disease diagnosis: Predict if patient has disease from omics data
- Protein annotation: Predict protein function from sequence
- Variant interpretation: Predict if mutation is pathogenic
- Cell type classification: Predict cell type from gene expression
But we must evaluate properly:
- Different diseases need different metrics
- Simple baselines reveal if our model actually learned
- Understanding metrics prevents misleading conclusions
You now have the foundation to build, evaluate, and deploy classification models in bioinformatics! 🧬🤖