Summary: the retrieval phase is the core latency challenge of context assembly — three independent data sources (directive, hot memories, cold semantic search) must be queried within 40ms total.
Milestone 3.5.3 — Memory Retrieval Pipeline (Parallel, 40ms Deadline)
Status: Planned
Goal: 3.5 — Context assembly engine
Phase: 3 — Memory Engine and Operator Platform
Estimated effort: 3–4 days
ADR required: None (follows ADR-0038 retrieval design)
Why This Milestone Exists
The retrieval phase is the core latency challenge of context assembly. Three independent data sources must be queried within 40ms total:
- Directive — Redis lookup by agent_id (< 1ms warm)
- Hot memories — Redis sorted set ZREVRANGE (< 5ms)
- Cold semantic search — Embed the current query + pgvector ANN search (< 30ms)
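Run sequentially, these worst-case latencies add up to 36ms (1 + 5 + 30) — leaving only 4ms for scoring, packing, and formatting. They must therefore run concurrently via asyncio.gather, each with an independent timeout. The one dependency is cold search, which needs the query embedding: embedding and vector search run back-to-back within the cold-search budget, while the directive and hot-memory lookups proceed in parallel. Partial results from any source are used; no source blocks the others.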
Deliverables
src/context/services/retrieval.py
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from uuid import UUID
import httpx
import redis.asyncio as aioredis
from ibex_proto.context.v1 import Message
from context.services.budget import TokenBudget
# Per-source timeouts in seconds, written as milliseconds / 1000.
DIRECTIVE_TIMEOUT_SECONDS = 10 / 1000  # Redis directive GET; expected well under 2ms
HOT_CACHE_TIMEOUT_SECONDS = 10 / 1000  # Redis hot-memory sorted-set read
COLD_SEARCH_TIMEOUT_SECONDS = 35 / 1000  # query embedding + vector search, combined
# Maximum number of records requested from each memory source.
HOT_CACHE_MEMORIES_LIMIT = 20
COLD_SEARCH_MEMORIES_LIMIT = 50
@dataclass
class RetrievalResult:
    """Outcome of one retrieval pass; every field defaults to an empty value."""

    # Active directive text ("" when none was retrieved).
    directive: str = ""
    # Embedding of the semantic query ([] when none was produced).
    query_embedding: list[float] = field(default_factory=list)
    # Raw memory records (hot and cold results merged).
    memories: list[dict] = field(default_factory=list)
    # Where the directive came from: "cache", "db", or "none".
    directive_from: str = "none"
    hot_cache_count: int = 0  # number of records from the hot cache
    cold_search_count: int = 0  # number of records from cold search
    retrieval_ms: float = 0.0  # elapsed retrieval time, in milliseconds
class ContextRetriever:
    """
    Orchestrates parallel retrieval of directive, hot memories, and cold memories.

    The directive lookup, the hot-memory lookup, and the embed-then-search chain
    run concurrently, each under its own timeout. Cold search needs the query
    embedding, so the two steps run back-to-back inside one task and share the
    COLD_SEARCH_TIMEOUT_SECONDS budget (embed + search combined). The hot path is
    bounded by the same budget so it cannot hold up the gather either. The
    retrieval as a whole is additionally bounded by asyncio.wait_for in the
    caller (3.5.1).
    """

    def __init__(
        self,
        redis: aioredis.Redis,
        embedder: httpx.AsyncClient,
        memory_service: httpx.AsyncClient,
        embedder_url: str,
        memory_service_url: str,
    ) -> None:
        """
        Args:
            redis: Client used for the directive and hot-memory lookups.
            embedder: HTTP client for the embedding service.
            memory_service: HTTP client for the memory service (bulk fetch and search).
            embedder_url: Base URL of the embedding service.
            memory_service_url: Base URL of the memory service.
        """
        self._redis = redis
        self._embedder = embedder
        self._memory = memory_service
        self._embed_url = embedder_url
        self._memory_url = memory_service_url

    async def retrieve(
        self,
        org_id: UUID,
        agent_id: UUID,
        session_id: UUID | None,
        messages: list[Message],
        budget: TokenBudget,
    ) -> RetrievalResult:
        """
        Run all retrieval operations concurrently and merge their results.

        Never raises on an individual source failure or timeout: a failed source
        contributes its empty default and the other sources are still used.
        Cancellation by the caller still propagates.

        Args:
            org_id: Organisation scope for the hot-cache key and cold search.
            agent_id: Agent whose directive and memories are fetched.
            session_id: Accepted for interface compatibility; not used here.
            messages: Conversation; the most recent user message with content is
                the semantic query.
            budget: Cold search is skipped when ``budget.is_constrained`` is true
                (the query embedding is still computed).

        Returns:
            RetrievalResult with merged, de-duplicated memories, per-source
            counts, the directive's provenance, and the elapsed time in ms.
        """
        loop = asyncio.get_running_loop()
        started = loop.time()

        # The last user message is the semantic query.
        query_text = self._extract_query(messages)

        # Launch all three concurrently — DO NOT await in sequence.
        directive_task = asyncio.create_task(self._get_directive(agent_id))
        # _get_hot_memories only times out its Redis step; its follow-up bulk
        # HTTP fetch is bounded here so a slow memory service cannot stall the
        # gather (and, via the caller's deadline, lose every other result).
        hot_cache_task = asyncio.create_task(
            asyncio.wait_for(
                self._get_hot_memories(org_id, agent_id),
                timeout=COLD_SEARCH_TIMEOUT_SECONDS,
            )
        )
        semantic_task = asyncio.create_task(
            self._embed_and_search(
                org_id, agent_id, query_text, search=not budget.is_constrained
            )
        )

        directive, hot_mems, semantic = await asyncio.gather(
            directive_task, hot_cache_task, semantic_task,
            return_exceptions=True,
        )

        # Resolve exceptions to defaults (fail open on any retrieval).
        directive = directive if isinstance(directive, str) else ""
        hot_mems = hot_mems if isinstance(hot_mems, list) else []
        query_embedding, cold_mems = (
            semantic if isinstance(semantic, tuple) else ([], [])
        )

        return RetrievalResult(
            directive=directive,
            query_embedding=query_embedding,
            # Merge hot and cold results, deduplicating by memory ID.
            memories=self._merge_dedup(hot_mems, cold_mems),
            # The only directive source implemented here is the Redis cache.
            directive_from="cache" if directive else "none",
            hot_cache_count=len(hot_mems),
            cold_search_count=len(cold_mems),
            retrieval_ms=(loop.time() - started) * 1000.0,
        )

    async def _embed_and_search(
        self, org_id: UUID, agent_id: UUID, query_text: str, search: bool
    ) -> tuple[list[float], list[dict]]:
        """
        Embed the query, then (when ``search`` is true) run cold search.

        Both steps share COLD_SEARCH_TIMEOUT_SECONDS: the search only gets what is
        left of the budget after embedding. Returns (embedding, cold_memories);
        each is [] when its step is skipped, fails, or runs out of time. The
        embedding is kept even if the search fails.
        """
        loop = asyncio.get_running_loop()
        started = loop.time()
        try:
            embedding = await self._embed_query(query_text)
        except Exception:
            return [], []
        if not isinstance(embedding, list):
            return [], []
        if not embedding or not search:
            return embedding, []

        remaining = COLD_SEARCH_TIMEOUT_SECONDS - (loop.time() - started)
        if remaining <= 0:
            return embedding, []
        try:
            memories = await asyncio.wait_for(
                self._cold_search(org_id, agent_id, embedding),
                timeout=remaining,
            )
        except Exception:
            # Best-effort: a failed or slow cold search leaves only hot results.
            return embedding, []
        return embedding, (memories if isinstance(memories, list) else [])

    async def _get_directive(self, agent_id: UUID) -> str:
        """Fetch active directive content from Redis. Returns '' on miss."""
        key = f"directive:{agent_id}"
        data = await asyncio.wait_for(
            self._redis.get(key), timeout=DIRECTIVE_TIMEOUT_SECONDS
        )
        if not data:
            return ""
        # Clients created with decode_responses=True already return str.
        return data.decode() if isinstance(data, bytes) else data

    async def _get_hot_memories(self, org_id: UUID, agent_id: UUID) -> list[dict]:
        """
        Fetch the top HOT_CACHE_MEMORIES_LIMIT memory IDs from the Redis sorted
        set (highest score first), then bulk-fetch those records from the memory
        service. Only the Redis step carries a timeout here; callers bound the
        whole operation.
        """
        key = f"{org_id}:hot_memories:{agent_id}"
        ids = await asyncio.wait_for(
            self._redis.zrevrange(key, 0, HOT_CACHE_MEMORIES_LIMIT - 1),
            timeout=HOT_CACHE_TIMEOUT_SECONDS,
        )
        if not ids:
            return []
        resp = await self._memory.post(
            f"{self._memory_url}/internal/memories/bulk",
            # Members may be bytes or str depending on decode_responses.
            json={"ids": [i.decode() if isinstance(i, bytes) else i for i in ids]},
            timeout=5.0,
        )
        resp.raise_for_status()
        return resp.json()["memories"]

    async def _embed_query(self, text: str) -> list[float]:
        """
        Embed the query text via the embedding service.

        Returns [] for blank text. The text is truncated to 2000 characters
        before sending, and the call is bounded by COLD_SEARCH_TIMEOUT_SECONDS.
        """
        if not text.strip():
            return []
        resp = await asyncio.wait_for(
            self._embedder.post(
                f"{self._embed_url}/v1/embed",
                json={"text": text[:2000]},  # truncate very long queries for embedding
                timeout=5.0,
            ),
            timeout=COLD_SEARCH_TIMEOUT_SECONDS,
        )
        resp.raise_for_status()
        return resp.json()["embedding"]

    async def _cold_search(
        self, org_id: UUID, agent_id: UUID, embedding: list[float]
    ) -> list[dict]:
        """
        Semantic search via the memory service's vector search endpoint.

        Requests up to COLD_SEARCH_MEMORIES_LIMIT records; the caller applies
        the timeout.
        """
        resp = await self._memory.post(
            f"{self._memory_url}/internal/memories/search",
            json={
                "org_id": str(org_id),
                "agent_id": str(agent_id),
                "embedding": embedding,
                "limit": COLD_SEARCH_MEMORIES_LIMIT,
            },
            timeout=5.0,
        )
        resp.raise_for_status()
        return resp.json()["memories"]

    @staticmethod
    def _extract_query(messages: list[Message]) -> str:
        """Return the most recent user message with content, or '' if none."""
        for msg in reversed(messages):
            if msg.role == "user" and msg.content:
                return msg.content
        return ""

    @staticmethod
    def _merge_dedup(hot: list[dict], cold: list[dict]) -> list[dict]:
        """
        Merge hot and cold memory lists, deduplicating by memory ID.

        Hot records come first, so on a duplicate ID the hot copy wins. Records
        with a missing or empty "id" are dropped.
        """
        seen: set[str] = set()
        result: list[dict] = []
        for source in (hot, cold):
            for mem in source:
                mid = mem.get("id", "")
                if mid and mid not in seen:
                    seen.add(mid)
                    result.append(mem)
        return result

# Acceptance Criteria
- Directive, hot cache, and embedding retrievals run concurrently (not sequentially)
- Individual timeout for each retrieval source; one failure does not block others
- Results merged and deduplicated by memory ID
- `RetrievalResult.memories` never contains duplicates
- Empty query text (no user message) returns an empty `query_embedding` without error
- Integration test: inject 30 memories; verify cold search returns them within 35ms against testcontainers Postgres
Last updated on