Phase 3 memory engine

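The retrieval phase is the core latency challenge of context assembly. Three independent data sources must be queried within 40ms total: (1) Directive — Redis lookup by agent_id (< 1ms warm); (2) Hot memories — Redis sorted set ZREVRANGE (< 5ms); (3) Cold semantic search — embed the current query + pgvector ANN search (< 30ms).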

Milestone 3.5.3 — Memory Retrieval Pipeline (Parallel, 40ms Deadline)

Status: Planned
Goal: 3.5 — Context assembly engine
Phase: 3 — Memory Engine and Operator Platform
Estimated effort: 3–4 days
ADR required: None (follows ADR-0038 retrieval design)


Why This Milestone Exists

The retrieval phase is the core latency challenge of context assembly. Three independent data sources must be queried within 40ms total:

  1. Directive — Redis lookup by agent_id (< 1ms warm)
  2. Hot memories — Redis sorted set ZREVRANGE (< 5ms)
  3. Cold semantic search — Embed the current query + pgvector ANN search (< 30ms)

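Running these sequentially could take up to 36ms (1 + 5 + 30) — leaving only 4ms of the 40ms budget for scoring, packing, and formatting. They must therefore run concurrently via asyncio.gather. Each source has an independent timeout, and partial results from any source are used; no source blocks the others. The one ordering constraint is inside the cold path: the vector search needs the query embedding, so it can only start once the embedding call has returned.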


Deliverables

src/context/services/retrieval.py

Python
from __future__ import annotations
 
import asyncio
from dataclasses import dataclass, field
from uuid import UUID
 
import httpx
import redis.asyncio as aioredis
 
from ibex_proto.context.v1 import Message
from context.services.budget import TokenBudget
 
# Per-source timeouts, in seconds.
# Directive lookup: 10 ms ceiling; Redis should always answer in < 2 ms.
DIRECTIVE_TIMEOUT_SECONDS = 0.010
# Hot-cache lookup: 10 ms ceiling.
HOT_CACHE_TIMEOUT_SECONDS = 0.010
# Cold path: 35 ms ceiling, covering the embedding call plus the vector search.
COLD_SEARCH_TIMEOUT_SECONDS = 0.035

# Maximum number of records requested from each source.
HOT_CACHE_MEMORIES_LIMIT = 20  # top-20 from the hot cache
COLD_SEARCH_MEMORIES_LIMIT = 50  # top-50 from the cold search
 
@dataclass
class RetrievalResult:
    """Everything one retrieval pass produced, with empty defaults throughout.

    Every field has a default, so ``RetrievalResult()`` is a valid "nothing
    was retrieved" value.
    """

    # Directive text; empty string when no directive was obtained.
    directive: str = ""
    # Embedding vector of the query; empty list when no embedding was produced.
    query_embedding: list[float] = field(default_factory=list)
    # Raw memory records.
    memories: list[dict] = field(default_factory=list)
    # Origin of the directive: "cache", "db", or "none".
    directive_from: str = "none"
    # Number of records contributed by the hot cache.
    hot_cache_count: int = 0
    # Number of records contributed by the cold search.
    cold_search_count: int = 0
    # Duration of the retrieval, in milliseconds.
    retrieval_ms: float = 0.0
 
class ContextRetriever:
    """
    Orchestrates retrieval of the directive, hot memories, and cold memories.

    Three branches run concurrently under a single ``asyncio.gather``:

    * directive    -- Redis ``GET`` bounded by ``DIRECTIVE_TIMEOUT_SECONDS``.
    * hot memories -- Redis ``ZREVRANGE`` bounded by
      ``HOT_CACHE_TIMEOUT_SECONDS``, followed by a bounded bulk fetch from
      the memory service.
    * semantic     -- query embedding followed by the cold vector search; the
      two steps share one ``COLD_SEARCH_TIMEOUT_SECONDS`` budget.

    Each branch fails open: a timeout or error in one branch yields that
    branch's empty default and never delays or discards the other branches.
    The retrieval as a whole is bounded by asyncio.wait_for in the caller (3.5.1).
    """

    def __init__(
        self,
        redis: aioredis.Redis,
        embedder: httpx.AsyncClient,
        memory_service: httpx.AsyncClient,
        embedder_url: str,
        memory_service_url: str,
    ) -> None:
        """Store the clients and base URLs; performs no I/O."""
        self._redis = redis
        self._embedder = embedder
        self._memory = memory_service
        self._embed_url = embedder_url
        self._memory_url = memory_service_url

    async def retrieve(
        self,
        org_id: UUID,
        agent_id: UUID,
        session_id: UUID | None,
        messages: list[Message],
        budget: TokenBudget,
    ) -> RetrievalResult:
        """
        Run all retrieval branches concurrently and merge their results.

        Args:
            org_id: Organisation whose memories are searched.
            agent_id: Agent whose directive and memories are fetched.
            session_id: Accepted for interface compatibility; not used here.
            messages: Conversation; the latest non-empty user message is the
                semantic query.
            budget: When ``budget.is_constrained`` is true the cold search is
                skipped (the query embedding is still computed).

        Returns:
            A ``RetrievalResult`` holding whatever the branches produced.
            Failed branches contribute empty defaults; individual retrieval
            failures are never raised to the caller.
        """
        # The event-loop clock is monotonic, so it is safe for measuring elapsed time.
        clock = asyncio.get_running_loop().time
        started = clock()

        # The last user message is the semantic query.
        query_text = self._extract_query(messages)

        # Embedding and cold search are chained inside one branch so the cold
        # search starts the moment the vector is ready instead of also waiting
        # for the directive and hot-memory branches to finish.
        directive, hot_mems, semantic = await asyncio.gather(
            self._get_directive(agent_id),
            self._get_hot_memories(org_id, agent_id),
            self._embed_and_search(org_id, agent_id, query_text, budget),
            return_exceptions=True,
        )

        # Resolve exceptions (and unexpected payload types) to defaults: fail open.
        directive = directive if isinstance(directive, str) else ""
        hot_mems = hot_mems if isinstance(hot_mems, list) else []
        if isinstance(semantic, tuple):
            query_embedding, cold_mems = semantic
        else:
            query_embedding, cold_mems = [], []

        return RetrievalResult(
            directive=directive,
            query_embedding=query_embedding,
            memories=self._merge_dedup(hot_mems, cold_mems),
            # Redis is the only directive source consulted by this class.
            directive_from="cache" if directive else "none",
            hot_cache_count=len(hot_mems),
            cold_search_count=len(cold_mems),
            retrieval_ms=(clock() - started) * 1000.0,
        )

    async def _embed_and_search(
        self,
        org_id: UUID,
        agent_id: UUID,
        query_text: str,
        budget: TokenBudget,
    ) -> tuple[list[float], list[dict]]:
        """
        Embed the query, then run the cold search with the remaining budget.

        Both steps share a single ``COLD_SEARCH_TIMEOUT_SECONDS`` deadline.
        Returns ``(embedding, cold_memories)``; either part is an empty list
        when its step failed, timed out, or was skipped. A successful
        embedding is kept even when the search afterwards fails.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + COLD_SEARCH_TIMEOUT_SECONDS

        try:
            embedding = await self._embed_query(query_text)
        except Exception:
            # Timeouts and HTTP/JSON errors alike: no vector, hence no search.
            return [], []
        if not isinstance(embedding, list):
            return [], []

        if not embedding or budget.is_constrained:
            return embedding, []

        remaining = deadline - loop.time()
        if remaining <= 0:
            return embedding, []

        try:
            cold = await asyncio.wait_for(
                self._cold_search(org_id, agent_id, embedding),
                timeout=remaining,
            )
        except Exception:
            # asyncio.TimeoutError is an Exception subclass, so this covers it.
            return embedding, []
        return embedding, cold if isinstance(cold, list) else []

    async def _get_directive(self, agent_id: UUID) -> str:
        """Fetch active directive content from Redis. Returns '' on miss.

        Raises ``asyncio.TimeoutError`` after ``DIRECTIVE_TIMEOUT_SECONDS``.
        """
        data = await asyncio.wait_for(
            self._redis.get(f"directive:{agent_id}"),
            timeout=DIRECTIVE_TIMEOUT_SECONDS,
        )
        return self._as_text(data) if data else ""

    async def _get_hot_memories(self, org_id: UUID, agent_id: UUID) -> list[dict]:
        """Fetch top memory IDs from the Redis sorted set, then bulk-fetch the records.

        Returns ``[]`` when the sorted set is empty. Raises on Redis timeout,
        bulk-fetch timeout, or a non-2xx response from the memory service.
        """
        ids = await asyncio.wait_for(
            self._redis.zrevrange(
                f"{org_id}:hot_memories:{agent_id}", 0, HOT_CACHE_MEMORIES_LIMIT - 1
            ),
            timeout=HOT_CACHE_TIMEOUT_SECONDS,
        )
        if not ids:
            return []
        # Bound the memory-service call the same way _embed_query bounds its
        # HTTP call. Relying only on the 5 s client timeout would let one slow
        # bulk fetch stall the whole gather far beyond the retrieval deadline.
        resp = await asyncio.wait_for(
            self._memory.post(
                f"{self._memory_url}/internal/memories/bulk",
                json={"ids": [self._as_text(raw_id) for raw_id in ids]},
                timeout=5.0,
            ),
            timeout=COLD_SEARCH_TIMEOUT_SECONDS,
        )
        resp.raise_for_status()
        return resp.json()["memories"]

    async def _embed_query(self, text: str) -> list[float]:
        """Embed the query text via the embedding service.

        Returns ``[]`` without calling the service when ``text`` is blank.
        Raises on timeout or a non-2xx response.
        """
        if not text.strip():
            return []
        resp = await asyncio.wait_for(
            self._embedder.post(
                f"{self._embed_url}/v1/embed",
                json={"text": text[:2000]},  # truncate very long queries for embedding
                timeout=5.0,
            ),
            timeout=COLD_SEARCH_TIMEOUT_SECONDS,
        )
        resp.raise_for_status()
        return resp.json()["embedding"]

    async def _cold_search(
        self, org_id: UUID, agent_id: UUID, embedding: list[float]
    ) -> list[dict]:
        """Semantic search via the memory service's vector search endpoint.

        Applies no deadline of its own beyond the HTTP client timeout; the
        caller is expected to wrap it. Raises on a non-2xx response.
        """
        payload = {
            "org_id": str(org_id),
            "agent_id": str(agent_id),
            "embedding": embedding,
            "limit": COLD_SEARCH_MEMORIES_LIMIT,
        }
        resp = await self._memory.post(
            f"{self._memory_url}/internal/memories/search",
            json=payload,
            timeout=5.0,
        )
        resp.raise_for_status()
        return resp.json()["memories"]

    @staticmethod
    def _as_text(value: object) -> str:
        """Return a Redis reply as ``str``.

        Redis clients return ``bytes`` by default but ``str`` when created
        with ``decode_responses=True``; accept both instead of assuming bytes.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode()
        return str(value)

    @staticmethod
    def _extract_query(messages: list[Message]) -> str:
        """Return the most recent non-empty user message, or '' if there is none."""
        return next(
            (
                msg.content
                for msg in reversed(messages)
                if msg.role == "user" and msg.content
            ),
            "",
        )

    @staticmethod
    def _merge_dedup(hot: list[dict], cold: list[dict]) -> list[dict]:
        """Merge hot and cold memory lists, deduplicating by memory ID.

        Hot records come first, so the hot copy wins when both lists carry the
        same ID. Records with a missing or empty ``"id"`` are dropped.
        """
        seen: set[str] = set()
        merged: list[dict] = []
        for source in (hot, cold):
            for mem in source:
                mem_id = mem.get("id", "")
                if not mem_id or mem_id in seen:
                    continue
                seen.add(mem_id)
                merged.append(mem)
        return merged

Acceptance Criteria

  • Directive, hot cache, and embedding retrievals run concurrently (not sequentially)
  • Individual timeout for each retrieval source; one failure does not block others
  • Results merged and deduplicated by memory ID
  • RetrievalResult.memories never contains duplicates
  • Empty query text (no user message) returns empty query_embedding without error
  • Integration test: inject 30 memories; verify cold search returns them within 35ms against testcontainers Postgres

Edit on GitHub

Last updated on

On this page

0%