Systems Library / AI Model Setup / How to Build AI Systems with Fallback Models
AI Model Setup · Advanced

How to Build AI Systems with Fallback Models

Configure backup models that activate when your primary AI is unavailable.

Jay Banlasan

Jay Banlasan

The AI Systems Guy

Every major AI provider has outages. OpenAI, Anthropic, Google — they all go down. If your product depends on a single provider, an outage means your product stops working. A fallback model system routes around the failure automatically. Your users never notice.

The pattern I use treats AI providers like a waterfall: try primary, if it fails try secondary, if that fails try tertiary. Each fallback is chosen based on capability similarity, not just availability. You want the cheapest model that can do the job at each tier.

What You Need Before Starting

You need Python 3.10 or newer, an OpenAI API key, and an Anthropic API key. The same pattern extends to any provider with a Python SDK.

Step 1: Install Dependencies

pip install openai anthropic

Step 2: Define Your Model Priority Chain

Build a priority list per task type. Not every task needs the same fallback chain.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:
    """Static metadata for one model in the fallback pool.

    Instances are registered in MODEL_POOL and referenced by short name
    from FALLBACK_CHAINS.
    """
    provider: str        # Provider key: "openai", "anthropic", "google".
                         # NOTE(review): "google" has no caller in call_model — confirm before adding Google models.
    model_id: str        # Exact model identifier sent to the provider API.
    priority: int        # Lower = try first.
                         # NOTE(review): not read by the visible orchestration code;
                         # FALLBACK_CHAINS ordering is what actually controls try order.
    max_tokens: int      # Hard cap on completion tokens for this model.
    cost_per_1k_input: float   # USD per 1k input tokens (snapshot — verify current pricing).
    cost_per_1k_output: float  # USD per 1k output tokens (snapshot — verify current pricing).
    notes: str = ""      # Free-form operator notes; not used programmatically.

# Pool of every model the system may call, keyed by short name.
# NOTE(review): prices are hard-coded snapshots — verify against current
# provider pricing before using them for cost accounting.
MODEL_POOL = {
    "gpt-4o": ModelConfig(
        provider="openai", model_id="gpt-4o",
        priority=1, max_tokens=4096,
        cost_per_1k_input=0.005, cost_per_1k_output=0.015,
        notes="Primary for complex reasoning"
    ),
    "claude-3-5-sonnet": ModelConfig(
        provider="anthropic", model_id="claude-3-5-sonnet-20241022",
        priority=2, max_tokens=4096,
        cost_per_1k_input=0.003, cost_per_1k_output=0.015,
        notes="Primary fallback, comparable quality"
    ),
    "gpt-4o-mini": ModelConfig(
        provider="openai", model_id="gpt-4o-mini",
        priority=3, max_tokens=4096,
        cost_per_1k_input=0.00015, cost_per_1k_output=0.0006,
        notes="Cost fallback for simple tasks"
    ),
    "claude-3-haiku": ModelConfig(
        provider="anthropic", model_id="claude-3-haiku-20240307",
        priority=4, max_tokens=4096,
        cost_per_1k_input=0.00025, cost_per_1k_output=0.00125,
        notes="Last resort, fast and cheap"
    ),
}

# Per-task fallback chains: ordered lists of MODEL_POOL short names,
# tried left to right. Task types missing from this dict fall back to
# the "default" chain (see call_with_fallback).
FALLBACK_CHAINS = {
    "complex_reasoning":  ["gpt-4o", "claude-3-5-sonnet", "gpt-4o-mini"],
    "simple_generation":  ["gpt-4o-mini", "claude-3-haiku", "gpt-4o"],
    "document_analysis":  ["claude-3-5-sonnet", "gpt-4o", "gpt-4o-mini"],
    "code_generation":    ["gpt-4o", "claude-3-5-sonnet", "gpt-4o-mini"],
    "default":            ["gpt-4o", "claude-3-5-sonnet", "gpt-4o-mini", "claude-3-haiku"],
}

