Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hud/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import BaseHub, BaseTool
from .bash import BashTool
from .edit import EditTool
from .memory import MemoryTool
from .playwright import PlaywrightTool
from .response import ResponseTool
from .submit import SubmitTool
Expand All @@ -30,6 +31,7 @@
"OpenAIComputerTool",
"PlaywrightTool",
"ResponseTool",
"MemoryTool",
"SubmitTool",
]

Expand Down
175 changes: 175 additions & 0 deletions hud/tools/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Lightweight memory tool with optional Qdrant backend."""

from __future__ import annotations

from dataclasses import dataclass
import uuid
from typing import Any

from mcp.types import ContentBlock, TextContent

from hud.tools.base import BaseTool


def _tokenize(text: str) -> set[str]:
return {t.lower() for t in text.split() if t}


def _jaccard(a: set[str], b: set[str]) -> float:
if not a or not b:
return 0.0
inter = len(a & b)
union = len(a | b)
return inter / union if union else 0.0


@dataclass
class MemoryEntry:
    """One remembered item: its text, attached metadata and token set."""

    # Raw text of the entry, as supplied when it was stored.
    text: str
    # Arbitrary metadata kept alongside the text.
    metadata: dict[str, Any]
    # Token set used for overlap scoring by the in-memory store; entries
    # built from Qdrant search results carry an empty set instead.
    tokens: set[str]


class InMemoryStore:
    """Simple token-overlap store.

    Entries are kept in insertion order in a plain list and ranked at query
    time by the Jaccard similarity between the query's tokens and each
    entry's tokens.
    """

    def __init__(self) -> None:
        self._entries: list[MemoryEntry] = []

    def add(self, text: str, metadata: dict[str, Any] | None = None) -> None:
        """Store *text* together with an optional metadata mapping.

        Args:
            text: Content to remember; it is tokenized once, at insert time.
            metadata: Optional mapping stored with the entry. A shallow copy
                is kept so that later mutation of the caller's dict does not
                silently change what was stored.
        """
        self._entries.append(
            MemoryEntry(
                text=text,
                metadata=dict(metadata) if metadata else {},
                tokens=_tokenize(text),
            )
        )

    def query(self, query: str, top_k: int = 5) -> list[MemoryEntry]:
        """Return the entries most similar to *query*, best match first.

        Args:
            query: Text whose tokens are compared against every stored entry.
            top_k: Maximum number of entries to return. A non-positive value
                yields an empty list (a negative value used to be applied as
                a "drop the last k" slice, which is never what a caller wants).

        Returns:
            Up to ``top_k`` entries that share at least one token with
            *query*, ordered by descending similarity. Entries with equal
            scores keep their insertion order.
        """
        if top_k <= 0:
            return []
        q_tokens = _tokenize(query)
        if not q_tokens:
            # A query without tokens cannot overlap with anything.
            return []
        scored = [(entry, _jaccard(q_tokens, entry.tokens)) for entry in self._entries]
        # Drop non-matches before ranking; list.sort is stable, so ties stay
        # in insertion order even with reverse=True.
        matches = [pair for pair in scored if pair[1] > 0.0]
        matches.sort(key=lambda pair: pair[1], reverse=True)
        return [entry for entry, _score in matches[:top_k]]


class MemoryTool(BaseTool):
    """Add and search short-term memory for a session.

    If a ``qdrant_url`` is passed to the constructor and the ``qdrant_client``
    package can be imported, a remote Qdrant collection is used. Otherwise an
    in-memory fallback is used.

    Note: the URL must be passed explicitly; this class does not read a
    ``QDRANT_URL`` environment variable itself.
    """

    # Bounds advertised in the JSON schema and enforced in ``__call__``.
    _MIN_TOP_K = 1
    _MAX_TOP_K = 50
    # Vector size used when the collection has to be created. It has to match
    # the embedding width of the model loaded by QdrantBackend
    # (all-MiniLM-L6-v2 produces 384-dimensional vectors).
    _VECTOR_SIZE = 384

    def __init__(
        self,
        collection: str = "hud_memory",
        qdrant_url: str | None = None,
        qdrant_api_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the tool and select its storage backend.

        Args:
            collection: Name of the Qdrant collection to use or create.
            qdrant_url: URL of the Qdrant server; a falsy value selects the
                in-memory store.
            qdrant_api_key: Optional API key forwarded to the Qdrant client.
            **kwargs: Forwarded unchanged to ``BaseTool.__init__``.
        """
        super().__init__(**kwargs)
        self._backend = self._build_backend(collection, qdrant_url, qdrant_api_key)

    def _build_backend(
        self, collection: str, qdrant_url: str | None, qdrant_api_key: str | None
    ) -> Any:
        """Return a Qdrant-backed store when possible, else an in-memory one.

        Raises:
            RuntimeError: Propagated from ``QdrantBackend`` when
                sentence-transformers cannot be imported.
        """
        if not qdrant_url:
            return InMemoryStore()

        try:
            from qdrant_client import QdrantClient
            from qdrant_client.http.models import Distance, VectorParams
        except Exception:
            # Falling back is deliberate (the dependency is optional), but do
            # it loudly: the caller asked for Qdrant and is not getting it.
            import logging

            logging.getLogger(__name__).warning(
                "qdrant_url was provided but qdrant-client could not be imported; "
                "falling back to the in-memory store",
                exc_info=True,
            )
            return InMemoryStore()

        client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        try:
            client.get_collection(collection)
        except Exception:
            # Any failure is treated as "collection missing"; a genuine
            # connectivity problem will resurface from create_collection.
            client.create_collection(
                collection_name=collection,
                vectors_config=VectorParams(
                    size=self._VECTOR_SIZE, distance=Distance.COSINE
                ),
            )
        return QdrantBackend(client, collection)

    @property
    def parameters(self) -> dict[str, Any]:  # type: ignore[override]
        """JSON schema describing the arguments accepted by ``__call__``."""
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["add", "search"],
                    "description": "add = store text, search = retrieve similar items",
                },
                "text": {"type": "string", "description": "content to store or query"},
                "metadata": {
                    "type": "object",
                    "description": "optional metadata to store with the entry",
                },
                "top_k": {
                    "type": "integer",
                    "minimum": self._MIN_TOP_K,
                    "maximum": self._MAX_TOP_K,
                    "default": 5,
                    "description": "results to return when searching",
                },
            },
            "required": ["action", "text"],
        }

    async def __call__(
        self, action: str, text: str, metadata: dict[str, Any] | None = None, top_k: int = 5
    ) -> list[ContentBlock]:
        """Store *text* or search for entries similar to it.

        Args:
            action: ``"add"`` to store *text*, ``"search"`` to look it up.
            text: Content to store, or the query when searching.
            metadata: Optional metadata stored with the entry (``add`` only).
            top_k: Maximum number of results when searching; clamped to the
                range advertised by ``parameters``.

        Returns:
            A single text block: ``"stored"`` after an add, a numbered list
            of matches (or ``"no matches"``) after a search, and
            ``"unknown action"`` for any other action.
        """
        if action == "add":
            self._backend.add(text=text, metadata=metadata)
            return [TextContent(text="stored", type="text")]

        if action == "search":
            # The schema is only advisory for callers, so enforce it here.
            limit = max(self._MIN_TOP_K, min(top_k, self._MAX_TOP_K))
            entries = self._backend.query(query=text, top_k=limit)
            if not entries:
                return [TextContent(text="no matches", type="text")]
            lines = []
            for idx, entry in enumerate(entries, 1):
                meta = entry.metadata or {}
                meta_str = f" | metadata={meta}" if meta else ""
                lines.append(f"{idx}. {entry.text}{meta_str}")
            return [TextContent(text="\n".join(lines), type="text")]

        return [TextContent(text="unknown action", type="text")]


