How to Optimize RAG for Large Document Collections
Scale RAG systems to handle thousands of documents without degrading quality.
Jay Banlasan
The AI Systems Guy
When you optimize RAG for large document collections, the challenge shifts from accuracy alone to speed and relevance. I build these systems for organizations with 10,000+ documents, where naive RAG gets slow and retrieval quality degrades. The approach below uses pre-filtering, tiered indexing, and caching to keep query times under 2 seconds even at scale.
Small RAG systems forgive sloppy architecture. Large ones do not.
What You Need Before Starting
- A RAG system showing signs of degradation (slow queries, irrelevant results)
- Python 3.8+ with chromadb or Pinecone
- Document metadata (categories, dates, departments)
- Performance benchmarks from your current system
Step 1: Implement Category Pre-Filtering
Do not search the entire index. Narrow by category first:
from sentence_transformers import SentenceTransformer

# Load the embedding model once at import time, not on every query.
model = SentenceTransformer("all-MiniLM-L6-v2")

def smart_query(question, collection, user_context=None):
    # classify_query_category is defined elsewhere (keyword rules or a small classifier).
    category = classify_query_category(question)
    where_filter = {}
    if category != "general":
        where_filter["category"] = category
    if user_context and user_context.get("department"):
        where_filter["department"] = {"$in": [user_context["department"], "general"]}
    query_embedding = model.encode(question).tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=5,
        where=where_filter if where_filter else None,
    )
    return results
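The classifier itself is assumed above. A minimal keyword-based sketch (the category names and keywords here are placeholders, not from the original system) might look like:

```python
# Hypothetical keyword classifier; swap in an LLM call or a trained
# classifier in production. Categories and keywords are placeholders.
CATEGORY_KEYWORDS = {
    "hr": ["vacation", "benefits", "payroll", "leave"],
    "engineering": ["deploy", "api", "bug", "architecture"],
    "finance": ["invoice", "budget", "expense", "revenue"],
}

def classify_query_category(question):
    """Return the category whose keywords best match the question, else 'general'."""
    q = question.lower()
    scores = {
        cat: sum(1 for kw in keywords if kw in q)
        for cat, keywords in CATEGORY_KEYWORDS.items()
    }
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else "general"
```

Even a crude classifier like this pays off: filtering to one category before the vector search shrinks the candidate set dramatically.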
Step 2: Build a Tiered Index
Split documents into tiers by importance and freshness:
import chromadb
chroma = chromadb.PersistentClient(path="./tiered_rag")
tiers = {
"current": chroma.get_or_create_collection("current"), # Last 90 days
"recent": chroma.get_or_create_collection("recent"), # Last year
"archive": chroma.get_or_create_collection("archive"), # Older
}
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # load once, not per query

# How many candidates to pull from each tier, and how much to weight them.
TIER_K = {"current": 5, "recent": 3, "archive": 2}
TIER_BOOST = {"current": 1.2, "recent": 1.0, "archive": 0.8}

def tiered_search(question, top_k=5):
    query_embedding = model.encode(question).tolist()
    all_results = []
    for tier_name, tier_collection in tiers.items():
        results = tier_collection.query(
            query_embeddings=[query_embedding],
            n_results=TIER_K[tier_name],
        )
        for i in range(len(results["ids"][0])):
            all_results.append({
                "id": results["ids"][0][i],
                # Convert distance to similarity, then boost fresher tiers.
                "score": (1 - results["distances"][0][i]) * TIER_BOOST[tier_name],
                "content": results["documents"][0][i],
                "tier": tier_name,
            })
    all_results.sort(key=lambda x: x["score"], reverse=True)
    return all_results[:top_k]
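Routing documents into the right tier at ingest time is assumed above. A minimal sketch based on document age (the cutoffs mirror the 90-day and one-year comments on the collections):

```python
from datetime import datetime, timedelta

def assign_tier(doc_date, now=None):
    """Map a document's date to 'current' (last 90 days),
    'recent' (last year), or 'archive' (older)."""
    now = now or datetime.now()
    age = now - doc_date
    if age <= timedelta(days=90):
        return "current"
    if age <= timedelta(days=365):
        return "recent"
    return "archive"
```

Run the same logic on a schedule so documents migrate downward as they age; otherwise the "current" tier silently becomes a second archive.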
Step 3: Add Query Caching
import hashlib
from datetime import datetime

cache = {}  # in-process cache: {query_hash: {"results": ..., "cached_at": ...}}

def cached_query(question, collection, ttl_seconds=3600):
    cache_key = hashlib.md5(question.encode()).hexdigest()
    if cache_key in cache:
        entry = cache[cache_key]
        # total_seconds() is correct even for entries older than a day;
        # .seconds alone wraps around at 24 hours.
        if (datetime.now() - entry["cached_at"]).total_seconds() < ttl_seconds:
            return entry["results"]
    results = smart_query(question, collection)
    cache[cache_key] = {"results": results, "cached_at": datetime.now()}
    return results
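The dict above never evicts anything, so it grows for the life of the process. A bounded variant is straightforward; this sketch (the 1,000-entry cap is an arbitrary choice) drops expired entries on read and evicts the oldest insertions once full:

```python
import hashlib
from collections import OrderedDict
from datetime import datetime, timedelta

class TTLCache:
    """Small in-process query cache with TTL expiry and a size cap.
    Not thread-safe; use Redis or similar across processes."""

    def __init__(self, ttl_seconds=3600, max_entries=1000):
        self.ttl = timedelta(seconds=ttl_seconds)
        self.max_entries = max_entries
        self._store = OrderedDict()

    def get(self, question):
        key = hashlib.md5(question.encode()).hexdigest()
        entry = self._store.get(key)
        if entry and datetime.now() - entry["cached_at"] < self.ttl:
            return entry["results"]
        self._store.pop(key, None)  # expired or missing
        return None

    def put(self, question, results):
        key = hashlib.md5(question.encode()).hexdigest()
        self._store[key] = {"results": results, "cached_at": datetime.now()}
        while len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict oldest insertion
```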
Step 4: Optimize Chunk Sizes for Scale
import time

def benchmark_chunk_sizes(documents, test_questions):
    """Compare retrieval accuracy and latency across candidate chunk sizes.
    Assumes rechunk_documents, create_temp_collection, and evaluate_accuracy
    are defined elsewhere."""
    results = {}
    for chunk_size in [500, 800, 1200, 2000]:
        chunks = rechunk_documents(documents, chunk_size)
        temp_collection = create_temp_collection(chunks)
        start = time.time()
        accuracy = evaluate_accuracy(temp_collection, test_questions)
        query_time = (time.time() - start) / len(test_questions)
        results[chunk_size] = {
            "total_chunks": len(chunks),
            "accuracy": accuracy,
            "avg_query_ms": round(query_time * 1000, 1),
        }
    return results
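The rechunk_documents helper is assumed above. A minimal fixed-size splitter with overlap (the 100-character overlap is an arbitrary choice; production systems usually split on sentence or paragraph boundaries instead):

```python
def rechunk_documents(documents, chunk_size, overlap=100):
    """Split each document string into chunks of roughly chunk_size
    characters, carrying overlap characters of context between chunks."""
    chunks = []
    step = chunk_size - overlap
    for doc in documents:
        for start in range(0, max(len(doc), 1), step):
            chunk = doc[start:start + chunk_size]
            if chunk.strip():
                chunks.append(chunk)
    return chunks
```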
Step 5: Monitor Performance
import sqlite3
import time

def monitored_query(question, collection):
    start = time.time()
    results = smart_query(question, collection)
    query_time = time.time() - start
    conn = sqlite3.connect("rag_perf.db")
    conn.execute("""
        INSERT INTO query_performance (query, results_count, query_time_ms, queried_at)
        VALUES (?, ?, ?, datetime('now'))
    """, (question, len(results.get("ids", [[]])[0]), round(query_time * 1000, 1)))
    conn.commit()
    conn.close()
    if query_time > 2.0:
        log_slow_query(question, query_time)  # alerting hook, defined elsewhere
    return results

def get_performance_report():
    conn = sqlite3.connect("rag_perf.db")
    window = "queried_at > datetime('now', '-7 days')"
    avg_ms = conn.execute(
        f"SELECT AVG(query_time_ms) FROM query_performance WHERE {window}"
    ).fetchone()[0]
    # Approximate p95: skip the slowest 5% of rows, take the next one.
    p95_row = conn.execute(f"""
        SELECT query_time_ms FROM query_performance WHERE {window}
        ORDER BY query_time_ms DESC
        LIMIT 1 OFFSET (SELECT COUNT(*) * 5 / 100 FROM query_performance WHERE {window})
    """).fetchone()
    slow = conn.execute(
        f"SELECT COUNT(*) FROM query_performance WHERE query_time_ms > 2000 AND {window}"
    ).fetchone()[0]
    conn.close()
    return {
        "avg_query_ms": avg_ms,
        "p95_query_ms": p95_row[0] if p95_row else None,
        "slow_queries": slow,
    }
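The monitoring code assumes the query_performance table already exists. A schema sketch matching the columns it writes:

```python
import sqlite3

def init_perf_db(path="rag_perf.db"):
    """Create the query_performance table the monitoring code writes to."""
    conn = sqlite3.connect(path)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS query_performance (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT NOT NULL,
            results_count INTEGER,
            query_time_ms REAL,
            queried_at TEXT NOT NULL
        )
    """)
    conn.commit()
    conn.close()
```

Call it once at startup; CREATE TABLE IF NOT EXISTS makes it safe to run on every boot.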
What to Build Next
Add automatic index optimization. Periodically analyze which chunks never get retrieved and move them to cold storage. Focus the hot index on the 20% of content that answers 80% of questions.
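A sketch of that retrieval-count analysis (the retrieval_counts mapping is an assumption; in practice you would derive it from the query_performance log, and the collection calls follow chromadb's API):

```python
def find_cold_chunks(retrieval_counts, min_retrievals=1):
    """Given {chunk_id: times_retrieved}, return IDs retrieved
    fewer than min_retrievals times over the analysis window."""
    return [cid for cid, n in retrieval_counts.items() if n < min_retrievals]

def demote_to_cold(hot_collection, cold_collection, chunk_ids):
    """Move never-retrieved chunks from the hot index to a cold collection."""
    if not chunk_ids:
        return
    batch = hot_collection.get(
        ids=chunk_ids, include=["documents", "embeddings", "metadatas"]
    )
    cold_collection.add(
        ids=batch["ids"],
        documents=batch["documents"],
        embeddings=batch["embeddings"],
        metadatas=batch["metadatas"],
    )
    hot_collection.delete(ids=chunk_ids)
```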
Related Reading
- The Scalability Test - testing whether your RAG system can handle growth
- Real-Time vs Batch Processing - when to batch-process RAG updates
- The Complexity Trap - keeping scaled RAG systems manageable
Want this system built for your business?
Get a free assessment. We will map every system your business needs and show you the ROI.
Get Your Free Assessment