Step 3: Build Provider-Agnostic Callers

Abstract each provider into a common interface.

import openai
import anthropic

# NOTE(review): hardcoded key strings are placeholders — load real keys from
# environment variables or a secrets manager; never commit them to source.
openai_client = openai.OpenAI(api_key="YOUR_OPENAI_KEY")
anthropic_client = anthropic.Anthropic(api_key="YOUR_ANTHROPIC_KEY")

# Exception tuples, per provider, that should trigger a fallback to the
# next model in the chain (see the except clause in call_with_fallback).
PROVIDER_ERRORS = {
    "openai": (
        openai.RateLimitError,
        openai.APITimeoutError,
        openai.APIConnectionError,
        openai.InternalServerError,
        # NOTE(review): APIStatusError is the base class for HTTP-status
        # errors, including 4xx client errors (bad request, auth). Listing
        # it means even non-transient failures trigger a fallback — confirm
        # that is intended.
        openai.APIStatusError,
    ),
    "anthropic": (
        anthropic.RateLimitError,
        anthropic.APITimeoutError,
        anthropic.APIConnectionError,
        anthropic.InternalServerError,
    )
}

def call_openai(model_id: str, messages: list, max_tokens: int = 1000, temperature: float = 0.3) -> str:
    """Send one chat completion to OpenAI and return the assistant's text.

    messages must already be in OpenAI's role/content dict format.
    """
    request = {
        "model": model_id,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    completion = openai_client.chat.completions.create(**request)
    return completion.choices[0].message.content

def call_anthropic(model_id: str, messages: list, max_tokens: int = 1000, temperature: float = 0.3) -> str:
    """Send one message request to Anthropic and return the reply text.

    Anthropic takes the system prompt as a top-level parameter rather than
    as a role="system" message, so any system message is split out first.

    Args:
        model_id: Exact Anthropic model identifier.
        messages: Chat messages in OpenAI role/content dict format.
        max_tokens: Completion token cap.
        temperature: Sampling temperature, forwarded to the API.
    """
    # First system message (if any) becomes the top-level system prompt.
    system = next((m["content"] for m in messages if m["role"] == "system"), None)
    user_messages = [m for m in messages if m["role"] != "system"]

    kwargs = {
        "model": model_id,
        "messages": user_messages,
        "max_tokens": max_tokens,
        # Bug fix: temperature was accepted by this function but never
        # forwarded, so Anthropic calls always used the API default.
        "temperature": temperature,
    }
    if system:
        kwargs["system"] = system

    response = anthropic_client.messages.create(**kwargs)
    return response.content[0].text

def call_model(config: ModelConfig, messages: list, max_tokens: int, temperature: float) -> str:
    """Dispatch a request to the provider-specific caller named in config."""
    provider = config.provider
    if provider not in ("openai", "anthropic"):
        raise ValueError(f"Unknown provider: {config.provider}")
    caller = call_openai if provider == "openai" else call_anthropic
    return caller(config.model_id, messages, max_tokens, temperature)

Step 4: Build the Fallback Orchestrator

Try each model in the chain until one succeeds.

import time
import logging
from datetime import datetime

# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)

class FallbackResult:
    """Outcome of a fallback-chain call: reply text plus routing metadata."""

    def __init__(self, content: str, model_used: str, fallback_count: int, latency_ms: int):
        # latency_ms covers the whole call, including any failed attempts.
        self.latency_ms = latency_ms
        self.model_used = model_used
        self.fallback_count = fallback_count
        self.content = content
        # True whenever a model other than the chain's first answered.
        self.used_fallback = fallback_count > 0

def _record_health(model_name: str, success: bool, latency_ms: int, error_type: Optional[str] = None) -> None:
    """Best-effort write to the health tracker; never fail the request path."""
    try:
        log_model_call(model_name, success, latency_ms, error_type)
    except Exception:
        logger.debug("Health logging failed for %s", model_name, exc_info=True)


def call_with_fallback(
    messages: list,
    task_type: str = "default",
    max_tokens: int = 1000,
    temperature: float = 0.3,
    timeout_seconds: int = 30
) -> FallbackResult:
    """Try each model in the task's fallback chain until one succeeds.

    Args:
        messages: Chat messages in OpenAI role/content dict format.
        task_type: Key into FALLBACK_CHAINS; unknown types use "default".
        max_tokens: Completion token cap passed to the provider.
        temperature: Sampling temperature passed to the provider.
        timeout_seconds: NOTE(review): currently unused — kept for interface
            compatibility; wire it into the provider clients or remove it.

    Returns:
        FallbackResult whose latency_ms covers the whole call, including
        failed attempts and backoff sleeps.

    Raises:
        RuntimeError: If every model in the chain fails with a provider error.
        Exception: An unexpected (non-provider) error on the last model is
            re-raised as-is.
    """
    chain = FALLBACK_CHAINS.get(task_type, FALLBACK_CHAINS["default"])
    start_time = time.time()

    for attempt, model_name in enumerate(chain):
        config = MODEL_POOL[model_name]
        attempt_start = time.time()

        try:
            logger.info(f"Trying {model_name} (attempt {attempt + 1}/{len(chain)})")
            content = call_model(config, messages, max_tokens, temperature)
            latency = int((time.time() - start_time) * 1000)
            # Feed the health tracker so call_with_smart_fallback has data;
            # without these writes, get_model_health always returns {}.
            _record_health(model_name, True, int((time.time() - attempt_start) * 1000))

            if attempt > 0:
                logger.warning(f"Used fallback: {model_name} after {attempt} failure(s)")

            return FallbackResult(
                content=content,
                model_used=model_name,
                fallback_count=attempt,
                latency_ms=latency
            )

        except PROVIDER_ERRORS.get(config.provider, (Exception,)) as e:
            # Known-transient provider failure: record it, back off, move on.
            _record_health(model_name, False, int((time.time() - attempt_start) * 1000), type(e).__name__)
            logger.warning(f"{model_name} failed: {type(e).__name__}: {str(e)[:100]}")
            if attempt < len(chain) - 1:
                time.sleep(min(2 ** attempt, 10))  # Exponential backoff, capped at 10s
            continue

        except Exception as e:
            # Unexpected error: try the next model, but re-raise on the last one.
            _record_health(model_name, False, int((time.time() - attempt_start) * 1000), type(e).__name__)
            logger.error(f"Unexpected error on {model_name}: {e}")
            if attempt < len(chain) - 1:
                continue
            raise

    raise RuntimeError(f"All models in fallback chain failed for task_type={task_type}")

Step 5: Add Health Tracking

Track which models are failing so you can route around degraded providers proactively.

import sqlite3
from collections import defaultdict

def init_health_db():
    """Create the model_calls health-tracking table if it does not exist.

    Idempotent; safe to call at import time. One row per model call, with
    timestamp, success flag, latency, and the error class name on failure.
    """
    conn = sqlite3.connect("model_health.db")
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS model_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                model_name TEXT,
                success INTEGER,
                latency_ms INTEGER,
                error_type TEXT
            )
        """)
        conn.commit()
    finally:
        # Close even if the DDL raises; the previous shape leaked the
        # connection on failure.
        conn.close()

def log_model_call(model_name: str, success: bool, latency_ms: int, error_type: Optional[str] = None):
    """Record the outcome of one model call in the health database.

    The timestamp is written as UTC in SQLite's native
    'YYYY-MM-DD HH:MM:SS' format so string comparisons against
    datetime('now', ...) — which SQLite evaluates in UTC — are correct.
    (A local-time isoformat() string both carries the local UTC offset and
    uses a 'T' separator that sorts after SQLite's space separator, so the
    recent-window query in get_model_health would silently misbehave.)

    Args:
        model_name: MODEL_POOL short name of the model called.
        success: Whether the call returned a response.
        latency_ms: Wall-clock duration of the attempt in milliseconds.
        error_type: Exception class name on failure, None on success.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
    conn = sqlite3.connect("model_health.db")
    try:
        conn.execute(
            "INSERT INTO model_calls (timestamp, model_name, success, latency_ms, error_type) VALUES (?, ?, ?, ?, ?)",
            (timestamp, model_name, 1 if success else 0, latency_ms, error_type)
        )
        conn.commit()
    finally:
        conn.close()