class QdrantBackend:
    """Minimal Qdrant wrapper with on-the-fly sentence-transformer embeddings."""

    def __init__(self, client: Any, collection: str) -> None:
        """Bind to *collection* on *client* and load the embedding model.

        Raises:
            RuntimeError: If sentence-transformers cannot be imported.
        """
        self.client = client
        self.collection = collection
        self._embedder = self._load_embedder()

    def _load_embedder(self) -> Any:
        """Import sentence-transformers lazily and return the embedding model."""
        try:
            from sentence_transformers import SentenceTransformer
        except Exception as e:
            raise RuntimeError("sentence-transformers is required for Qdrant backend") from e
        return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def add(self, text: str, metadata: dict[str, Any] | None = None) -> None:
        """Embed *text* and upsert it as a new point with a random UUID id."""
        vec = self._embedder.encode(text).tolist()
        payload = {"text": text, "metadata": metadata or {}}
        self.client.upsert(
            collection_name=self.collection,
            points=[{"id": uuid.uuid4().hex, "vector": vec, "payload": payload}],
        )

    def query(self, query: str, top_k: int = 5) -> list[MemoryEntry]:
        """Return up to *top_k* stored entries nearest to *query*.

        The client's ``query_points`` method is preferred because ``search``
        is deprecated in current qdrant-client releases; ``search`` is still
        used when the client does not offer ``query_points``.

        Returns:
            Entries rebuilt from the point payloads, in the order the server
            returned them. ``tokens`` is always an empty set because no token
            data is stored in Qdrant.
        """
        vec = self._embedder.encode(query).tolist()
        query_points = getattr(self.client, "query_points", None)
        if callable(query_points):
            response = query_points(
                collection_name=self.collection,
                query=vec,
                limit=top_k,
                with_payload=True,
            )
            hits = response.points
        else:
            hits = self.client.search(
                collection_name=self.collection,
                query_vector=vec,
                limit=top_k,
                with_payload=True,
            )

        entries: list[MemoryEntry] = []
        for point in hits:
            payload = point.payload or {}
            entries.append(
                MemoryEntry(
                    text=payload.get("text", ""),
                    # `or {}` also covers a payload whose metadata is null.
                    metadata=payload.get("metadata") or {},
                    tokens=set(),
                )
            )
        return entries
