Systems Library / AI Model Setup / How to Build a Multi-Model AI Router
AI Model Setup routing optimization

How to Build a Multi-Model AI Router

Route requests to the best AI model based on task type, cost, and quality needs.

Jay Banlasan

Jay Banlasan

The AI Systems Guy

Using one AI model for every task is like using a sledgehammer for every job in your toolbox. GPT-4o is excellent at reasoning but costs 30x more than GPT-4o-mini, which handles simple tasks just as well. A multi-model router sends each request to the cheapest model capable of handling it well. I use this in every production system I build. The cost savings are immediate and significant.

The router classifies each incoming request, matches it to a model tier, and routes accordingly. Simple tasks go to cheap models. Complex tasks go to capable models. Edge cases fall back to the highest-tier model. Your costs drop 40-70% with no quality loss on the tasks that matter.

What You Need Before Starting

Step 1: Define Your Model Tiers

from dataclasses import dataclass

@dataclass
class ModelTier:
    """Static description of one model tier used for routing decisions:
    identity, per-token pricing, and qualitative strengths/weaknesses."""
    name: str                     # tier key, e.g. "nano" / "standard" / "reasoning"
    model_id: str                 # provider's concrete model identifier
    provider: str                 # e.g. "openai"
    cost_per_1k_input: float      # USD per 1,000 prompt tokens
    cost_per_1k_output: float     # USD per 1,000 completion tokens
    max_tokens: int               # completion token cap for this tier
    strengths: list[str]          # task kinds this tier handles well
    weaknesses: list[str]         # task kinds better routed elsewhere

# Registry of available tiers, cheapest first. Prices are USD per 1,000
# tokens as published at time of writing -- re-verify against the provider's
# current price sheet before trusting the cost estimates.
TIERS = {
    "nano": ModelTier(
        name="nano",
        model_id="gpt-4o-mini",
        provider="openai",
        cost_per_1k_input=0.00015,
        cost_per_1k_output=0.0006,
        max_tokens=4096,
        strengths=["simple Q&A", "classification", "short summaries", "formatting"],
        weaknesses=["complex reasoning", "long document analysis", "nuanced writing"]
    ),
    "standard": ModelTier(
        name="standard",
        model_id="gpt-4o",
        provider="openai",
        cost_per_1k_input=0.005,
        cost_per_1k_output=0.015,
        max_tokens=4096,
        strengths=["complex reasoning", "long context", "nuanced writing", "coding"],
        weaknesses=["cost", "speed vs nano"]
    ),
    "reasoning": ModelTier(
        name="reasoning",
        model_id="o1-mini",
        provider="openai",
        cost_per_1k_input=0.003,
        cost_per_1k_output=0.012,
        max_tokens=4096,
        strengths=["multi-step reasoning", "math", "strategic planning", "complex code"],
        weaknesses=["cost", "slow", "overkill for most tasks"]
    )
}

# Map complexity to tier.
# NOTE: "complex" deliberately maps to "standard", not "reasoning" -- the
# reasoning tier is reserved for explicit multi-step/math work (see the
# classifier's "reasoning" label and the hard routing rules).
COMPLEXITY_TO_TIER = {
    "simple": "nano",
    "medium": "standard",
    "complex": "standard",
    "reasoning": "reasoning"
}

Step 2: Build the Request Classifier

Classify incoming requests by complexity and task type before routing.

import openai
import json
import os

# Never hard-code API keys in source (they end up in version control).
# Read from the environment; fall back to the original placeholder so the
# tutorial snippet still loads verbatim if the variable is unset.
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY"))

# System prompt for the cheap classification pre-pass. Instructs the model
# to return strict JSON (complexity, task_type, one-line rationale); the
# complexity labels must match the keys of COMPLEXITY_TO_TIER.
CLASSIFIER_PROMPT = """Analyze this AI request and classify it.

COMPLEXITY:
- simple: Short factual answers, basic formatting, classification, yes/no, summarizing <500 words
- medium: Multi-part questions, writing 200-500 words, moderate analysis, basic code
- complex: Deep analysis, long-form writing, nuanced reasoning, complex code, legal/medical content
- reasoning: Multi-step math, strategic planning, debugging complex systems, requires chain of thought

TASK_TYPE:
- qa, summarization, extraction, classification, writing, coding, analysis, translation, other

Return JSON: {"complexity": "...", "task_type": "...", "reasoning": "one sentence"}"""

