Systems Library / AI Model Setup / How to Implement Cost-Based AI Model Selection
AI Model Setup routing optimization

How to Implement Cost-Based AI Model Selection

Automatically choose the cheapest AI model that meets quality thresholds.

Jay Banlasan

The AI Systems Guy

Cost-based model selection is a specific discipline within routing: given a quality threshold, find the cheapest model that meets it. This is different from general routing, which focuses on capability matching. Cost-based selection starts with the cheapest option and escalates only when the output quality is insufficient.

I built a version of this for a client running 50,000 AI requests per month. After implementing cost-based selection with quality fallback, their monthly AI spend dropped from $380 to $110 with no measurable change in user satisfaction. The savings compound as volume grows.

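What You Need Before Starting

Python 3.9 or newer, the `openai` and `anthropic` Python SDKs, and API keys for both OpenAI and Anthropic.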

Step 1: Define Cost and Quality Targets

import contextlib
import re
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelOption:
    """Price sheet and baseline quality for a single provider model.

    All prices are USD per 1,000 tokens. ``expected_quality_score`` is your
    own 1-10 baseline measurement for the model on your tasks.
    """

    # Model name exactly as the provider's API expects it.
    id: str
    # Dispatch key; call_model_provider handles "openai" and "anthropic".
    provider: str
    # USD per 1K prompt tokens.
    cost_per_1k_input: float
    # USD per 1K completion tokens.
    cost_per_1k_output: float
    # 1-10, your baseline measurement.
    expected_quality_score: float
    # NOTE(review): not read by the selection code in this file; calls pass
    # their own max_tokens instead.
    max_tokens: int = 4096

# Candidate models, cheapest first. Prices are USD per 1K tokens as recorded
# when this list was written; provider prices change, so re-check them.
# The list is sorted by combined input + output price instead of relying on
# hand ordering: the selector tries models in list order, so a mis-ordered
# entry would be tried (and paid for) before a cheaper one. With the prices
# below, claude-3-5-sonnet (0.018 combined) sorts ahead of gpt-4o (0.020).
MODELS_BY_COST = sorted(
    [
        ModelOption("gpt-4o-mini",                  "openai",    0.00015, 0.0006,  7.2),
        ModelOption("claude-3-haiku-20240307",       "anthropic", 0.00025, 0.00125, 7.5),
        ModelOption("gpt-4o",                        "openai",    0.005,   0.015,   9.1),
        ModelOption("claude-3-5-sonnet-20241022",    "anthropic", 0.003,   0.015,   9.0),
    ],
    # sorted() is stable, so entries with equal combined price keep list order.
    key=lambda m: m.cost_per_1k_input + m.cost_per_1k_output,
)

def estimate_request_cost(model: ModelOption, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of one request to ``model``.

    ModelOption prices are per 1,000 tokens, so each token count is scaled
    down by 1,000 before being multiplied by its rate. Token counts may be
    fractional estimates.
    """
    input_cost = model.cost_per_1k_input * (input_tokens / 1000)
    output_cost = model.cost_per_1k_output * (output_tokens / 1000)
    return input_cost + output_cost

Step 2: Build the Quality-Gated Selector

Try the cheapest model first. If quality is below threshold, escalate.

import openai
import anthropic

# Credentials come from the environment (OPENAI_API_KEY / ANTHROPIC_API_KEY),
# which both SDKs read when no api_key is passed. This keeps secrets out of
# source control. The OpenAI SDK raises at construction if no key is found.
openai_client = openai.OpenAI()
anthropic_client = anthropic.Anthropic()

def call_model_provider(
    model: ModelOption,
    messages: list,
    max_tokens: int = 800,
    temperature: float = 0.3,
) -> str:
    """Send ``messages`` to ``model``'s provider and return the reply text.

    ``messages`` uses the OpenAI chat format (dicts with "role" and
    "content"). For Anthropic, system messages are lifted out into the
    top-level ``system`` argument that the Messages API expects.

    Args:
        model: The model (and provider) to call.
        messages: Chat messages in OpenAI format.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. It is applied to BOTH providers so
            that outputs compared by quality score were sampled the same way.

    Returns:
        The reply text, or an empty string if the provider returned no text.

    Raises:
        ValueError: if ``model.provider`` is neither "openai" nor "anthropic".
    """
    if model.provider == "openai":
        response = openai_client.chat.completions.create(
            model=model.id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # The SDK types message.content as optional; normalise a missing
        # reply to "" so callers can always use str methods on the result.
        return response.choices[0].message.content or ""

    if model.provider == "anthropic":
        # Anthropic takes the system prompt as a parameter, not a message.
        # Join every system message rather than silently keeping only the first.
        system = "\n\n".join(
            m["content"] for m in messages if m["role"] == "system" and m["content"]
        )
        chat_messages = [m for m in messages if m["role"] != "system"]
        kwargs = {
            "model": model.id,
            "messages": chat_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if system:
            kwargs["system"] = system
        response = anthropic_client.messages.create(**kwargs)
        # Concatenate all text blocks; indexing [0] fails on an empty reply.
        return "".join(
            block.text for block in response.content if block.type == "text"
        )

    raise ValueError(f"Unknown provider: {model.provider}")

def quick_quality_check(output: str, task_description: str) -> float:
    """Score ``output`` 1-10 for the stated task using a cheap judge model.

    gpt-4o-mini is the judge, to keep meta-costs low. It is told to reply
    with only a number, but replies such as "8/10" or "Score: 8" still occur.
    The first number in the reply is therefore used and clamped to the 1-10
    scale.

    Args:
        output: The model output being graded.
        task_description: What the output was supposed to accomplish.

    Returns:
        A score in [1.0, 10.0], or 5.0 (a neutral middle score) when the
        reply is empty or contains no number.
    """
    response = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "You are a quality evaluator. Score AI outputs 1-10 based on accuracy, completeness, and usefulness for the stated task. Return ONLY the number."
            },
            {
                "role": "user",
                "content": f"Task: {task_description}\n\nOutput to evaluate:\n{output}\n\nScore (1-10):"
            }
        ],
        temperature=0,
        max_tokens=5
    )
    # content is optional in the SDK types; treat a missing reply as unparseable
    # instead of crashing on None.strip().
    reply = response.choices[0].message.content or ""
    # A regex (rather than float(reply)) tolerates "8/10" and "Score: 8" and
    # never yields nan/inf, which float() would happily parse.
    match = re.search(r"\d+(?:\.\d+)?", reply)
    if match is None:
        return 5.0  # Default middle score if parse fails
    return min(max(float(match.group()), 1.0), 10.0)

Step 3: Build the Cost-Optimized Completion Function

from dataclasses import dataclass
import time

@dataclass
class OptimizedResult:
    """Outcome of one cost-optimized completion, including how it was reached."""

    # The accepted model output.
    content: str
    # id of the model whose output was accepted.
    model_used: str
    # Number of models actually called, including the accepted one.
    attempts: int
    # Judge score, or the model's expected_quality_score when no check ran.
    quality_score: float
    # Estimated (not billed) USD spend across all attempts and checks.
    total_cost_usd: float
    # Wall-clock time for the whole selection, in milliseconds.
    latency_ms: int
    # True when more than one model had to be tried.
    escalated: bool

def optimized_complete(
    messages: list,
    quality_threshold: float = 7.5,
    task_description: str = "Respond helpfully and accurately",
    max_cost_per_request: float = 0.05,
    check_quality: bool = True
) -> OptimizedResult:
    """Answer ``messages`` with the cheapest model whose output meets the bar.

    Candidates are taken from ``MODELS_BY_COST`` in list order, keeping only
    models that:
      * have an estimated cost (input estimate + 800 output tokens) no higher
        than ``max_cost_per_request``; and
      * when ``check_quality`` is true, have an ``expected_quality_score`` no
        more than 1 point below ``quality_threshold``.

    When ``check_quality`` is false, the first candidate's output is returned.
    Otherwise each output is scored with ``quick_quality_check``, and the
    first one scoring at least ``quality_threshold`` is returned. The LAST
    ELIGIBLE candidate is the last resort: there is nothing left to escalate
    to, so its output is returned without a check.

    Args:
        messages: Chat messages in OpenAI format; each "content" must be a str.
        quality_threshold: Minimum judge score (1-10) to accept an output.
        task_description: Task statement shown to the quality judge.
        max_cost_per_request: Per-model estimated cost ceiling in USD.
        check_quality: If false, skip expected-quality filtering and scoring.

    Returns:
        OptimizedResult. ``quality_score`` is the judge's score, or the
        model's ``expected_quality_score`` when no check was run.
        ``total_cost_usd`` is a word-count-based estimate, not billed cost.

    Raises:
        RuntimeError: if no model passes the cost and expected-quality filters.
    """
    start = time.time()
    total_cost = 0.0
    attempts = 0

    # Rough token estimate: ~1.3 tokens per whitespace-separated word.
    input_text = " ".join(m["content"] for m in messages)
    estimated_input_tokens = len(input_text.split()) * 1.3

    # Decide up front which models may be tried, so that the final *eligible*
    # model is recognised as the last resort. Comparing against
    # MODELS_BY_COST[-1] breaks when the budget or quality filter excludes that
    # entry: the last tried output would be checked, discarded, and an error raised.
    candidates = [
        model
        for model in MODELS_BY_COST
        if estimate_request_cost(model, estimated_input_tokens, 800) <= max_cost_per_request
        # No point trying a model whose baseline is far below the bar.
        and not (check_quality and model.expected_quality_score < quality_threshold - 1.0)
    ]
    if not candidates:
        raise RuntimeError("No models available within cost and quality constraints")

    def _result(model: ModelOption, output: str, quality: float) -> OptimizedResult:
        """Package an accepted attempt with the current running totals."""
        return OptimizedResult(
            content=output,
            model_used=model.id,
            attempts=attempts,
            quality_score=quality,
            total_cost_usd=round(total_cost, 6),
            latency_ms=int((time.time() - start) * 1000),
            escalated=attempts > 1,
        )

    last_index = len(candidates) - 1
    for index, model in enumerate(candidates):
        attempts += 1
        output = call_model_provider(model, messages)
        total_cost += estimate_request_cost(
            model, estimated_input_tokens, len(output.split()) * 1.3
        )

        # Skip the quality check when it is disabled, or for the last resort.
        if not check_quality or index == last_index:
            return _result(model, output, model.expected_quality_score)

        quality = quick_quality_check(output, task_description)
        total_cost += 0.0001  # Approximate cost of quality check

        if quality >= quality_threshold:
            return _result(model, output, quality)

        print(f"{model.id} scored {quality:.1f}/{quality_threshold}. Escalating...")

    # Unreachable: the last candidate always returns inside the loop.
    raise RuntimeError("No models available within cost and quality constraints")

Step 4: Add a Quality Calibration Mode

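Before deploying, calibrate the quality scores for your specific tasks. The mean score each model earns here is the value to record as its `expected_quality_score` in Step 1.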

import statistics

def calibrate_models(
    test_inputs: list[str],
    task_description: str,
    system_prompt: str = "You are a helpful assistant."
) -> dict:
    """Collect judge scores for every model in MODELS_BY_COST on sample inputs.

    Each input is sent (after ``system_prompt``) to every model. Each reply
    is scored by ``quick_quality_check`` against ``task_description``, and
    every score is printed as it arrives.

    Returns:
        {model_id: {"mean", "min", "max", "stdev"}}, each rounded to 2 dp.
        stdev is 0 when there is only one input. An empty ``test_inputs``
        raises statistics.StatisticsError.
    """
    scores_by_model = {option.id: [] for option in MODELS_BY_COST}

    for prompt in test_inputs:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]
        for option in MODELS_BY_COST:
            reply = call_model_provider(option, conversation)
            score = quick_quality_check(reply, task_description)
            scores_by_model[option.id].append(score)
            print(f"{option.id}: {score:.1f}")

    def _summarise(scores):
        # Sample stdev needs at least two points.
        spread = statistics.stdev(scores) if len(scores) > 1 else 0
        return {
            "mean": round(statistics.mean(scores), 2),
            "min": round(min(scores), 2),
            "max": round(max(scores), 2),
            "stdev": round(spread, 2)
        }

    return {model_id: _summarise(scores) for model_id, scores in scores_by_model.items()}

# Run once per task type to get real quality data
# calibration = calibrate_models(TEST_INPUTS, "Answer customer support questions accurately")
# print(calibration)

Step 5: Set Budget Caps by Time Period

Prevent runaway costs with daily and monthly spend limits. The "monthly" limit covers a rolling 30-day window, not the calendar month.

import sqlite3
from datetime import datetime

def init_cost_db(db_path: str = "ai_costs.db"):
    """Create the ``cost_log`` table in ``db_path`` if it does not exist yet.

    Safe to call repeatedly (CREATE TABLE IF NOT EXISTS). The connection is
    closed even if the statement fails.

    Args:
        db_path: SQLite database file; defaults to ai_costs.db in the current
            working directory.
    """
    # closing() guarantees close(); the inner `conn` context commits on success
    # (a sqlite3 connection's own context manager does not close it).
    with contextlib.closing(sqlite3.connect(db_path)) as conn, conn:
        # timestamp is TEXT; spend queries compare it with datetime('now', ...).
        conn.execute("""
            CREATE TABLE IF NOT EXISTS cost_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                model_id TEXT,
                cost_usd REAL,
                task_type TEXT
            )
        """)

def log_cost(model_id: str, cost: float, task_type: str = "general", db_path: str = "ai_costs.db"):
    """Append one spend record to the ``cost_log`` table.

    The timestamp comes from SQLite's ``datetime('now')``, i.e. UTC in
    "YYYY-MM-DD HH:MM:SS" form. That is exactly the format and clock of the
    ``datetime('now', ...)`` cutoff used when summing spend per period. A
    local-time ``isoformat()`` value (with a "T" separator) does not compare
    correctly against that cutoff as a string.

    Args:
        model_id: Model the cost is attributed to.
        cost: Spend in USD.
        task_type: Free-form category label.
        db_path: SQLite database file holding ``cost_log``.
    """
    with contextlib.closing(sqlite3.connect(db_path)) as conn, conn:
        conn.execute(
            "INSERT INTO cost_log (timestamp, model_id, cost_usd, task_type) "
            "VALUES (datetime('now'), ?, ?, ?)",
            (model_id, cost, task_type),
        )

def get_period_spend(period: str = "day", db_path: str = "ai_costs.db") -> float:
    """Return total logged spend (USD) over a rolling window, rounded to 4 dp.

    Args:
        period: "day" selects the last 24 hours; any other value selects the
            last 30 days (the caller in this file passes "month").
        db_path: SQLite database file holding ``cost_log``.

    Returns:
        The summed ``cost_usd``, or 0.0 when no rows fall in the window.
    """
    window = "-1 day" if period == "day" else "-30 days"
    with contextlib.closing(sqlite3.connect(db_path)) as conn:
        # datetime(timestamp) normalises ISO-8601 values written with a "T"
        # separator to SQLite's "YYYY-MM-DD HH:MM:SS" form. A raw string
        # comparison would count every "T"-style row from the cutoff date,
        # because "T" sorts after " ". The cutoff uses SQLite's UTC clock.
        # NOTE(review): rows stored in local time are offset by the UTC offset.
        (total,) = conn.execute(
            "SELECT SUM(cost_usd) FROM cost_log "
            "WHERE datetime(timestamp) > datetime('now', ?)",
            (window,),
        ).fetchone()
    return round(total or 0.0, 4)

# Spend ceilings in USD; budget_checked_complete refuses requests once reached.
DAILY_BUDGET: float = 5.0  # compared against the last 24 hours of logged spend
MONTHLY_BUDGET: float = 100.0  # compared against the last 30 days (rolling)

def budget_checked_complete(messages: list, **kwargs) -> Optional[OptimizedResult]:
    """Run optimized_complete only while logged spend is under both caps.

    Returns None, after printing which cap was hit, when spend in the last
    day has reached DAILY_BUDGET or spend in the last 30 days has reached
    MONTHLY_BUDGET. Otherwise it forwards ``messages`` and ``kwargs`` to
    optimized_complete, logs the result's estimated cost with log_cost's
    default task type, and returns the result.
    """
    spent_today = get_period_spend("day")
    spent_this_month = get_period_spend("month")

    # Daily cap is checked first, so its message wins when both are exceeded.
    if spent_today >= DAILY_BUDGET:
        print(f"Daily budget exhausted: ${spent_today:.4f} / ${DAILY_BUDGET}")
        return None
    if spent_this_month >= MONTHLY_BUDGET:
        print(f"Monthly budget exhausted: ${spent_this_month:.2f} / ${MONTHLY_BUDGET}")
        return None

    outcome = optimized_complete(messages, **kwargs)
    log_cost(outcome.model_used, outcome.total_cost_usd)
    return outcome

# Ensure the cost_log table exists as soon as this module is loaded.
# This creates ai_costs.db in the current working directory if it is missing;
# it is harmless when the table already exists (CREATE TABLE IF NOT EXISTS).
init_cost_db()

What to Build Next

Related Reading

Want this system built for your business?

Get a free assessment. We will map every system your business needs and show you the ROI.

Get Your Free Assessment

Related Systems