34 changes: 34 additions & 0 deletions hud/tools/tests/test_memory_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import pytest
from mcp.types import TextContent

from hud.tools.memory import InMemoryStore, MemoryTool


def test_inmemory_store_add_and_query() -> None:
    """Only the entry sharing a token with the query is returned."""
    store = InMemoryStore()
    for text, kind in (("apple orange", "fruit"), ("carrot celery", "veg")):
        store.add(text, {"kind": kind})

    matches = store.query("apple", top_k=5)

    # Exactly one hit, and it is the fruit entry.
    assert [match.metadata["kind"] for match in matches] == ["fruit"]


@pytest.mark.asyncio
async def test_memory_tool_add_and_search() -> None:
    """Text added through the tool comes back, with its metadata, from a search."""
    tool = MemoryTool()

    out_add = await tool(action="add", text="alpha beta", metadata={"id": 1})
    assert len(out_add) == 1
    assert isinstance(out_add[0], TextContent)
    # Pin the reply itself, not just its type.
    assert out_add[0].text == "stored"

    out_search = await tool(action="search", text="alpha")
    assert len(out_search) == 1
    assert isinstance(out_search[0], TextContent)
    # A bare "1." prefix check would also pass for the wrong entry; make sure
    # the stored text and its metadata are what was returned.
    assert out_search[0].text.startswith("1. alpha beta")
    assert "'id': 1" in out_search[0].text


@pytest.mark.asyncio
async def test_memory_tool_unknown_action() -> None:
    """An action outside the supported set yields the 'unknown action' reply."""
    blocks = await MemoryTool()(action="noop", text="x")
    first = blocks[0]
    assert first.text == "unknown action"