How to Build a Multi-Model AI Router
Route requests to the best AI model based on task type, cost, and quality needs.
Jay Banlasan
The AI Systems Guy
Using one AI model for every task is like using a sledgehammer for every job in your toolbox. GPT-4o is excellent at reasoning but costs 30x more than GPT-4o-mini, which handles simple tasks just as well. A multi-model router sends each request to the cheapest model capable of handling it well. I use this in every production system I build. The cost savings are immediate and significant.
The router classifies each incoming request, matches it to a model tier, and routes accordingly. Simple tasks go to cheap models. Complex tasks go to capable models. Requests the classifier cannot label cleanly fall back to the standard tier. Your costs drop 40-70% with no quality loss on the tasks that matter.
What You Need Before Starting
- Python 3.9+
- API keys for OpenAI (at minimum)
- Clear definitions of what "simple", "medium", "complex", and "reasoning" mean for your tasks
- Sample requests to calibrate the classifier
Step 1: Define Your Model Tiers
from dataclasses import dataclass


@dataclass
class ModelTier:
    """Static metadata for one routable model tier.

    Pricing fields are USD per 1,000 tokens. The strengths/weaknesses
    lists are human-readable notes for maintainers tuning the router.
    """

    name: str                   # tier key; matches this tier's key in TIERS
    model_id: str               # provider-specific model identifier
    provider: str               # e.g. "openai"
    cost_per_1k_input: float    # USD per 1k prompt tokens
    cost_per_1k_output: float   # USD per 1k completion tokens
    max_tokens: int             # completion-size cap for this tier
    strengths: list[str]
    weaknesses: list[str]


# The three tiers the router chooses between, cheapest first.
TIERS = {
    "nano": ModelTier(
        name="nano",
        model_id="gpt-4o-mini",
        provider="openai",
        cost_per_1k_input=0.00015,
        cost_per_1k_output=0.0006,
        max_tokens=4096,
        strengths=["simple Q&A", "classification", "short summaries", "formatting"],
        weaknesses=["complex reasoning", "long document analysis", "nuanced writing"],
    ),
    "standard": ModelTier(
        name="standard",
        model_id="gpt-4o",
        provider="openai",
        cost_per_1k_input=0.005,
        cost_per_1k_output=0.015,
        max_tokens=4096,
        strengths=["complex reasoning", "long context", "nuanced writing", "coding"],
        weaknesses=["cost", "speed vs nano"],
    ),
    "reasoning": ModelTier(
        name="reasoning",
        model_id="o1-mini",
        provider="openai",
        cost_per_1k_input=0.003,
        cost_per_1k_output=0.012,
        max_tokens=4096,
        strengths=["multi-step reasoning", "math", "strategic planning", "complex code"],
        weaknesses=["cost", "slow", "overkill for most tasks"],
    ),
}

# Map classifier complexity labels to tier keys. Note that both "medium"
# and "complex" currently route to the standard tier.
COMPLEXITY_TO_TIER = {
    "simple": "nano",
    "medium": "standard",
    "complex": "standard",
    "reasoning": "reasoning",
}
Step 2: Build the Request Classifier
Classify incoming requests by complexity and task type before routing.
import openai
import json
import os

# SECURITY: read the key from the environment rather than hard-coding it in
# source control. The placeholder fallback keeps the tutorial copy-pasteable.
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY"))

CLASSIFIER_PROMPT = """Analyze this AI request and classify it.
COMPLEXITY:
- simple: Short factual answers, basic formatting, classification, yes/no, summarizing <500 words
- medium: Multi-part questions, writing 200-500 words, moderate analysis, basic code
- complex: Deep analysis, long-form writing, nuanced reasoning, complex code, legal/medical content
- reasoning: Multi-step math, strategic planning, debugging complex systems, requires chain of thought
TASK_TYPE:
- qa, summarization, extraction, classification, writing, coding, analysis, translation, other
Return JSON: {"complexity": "...", "task_type": "...", "reasoning": "one sentence"}"""


