How to Build an AI Load Balancer Across Providers
Distribute AI requests across providers to avoid rate limits and outages.
Jay Banlasan
The AI Systems Guy
Rate limits kill batch jobs and spike latency during peak hours. A load balancer distributes requests across multiple AI providers or API keys so no single endpoint gets saturated. I use this pattern for high-volume pipelines that need to process thousands of requests without hitting throttling or queuing delays.
The approach is straightforward: maintain a pool of API endpoints, track current load and recent error rates on each, and route each request to the endpoint with the most available capacity. When one endpoint is degraded or rate-limited, traffic automatically shifts to the others.
What You Need Before Starting
- Python 3.9+
- API keys for at least 2 providers or 2 API keys for the same provider (different billing accounts)
- aiohttp for async requests if you need concurrent processing
- Basic understanding of threading or asyncio
Step 1: Define Your Endpoint Pool
from dataclasses import dataclass, field
from threading import Lock
from datetime import datetime, timedelta
import time
@dataclass
class Endpoint:
    """One routable API endpoint plus its runtime usage and health counters."""
    name: str
    provider: str
    model_id: str
    api_key: str
    requests_per_minute: int  # rate limit for your billing tier
    requests_per_day: int     # daily request cap
    weight: float = 1.0       # relative share of traffic (higher = more)
    # Runtime bookkeeping -- maintained by the pool, never set by hand.
    current_minute_requests: int = 0
    current_day_requests: int = 0
    last_error_at: float = 0.0
    consecutive_errors: int = 0
    total_requests: int = 0
    total_errors: int = 0
    last_minute_reset: float = field(default_factory=time.time)
    last_day_reset: float = field(default_factory=time.time)
    lock: Lock = field(default_factory=Lock)

    @property
    def available_capacity(self) -> float:
        """Fraction of headroom left: 0.0 means exhausted, 1.0 means idle."""
        minute_fraction = self.current_minute_requests / self.requests_per_minute
        day_fraction = self.current_day_requests / self.requests_per_day
        headroom = min(1.0 - minute_fraction, 1.0 - day_fraction)
        return headroom if headroom > 0.0 else 0.0

    @property
    def is_healthy(self) -> bool:
        """False while cooling down (60s) after 3+ consecutive errors."""
        tripped = self.consecutive_errors >= 3
        if tripped and time.time() - self.last_error_at < 60:
            return False
        return True

    def reset_minute_counter(self):
        """Zero the per-minute counter once a full minute has elapsed."""
        now = time.time()
        if now - self.last_minute_reset >= 60:
            self.current_minute_requests = 0
            self.last_minute_reset = now

    def reset_day_counter(self):
        """Zero the per-day counter once 24 hours have elapsed."""
        now = time.time()
        if now - self.last_day_reset >= 86400:
            self.current_day_requests = 0
            self.last_day_reset = now
Step 2: Build the Endpoint Pool Manager
import random
class EndpointPool:
    """Holds the endpoint roster and routes each request to the best one.

    Selection is a weighted random draw over healthy endpoints, where each
    draw weight is available_capacity * configured weight. A single pool
    lock guards counter mutation so the pool can be shared by worker threads.
    """

    def __init__(self, endpoints: list[Endpoint]):
        self.endpoints = endpoints
        self.lock = Lock()

    def get_best_endpoint(self) -> Endpoint | None:
        """Pick an endpoint with spare capacity, or None if all are saturated."""
        with self.lock:
            # Roll the per-minute/per-day windows forward before measuring.
            for endpoint in self.endpoints:
                endpoint.reset_minute_counter()
                endpoint.reset_day_counter()

            usable = [
                endpoint for endpoint in self.endpoints
                if endpoint.is_healthy and endpoint.available_capacity > 0
            ]
            if not usable:
                return None

            draw_weights = [ep.available_capacity * ep.weight for ep in usable]
            weight_total = sum(draw_weights)
            if weight_total == 0:
                return None

            # Weighted roulette-wheel selection.
            threshold = random.uniform(0, weight_total)
            running = 0
            for endpoint, draw_weight in zip(usable, draw_weights):
                running += draw_weight
                if threshold <= running:
                    return endpoint
            # Float rounding can leave threshold a hair past the last bucket.
            return usable[-1]

    def record_success(self, endpoint: Endpoint):
        """Count a successful call and clear the endpoint's error streak."""
        with self.lock:
            endpoint.current_minute_requests += 1
            endpoint.current_day_requests += 1
            endpoint.total_requests += 1
            endpoint.consecutive_errors = 0

    def record_error(self, endpoint: Endpoint, is_rate_limit: bool = False):
        """Count a failed call; on a rate limit, exhaust the minute budget."""
        with self.lock:
            endpoint.total_requests += 1
            endpoint.total_errors += 1
            endpoint.last_error_at = time.time()
            endpoint.consecutive_errors += 1
            if is_rate_limit:
                # Saturate the minute counter so routing avoids this
                # endpoint until the window resets.
                endpoint.current_minute_requests = endpoint.requests_per_minute

    def get_stats(self) -> list[dict]:
        """Snapshot per-endpoint usage and health for logging/monitoring."""
        return [
            {
                "name": ep.name,
                "capacity": round(ep.available_capacity, 3),
                "healthy": ep.is_healthy,
                "total_requests": ep.total_requests,
                "error_rate": round(ep.total_errors / max(1, ep.total_requests) * 100, 1),
                "consecutive_errors": ep.consecutive_errors
            }
            for ep in self.endpoints
        ]
