
How to Implement Semantic Caching for AI Queries

Cache similar AI queries to avoid redundant API calls and reduce costs by 30%.

Jay Banlasan

The AI Systems Guy

A support chatbot I was running was getting the same questions answered over and over. "What are your payment terms?" "How do I cancel?" "What's your refund policy?" Different wording, same semantic intent, same answer. Every unique phrasing hit the API. Semantic caching cut that chatbot's API spend by 34% in the first week without changing a single answer. Similar queries now return cached results instantly, and the API only fires when a genuinely new question comes in.

Regular string-match caching is too brittle for natural language. "What's your refund policy?" and "Can I get a refund?" are completely different strings but semantically near-identical. Semantic caching uses embeddings to measure meaning distance and returns cached answers when two queries are close enough.
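To see the brittleness concretely, here is a sketch of the exact-match approach (the `exact_key` helper is illustrative, mirroring the hash used in Step 2):

```python
import hashlib

def exact_key(query: str) -> str:
    # Normalize trivially (lowercase + strip), then hash the string
    return hashlib.md5(query.lower().strip().encode()).hexdigest()

# Two semantically identical questions hash to unrelated keys,
# so an exact-match cache misses on the second phrasing.
print(exact_key("What's your refund policy?") ==
      exact_key("Can I get a refund?"))  # False
```

No amount of string normalization closes that gap; only a measure of meaning does.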

Step 1: Build the Embedding Function

Embeddings convert text into a numeric vector; vectors that sit close together in that space represent semantically similar content.

import anthropic
import numpy as np
import json
import os
import openai

_client = anthropic.Anthropic()
_oa = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def embed(text: str) -> list[float]:
    """Get a text embedding via the OpenAI embeddings API."""
    response = _oa.embeddings.create(
        model="text-embedding-3-small",
        input=text[:8000]  # rough truncation to stay under the model's input limit
    )
    return response.data[0].embedding

def cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
    a = np.array(vec_a)
    b = np.array(vec_b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
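A quick sanity check on the cosine helper with toy vectors (the function is restated so the snippet is self-contained):

```python
import numpy as np

def cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
    a, b = np.array(vec_a), np.array(vec_b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0  (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0  (orthogonal)
print(round(cosine_similarity([1.0, 1.0], [1.0, 0.0]), 4))  # 0.7071
```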

text-embedding-3-small costs $0.00002 per 1K tokens. That's roughly 100x cheaper than running the same query through a chat model. The math works: spending $0.0001 on an embedding to potentially save $0.01 on a chat call is a 100x return.
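The break-even logic can be sketched as a small cost model (the per-call dollar figures are illustrative assumptions, not quoted prices):

```python
def cache_savings(queries: int, hit_rate: float,
                  chat_cost: float = 0.01, embed_cost: float = 0.0001) -> float:
    """Net dollar savings: every query pays for one embedding,
    but cache hits skip the chat-model call entirely."""
    baseline = queries * chat_cost
    with_cache = queries * embed_cost + queries * (1 - hit_rate) * chat_cost
    return round(baseline - with_cache, 2)

# At a 34% hit rate over 1,000 queries:
print(cache_savings(1000, 0.34))  # 3.3 (dollars saved)
```

Because the embedding is ~100x cheaper than the chat call, the cache pays for itself at even a 1% hit rate.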

Step 2: Build the SQLite Cache Store

Store embeddings alongside cached responses.

import sqlite3
import hashlib
from datetime import datetime

CACHE_DB = "semantic_cache.db"

def init_cache():
    conn = sqlite3.connect(CACHE_DB)
    conn.executescript("""
        CREATE TABLE IF NOT EXISTS cache_entries (
            id          INTEGER PRIMARY KEY AUTOINCREMENT,
            query_hash  TEXT NOT NULL,
            query_text  TEXT NOT NULL,
            embedding   TEXT NOT NULL,
            response    TEXT NOT NULL,
            model       TEXT NOT NULL,
            created_at  TEXT NOT NULL,
            hit_count   INTEGER DEFAULT 0,
            last_hit    TEXT
        );
        CREATE INDEX IF NOT EXISTS idx_hash ON cache_entries(query_hash);
    """)
    conn.commit()
    conn.close()

def save_to_cache(query: str, embedding: list[float],
                  response: str, model: str):
    query_hash = hashlib.md5(query.lower().strip().encode()).hexdigest()
    conn = sqlite3.connect(CACHE_DB)
    conn.execute("""
        INSERT INTO cache_entries
        (query_hash, query_text, embedding, response, model, created_at)
        VALUES (?,?,?,?,?,?)
    """, (query_hash, query, json.dumps(embedding), response, model,
          # store in SQLite's own datetime format ("YYYY-MM-DD HH:MM:SS") so
          # the datetime('now', ...) comparison in the lookup works correctly
          datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")))
    conn.commit()
    conn.close()

def increment_hit(entry_id: int):
    conn = sqlite3.connect(CACHE_DB)
    conn.execute("""
        UPDATE cache_entries
        SET hit_count = hit_count + 1, last_hit = ?
        WHERE id = ?
    """, (datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"), entry_id))
    conn.commit()
    conn.close()
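A quick smoke test of the schema using an in-memory database (same table definition, no files touched; the sample row values are illustrative):

```python
import sqlite3
import json
from datetime import datetime

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE cache_entries (
        id          INTEGER PRIMARY KEY AUTOINCREMENT,
        query_hash  TEXT NOT NULL,
        query_text  TEXT NOT NULL,
        embedding   TEXT NOT NULL,
        response    TEXT NOT NULL,
        model       TEXT NOT NULL,
        created_at  TEXT NOT NULL,
        hit_count   INTEGER DEFAULT 0,
        last_hit    TEXT
    );
""")
# Insert one entry, bump its hit count, and read it back
conn.execute(
    "INSERT INTO cache_entries "
    "(query_hash, query_text, embedding, response, model, created_at) "
    "VALUES (?,?,?,?,?,?)",
    ("abc123", "How do I cancel?", json.dumps([0.1, 0.2]),
     "Go to Settings > Billing.", "claude-3-haiku-20240307",
     datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")),
)
conn.execute("UPDATE cache_entries SET hit_count = hit_count + 1 WHERE id = 1")
row = conn.execute(
    "SELECT query_text, hit_count FROM cache_entries WHERE id = 1"
).fetchone()
print(row)  # ('How do I cancel?', 1)
conn.close()
```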

Step 3: Build the Cache Lookup

The lookup embeds the incoming query and compares it against recent cached entries. The query below caps the scan at the 500 most recent rows, which keeps the linear similarity check fast; at larger scale you would swap in a vector index.

def find_cached_response(
    query: str,
    similarity_threshold: float = 0.92,
    max_age_days: int = 30
) -> dict | None:
    query_embedding = embed(query)
    
    conn = sqlite3.connect(CACHE_DB)
    conn.row_factory = sqlite3.Row
    rows = conn.execute("""
        SELECT id, query_text, embedding, response, model, created_at
        FROM cache_entries
        WHERE created_at >= datetime('now', ?)
        ORDER BY created_at DESC
        LIMIT 500
    """, (f'-{max_age_days} days',)).fetchall()
    conn.close()
    
    best_match = None
    best_score = 0.0
    
    for row in rows:
        cached_embedding = json.loads(row["embedding"])
        similarity = cosine_similarity(query_embedding, cached_embedding)
        
        if similarity > best_score:
            best_score = similarity
            best_match = dict(row)
            best_match["similarity"] = similarity
    
    if best_match and best_score >= similarity_threshold:
        increment_hit(best_match["id"])
        return best_match
    
    return None

The threshold of 0.92 is a good starting point. Lower it to 0.88 for higher recall (more cache hits, slightly more risk of wrong answers). Raise it to 0.95 for high-precision use cases like legal or medical content.

Step 4: Build the Cached AI Call Wrapper

The main interface. Checks cache first, falls back to live API, then stores the new result.

def cached_ai_call(
    query: str,
    system_prompt: str = "",
    model: str = "claude-3-haiku-20240307",
    similarity_threshold: float = 0.92,
    cache_ttl_days: int = 30
) -> dict:
    # Step 1: Check cache
    cached = find_cached_response(query, similarity_threshold, cache_ttl_days)
    
    if cached:
        return {
            "response":   cached["response"],
            "source":     "cache",
            "similarity": round(cached["similarity"], 4),
            "cached_query": cached["query_text"],
            "model":      cached["model"]
        }
    
    # Step 2: Cache miss — call the API
    messages = [{"role": "user", "content": query}]
    kwargs = {"model": model, "max_tokens": 1024, "messages": messages}
    if system_prompt:
        kwargs["system"] = system_prompt
    
    response = _client.messages.create(**kwargs)
    response_text = response.content[0].text
    
    # Step 3: Embed and store for future hits
    query_embedding = embed(query)
    save_to_cache(query, query_embedding, response_text, model)
    
    return {
        "response": response_text,
        "source":   "api",
        "similarity": None,
        "model":    model
    }

Step 5: Calibrate Your Threshold with Real Data

Before shipping to production, run a calibration test on real queries from your use case.

def calibrate_threshold(query_pairs: list[tuple[str, str, bool]]) -> dict:
    """
    query_pairs: list of (query_a, query_b, should_match)
    Returns recommended threshold and accuracy at different cutoffs.
    """
    scored = []
    for qa, qb, should_match in query_pairs:
        emb_a = embed(qa)
        emb_b = embed(qb)
        sim = cosine_similarity(emb_a, emb_b)
        scored.append({"qa": qa, "qb": qb, "sim": sim, "expected": should_match})
    
    results = {}
    for threshold in [0.85, 0.88, 0.90, 0.92, 0.95]:
        tp = sum(1 for s in scored if s["sim"] >= threshold and s["expected"])
        fp = sum(1 for s in scored if s["sim"] >= threshold and not s["expected"])
        fn = sum(1 for s in scored if s["sim"] < threshold and s["expected"])
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall    = tp / (tp + fn) if (tp + fn) > 0 else 0
        results[threshold] = {"precision": precision, "recall": recall}
    
    return results

# Example calibration data for a support bot
test_pairs = [
    ("What's your refund policy?", "Can I get a refund?", True),
    ("How do I cancel?", "I want to cancel my subscription", True),
    ("What are your prices?", "How much does it cost?", True),
    ("What's your refund policy?", "How do I contact support?", False),
    ("Cancel my account", "Reset my password", False),
]
print(calibrate_threshold(test_pairs))

Step 6: Add Cache Analytics

Track hit rates so you know the system is working.

def cache_stats() -> dict:
    conn = sqlite3.connect(CACHE_DB)
    row = conn.execute("""
        SELECT COUNT(*) as entries,
               SUM(hit_count) as total_hits,
               AVG(hit_count) as avg_hits,
               MAX(hit_count) as max_hits
        FROM cache_entries
    """).fetchone()
    conn.close()
    return {
        "cache_entries": row[0],
        "total_cache_hits": row[1] or 0,
        "avg_hits_per_entry": round(row[2] or 0, 2),
        "most_hit_count": row[3] or 0
    }

# Cache hit rate = total_hits / (total_hits + api_calls)
# Log api_calls separately in your usage tracking
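The hit-rate comment above can be made concrete (the counts are illustrative):

```python
def hit_rate(total_cache_hits: int, api_calls: int) -> float:
    """Fraction of queries served from cache rather than the live API."""
    total = total_cache_hits + api_calls
    return total_cache_hits / total if total else 0.0

print(hit_rate(340, 660))  # 0.34 — roughly the 34% savings level cited earlier
```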