def classify_request(messages: list) -> dict:
    """Classify a chat request by complexity and task type.

    Only the last three messages are included, each truncated to 300
    characters, to keep the classification call itself cheap. Returns the
    parsed JSON dict produced by the classifier model.
    """
    snippet = "".join(
        f"{m['role'].upper()}: {m['content'][:300]}\n" for m in messages[-3:]
    )

    completion = client.chat.completions.create(
        model="gpt-4o-mini",  # always use the cheapest model for classification
        messages=[
            {"role": "system", "content": CLASSIFIER_PROMPT},
            {"role": "user", "content": snippet},
        ],
        temperature=0,
        max_tokens=100,
        response_format={"type": "json_object"},
    )

    return json.loads(completion.choices[0].message.content)

Step 3: Build the Router

Combine classification with model selection and execution.

import time
from dataclasses import dataclass

@dataclass
class RouterResult:
    """Outcome of one routed completion: the model's answer plus the
    metadata needed for cost/latency analytics."""
    content: str        # model's response text
    model_used: str     # concrete model id, e.g. "gpt-4o"
    tier: str           # tier name the request was routed to
    complexity: str     # classifier label, or "forced"/"unknown"
    task_type: str      # classifier label, or "forced"/"unknown"
    input_tokens: int   # prompt tokens reported by the API
    output_tokens: int  # completion tokens reported by the API
    cost_usd: float     # estimated cost, rounded to 6 decimal places
    latency_ms: int     # wall-clock time for the routed call (incl. classification)

def estimate_cost(tier: ModelTier, input_tokens: int, output_tokens: int) -> float:
    """Return the estimated USD cost of a call on *tier*, rounded to 6 dp.

    Pricing fields on ModelTier are expressed per 1,000 tokens.
    """
    cost_in = (input_tokens / 1000) * tier.cost_per_1k_input
    cost_out = (output_tokens / 1000) * tier.cost_per_1k_output
    return round(cost_in + cost_out, 6)

def route_request(
    messages: list,
    force_tier: str = None,
    temperature: float = 0.3,
    max_tokens: int = 1000
) -> RouterResult:
    """Pick a model tier for *messages*, run the completion, and return a
    RouterResult with the answer plus token/cost/latency metadata.

    When *force_tier* is truthy, classification is skipped and the request
    is sent straight to that tier; otherwise the request is classified and
    mapped through COMPLEXITY_TO_TIER (defaulting to "standard" for any
    unrecognized complexity label).
    """
    started_at = time.time()

    if force_tier:
        classification = {"complexity": "forced", "task_type": "forced"}
        tier_name = force_tier
    else:
        classification = classify_request(messages)
        tier_name = COMPLEXITY_TO_TIER.get(classification["complexity"], "standard")

    selected = TIERS[tier_name]

    completion = client.chat.completions.create(
        model=selected.model_id,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )

    elapsed_ms = int((time.time() - started_at) * 1000)
    usage = completion.usage

    return RouterResult(
        content=completion.choices[0].message.content,
        model_used=selected.model_id,
        tier=tier_name,
        complexity=classification.get("complexity", "unknown"),
        task_type=classification.get("task_type", "unknown"),
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
        cost_usd=estimate_cost(selected, usage.prompt_tokens, usage.completion_tokens),
        latency_ms=elapsed_ms
    )

Step 4: Add Routing Rules for Specific Content Patterns

Some requests should always go to a specific tier regardless of classification.

import re

HARD_ROUTE_RULES = [
    # (pattern, tier_override, reason)
    (r"HIPAA|PHI|medical record|patient data", "standard", "Medical/health content needs careful handling"),
    (r"legal advice|contract review|litigation", "standard", "Legal content needs higher accuracy"),
    (r"calculate|compute|solve for|proof|equation", "reasoning", "Math needs reasoning model"),
    (r"refactor|debug|optimize.*code|stack trace", "standard", "Complex code needs standard tier"),
    (r"translate to [a-z]+", "nano", "Translation is nano-appropriate"),
    (r"yes or no|true or false|is it|does it", "nano", "Binary questions are nano-appropriate"),
]