def get_model_health(window_minutes: int = 15) -> dict:
    """Summarize per-model success rate and latency over a recent window.

    Timestamps in the table must be stored as UTC in SQLite's
    'YYYY-MM-DD HH:MM:SS' format, because datetime('now', ...) is UTC and
    the comparison is a plain string compare.

    Args:
        window_minutes: How far back to look, in minutes.

    Returns:
        Mapping of model name -> {"success_rate", "avg_latency_ms",
        "total_calls", "degraded"}. Models with no calls inside the
        window are absent from the result.
    """
    conn = sqlite3.connect("model_health.db")
    try:
        rows = conn.execute("""
            SELECT model_name,
                   COUNT(*) as total,
                   SUM(success) as successes,
                   AVG(latency_ms) as avg_latency
            FROM model_calls
            WHERE timestamp > datetime('now', ?)
            GROUP BY model_name
        """, (f"-{window_minutes} minutes",)).fetchall()
    finally:
        # Close even when the query raises; the previous shape leaked the
        # connection on failure.
        conn.close()

    health = {}
    for name, total, successes, avg_latency in rows:
        success_rate = successes / total if total > 0 else 1.0
        health[name] = {
            "success_rate": round(success_rate, 3),
            "avg_latency_ms": round(avg_latency or 0),
            "total_calls": total,
            # NOTE(review): flips on a single failed call at low traffic —
            # consider requiring a minimum sample size before flagging.
            "degraded": success_rate < 0.8,
        }
    return health

