diff --git a/.github/workflows/branching-database.yml b/.github/workflows/branching-database.yml index c9e05d5..de39bde 100644 --- a/.github/workflows/branching-database.yml +++ b/.github/workflows/branching-database.yml @@ -48,16 +48,17 @@ jobs: uses: neondatabase/create-branch-action@v6 with: project_id: ${{ vars.NEON_PROJECT_ID }} - branch_name: pr/${{ needs.setup.outputs.current_branch }} - parent_branch: ${{ needs.setup.outputs.base_ref_branch != '' && format('pr/{0}', needs.setup.outputs.base_ref_branch) || 'production' }} + branch_name: ${{ needs.setup.outputs.current_branch }} branch_type: "schema-only" + parent_branch: ${{ needs.setup.outputs.base_ref_branch || 'develop' }} api_key: ${{ secrets.NEON_API_KEY }} expires_at: ${{ env.EXPIRES_AT }} - name: Post Schema Diff Comment to PR uses: neondatabase/schema-diff-action@v1 with: project_id: ${{ vars.NEON_PROJECT_ID }} - compare_branch: pr/${{ needs.setup.outputs.current_branch }} + compare_branch: ${{ needs.setup.outputs.current_branch }} + base_branch: ${{ needs.setup.outputs.base_ref_branch || 'develop' }} api_key: ${{ secrets.NEON_API_KEY }} delete_db_branch: diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 860794c..dbbe889 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -1,53 +1,53 @@ name: "Copilot Setup Steps" on: - workflow_dispatch: - push: - paths: - - .github/workflows/copilot-setup-steps.yml - pull_request: - paths: - - .github/workflows/copilot-setup-steps.yml + workflow_dispatch: + push: + paths: + - .github/workflows/copilot-setup-steps.yml + pull_request: + paths: + - .github/workflows/copilot-setup-steps.yml jobs: - copilot-setup-steps: - runs-on: ubuntu-latest - permissions: - contents: read - - steps: - - name: Checkout code - uses: actions/checkout@v5 - - - name: Get branch name - id: branch_name - uses: tj-actions/branch-names@v8 - - - name: Create DB Branch - id: create_db_branch - uses: neondatabase/create-branch-action@v6 - with: - project_id: ${{ vars.NEON_PROJECT_ID }} - branch_name: pr/${{ steps.branch_name.outputs.current_branch }} - parent_branch: ${{ steps.branch_name.outputs.base_ref_branch != '' && format('pr/{0}', steps.branch_name.outputs.base_ref_branch) || 'production' }} - branch_type: "schema-only" - api_key: ${{ secrets.NEON_API_KEY }} - - - name: Write database URL to .env - run: | - echo "DATABASE_URL=${{ steps.create_db_branch.outputs.db_url_with_pooler }}" > .env - echo "✅ Database URL written to .env file" - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Set up PDM - uses: pdm-project/setup-pdm@v4 - with: - python-version: "3.12" - cache: true - - - name: Install dependencies - run: pdm install -G dev --frozen-lockfile + copilot-setup-steps: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v5 + + - name: Get branch name + id: branch_name + uses: tj-actions/branch-names@v8 + + - name: Create DB Branch + id: create_db_branch + uses: neondatabase/create-branch-action@v6 + with: + project_id: ${{ vars.NEON_PROJECT_ID }} + branch_name: ${{ steps.branch_name.outputs.current_branch }} + parent_branch: ${{ needs.setup.outputs.base_ref_branch || 'develop' }} + branch_type: "schema-only" + api_key: ${{ secrets.NEON_API_KEY }} + + - name: Write database URL to .env + run: | + echo "DATABASE_URL=${{ steps.create_db_branch.outputs.db_url_with_pooler }}" > .env + echo "✅ Database URL written to .env file" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Set up PDM + uses: pdm-project/setup-pdm@v4 + with: + python-version: "3.12" + cache: true + + - name: Install dependencies + run: pdm install -G dev --frozen-lockfile diff --git a/.github/workflows/openapi-doc.yml b/.github/workflows/openapi-doc.yml index 27264cc..96316b1 100644 --- a/.github/workflows/openapi-doc.yml +++ b/.github/workflows/openapi-doc.yml @@ -2,6 +2,7 @@ name: Check & deploy API documentation permissions: contents: write pull-requests: write + on: push: branches: @@ -13,6 +14,7 @@ on: - main - staging - develop + jobs: generate-openapi: runs-on: ubuntu-latest @@ -41,44 +43,55 @@ jobs: SKIP_EXTENSIONS_SYNC: "1" OBSRV__LOGGING_BACKEND: "none" run: pdm run python scripts/generate-openapi.py - - - name: Check if schema changed - id: check_changes - run: | - if git diff --quiet docs/openapi.json; then - echo "changed=false" >> $GITHUB_OUTPUT - else - echo "changed=true" >> $GITHUB_OUTPUT - fi + - name: Upload OpenAPI artifact + uses: actions/upload-artifact@v4 + with: + name: openapi-schema + path: docs/openapi.json + retention-days: 1 - name: Commit and push changes - if: steps.check_changes.outputs.changed == 'true' - run: | - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" - git add docs/openapi.json - git commit -m "chore: update OpenAPI documentation" - git push + uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "chore: update OpenAPI documentation" + file_pattern: docs/openapi.json + deploy-doc: if: ${{ github.event_name == 'push' }} name: Deploy API documentation on Bump.sh runs-on: ubuntu-latest + needs: generate-openapi steps: - name: Checkout uses: actions/checkout@v3 + - name: Download OpenAPI artifact + uses: actions/download-artifact@v4 + with: + name: openapi-schema + path: docs + - name: Deploy API documentation uses: bump-sh/github-action@v1 with: doc: core token: ${{secrets.BUMP_SH_TOKEN}} file: docs/openapi.json + api-diff: if: ${{ github.event_name == 'pull_request' }} name: Check API diff on Bump.sh runs-on: ubuntu-latest + needs: generate-openapi steps: - name: Checkout uses: actions/checkout@v3 + + - name: Download OpenAPI artifact + uses: actions/download-artifact@v4 + with: + name: openapi-schema + path: docs + - name: Comment pull request with API diff uses: bump-sh/github-action@v1 with: diff --git a/.gitignore b/.gitignore index c218f80..3da9440 100644 --- a/.gitignore +++ b/.gitignore @@ -217,7 +217,4 @@ scratch/ ipynb # extensions scripts -extensions/**/scripts/* - -# Database migrations -migrations/versions/ \ No newline at end of file +extensions/**/scripts/* \ No newline at end of file diff --git a/app/business/info_base/block.py b/app/business/info_base/block.py index 8bc803b..378da6b 100644 --- a/app/business/info_base/block.py +++ b/app/business/info_base/block.py @@ -20,7 +20,6 @@ one_chat, ) from app.schemas.info_base.block import ( - BlockEmbeddingModel, BlockID, BlockModel, ResolverType, @@ -101,52 +100,18 @@ def create( "Block created successfully", extra={"block_id": block.id, "resolver": block.resolver}, ) - - scheduler.add_job( - func=cls._upsert_embedding, - kwargs={"block_id": block.id}, - misfire_grace_time=None, + logger.debug( + "Embedding will be created asynchronously by interval job", + extra={"block_id": block.id}, ) return block @classmethod async def refresh_embeddings(cls): - """Rebuild all blocks' embeddings""" - with SessionLocal() as db_session: - blocks = db_session.exec( - sqlmodel.select(BlockModel).where( - BlockModel.resolver == "learn_english.lexical" - ) # FIXME - ).all() - tasks = tuple(cls._upsert_embedding(block, db_session) for block in blocks) - await asyncio.gather(*tasks) - db_session.commit() - - @classmethod - async def _upsert_embedding( - cls, block: BlockModel, db_session: Opt[sqlmodel.Session] = None - ) -> BlockEmbeddingModel: - """Upsert a block's embedding - - :param block: 块 - :param db_session: 可选的数据库会话,如果提供则使用该会话;不会提交。 - """ - from .resolver import ResolverManager - - resolver = ResolverManager.new_resolver(block) - embedding = BlockEmbeddingModel( - id=block.id, # type: ignore[arg-type] - embedding=Embedding("", "text-embedding-v3").embed(resolver.get_str_for_embedding()), - ) - if db_session: - db_session.merge(embedding) - return embedding - with SessionLocal() as db_session: - db_session.merge(embedding) - db_session.commit() - db_session.refresh(embedding) - return embedding + """Rebuild all blocks' embeddings - delegates to sink embedding service""" + from app.business.sink.embedding import EmbeddingManager + await EmbeddingManager.refresh_all_block_embeddings() @classmethod async def fetchsert(cls, block: BlockModel, db_session: sqlmodel.Session) -> BlockModel: @@ -154,6 +119,8 @@ async def fetchsert(cls, block: BlockModel, db_session: sqlmodel.Session) -> Blo Will NOT commit the session. """ + from app.business.sink.embedding import EmbeddingManager + resolver = ResolverManager.new_resolver(block) existing = resolver.get_existing(db_session) if existing is not None: @@ -170,8 +137,8 @@ async def fetchsert(cls, block: BlockModel, db_session: sqlmodel.Session) -> Blo db_session.add(block) db_session.flush() db_session.refresh(block) - # and embedding - await cls._upsert_embedding(block, db_session) + # and embedding - use sink service + await EmbeddingManager.upsert_block_embedding(block, db_session) return block @@ -205,39 +172,20 @@ def query_by_embedding( num: int = 10, max_distance: float = 0.3, ) -> tuple[BlockModel, ...]: - """根据余弦相似度查询块 + """Query blocks by cosine similarity - delegates to sink embedding service - :param block_id: 用已有块的embedding查询 - :param embedding: 用给定的embedding查询 - :param resolver: 限定解析器类型, None则不限定 + :param block_id: Use existing block's embedding for query + :param embedding: Use given embedding for query + :param resolver: Filter by resolver type, None means no filter """ - with SessionLocal() as db_session: - if block_id is not None: - base_embedding = db_session.exec( - sqlmodel.select(BlockEmbeddingModel.embedding).where( - BlockEmbeddingModel.id == block_id - ) - ) - else: - if embedding is not None: - base_embedding = embedding - else: - raise ValueError("one of block_id or embedding must be provided") - - similar_blocks = db_session.exec( - sqlmodel.select(BlockModel) - .select_from(BlockModel) - .join(BlockEmbeddingModel, BlockEmbeddingModel.id == BlockModel.id) # type: ignore - .where(BlockModel.resolver == resolver if resolver else True) - .where(BlockEmbeddingModel.embedding is not None) - .where(BlockEmbeddingModel.id != block_id) - .where( - BlockEmbeddingModel.embedding.cosine_distance(base_embedding) < max_distance # type: ignore - ) - .limit(num) - ).all() - - return tuple(similar_blocks) # type: ignore + from app.business.sink.embedding import EmbeddingManager + return EmbeddingManager.query_blocks_by_embedding( + block_id=block_id, + embedding=embedding, + resolver=resolver, + num=num, + max_distance=max_distance, + ) @classmethod async def iterate_from_block( @@ -445,11 +393,9 @@ def edit_block( db_session.refresh(block) logger.info("Block edited successfully", extra={"block_id": block.id}) - - scheduler.add_job( - func=cls._upsert_embedding, - kwargs={"block_id": block.id}, - misfire_grace_time=None, + logger.debug( + "Embedding will be updated asynchronously by interval job", + extra={"block_id": block.id}, ) return block diff --git a/app/business/info_base/storage/main.py b/app/business/info_base/storage/main.py index b68c5ab..5699192 100644 --- a/app/business/info_base/storage/main.py +++ b/app/business/info_base/storage/main.py @@ -145,22 +145,28 @@ def setup_builtin_storages(cls) -> None: Uses PostgreSQL upsert to ensure built-in storages exist with correct configuration. """ + from app.business.info_base.storage.http import ( + HTTPHtmlStorage, + HTTPImageStorage, + HTTPVideoStorage, + ) + builtin_storages = [ { "id": -1, - "type": "app.business.storage.http.HTTPImageStorage", + "type": ".".join((HTTPImageStorage.__module__, HTTPImageStorage.__qualname__)), "nickname": "http_image", "config": {}, }, { "id": -2, - "type": "app.business.storage.http.HTTPVideoStorage", + "type": ".".join((HTTPVideoStorage.__module__, HTTPVideoStorage.__qualname__)), "nickname": "http_video", "config": {}, }, { "id": -3, - "type": "app.business.storage.http.HTTPHtmlStorage", + "type": ".".join((HTTPHtmlStorage.__module__, HTTPHtmlStorage.__qualname__)), "nickname": "http_html", "config": {}, }, diff --git a/app/business/sink/__init__.py b/app/business/sink/__init__.py index b3eb465..92afd48 100644 --- a/app/business/sink/__init__.py +++ b/app/business/sink/__init__.py @@ -1,5 +1,7 @@ from .main import SinkManager +from .embedding import EmbeddingManager __all__ = [ "SinkManager", + "EmbeddingManager", ] diff --git a/app/business/sink/embedding.py b/app/business/sink/embedding.py new file mode 100644 index 0000000..9ca0b1b --- /dev/null +++ b/app/business/sink/embedding.py @@ -0,0 +1,259 @@ +"""Embedding Manager for RAG Sink + +This module manages embeddings for blocks and relations. +Embeddings are created/updated here as they are part of the RAG sink (output/usage of info-base). +""" + +__all__ = ["EmbeddingManager"] + +import asyncio +import sqlmodel +from typing import Optional as Opt +from app.engine import SessionLocal +from libs.obsrv.main import get_logger +from libs.ai import Embedding +from app.schemas.sink.embedding import BlockEmbeddingModel, RelationEmbeddingModel +from app.schemas.info_base.block import BlockModel, BlockID +from app.schemas.info_base.relation import RelationModel, RelationID +from app.schemas.info_base.main import Vector + +logger = get_logger() + + +class EmbeddingManager: + @classmethod + async def upsert_block_embedding( + cls, block_id: Opt[BlockID] = None, block: Opt[BlockModel] = None, db_session: Opt[sqlmodel.Session] = None + ) -> BlockEmbeddingModel: + """Upsert a block's embedding + + :param block_id: Block ID to create/update embedding for + :param block: Block model to create/update embedding for (alternative to block_id) + :param db_session: Optional database session, if provided uses that session; won't commit. + """ + from app.business.info_base.resolver import ResolverManager + + if block is None: + if block_id is None: + raise ValueError("Either block_id or block must be provided") + # Fetch block from database + with SessionLocal() as fetch_session: + block = fetch_session.exec( + sqlmodel.select(BlockModel).where(BlockModel.id == block_id) + ).one_or_none() + if block is None: + raise ValueError(f"Block with id {block_id} not found") + + resolver = ResolverManager.new_resolver(block) + embedding = BlockEmbeddingModel( + id=block.id, # type: ignore[arg-type] + embedding=Embedding("", "text-embedding-v3").embed(resolver.get_str_for_embedding()), + ) + if db_session: + db_session.merge(embedding) + return embedding + with SessionLocal() as db_session: + db_session.merge(embedding) + db_session.commit() + db_session.refresh(embedding) + return embedding + + @classmethod + async def upsert_relation_embedding( + cls, relation_id: Opt[RelationID] = None, relation: Opt[RelationModel] = None, db_session: Opt[sqlmodel.Session] = None + ) -> RelationEmbeddingModel: + """Upsert a relation's embedding + + :param relation_id: Relation ID to create/update embedding for + :param relation: Relation model to create/update embedding for (alternative to relation_id) + :param db_session: Optional database session, if provided uses that session; won't commit. + """ + if relation is None: + if relation_id is None: + raise ValueError("Either relation_id or relation must be provided") + # Fetch relation from database + with SessionLocal() as fetch_session: + relation = fetch_session.exec( + sqlmodel.select(RelationModel).where(RelationModel.id == relation_id) + ).one_or_none() + if relation is None: + raise ValueError(f"Relation with id {relation_id} not found") + + # For relations, we embed the content directly + embedding = RelationEmbeddingModel( + id=relation.id, # type: ignore[arg-type] + embedding=Embedding("", "text-embedding-v3").embed(relation.content), + ) + if db_session: + db_session.merge(embedding) + return embedding + with SessionLocal() as db_session: + db_session.merge(embedding) + db_session.commit() + db_session.refresh(embedding) + return embedding + + @classmethod + async def refresh_all_block_embeddings(cls): + """Rebuild all blocks' embeddings""" + with SessionLocal() as db_session: + blocks = db_session.exec( + sqlmodel.select(BlockModel).where( + BlockModel.resolver == "learn_english.lexical" + ) # FIXME + ).all() + tasks = tuple(cls.upsert_block_embedding(block=block, db_session=db_session) for block in blocks) + await asyncio.gather(*tasks) + db_session.commit() + + @classmethod + async def check_and_create_missing_embeddings(cls): + """Check for blocks/relations missing embeddings and create them + + This is called periodically by the scheduler to ensure all content has embeddings. + """ + logger.info("Checking for missing embeddings") + with SessionLocal() as db_session: + # Find blocks without embeddings + blocks_without_embeddings = db_session.exec( + sqlmodel.select(BlockModel) + .outerjoin(BlockEmbeddingModel, BlockModel.id == BlockEmbeddingModel.id) + .where(BlockEmbeddingModel.id.is_(None)) + .limit(10) # Process in batches to avoid long-running jobs + ).all() + + # Find relations without embeddings + relations_without_embeddings = db_session.exec( + sqlmodel.select(RelationModel) + .outerjoin(RelationEmbeddingModel, RelationModel.id == RelationEmbeddingModel.id) + .where(RelationEmbeddingModel.id.is_(None)) + .limit(10) # Process in batches + ).all() + + if blocks_without_embeddings: + logger.info( + f"Creating embeddings for {len(blocks_without_embeddings)} blocks" + ) + block_tasks = tuple( + cls.upsert_block_embedding(block=block, db_session=db_session) + for block in blocks_without_embeddings + ) + await asyncio.gather(*block_tasks) + + if relations_without_embeddings: + logger.info( + f"Creating embeddings for {len(relations_without_embeddings)} relations" + ) + relation_tasks = tuple( + cls.upsert_relation_embedding(relation=relation, db_session=db_session) + for relation in relations_without_embeddings + ) + await asyncio.gather(*relation_tasks) + + db_session.commit() + + if blocks_without_embeddings or relations_without_embeddings: + logger.info( + f"Created embeddings for {len(blocks_without_embeddings)} blocks " + f"and {len(relations_without_embeddings)} relations" + ) + + @classmethod + def query_blocks_by_embedding( + cls, + block_id: Opt[int] = None, + embedding: Opt[Vector] = None, + resolver: Opt[str] = None, + num: int = 10, + max_distance: float = 0.3, + ) -> tuple[BlockModel, ...]: + """Query blocks by cosine similarity + + :param block_id: Use embedding from existing block + :param embedding: Use given embedding + :param resolver: Filter by resolver type, None means no filter + :param num: Number of results to return + :param max_distance: Maximum cosine distance threshold + """ + with SessionLocal() as db_session: + if block_id is not None: + base_embedding = db_session.exec( + sqlmodel.select(BlockEmbeddingModel.embedding).where( + BlockEmbeddingModel.id == block_id + ) + ).one() + else: + if embedding is not None: + base_embedding = embedding + else: + raise ValueError("one of block_id or embedding must be provided") + + query = ( + sqlmodel.select(BlockModel) + .select_from(BlockModel) + .join(BlockEmbeddingModel, BlockEmbeddingModel.id == BlockModel.id) # type: ignore + .where(BlockEmbeddingModel.embedding is not None) + .where(BlockEmbeddingModel.id != block_id) + .where( + BlockEmbeddingModel.embedding.cosine_distance(base_embedding) < max_distance # type: ignore + ) + .order_by(BlockEmbeddingModel.embedding.cosine_distance(base_embedding)) # type: ignore + .limit(num) + ) + + # Apply resolver filter if specified + if resolver is not None: + query = query.where(BlockModel.resolver == resolver) + + similar_blocks = db_session.exec(query).all() + + return tuple(similar_blocks) # type: ignore + + @classmethod + def rerank_blocks( + cls, + query: str, + blocks: tuple[BlockModel, ...], + top_k: int = 5, + ) -> tuple[BlockModel, ...]: + """Rerank blocks using a more sophisticated method + + This uses cross-encoder or similar reranking approach to improve retrieval quality. + Currently implements a simple score-based reranking using query embedding similarity. + + :param query: The search query + :param blocks: Candidate blocks to rerank + :param top_k: Number of top results to return after reranking + """ + if not blocks: + return tuple() + + # Generate query embedding + query_embedding = Embedding("", "text-embedding-v3").embed(query) + + # Calculate scores for each block + with SessionLocal() as db_session: + block_scores: list[tuple[BlockModel, float]] = [] + + for block in blocks: + block_embedding = db_session.exec( + sqlmodel.select(BlockEmbeddingModel.embedding).where( + BlockEmbeddingModel.id == block.id + ) + ).one_or_none() + + if block_embedding: + # Calculate cosine distance (lower is better) + # We'll use SQLAlchemy's cosine_distance for consistency + distance = db_session.exec( + sqlmodel.select( + BlockEmbeddingModel.embedding.cosine_distance(query_embedding) # type: ignore + ).where(BlockEmbeddingModel.id == block.id) + ).one() + block_scores.append((block, distance)) + + # Sort by distance (ascending) and take top_k + block_scores.sort(key=lambda x: x[1]) + reranked_blocks = tuple(block for block, _ in block_scores[:top_k]) + + return reranked_blocks diff --git a/app/business/sink/main.py b/app/business/sink/main.py index 4c5651b..93de75d 100644 --- a/app/business/sink/main.py +++ b/app/business/sink/main.py @@ -6,7 +6,7 @@ import sqlmodel from app.business.info_base.block import BlockManager -from libs.ai import Chat, Message, MessageContent, Prompt +from libs.ai import Chat, Message, MessageContent, Prompt, Embedding from app.schemas.info_base.block import BlockID @@ -23,8 +23,23 @@ async def rag( query: str, context: Opt[str] = None, context_blocks: list[BlockID] = fastapi.Query([]), - retrieve_mode: RetrieveMode = "feature", + retrieve_mode: RetrieveMode = "embedding", + use_reranker: bool = True, + num_retrieve: int = 20, + num_rerank: int = 5, ) -> SinkV1RAGResBody: + """RAG (Retrieval Augmented Generation) endpoint + + :param query: User query + :param context: Additional context string + :param context_blocks: Additional context block IDs + :param retrieve_mode: Retrieval mode - "embedding", "reasoning", or "feature" + :param use_reranker: Whether to use reranker to improve retrieval results + :param num_retrieve: Number of blocks to retrieve initially + :param num_rerank: Number of blocks to keep after reranking + """ + from .embedding import EmbeddingManager + # retrieve from base if retrieve_mode == "reasoning": related_blocks = await BlockManager.query_by_reasoning(query=query) @@ -32,8 +47,30 @@ async def rag( for block in related_blocks: tmp.append(await block.get_context_as_text()) retrieve_result_prompt = MessageContent(content="\n".join(tmp)) + elif retrieve_mode == "embedding": + # Use embedding-based retrieval + query_embedding = Embedding("", "text-embedding-v3").embed(query) + related_blocks = EmbeddingManager.query_blocks_by_embedding( + embedding=query_embedding, + num=num_retrieve, + max_distance=0.5, # More lenient initial retrieval + ) + + # Apply reranker if enabled + if use_reranker and related_blocks: + related_blocks = EmbeddingManager.rerank_blocks( + query=query, + blocks=related_blocks, + top_k=num_rerank, + ) + + # Convert blocks to text for LLM + tmp = [] + for block in related_blocks: + tmp.append(await block.get_context_as_text()) + retrieve_result_prompt = MessageContent(content="\n".join(tmp)) else: - raise NotImplementedError + raise NotImplementedError(f"Retrieve mode '{retrieve_mode}' not implemented") # context + context_blocks -> context_text context_text = context or "" diff --git a/app/schemas/AGENTS.md b/app/schemas/AGENTS.md new file mode 100644 index 0000000..7f7adcd --- /dev/null +++ b/app/schemas/AGENTS.md @@ -0,0 +1 @@ +- Import your schema in `app/schemas/__init__.py` to make Alembic discovers your schema while generating migrations. \ No newline at end of file diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index ff8428f..3cc9f9b 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -7,6 +7,8 @@ "SourceModel", "SourceCollectJobModel", "ExtensionModel", + "RelationEmbeddingModel", + "BlockEmbeddingModel", ] import sqlalchemy.orm @@ -20,3 +22,4 @@ from .info_base.relation import RelationModel from .source import SourceModel, SourceCollectJobModel from .extension.main import ExtensionModel +from .sink import RelationEmbeddingModel, BlockEmbeddingModel diff --git a/app/schemas/info_base/__init__.py b/app/schemas/info_base/__init__.py index 53e59cd..960d352 100644 --- a/app/schemas/info_base/__init__.py +++ b/app/schemas/info_base/__init__.py @@ -1,4 +1,4 @@ -from .block import BlockModel, BlockID, ResolverType, BlockEmbeddingModel +from .block import BlockModel, BlockID, ResolverType from .relation import RelationModel, RelationID from .storage import StorageModel, StorageID, StorageTypesModel from .main import StarGraphForm, ArcForm, Vector @@ -7,7 +7,6 @@ "BlockModel", "BlockID", "ResolverType", - "BlockEmbeddingModel", "RelationModel", "RelationID", "StorageModel", diff --git a/app/schemas/info_base/block.py b/app/schemas/info_base/block.py index b5b5b8b..a238a24 100644 --- a/app/schemas/info_base/block.py +++ b/app/schemas/info_base/block.py @@ -57,24 +57,3 @@ async def get_context_as_text(self) -> str: resolver = ResolverManager.new_resolver(self) return await resolver.get_text() - - -class BlockEmbeddingModel(sqlmodel.SQLModel, table=True): - __tablename__ = "block_embeddings" # type: ignore - - id: int = sqlmodel.Field( - sa_column=sqlalchemy.Column( - sqlalchemy.Integer, - sqlalchemy.ForeignKey("blocks.id", ondelete="CASCADE", onupdate="CASCADE"), - primary_key=True, - ), - ) - embedding: "Vector" = sqlmodel.Field( - sa_column=sqlalchemy.Column(pgvector.sqlalchemy.VECTOR(1024), nullable=False) - ) - updated_at: datetime.datetime = sqlmodel.Field( - default_factory=datetime.datetime.now, - sa_column=sqlalchemy.Column( - sqlalchemy.TIMESTAMP(timezone=True), onupdate=datetime.datetime.now - ), - ) diff --git a/app/schemas/info_base/relation.py b/app/schemas/info_base/relation.py index 83accb4..8475494 100644 --- a/app/schemas/info_base/relation.py +++ b/app/schemas/info_base/relation.py @@ -3,7 +3,6 @@ from typing import Optional as Opt import datetime import sqlalchemy -import pgvector.sqlalchemy import sqlmodel if typing.TYPE_CHECKING: @@ -44,24 +43,3 @@ class RelationModel(sqlmodel.SQLModel, table=True): content: str = sqlmodel.Field( sa_column=sqlalchemy.Column(sqlalchemy.Text, nullable=False) ) - - -class RelationEmbeddingModel(sqlmodel.SQLModel, table=True): - __tablename__ = "relation_embeddings" # type: ignore - - id: int = sqlmodel.Field( - sa_column=sqlalchemy.Column( - sqlalchemy.Integer, - sqlalchemy.ForeignKey("relations.id", ondelete="CASCADE", onupdate="CASCADE"), - primary_key=True, - ), - ) - embedding: "Vector" = sqlmodel.Field( - sa_column=sqlalchemy.Column(pgvector.sqlalchemy.VECTOR(1024), nullable=False) - ) - updated_at: datetime.datetime = sqlmodel.Field( - default_factory=datetime.datetime.now, - sa_column=sqlalchemy.Column( - sqlalchemy.TIMESTAMP(timezone=True), onupdate=datetime.datetime.now - ), - ) diff --git a/app/schemas/sink/__init__.py b/app/schemas/sink/__init__.py new file mode 100644 index 0000000..6b1d851 --- /dev/null +++ b/app/schemas/sink/__init__.py @@ -0,0 +1,6 @@ +from .embedding import BlockEmbeddingModel, RelationEmbeddingModel + +__all__ = [ + "BlockEmbeddingModel", + "RelationEmbeddingModel", +] diff --git a/app/schemas/sink/embedding.py b/app/schemas/sink/embedding.py new file mode 100644 index 0000000..91ccdf6 --- /dev/null +++ b/app/schemas/sink/embedding.py @@ -0,0 +1,50 @@ +import datetime +import typing +import sqlalchemy +import pgvector.sqlalchemy +import sqlmodel + +if typing.TYPE_CHECKING: + from app.schemas.info_base.main import Vector + + +class BlockEmbeddingModel(sqlmodel.SQLModel, table=True): + __tablename__ = "block_embeddings" # type: ignore + + id: int = sqlmodel.Field( + sa_column=sqlalchemy.Column( + sqlalchemy.Integer, + sqlalchemy.ForeignKey("blocks.id", ondelete="CASCADE", onupdate="CASCADE"), + primary_key=True, + ), + ) + embedding: "Vector" = sqlmodel.Field( + sa_column=sqlalchemy.Column(pgvector.sqlalchemy.VECTOR(1024), nullable=False) + ) + updated_at: datetime.datetime = sqlmodel.Field( + default_factory=datetime.datetime.now, + sa_column=sqlalchemy.Column( + sqlalchemy.TIMESTAMP(timezone=True), onupdate=datetime.datetime.now + ), + ) + + +class RelationEmbeddingModel(sqlmodel.SQLModel, table=True): + __tablename__ = "relation_embeddings" # type: ignore + + id: int = sqlmodel.Field( + sa_column=sqlalchemy.Column( + sqlalchemy.Integer, + sqlalchemy.ForeignKey("relations.id", ondelete="CASCADE", onupdate="CASCADE"), + primary_key=True, + ), + ) + embedding: "Vector" = sqlmodel.Field( + sa_column=sqlalchemy.Column(pgvector.sqlalchemy.VECTOR(1024), nullable=False) + ) + updated_at: datetime.datetime = sqlmodel.Field( + default_factory=datetime.datetime.now, + sa_column=sqlalchemy.Column( + sqlalchemy.TIMESTAMP(timezone=True), onupdate=datetime.datetime.now + ), + ) diff --git a/docs/development.md b/docs/development.md index 36a802b..d3ce03e 100644 --- a/docs/development.md +++ b/docs/development.md @@ -1,4 +1,5 @@ - To develop the InKCre/core-py, following resources are required for you to test and debug so as to verify your changes: - PostgreSQL database - - A Github Action is configured to checkout a database branch (NeonDB) for each PR with branch name `pr/`. The checked out branch's parent branch is `pr/` and schema only. + - A Github Action is configured to checkout a database branch (NeonDB) for each PR with database branch name ``. The checked out branch's parent branch is `` and schema only (use `develop` as fallback) (the parent branch cannot has legacy web acesss roles) + - And the schema diff will be commented to the PR (compare to the PR target branch) - `copilot-setup-steps` also checked out a database branch for Github Copilot Agent and configure the DATABASE_URL in `.env`. \ No newline at end of file diff --git a/docs/openapi.json b/docs/openapi.json index 0729604..de9a577 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -353,7 +353,7 @@ "block" ], "summary": "Refresh Embeddings", - "description": "Rebuild all blocks' embeddings", + "description": "Rebuild all blocks' embeddings - delegates to sink embedding service", "operationId": "refresh_embeddings_blocks_embeddings_put", "responses": { "200": { @@ -764,6 +764,7 @@ "sink" ], "summary": "Rag", + "description": "RAG (Retrieval Augmented Generation) endpoint\n\n:param query: User query\n:param context: Additional context string\n:param context_blocks: Additional context block IDs\n:param retrieve_mode: Retrieval mode - \"embedding\", \"reasoning\", or \"feature\"\n:param use_reranker: Whether to use reranker to improve retrieval results\n:param num_retrieve: Number of blocks to retrieve initially\n:param num_rerank: Number of blocks to keep after reranking", "operationId": "rag_sink_rag_get", "parameters": [ { @@ -810,7 +811,37 @@ "required": false, "schema": { "$ref": "#/components/schemas/RetrieveMode", - "default": "feature" + "default": "embedding" + } + }, + { + "name": "use_reranker", + "in": "query", + "required": false, + "schema": { + "type": "boolean", + "default": true, + "title": "Use Reranker" + } + }, + { + "name": "num_retrieve", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "default": 20, + "title": "Num Retrieve" + } + }, + { + "name": "num_rerank", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "default": 5, + "title": "Num Rerank" } } ], diff --git a/migrations/grant.sql b/migrations/grant.sql index 1d39d43..9a31e50 100644 --- a/migrations/grant.sql +++ b/migrations/grant.sql @@ -1,4 +1,4 @@ -CREATE ROLE authenticated NOLOGIN; +CREATE ROLE authenticated NOLOGIN ; GRANT authenticated TO neondb_owner; GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO authenticated; GRANT SELECT ON public.sources_types TO authenticated; diff --git a/migrations/versions/e5a01f9e69ef_init.py b/migrations/versions/e5a01f9e69ef_init.py new file mode 100644 index 0000000..8b7965e --- /dev/null +++ b/migrations/versions/e5a01f9e69ef_init.py @@ -0,0 +1,243 @@ +"""init + +Revision ID: e5a01f9e69ef +Revises: +Create Date: 2025-12-27 21:22:28.578426 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +import app.schemas.source.main +import pgvector.sqlalchemy.vector + +# revision identifiers, used by Alembic. +revision: str = "e5a01f9e69ef" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Install extensions + op.execute("CREATE EXTENSION IF NOT EXISTS vector;") + + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "extensions", + sa.Column("id", sa.TEXT, nullable=False), + sa.Column("version", sa.Text(), nullable=False), + sa.Column("disabled", sa.Boolean(), nullable=False), + sa.Column("nickname", sa.TEXT, nullable=True), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.Column("config_schema", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "logs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "timestamp", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column("severity_number", sa.SmallInteger(), nullable=False), + sa.Column("severity_text", sa.TEXT, nullable=False), + sa.Column("body", sa.TEXT, nullable=False), + sa.Column("trace_id", sa.TEXT, nullable=True), + sa.Column("span_id", sa.TEXT, nullable=True), + sa.Column( + "attributes", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "sources_types", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "config_schema", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "storage_types", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "config_schema", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "sources", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("nickname", sa.TEXT, nullable=True), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.Column( + "collect_at", + app.schemas.source.main.CollectAtType(astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "state", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["type"], ["sources_types.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "storages", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("nickname", sa.TEXT, nullable=True), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["type"], ["storage_types.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "blocks", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column("storage", sa.Integer(), nullable=True), + sa.Column("resolver", sa.Text(), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ["storage"], ["storages.id"], onupdate="CASCADE", ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "sources_collect_jobs", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("source", sa.Integer(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("started_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("closed_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column( + "status", + sa.Enum("pending", "running", "finished", "failed", name="sourcecollectjobstatus"), + server_default="pending", + nullable=True, + ), + sa.Column("state", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.ForeignKeyConstraint( + ["source"], ["sources.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "block_embeddings", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=1024), nullable=False), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint(["id"], ["blocks.id"], onupdate="CASCADE", ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "relations", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column("from_", sa.Integer(), nullable=True), + sa.Column("to_", sa.Integer(), nullable=True), + sa.Column("content", sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ["from_"], ["blocks.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["to_"], ["blocks.id"], onupdate="CASCADE", ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "relation_embeddings", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=1024), nullable=False), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["id"], ["relations.id"], onupdate="CASCADE", ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + # Create role and grant for PostgREST + op.execute("CREATE ROLE authenticated NOLOGIN NOINHERIT;") + op.execute("GRANT authenticated TO postgres;") + op.execute("GRANT SELECT ON public.sources_types TO authenticated;") + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON public.sources TO authenticated;") + op.execute("GRANT SELECT, INSERT ON public.sources_collect_jobs TO authenticated;") + op.execute("GRANT SELECT ON public.extensions TO authenticated;") + op.execute("GRANT SELECT ON public.logs TO authenticated;") + op.execute( + "GRANT SELECT ON public.blocks, public.relations, public.storages, public.storage_types TO authenticated;" + ) + op.execute("GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO authenticated;") + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("relation_embeddings") + op.drop_table("relations") + op.drop_table("block_embeddings") + op.drop_table("sources_collect_jobs") + op.drop_table("blocks") + op.drop_table("storages") + op.drop_table("sources") + op.drop_table("storage_types") + op.drop_table("sources_types") + op.drop_table("logs") + op.drop_table("extensions") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 4f4e939..2cc16ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,14 +47,11 @@ package-mode = false distribution = false [tool.pdm.scripts] -alembic-gengrade = "alembic revision --autogenerate -m {args} && alembic upgrade head" -alembic-revision = "alembic revision -m {args}" -alembic-autogen = "alembic revision --autogenerate -m {args}" -alembic-upgrade = "alembic upgrade {args}" -alembic-downgrade = "alembic downgrade {args}" -alembic-current = "alembic current" -alembic-history = "alembic history" -alembic-stamp = "alembic stamp {args}" +"db:revision" = "alembic revision -m {args}" +"db:generate" = "alembic revision --autogenerate -m {args}" +"db:migrate" = "alembic upgrade head" +"db:downgrade" = "alembic downgrade {args}" + [build-system] requires = [] build-backend = "none" diff --git a/run.py b/run.py index 5d633fb..10db4f8 100644 --- a/run.py +++ b/run.py @@ -43,6 +43,7 @@ async def lifespan(app: fastapi.FastAPI): from app.business.source import SourceCollectJobManager from app.business.info_base.storage import StorageManager + from app.business.sink.embedding import EmbeddingManager logger.info("Application startup") @@ -59,6 +60,14 @@ async def lifespan(app: fastapi.FastAPI): id="sources.collect_jobs.check_pending", ) + # Add periodic job to check and create missing embeddings + scheduler.add_job( + EmbeddingManager.check_and_create_missing_embeddings, + "interval", + seconds=60, # Check every minute + id="sink.embeddings.check_missing", + ) + yield logger.info("Application shutdown") scheduler.shutdown(wait=True)