def check_hard_routes(messages: list) -> str | None:
    full_text = " ".join(m["content"] for m in messages if m.get("content")).lower()

    for pattern, tier, reason in HARD_ROUTE_RULES:
        if re.search(pattern, full_text, re.IGNORECASE):
            return tier

    return None

def smart_route(messages: list, **kwargs) -> RouterResult:
    """Route a request, applying hard content rules before classification.

    A caller-supplied force_tier takes precedence over the hard rules.

    Bug fixed: the original passed force_tier explicitly AND forwarded
    **kwargs, so calling smart_route(..., force_tier="standard") raised
    TypeError("got multiple values for keyword argument 'force_tier'").
    """
    forced_tier = kwargs.pop("force_tier", None) or check_hard_routes(messages)
    return route_request(messages, force_tier=forced_tier, **kwargs)

Step 5: Track Routing Analytics

Log every routing decision so you can measure savings and spot mis-classifications.

import sqlite3
from datetime import datetime

def init_router_db():
    """Create router_analytics.db and its routing_log table if missing.

    Idempotent (CREATE TABLE IF NOT EXISTS) -- safe to call on every start.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS routing_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            tier TEXT,
            model_used TEXT,
            complexity TEXT,
            task_type TEXT,
            input_tokens INTEGER,
            output_tokens INTEGER,
            cost_usd REAL,
            latency_ms INTEGER
        )
    """
    conn = sqlite3.connect("router_analytics.db")
    conn.execute(ddl)
    conn.commit()
    conn.close()

def log_route(result: RouterResult):
    """Append one routing decision to router_analytics.db for later analysis."""
    sql = """INSERT INTO routing_log 
        (timestamp, tier, model_used, complexity, task_type, input_tokens, output_tokens, cost_usd, latency_ms)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"""
    row = (
        datetime.now().isoformat(),
        result.tier,
        result.model_used,
        result.complexity,
        result.task_type,
        result.input_tokens,
        result.output_tokens,
        result.cost_usd,
        result.latency_ms,
    )
    conn = sqlite3.connect("router_analytics.db")
    conn.execute(sql, row)
    conn.commit()
    conn.close()

def get_routing_summary(days: int = 7) -> dict:
    """Summarize routing activity over the last *days* days.

    Returns {tier: {"calls": int, "total_cost": float, "avg_latency_ms": int}}
    aggregated from routing_log.

    Bug fixed: the original called conn.close() and then issued two more
    queries on the closed connection (sqlite3.ProgrammingError at runtime).
    Those token-total queries were also dead code -- their results were
    discarded -- so they are removed; close now happens after the one query
    that is actually used.
    """
    conn = sqlite3.connect("router_analytics.db")
    try:
        rows = conn.execute("""
            SELECT tier, COUNT(*) as calls, SUM(cost_usd) as total_cost, AVG(latency_ms) as avg_latency
            FROM routing_log
            WHERE timestamp > datetime('now', ?)
            GROUP BY tier
        """, (f"-{days} days",)).fetchall()
    finally:
        conn.close()

    # NOTE(review): timestamps are stored via datetime.now().isoformat()
    # ('T' separator) but compared against SQLite's datetime() output
    # (space separator). The lexicographic comparison orders correctly when
    # the date parts differ; confirm behavior for same-day cutoffs.
    summary = {}
    for tier, calls, cost, latency in rows:
        summary[tier] = {"calls": calls, "total_cost": round(cost, 4), "avg_latency_ms": round(latency)}

    return summary

init_router_db()

Step 6: Use the Router in Your Application

# Simple usage: classification picks the tier automatically.
result = smart_route([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does API stand for?"}
])
print(f"Answer: {result.content}")
print(f"Routed to: {result.tier} ({result.model_used})")
print(f"Cost: ${result.cost_usd:.6f}")
log_route(result)  # persist the decision for the analytics summary

# Force a specific tier for sensitive content
# NOTE(review): smart_route forwards **kwargs into route_request, which
# already receives force_tier -- verify this keyword does not collide.
result = smart_route(
    [{"role": "user", "content": "Review this contract clause for hidden risks..."}],
    force_tier="standard"
)

What to Build Next

Related Reading

Want this system built for your business?

Get a free assessment. We will map every system your business needs and show you the ROI.

Get Your Free Assessment

Related Systems