def classify_request(messages: list) -> dict:
    """Classify a chat request by complexity and task type.

    Only the last 3 messages, each truncated to 300 characters, are shown to
    the classifier so its own token cost stays negligible.

    Returns the parsed classifier JSON: keys "complexity", "task_type",
    "reasoning". May raise json.JSONDecodeError if the model emits invalid
    JSON (unlikely with response_format=json_object, but possible).
    """
    # join() over a generator instead of repeated string += concatenation.
    request_text = "".join(
        f"{msg['role'].upper()}: {msg['content'][:300]}\n" for msg in messages[-3:]
    )
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # always use the cheapest model for classification
        messages=[
            {"role": "system", "content": CLASSIFIER_PROMPT},
            {"role": "user", "content": request_text}
        ],
        temperature=0,  # deterministic labels for stable routing
        max_tokens=100,
        response_format={"type": "json_object"}
    )
    return json.loads(response.choices[0].message.content)
Step 3: Build the Router
Combine classification with model selection and execution.
import time
from dataclasses import dataclass


@dataclass
class RouterResult:
    """Everything the router records about one completed request."""

    content: str        # the model's reply text
    model_used: str     # concrete model id, e.g. "gpt-4o-mini"
    tier: str           # tier key the request was routed to
    complexity: str     # classifier label, or "forced" when a tier was forced
    task_type: str      # classifier label, or "forced"
    input_tokens: int   # prompt tokens reported by the API
    output_tokens: int  # completion tokens reported by the API
    cost_usd: float     # estimated cost of this single call
    latency_ms: int     # wall-clock time including classification
def estimate_cost(tier: ModelTier, input_tokens: int, output_tokens: int) -> float:
    """Return the estimated USD cost of one call at *tier*'s per-1k rates.

    Rounded to 6 decimal places (sub-cent costs are common on the nano tier).
    """
    cost_in = (input_tokens / 1000) * tier.cost_per_1k_input
    cost_out = (output_tokens / 1000) * tier.cost_per_1k_output
    return round(cost_in + cost_out, 6)
def route_request(
    messages: list,
    force_tier: str = None,
    temperature: float = 0.3,
    max_tokens: int = 1000
) -> RouterResult:
    """Classify *messages*, pick a model tier, and execute the request.

    Args:
        messages: OpenAI-style chat messages.
        force_tier: a TIERS key to bypass classification; None auto-classifies.
        temperature, max_tokens: passed through to the completion call.

    Returns a RouterResult with the reply plus routing/cost/latency metadata.
    Raises KeyError if force_tier is not a valid TIERS key.
    """
    # perf_counter is monotonic; time.time() can jump (NTP adjustments) and
    # produce negative or skewed latency measurements.
    start = time.perf_counter()
    if force_tier:
        tier_name = force_tier
        # Skip the classifier entirely; record that the tier was forced.
        classification = {"complexity": "forced", "task_type": "forced"}
    else:
        classification = classify_request(messages)
        # Unknown complexity labels fall back to the safe middle tier.
        tier_name = COMPLEXITY_TO_TIER.get(classification["complexity"], "standard")
    tier = TIERS[tier_name]  # KeyError here means force_tier was invalid
    response = client.chat.completions.create(
        model=tier.model_id,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )
    latency = int((time.perf_counter() - start) * 1000)
    usage = response.usage
    return RouterResult(
        content=response.choices[0].message.content,
        model_used=tier.model_id,
        tier=tier_name,
        complexity=classification.get("complexity", "unknown"),
        task_type=classification.get("task_type", "unknown"),
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
        cost_usd=estimate_cost(tier, usage.prompt_tokens, usage.completion_tokens),
        latency_ms=latency
    )
Step 4: Add Routing Rules for Specific Content Patterns
Some requests should always go to a specific tier regardless of classification.
import re
HARD_ROUTE_RULES = [
# (pattern, tier_override, reason)
(r"HIPAA|PHI|medical record|patient data", "standard", "Medical/health content needs careful handling"),
(r"legal advice|contract review|litigation", "standard", "Legal content needs higher accuracy"),
(r"calculate|compute|solve for|proof|equation", "reasoning", "Math needs reasoning model"),
(r"refactor|debug|optimize.*code|stack trace", "standard", "Complex code needs standard tier"),
(r"translate to [a-z]+", "nano", "Translation is nano-appropriate"),
(r"yes or no|true or false|is it|does it", "nano", "Binary questions are nano-appropriate"),
]
def check_hard_routes(messages: list) -> str | None:
full_text = " ".join(m["content"] for m in messages if m.get("content")).lower()
for pattern, tier, reason in HARD_ROUTE_RULES:
if re.search(pattern, full_text, re.IGNORECASE):
return tier
return None
def smart_route(messages: list, **kwargs) -> RouterResult:
    """Route a request, letting hard rules take precedence over the classifier.

    A matching hard-route rule forces its tier; a None result lets
    route_request classify normally. Extra kwargs pass through unchanged.
    """
    return route_request(messages, force_tier=check_hard_routes(messages), **kwargs)