Step 3: Build the Load-Balanced Caller
import openai
import anthropic
def call_openai_endpoint(endpoint: Endpoint, messages: list, max_tokens: int, temperature: float) -> str:
    """Send a chat completion to an OpenAI endpoint and return the reply text."""
    client = openai.OpenAI(api_key=endpoint.api_key)
    completion = client.chat.completions.create(
        model=endpoint.model_id,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content
def call_anthropic_endpoint(endpoint: Endpoint, messages: list, max_tokens: int, temperature: float) -> str:
    """Send a chat request to an Anthropic endpoint and return the reply text.

    Anthropic's API takes the system prompt as a top-level argument rather
    than as a message, so OpenAI-style message lists are split here.
    """
    client = anthropic.Anthropic(api_key=endpoint.api_key)
    # First system message (if any) becomes the top-level system prompt.
    system = next((m["content"] for m in messages if m["role"] == "system"), None)
    user_messages = [m for m in messages if m["role"] != "system"]
    kwargs = {
        "model": endpoint.model_id,
        "messages": user_messages,
        "max_tokens": max_tokens,
        # BUG FIX: temperature was accepted but silently dropped, making this
        # caller inconsistent with call_openai_endpoint.
        "temperature": temperature,
    }
    if system:
        kwargs["system"] = system
    response = client.messages.create(**kwargs)
    return response.content[0].text
# Dispatch table: Endpoint.provider string -> function that performs the call.
PROVIDER_CALLERS = {
    "openai": call_openai_endpoint,
    "anthropic": call_anthropic_endpoint,
}
def load_balanced_complete(
    pool: EndpointPool,
    messages: list,
    max_tokens: int = 800,
    temperature: float = 0.3,
    max_retries: int = 3
) -> dict:
    """Route one completion request through the pool with retry and failover.

    Picks the best endpoint per attempt, records success/failure so the
    pool's routing adapts, and retries on rate limits and transient errors.

    Returns a dict with "content", "endpoint" (name used) and "attempt".
    Raises RuntimeError when no endpoint is available or all retries fail;
    re-raises non-retryable provider errors (auth, bad request, ...).
    """
    last_error = None
    for attempt in range(max_retries):
        endpoint = pool.get_best_endpoint()
        if endpoint is None:
            raise RuntimeError("No available endpoints in pool")
        caller = PROVIDER_CALLERS.get(endpoint.provider)
        if not caller:
            raise ValueError(f"Unknown provider: {endpoint.provider}")
        try:
            content = caller(endpoint, messages, max_tokens, temperature)
            pool.record_success(endpoint)
            return {
                "content": content,
                "endpoint": endpoint.name,
                "attempt": attempt + 1
            }
        # BUG FIX: the original caught only openai.RateLimitError, so an
        # Anthropic 429 fell through to the generic handler and aborted
        # instead of backing off and failing over to another endpoint.
        except (openai.RateLimitError, anthropic.RateLimitError) as e:
            pool.record_error(endpoint, is_rate_limit=True)
            last_error = e
            time.sleep(1)  # brief pause; next attempt picks a fresh endpoint
            continue
        except (openai.APIConnectionError, openai.InternalServerError,
                anthropic.APIConnectionError, anthropic.InternalServerError) as e:
            pool.record_error(endpoint, is_rate_limit=False)
            last_error = e
            time.sleep(2 ** attempt)  # exponential backoff for transient faults
            continue
        except Exception:
            # Anything else is not retryable: mark the endpoint degraded and
            # surface the error to the caller unchanged.
            pool.record_error(endpoint, is_rate_limit=False)
            raise
    raise RuntimeError(f"All {max_retries} attempts failed. Last error: {last_error}")
Step 4: Set Up a Real Pool
import os
# Example pool: two OpenAI keys (separate billing accounts) for capacity,
# plus an Anthropic endpoint held mostly in reserve for failover.
# NOTE(review): os.getenv returns None when a variable is unset -- that only
# surfaces later as an auth error; consider validating the keys up front.
pool = EndpointPool([
    Endpoint(
        name="openai-primary",
        provider="openai",
        model_id="gpt-4o-mini",
        api_key=os.getenv("OPENAI_API_KEY_1"),
        requests_per_minute=500,
        requests_per_day=10000,
        weight=2.0  # Prefer this endpoint
    ),
    Endpoint(
        name="openai-secondary",
        provider="openai",
        model_id="gpt-4o-mini",
        api_key=os.getenv("OPENAI_API_KEY_2"),
        requests_per_minute=500,
        requests_per_day=10000,
        weight=1.0
    ),
    Endpoint(
        name="anthropic-fallback",
        provider="anthropic",
        model_id="claude-3-haiku-20240307",
        api_key=os.getenv("ANTHROPIC_API_KEY"),
        requests_per_minute=300,
        requests_per_day=5000,
        weight=0.5  # Use less often, holds capacity for failover
    ),
])
# Use it
result = load_balanced_complete(
    pool,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize this in one sentence: ..."}
    ]
)
print(f"Response via {result['endpoint']}: {result['content']}")
Step 5: Build a Concurrent Batch Processor
Process large batches in parallel across the pool.
from concurrent.futures import ThreadPoolExecutor, as_completed
def process_batch_concurrent(
    pool: EndpointPool,
    requests: list[list],  # List of message arrays
    max_workers: int = 10,
    **kwargs
) -> list[dict]:
    """Fan a batch of requests out across the pool with a thread pool.

    Results come back in the same order as `requests`; a request that fails
    after all retries yields {"error": ..., "content": None} rather than
    raising, so one bad item never aborts the batch.
    """
    results = [None] * len(requests)

    def run_single(job):
        # job is a (position, messages) pair; carry the position through so
        # out-of-order completion can still be written back in order.
        position, messages = job
        try:
            return position, load_balanced_complete(pool, messages, **kwargs)
        except Exception as exc:
            return position, {"error": str(exc), "content": None}

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = [
            executor.submit(run_single, (position, messages))
            for position, messages in enumerate(requests)
        ]
        done_count = 0
        for future in as_completed(pending):
            position, outcome = future.result()
            results[position] = outcome
            done_count += 1
            if done_count % 50 == 0:
                stats = pool.get_stats()
                print(f"Progress: {done_count}/{len(requests)} | Pool stats: {stats}")
    return results
What to Build Next
- Add a circuit breaker that removes an endpoint from the pool for a configurable backoff period when its error rate exceeds a threshold, rather than retrying it on every request
- Build a cost-aware load balancer extension that weights cheaper endpoints more heavily during budget-constrained periods
- Implement latency-weighted routing that shifts traffic away from slow endpoints based on rolling p95 latency measurements
Related Reading
- How to Build a Multi-Model AI Router - the router decides what to run; the load balancer decides where to run it
- How to Implement Cost-Based AI Model Selection - layer cost selection on top of load balancing for full control
- How to Set Up LiteLLM as Your AI Gateway - LiteLLM includes load balancing and can replace this custom implementation for many use cases
Want this system built for your business?
Get a free assessment. We will map every system your business needs and show you the ROI.
Get Your Free Assessment