# Create the health table at import time so later logging calls never hit
# a missing table.
init_health_db()

Step 6: Use Health Data to Reorder the Fallback Chain

Skip models that are currently degraded.

def call_with_smart_fallback(messages: list, task_type: str = "default", **kwargs) -> FallbackResult:
    """Like call_with_fallback, but demotes currently-degraded models.

    Reads the last 10 minutes of health data and moves any model flagged
    as degraded to the end of the chain (preserving relative order)
    before dispatching.

    NOTE(review): this temporarily mutates the global FALLBACK_CHAINS and
    is therefore not safe under concurrent calls — consider passing the
    chain through explicitly instead.
    """
    health = get_model_health(window_minutes=10)
    chain = FALLBACK_CHAINS.get(task_type, FALLBACK_CHAINS["default"])

    # Models with no recent health data are treated as healthy.
    healthy = [m for m in chain if not health.get(m, {}).get("degraded", False)]
    degraded = [m for m in chain if health.get(m, {}).get("degraded", False)]
    reordered_chain = healthy + degraded

    if reordered_chain != chain:
        logger.info(f"Reordered chain due to health: {reordered_chain}")

    # Temporarily override the chain, then restore it exactly. The sentinel
    # distinguishes "key absent" from "key present": for an unknown
    # task_type, .get() returns None, and writing None back in the finally
    # would corrupt FALLBACK_CHAINS for every later call with that
    # task_type — so an absent key is removed instead.
    _missing = object()
    original = FALLBACK_CHAINS.get(task_type, _missing)
    FALLBACK_CHAINS[task_type] = reordered_chain
    try:
        return call_with_fallback(messages, task_type, **kwargs)
    finally:
        if original is _missing:
            FALLBACK_CHAINS.pop(task_type, None)
        else:
            FALLBACK_CHAINS[task_type] = original

What to Build Next

Related Reading

Want this system built for your business?

Get a free assessment. We will map every system your business needs and show you the ROI.

Get Your Free Assessment

Related Systems