Step 5: Track Routing Analytics
Log every routing decision so you can measure savings and spot mis-classifications.
import sqlite3
from datetime import datetime


def init_router_db():
    """Create the routing_log analytics table if it does not already exist."""
    conn = sqlite3.connect("router_analytics.db")
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS routing_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                tier TEXT,
                model_used TEXT,
                complexity TEXT,
                task_type TEXT,
                input_tokens INTEGER,
                output_tokens INTEGER,
                cost_usd REAL,
                latency_ms INTEGER
            )
        """)
        conn.commit()
    finally:
        # Close even if the DDL raises so the handle is never leaked.
        conn.close()
def log_route(result: RouterResult):
    """Append one routing decision to the analytics database.

    Timestamps are local-time ISO-8601 strings from datetime.now().
    """
    conn = sqlite3.connect("router_analytics.db")
    try:
        conn.execute(
            """INSERT INTO routing_log
            (timestamp, tier, model_used, complexity, task_type, input_tokens, output_tokens, cost_usd, latency_ms)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (datetime.now().isoformat(), result.tier, result.model_used, result.complexity,
             result.task_type, result.input_tokens, result.output_tokens, result.cost_usd, result.latency_ms)
        )
        conn.commit()
    finally:
        # Close even on insert failure so the handle is never leaked.
        conn.close()
def get_routing_summary(days: int = 7) -> dict:
    """Summarize routing volume, cost, and latency per tier over *days* days.

    Returns {tier: {"calls": int, "total_cost": float, "avg_latency_ms": int}}.
    An empty dict means no rows fell inside the window.

    NOTE(review): timestamps are stored as local-time isoformat but compared
    against sqlite's datetime('now'), which is UTC — the window boundary can
    be off by the local UTC offset. Confirm whether that matters here.
    """
    conn = sqlite3.connect("router_analytics.db")
    try:
        # BUG FIX: the original closed the connection and then ran two more
        # queries on it, raising sqlite3.ProgrammingError on every call.
        # All database work now happens before close.
        rows = conn.execute("""
            SELECT tier, COUNT(*) as calls, SUM(cost_usd) as total_cost, AVG(latency_ms) as avg_latency
            FROM routing_log
            WHERE timestamp > datetime('now', ?)
            GROUP BY tier
        """, (f"-{days} days",)).fetchall()
    finally:
        conn.close()
    summary = {}
    for tier, calls, cost, latency in rows:
        summary[tier] = {"calls": calls, "total_cost": round(cost, 4), "avg_latency_ms": round(latency)}
    # TODO: also sum input/output tokens in the window and price them at the
    # standard tier's rates to report estimated savings vs routing everything
    # to standard (see "What to Build Next").
    return summary
# Ensure the analytics table exists as soon as this module is loaded.
init_router_db()
Step 6: Use the Router in Your Application
# Simple usage: let the classifier pick the tier. A short factual question
# like this should land on the nano tier per the classifier rubric.
result = smart_route([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does API stand for?"}
])
print(f"Answer: {result.content}")
print(f"Routed to: {result.tier} ({result.model_used})")
print(f"Cost: ${result.cost_usd:.6f}")
# Persist the decision so the analytics queries can measure savings.
log_route(result)
# Force a specific tier for sensitive content, bypassing classification.
result = smart_route(
    [{"role": "user", "content": "Review this contract clause for hidden risks..."}],
    force_tier="standard"
)
What to Build Next
- Add quality scoring on a sample of nano-tier responses to verify the classifier is not sending complex requests to the wrong tier
- Build a weekly cost report that shows estimated savings vs routing everything to the standard tier so you can quantify the router's value
- Implement user-level tier overrides so premium users can request standard-tier processing for all their requests regardless of classification
Related Reading
- How to Implement Cost-Based AI Model Selection - cost-based selection is the financial logic layer beneath the router
- How to Build an AI Load Balancer Across Providers - combine routing with load balancing for maximum reliability
- How to Set Up LiteLLM as Your AI Gateway - LiteLLM can simplify the provider abstraction layer in your router
Want this system built for your business?
Get a free assessment. We will map every system your business needs and show you the ROI.
Get Your Free Assessment