diff --git a/examples/pg-memory.py b/examples/pg-memory.py new file mode 100644 index 00000000..6087fb06 --- /dev/null +++ b/examples/pg-memory.py @@ -0,0 +1,47 @@ +import controlflow as cf +from controlflow.memory.memory import Memory +from controlflow.memory.providers.postgres import PostgresMemory + +provider = PostgresMemory( + database_url="postgresql://postgres:postgres@localhost:5432/your_database", + # embedding_dimension=1536, + # embedding_fn=OpenAIEmbeddings(), + table_name="vector_db", +) +# Create a memory module for user preferences +user_preferences = cf.Memory( + key="user_preferences", + instructions="Store and retrieve user preferences.", + provider=provider, +) + +# Create an agent with access to the memory +agent = cf.Agent(memories=[user_preferences]) + + +# Create a flow to ask for the user's favorite color +@cf.flow +def remember_color(): + return cf.run( + "Ask the user for their favorite color and store it in memory", + agents=[agent], + interactive=True, + ) + + +# Create a flow to recall the user's favorite color +@cf.flow +def recall_color(): + return cf.run( + "What is the user's favorite color?", + agents=[agent], + ) + + +if __name__ == "__main__": + print("First flow:") + remember_color() + + print("\nSecond flow:") + result = recall_color() + print(result) diff --git a/src/controlflow/memory/memory.py b/src/controlflow/memory/memory.py index 2e521081..1e0dd840 100644 --- a/src/controlflow/memory/memory.py +++ b/src/controlflow/memory/memory.py @@ -166,4 +166,16 @@ def get_memory_provider(provider: str) -> MemoryProvider: return lance_providers.LanceMemory() + # --- Postgres --- + elif provider.startswith("postgres"): + try: + import sqlalchemy + except ImportError: + raise ImportError( + "To use Postgres as a memory provider, please install the `sqlalchemy` package." + ) + + import controlflow.memory.providers.postgres as postgres_providers + + return postgres_providers.PostgresMemory() raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.') diff --git a/src/controlflow/memory/providers/postgres.py b/src/controlflow/memory/providers/postgres.py new file mode 100644 index 00000000..6887722a --- /dev/null +++ b/src/controlflow/memory/providers/postgres.py @@ -0,0 +1,197 @@ +import uuid +from typing import Callable, Dict, Optional + +import sqlalchemy +from pgvector.sqlalchemy import Vector +from pydantic import Field +from sqlalchemy import Column, String, select, text +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.orm import Session, declarative_base, sessionmaker +from sqlalchemy_utils import create_database, database_exists + +import controlflow +from controlflow.memory.memory import MemoryProvider + +try: + # For embeddings, we can use langchain_openai or any other library: + from langchain_openai import OpenAIEmbeddings +except ImportError: + raise ImportError( + "To use an embedding function similar to LanceDB's default, " + "please install lancedb with: pip install lancedb" + ) + +# SQLAlchemy base class for declarative models +Base = declarative_base() + + +class SQLMemoryTable(Base): + """ + A simple declarative model that represents a memory record. + + We’ll dynamically set the __tablename__ at runtime. + """ + + __abstract__ = True + id = Column(String, primary_key=True) + text = Column(String) + # Use pgvector for storing embeddings in a Postgres Vector column + # vector = Column(Vector(dim=1536)) # Adjust dimension to match your embedding model + + +class PostgresMemory(MemoryProvider): + """ + A ControlFlow MemoryProvider that stores text + embeddings in PostgreSQL + using SQLAlchemy and pg_vector. Each Memory module gets its own table. + """ + + # Default database URL. You can point this to your actual Postgres instance. + # Requires the pgvector extension installed and the sqlalchemy-pgvector package. + database_url: str = Field( + default="postgresql://user:password@localhost:5432/your_database", + description="SQLAlchemy-compatible database URL to a Postgres instance with pgvector.", + ) + table_name: str = Field( + "memory_{key}", + description=""" + Name of the table to store this memory partition. "{key}" will be replaced + by the memory’s key attribute. + """, + ) + + embedding_dimension: int = Field( + default=1536, + description="Dimension of the embedding vectors. Match your model's output.", + ) + + embedding_fn: Callable = Field( + default_factory=lambda: OpenAIEmbeddings( + model="text-embedding-ada-002", + ), + description="A function that turns a string into a vector.", + ) + + # Internal: keep a cached Session maker + _SessionLocal: Optional[sessionmaker] = None + + # This dict will map "table_name" -> "model class" + _table_class_cache: Dict[str, Base] = {} + + def configure(self, memory_key: str) -> None: + """ + Configure a SQLAlchemy session and ensure the table for this + memory partition is created if it does not already exist. + """ + engine = sqlalchemy.create_engine(self.database_url) + + # 2) If DB doesn't exist, create it! + if not database_exists(engine.url): + create_database(engine.url) + + with engine.connect() as conn: + conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + conn.commit() + + self._SessionLocal = sessionmaker(bind=engine) + + # Dynamically create a specialized table model for this memory_key + table_name = self.table_name.format(key=memory_key) + + # 1) Check if table already in metadata + if table_name not in Base.metadata.tables: + # 2) Create the dynamic class + table + memory_model = type( + f"SQLMemoryTable_{memory_key}", + (SQLMemoryTable,), + { + "__tablename__": table_name, + "vector": Column(Vector(dim=self.embedding_dimension)), + }, + ) + + try: + Base.metadata.create_all(engine, tables=[memory_model.__table__]) + # Store it in the cache + self._table_class_cache[table_name] = memory_model + except ProgrammingError as e: + raise RuntimeError(f"Failed to create table {table_name}: {e}") + + def _get_session(self) -> Session: + if not self._SessionLocal: + raise RuntimeError( + "Session is not initialized. Make sure to call configure() first." + ) + return self._SessionLocal() + + def _get_table(self, memory_key: str) -> Base: + """ + Return a dynamically generated declarative model class + mapped to the memory_{key} table. Each memory partition + has a separate table. + """ + table_name = self.table_name.format(key=memory_key) + + # Return the cached class if already built + if table_name in self._table_class_cache: + return self._table_class_cache[table_name] + + # If for some reason it's not there, create it now (or raise error): + memory_model = type( + f"SQLMemoryTable_{memory_key}", + (SQLMemoryTable,), + { + "__tablename__": table_name, + "vector": Column(Vector(dim=self.embedding_dimension)), + }, + ) + self._table_class_cache[table_name] = memory_model + return memory_model + + def add(self, memory_key: str, content: str) -> str: + """ + Insert a new memory record into the Postgres table, + generating an embedding and storing it in a vector column. + Returns the memory’s ID (uuid). + """ + memory_id = str(uuid.uuid4()) + model_cls = self._get_table(memory_key) + + # Generate an embedding for the content + embedding = self.embedding_fn.embed_query(content) + + with self._get_session() as session: + record = model_cls(id=memory_id, text=content, vector=embedding) + session.add(record) + session.commit() + + return memory_id + + def delete(self, memory_key: str, memory_id: str) -> None: + """ + Delete a memory record by its UUID. + """ + model_cls = self._get_table(memory_key) + + with self._get_session() as session: + session.query(model_cls).filter(model_cls.id == memory_id).delete() + session.commit() + + def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: + """ + Uses pgvector’s approximate nearest neighbor search with the `<->` operator to find + the top N matching records for the embedded query. Returns a dict of {id: text}. + """ + model_cls = self._get_table(memory_key) + # Generate embedding for the query + query_embedding = self.embedding_fn.embed_query(query) + embedding_col = model_cls.vector + + with self._get_session() as session: + results = session.execute( + select(model_cls.id, model_cls.text) + .order_by(embedding_col.l2_distance(query_embedding)) + .limit(n) + ).all() + + return {row.id: row.text for row in results}