diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 8b99b71dd..ad0d83726 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -15,7 +15,7 @@ body: Issue 标题请保持原有模板分类(例如:`[Bug]`), 长段描述之间可以增加`空行`或使用`序号`标记,保持排版清晰。 3. Please submit an issue in the corresponding module, lack of valid information / long time (>14 days) unanswered issues may be closed (will be reopened when updated). 请在对应的模块提交 issue, 缺乏有效信息 / 长时间 (> 14 天) 没有回复的 issue 可能会被 **关闭**(更新时会再开启)。 - + - type: dropdown attributes: label: Bug Type (问题类型) @@ -26,7 +26,7 @@ body: - data inconsistency (数据不一致) - exception / error (异常报错) - others (please comment below) - + - type: checkboxes attributes: label: Before submit @@ -45,7 +45,7 @@ body: - Data Size: xx vertices, xx edges > validations: required: true - + - type: textarea attributes: label: Expected & Actual behavior (期望与实际表现) @@ -53,8 +53,8 @@ body: we can refer [How to create a minimal reproducible Example](https://stackoverflow.com/help/minimal-reproducible-example), if possible, please provide screenshots or GIF. 可以参考 [如何提供最简的可复现用例](https://stackoverflow.com/help/minimal-reproducible-example),请提供清晰的截图,动图录屏更佳。 placeholder: | - type the main problem here - + type the main problem here + ```java // Detailed exception / error info (尽可能详细的日志 + 完整异常栈) diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 62547e084..2b3d5e7a1 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -23,7 +23,7 @@ body: description: > Please describe the function you want in as much detail as possible. (请简要描述新功能 / 需求的使用场景或上下文, 最好能给个具体的例子说明) - placeholder: type the feature description here + placeholder: type the feature description here validations: required: true @@ -50,4 +50,3 @@ body: - type: markdown attributes: value: "Thanks for completing our form, and we will reply you as soon as possible." - diff --git a/.github/ISSUE_TEMPLATE/question_ask.yml b/.github/ISSUE_TEMPLATE/question_ask.yml index a59a4fb3d..5f910c909 100644 --- a/.github/ISSUE_TEMPLATE/question_ask.yml +++ b/.github/ISSUE_TEMPLATE/question_ask.yml @@ -27,7 +27,7 @@ body: - configs (配置项 / 文档相关) - exception / error (异常报错) - others (please comment below) - + - type: checkboxes attributes: label: Before submit @@ -54,8 +54,8 @@ body: For issues related to graph usage/configuration, please refer to [REST-API documentation](https://hugegraph.apache.org/docs/clients/restful-api/), and [Server configuration documentation](https://hugegraph.apache.org/docs/config/config-option/) (if possible, please provide screenshots or GIF). 图使用 / 配置相关问题,请优先参考 [REST-API 文档](https://hugegraph.apache.org/docs/clients/restful-api/), 以及 [Server 配置文档](https://hugegraph.apache.org/docs/config/config-option/) (请提供清晰的截图,动图录屏更佳) placeholder: | - type the main problem here - + type the main problem here + ```java // Exception / Error info (尽可能详细的日志 + 完整异常栈) diff --git a/.github/workflows/pylint.yml b/.github/workflows/ruff.yml similarity index 79% rename from .github/workflows/pylint.yml rename to .github/workflows/ruff.yml index 351628aa0..f158a62e8 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/ruff.yml @@ -1,12 +1,10 @@ -# TODO: replace by ruff & mypy soon -name: "Pylint" +name: "Ruff Code Quality" on: push: branches: - - 'main' - - 'master' - - 'release-*' + - "main" + - "release-*" pull_request: jobs: @@ -14,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -48,6 +46,10 @@ jobs: run: | uv run python -c "import dgl; print(dgl.__version__)" - - name: Analysing the code with pylint + - name: Check code formatting with Ruff run: | - uv run bash ./style/code_format_and_analysis.sh -p + uv run ruff format --check . + + - name: Lint code with Ruff + run: | + uv run ruff check . diff --git a/.gitignore b/.gitignore index 764121125..1822afb7f 100644 --- a/.gitignore +++ b/.gitignore @@ -185,8 +185,6 @@ venv.bak/ # Rope project settings .ropeproject -.pre-commit-config.yaml - # mkdocs documentation /site diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..bdd33dc77 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: ["--maxkb=1000"] + - id: check-merge-conflict + - id: check-case-conflict + - id: check-docstring-first + - id: debug-statements + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.1 + hooks: + - id: ruff-format + types_or: [python, pyi] + - id: ruff + types_or: [python, pyi] + args: [--fix] diff --git a/DISCLAIMER b/DISCLAIMER index be718eef3..be557e360 100644 --- a/DISCLAIMER +++ b/DISCLAIMER @@ -1,7 +1,7 @@ Apache HugeGraph (incubating) is an effort undergoing incubation at The Apache Software Foundation (ASF), sponsored by the Apache Incubator PMC. -Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, +Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. -While incubation status is not necessarily a reflection of the completeness or stability of the code, +While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. diff --git a/README.md b/README.md index a495968ec..ef0708251 100644 --- a/README.md +++ b/README.md @@ -178,8 +178,19 @@ uv add numpy # Add to base dependencies uv add --group dev pytest-mock # Add to dev group ``` -**Key Points:** +### Code Quality (ruff + pre-commit) + +- Ruff is used for linting and formatting: + - \`ruff format .\` + - \`ruff check .\` +- Enable Git hooks via pre-commit: + - \`pre-commit install\` + - \`pre-commit run --all-files\` +- Config: [.pre-commit-config.yaml](.pre-commit-config.yaml). CI enforces these checks. + **Key Points:** +- Config: [.pre-commit-config.yaml](.pre-commit-config.yaml). CI enforces these checks. +**Key Points:** - Use [GitHub Desktop](https://desktop.github.com/) for easier PR management - Check existing issues before reporting bugs diff --git a/docker/docker-compose-network.yml b/docker/docker-compose-network.yml index 04ab6e004..43c3aecc3 100644 --- a/docker/docker-compose-network.yml +++ b/docker/docker-compose-network.yml @@ -48,4 +48,4 @@ services: interval: 30s timeout: 10s retries: 3 - start_period: 60s + start_period: 60s diff --git a/hugegraph-llm/AGENTS.md b/hugegraph-llm/AGENTS.md index a15eb4328..4ca973fff 100644 --- a/hugegraph-llm/AGENTS.md +++ b/hugegraph-llm/AGENTS.md @@ -4,9 +4,9 @@ This file provides guidance to AI coding tools and developers when working with ## Project Overview -HugeGraph-LLM is a comprehensive toolkit that bridges graph databases and large language models, -part of the Apache HugeGraph AI ecosystem. It enables seamless integration between HugeGraph and LLMs for building -intelligent applications with three main capabilities: Knowledge Graph Construction, Graph-Enhanced RAG, +HugeGraph-LLM is a comprehensive toolkit that bridges graph databases and large language models, +part of the Apache HugeGraph AI ecosystem. It enables seamless integration between HugeGraph and LLMs for building +intelligent applications with three main capabilities: Knowledge Graph Construction, Graph-Enhanced RAG, and Text2Gremlin query generation. ## Tech Stack diff --git a/hugegraph-llm/CI_FIX_SUMMARY.md b/hugegraph-llm/CI_FIX_SUMMARY.md index 65a6ce8e2..ed186764a 100644 --- a/hugegraph-llm/CI_FIX_SUMMARY.md +++ b/hugegraph-llm/CI_FIX_SUMMARY.md @@ -33,7 +33,7 @@ export SKIP_EXTERNAL_SERVICES=true cd hugegraph-llm export PYTHONPATH="$(pwd)/src:$PYTHONPATH" - + # 跳过有问题的测试 python -m pytest src/tests/ -v --tb=short \ --ignore=src/tests/integration/ \ @@ -46,7 +46,7 @@ - uses: actions/checkout@v4 with: fetch-depth: 0 # 获取完整历史 - + - name: Sync latest changes run: | git pull origin main # 确保获取最新更改 diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index 9e2bbe05d..d5dd83627 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -22,6 +22,16 @@ For detailed source code doc, visit our [DeepWiki](https://deepwiki.com/apache/i > - **HugeGraph Server**: 1.3+ (recommended: 1.5+) > - **UV Package Manager**: 0.7+ +### Code Quality (ruff + pre-commit) + +- Ruff is used for linting and formatting: + - `ruff format .` + - `ruff check .` +- Enable Git hooks via pre-commit: + - `pre-commit install` (in the root dir) + - `pre-commit run --all-files` +- Config: [../.pre-commit-config.yaml](../.pre-commit-config.yaml) + ## 🚀 Quick Start Choose your preferred deployment method: @@ -177,7 +187,7 @@ The system supports both English and Chinese prompts. To switch languages: If you previously used high-level classes like `RAGPipeline` or `KgBuilder`, the project now exposes stable flows through the `Scheduler` API. Use `SchedulerSingleton.get_instance().schedule_flow(...)` to invoke workflows programmatically. Below are concise, working examples that match the new architecture. -1) RAG (graph-only) query example +1. RAG (graph-only) query example ```python from hugegraph_llm.flows.scheduler import SchedulerSingleton @@ -196,7 +206,7 @@ res = scheduler.schedule_flow( print(res.get("graph_only_answer")) ``` -2) RAG (vector-only) query example +2. RAG (vector-only) query example ```python from hugegraph_llm.flows.scheduler import SchedulerSingleton @@ -212,7 +222,7 @@ res = scheduler.schedule_flow( print(res.get("vector_only_answer")) ``` -3) Text -> Gremlin (text2gremlin) example +3. Text -> Gremlin (text2gremlin) example ```python from hugegraph_llm.flows.scheduler import SchedulerSingleton @@ -230,7 +240,7 @@ response = scheduler.schedule_flow( print(response.get("template_gremlin")) ``` -4) Build example index (used by text2gremlin examples) +4. Build example index (used by text2gremlin examples) ```python from hugegraph_llm.flows.scheduler import SchedulerSingleton diff --git a/hugegraph-llm/quick_start.md b/hugegraph-llm/quick_start.md index dab247c41..8a8bfe8a4 100644 --- a/hugegraph-llm/quick_start.md +++ b/hugegraph-llm/quick_start.md @@ -26,7 +26,7 @@ graph TD; A --> F[Text Segmentation] F --> G[LLM extracts graph based on schema \nand segmented text] G --> H[Store graph in Graph Database, \nautomatically vectorize vertices \nand store in Vector Database] - + I[Retrieve vertices from Graph Database] --> J[Vectorize vertices and store in Vector Database \nNote: Incremental update] ``` @@ -85,7 +85,7 @@ graph TD; F --> G[Match vertices precisely in Graph Database \nusing keywords; perform fuzzy matching in \nVector Database (graph vid)] G --> H[Generate Gremlin query using matched vertices and query with LLM] H --> I[Execute Gremlin query; if successful, finish; if failed, fallback to BFS] - + B --> J[Sort results] I --> J J --> K[Generate answer] @@ -162,7 +162,7 @@ The first part is straightforward, so the focus is on the second part. graph TD; A[Gremlin Pairs File] --> C[Vectorize query] C --> D[Store in Vector Database] - + F[Natural Language Query] --> G[Search for the most similar query \nin the Vector Database \n(If no Gremlin pairs exist in the Vector Database, \ndefault files will be automatically vectorized) \nand retrieve the corresponding Gremlin] G --> H[Add the matched pair to the prompt \nand use LLM to generate the Gremlin \ncorresponding to the Natural Language Query] ``` diff --git a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py index 109da4a99..016427f1c 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py @@ -17,7 +17,7 @@ import os -from fastapi import status, APIRouter +from fastapi import APIRouter, status from fastapi.responses import StreamingResponse from hugegraph_llm.api.exceptions.rag_exceptions import generate_response diff --git a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py index 18723e30b..fef71733f 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py +++ b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py @@ -16,14 +16,13 @@ # under the License. from fastapi import HTTPException + from hugegraph_llm.api.models.rag_response import RAGResponse class ExternalException(HTTPException): def __init__(self): - super().__init__( - status_code=400, detail="Connect failed with error code -1, please check the input." - ) + super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.") class ConnectionFailedException(HTTPException): diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index f46aea02c..433438eee 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Literal, List from enum import Enum +from typing import List, Literal, Optional + from fastapi import Query from pydantic import BaseModel, field_validator @@ -36,23 +37,13 @@ class RAGRequest(BaseModel): raw_answer: bool = Query(False, description="Use LLM to generate answer directly") vector_only: bool = Query(False, description="Use LLM to generate answer with vector") graph_only: bool = Query(True, description="Use LLM to generate answer with graph RAG only") - graph_vector_answer: bool = Query( - False, description="Use LLM to generate answer with vector & GraphRAG" - ) + graph_vector_answer: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG") graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans") - rerank_method: Literal["bleu", "reranker"] = Query( - "bleu", description="Method to rerank the results." - ) - near_neighbor_first: bool = Query( - False, description="Prioritize near neighbors in the search results." - ) - custom_priority_info: str = Query( - "", description="Custom information to prioritize certain results." - ) + rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") + near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") + custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") # Graph Configs - max_graph_items: int = Query( - 30, description="Maximum number of items for GQL queries in graph." - ) + max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.") topk_return_results: int = Query(20, description="Number of sorted results to return finally.") vector_dis_threshold: float = Query( 0.9, @@ -64,14 +55,10 @@ class RAGRequest(BaseModel): description="TopK results returned for each keyword \ extracted from the query, by default only the most similar one is returned.", ) - client_config: Optional[GraphConfigRequest] = Query( - None, description="hugegraph server config." - ) + client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") # Keep prompt params in the end - answer_prompt: Optional[str] = Query( - prompt.answer_prompt, description="Prompt to guide the answer generation." - ) + answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.") keywords_extract_prompt: Optional[str] = Query( prompt.keywords_extract_prompt, description="Prompt for extracting keywords from query.", @@ -87,9 +74,7 @@ class RAGRequest(BaseModel): class GraphRAGRequest(BaseModel): query: str = Query(..., description="Query you want to ask") # Graph Configs - max_graph_items: int = Query( - 30, description="Maximum number of items for GQL queries in graph." - ) + max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.") topk_return_results: int = Query(20, description="Number of sorted results to return finally.") vector_dis_threshold: float = Query( 0.9, @@ -102,24 +87,16 @@ class GraphRAGRequest(BaseModel): from the query, by default only the most similar one is returned.", ) - client_config: Optional[GraphConfigRequest] = Query( - None, description="hugegraph server config." - ) + client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") get_vertex_only: bool = Query(False, description="return only keywords & vertex (early stop).") gremlin_tmpl_num: int = Query( 1, description="Number of Gremlin templates to use. If num <=0 means template is not provided", ) - rerank_method: Literal["bleu", "reranker"] = Query( - "bleu", description="Method to rerank the results." - ) - near_neighbor_first: bool = Query( - False, description="Prioritize near neighbors in the search results." - ) - custom_priority_info: str = Query( - "", description="Custom information to prioritize certain results." - ) + rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") + near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") + custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") gremlin_prompt: Optional[str] = Query( prompt.gremlin_generate_prompt, description="Prompt for the Text2Gremlin query.", @@ -163,16 +140,12 @@ class GremlinOutputType(str, Enum): class GremlinGenerateRequest(BaseModel): query: str - example_num: Optional[int] = Query( - 0, description="Number of Gremlin templates to use.(0 means no templates)" - ) + example_num: Optional[int] = Query(0, description="Number of Gremlin templates to use.(0 means no templates)") gremlin_prompt: Optional[str] = Query( prompt.gremlin_generate_prompt, description="Prompt for the Text2Gremlin query.", ) - client_config: Optional[GraphConfigRequest] = Query( - None, description="hugegraph server config." - ) + client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") output_types: Optional[List[GremlinOutputType]] = Query( default=[GremlinOutputType.TEMPLATE_GREMLIN], description=""" @@ -189,7 +162,5 @@ def validate_prompt_placeholders(cls, v): required_placeholders = ["{query}", "{schema}", "{example}", "{vertices}"] missing = [p for p in required_placeholders if p not in v] if missing: - raise ValueError( - f"Prompt template is missing required placeholders: {', '.join(missing)}" - ) + raise ValueError(f"Prompt template is missing required placeholders: {', '.join(missing)}") return v diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index ca29cb9ab..3838df2fc 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -17,20 +17,19 @@ import json -from fastapi import status, APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, status from hugegraph_llm.api.exceptions.rag_exceptions import generate_response from hugegraph_llm.api.models.rag_requests import ( - RAGRequest, GraphConfigRequest, - LLMConfigRequest, - RerankerConfigRequest, GraphRAGRequest, GremlinGenerateRequest, + LLMConfigRequest, + RAGRequest, + RerankerConfigRequest, ) from hugegraph_llm.api.models.rag_response import RAGResponse -from hugegraph_llm.config import huge_settings -from hugegraph_llm.config import llm_settings, prompt +from hugegraph_llm.config import huge_settings, llm_settings, prompt from hugegraph_llm.utils.graph_index_utils import get_vertex_details from hugegraph_llm.utils.log import log @@ -74,8 +73,7 @@ def rag_answer_api(req: RAGRequest): # Keep prompt params in the end custom_related_information=req.custom_priority_info, answer_prompt=req.answer_prompt or prompt.answer_prompt, - keywords_extract_prompt=req.keywords_extract_prompt - or prompt.keywords_extract_prompt, + keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt, gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt, ) # TODO: we need more info in the response for users to understand the query logic @@ -146,9 +144,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): except TypeError as e: log.error("TypeError in graph_rag_recall_api: %s", e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) - ) from e + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e except Exception as e: log.error("Unexpected error occurred: %s", e) raise HTTPException( @@ -159,9 +155,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): @router.post("/config/graph", status_code=status.HTTP_201_CREATED) def graph_config_api(req: GraphConfigRequest): # Accept status code - res = apply_graph_conf( - req.url, req.name, req.user, req.pwd, req.gs, origin_call="http" - ) + res = apply_graph_conf(req.url, req.name, req.user, req.pwd, req.gs, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) # TODO: restructure the implement of llm to three types, like "/config/chat_llm" @@ -178,9 +172,7 @@ def llm_config_api(req: LLMConfigRequest): origin_call="http", ) else: - res = apply_llm_conf( - req.host, req.port, req.language_model, None, origin_call="http" - ) + res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) @@ -188,13 +180,9 @@ def embedding_config_api(req: LLMConfigRequest): llm_settings.embedding_type = req.llm_type if req.llm_type == "openai": - res = apply_embedding_conf( - req.api_key, req.api_base, req.language_model, origin_call="http" - ) + res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") else: - res = apply_embedding_conf( - req.host, req.port, req.language_model, origin_call="http" - ) + res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) @router.post("/config/rerank", status_code=status.HTTP_201_CREATED) @@ -202,13 +190,9 @@ def rerank_config_api(req: RerankerConfigRequest): llm_settings.reranker_type = req.reranker_type if req.reranker_type == "cohere": - res = apply_reranker_conf( - req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http" - ) + res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http") elif req.reranker_type == "siliconflow": - res = apply_reranker_conf( - req.api_key, req.reranker_model, None, origin_call="http" - ) + res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") else: res = status.HTTP_501_NOT_IMPLEMENTED return generate_response(RAGResponse(status_code=res, message="Missing Value")) diff --git a/hugegraph-llm/src/hugegraph_llm/config/admin_config.py b/hugegraph-llm/src/hugegraph_llm/config/admin_config.py index fabc75de4..5f6ed5761 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/admin_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/admin_config.py @@ -16,6 +16,7 @@ # under the License. from typing import Optional + from .models import BaseConfig diff --git a/hugegraph-llm/src/hugegraph_llm/config/generate.py b/hugegraph-llm/src/hugegraph_llm/config/generate.py index 1bd7adea8..162822959 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/generate.py +++ b/hugegraph-llm/src/hugegraph_llm/config/generate.py @@ -28,9 +28,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate hugegraph-llm config file") - parser.add_argument( - "-U", "--update", default=True, action="store_true", help="Update the config file" - ) + parser.add_argument("-U", "--update", default=True, action="store_true", help="Update the config file") args = parser.parse_args() if args.update: huge_settings.generate_env() diff --git a/hugegraph-llm/src/hugegraph_llm/config/index_config.py b/hugegraph-llm/src/hugegraph_llm/config/index_config.py index 63895e6a7..346f5f03c 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/index_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/index_config.py @@ -26,9 +26,7 @@ class IndexConfig(BaseConfig): qdrant_host: Optional[str] = os.environ.get("QDRANT_HOST", None) qdrant_port: int = int(os.environ.get("QDRANT_PORT", "6333")) - qdrant_api_key: Optional[str] = ( - os.environ.get("QDRANT_API_KEY") if os.environ.get("QDRANT_API_KEY") else None - ) + qdrant_api_key: Optional[str] = os.environ.get("QDRANT_API_KEY") if os.environ.get("QDRANT_API_KEY") else None milvus_host: Optional[str] = os.environ.get("MILVUS_HOST", None) milvus_port: int = int(os.environ.get("MILVUS_PORT", "19530")) diff --git a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py index eb094ef88..fd2c82303 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py @@ -17,7 +17,7 @@ import os -from typing import Optional, Literal +from typing import Literal, Optional from .models import BaseConfig @@ -36,33 +36,23 @@ class LLMConfig(BaseConfig): hybrid_llm_weights: Optional[float] = 0.5 # TODO: divide RAG part if necessary # 1. OpenAI settings - openai_chat_api_base: Optional[str] = os.environ.get( - "OPENAI_BASE_URL", "https://api.openai.com/v1" - ) + openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_chat_language_model: Optional[str] = "gpt-4.1-mini" - openai_extract_api_base: Optional[str] = os.environ.get( - "OPENAI_BASE_URL", "https://api.openai.com/v1" - ) + openai_extract_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") openai_extract_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_extract_language_model: Optional[str] = "gpt-4.1-mini" - openai_text2gql_api_base: Optional[str] = os.environ.get( - "OPENAI_BASE_URL", "https://api.openai.com/v1" - ) + openai_text2gql_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") openai_text2gql_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_text2gql_language_model: Optional[str] = "gpt-4.1-mini" - openai_embedding_api_base: Optional[str] = os.environ.get( - "OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1" - ) + openai_embedding_api_base: Optional[str] = os.environ.get("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1") openai_embedding_api_key: Optional[str] = os.environ.get("OPENAI_EMBEDDING_API_KEY") openai_embedding_model: Optional[str] = "text-embedding-3-small" openai_chat_tokens: int = 8192 openai_extract_tokens: int = 256 openai_text2gql_tokens: int = 4096 # 2. Rerank settings - cohere_base_url: Optional[str] = os.environ.get( - "CO_API_URL", "https://api.cohere.com/v1/rerank" - ) + cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank") reranker_api_key: Optional[str] = None reranker_model: Optional[str] = None # 3. Ollama settings diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py index 4b0c4dc76..57d8ba638 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py @@ -107,7 +107,6 @@ def ensure_yaml_file_exists(self): log.info("Prompt file '%s' doesn't exist, create it.", yaml_file_path) def save_to_yaml(self): - def to_literal(val): return LiteralStr(val) if isinstance(val, str) else val diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py index cc79b3cef..d56c830e8 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py @@ -183,7 +183,7 @@ def __init__(self, llm_config_object): - Adjust keyword length: If keywords are relatively broad, you can appropriately increase individual keyword length based on context (e.g., "illegal behavior" can be extracted as a single keyword, or as "illegal", but should not be split into "illegal" and "behavior"). Output Format: - - Output only one line, prefixed with KEYWORDS:, followed by a comma-separated list of items. Each item should be in the format keyword:importance_score(round to two decimal places). If a keyword has been replaced by a synonym, use the synonym as the keyword in the output. + - Output only one line, prefixed with KEYWORDS:, followed by a comma-separated list of items. Each item should be in the format keyword:importance_score(round to two decimal places). If a keyword has been replaced by a synonym, use the synonym as the keyword in the output. - Format example: KEYWORDS:keyword1:score1,keyword2:score2,...,keywordN:scoreN diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py index 8a0f7ced0..c26164fbf 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py @@ -108,9 +108,7 @@ def create_admin_block(): ) # Error message box, initially hidden - error_message = gr.Textbox( - label="", visible=False, interactive=False, elem_classes="error-message" - ) + error_message = gr.Textbox(label="", visible=False, interactive=False, elem_classes="error-message") # Button to submit password submit_button = gr.Button("Submit") diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py index d584ccd88..34cd07110 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py @@ -19,23 +19,22 @@ import gradio as gr import uvicorn -from fastapi import FastAPI, Depends, APIRouter -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import APIRouter, Depends, FastAPI +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from hugegraph_llm.api.admin_api import admin_http_api from hugegraph_llm.api.rag_api import rag_http_api from hugegraph_llm.config import admin_settings, huge_settings, prompt from hugegraph_llm.demo.rag_demo.admin_block import create_admin_block, log_stream from hugegraph_llm.demo.rag_demo.configs_block import ( - create_configs_block, - apply_llm_config, apply_embedding_config, - apply_reranker_config, apply_graph_config, + apply_llm_config, + apply_reranker_config, + create_configs_block, + get_header_with_language_indicator, ) -from hugegraph_llm.demo.rag_demo.configs_block import get_header_with_language_indicator -from hugegraph_llm.demo.rag_demo.other_block import create_other_block -from hugegraph_llm.demo.rag_demo.other_block import lifespan +from hugegraph_llm.demo.rag_demo.other_block import create_other_block, lifespan from hugegraph_llm.demo.rag_demo.rag_block import create_rag_block, rag_answer from hugegraph_llm.demo.rag_demo.text2gremlin_block import ( create_text2gremlin_block, @@ -94,20 +93,16 @@ def init_rag_ui() -> gr.Interface: textbox_array_graph_config = create_configs_block() with gr.Tab(label="1. Build RAG Index 💡"): - textbox_input_text, textbox_input_schema, textbox_info_extract_template = ( - create_vector_graph_block() - ) + textbox_input_text, textbox_input_schema, textbox_info_extract_template = create_vector_graph_block() with gr.Tab(label="2. (Graph)RAG & User Functions 📖"): ( textbox_inp, textbox_answer_prompt_input, textbox_keywords_extract_prompt_input, - textbox_custom_related_information + textbox_custom_related_information, ) = create_rag_block() with gr.Tab(label="3. Text2gremlin ⚙️"): - textbox_gremlin_inp, textbox_gremlin_schema, textbox_gremlin_prompt = ( - create_text2gremlin_block() - ) + textbox_gremlin_inp, textbox_gremlin_schema, textbox_gremlin_prompt = create_text2gremlin_block() with gr.Tab(label="4. Graph Tools 🚧"): create_other_block() with gr.Tab(label="5. Admin Tools 🛠"): diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py index b472ea0ba..72c1eec06 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py @@ -64,9 +64,7 @@ def test_litellm_chat(api_key, api_base, model_name, max_tokens: int) -> int: return 200 -def test_api_connection( - url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None -) -> int: +def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int: # TODO: use fastapi.request / starlette instead? log.debug("Request URL: %s", url) try: @@ -133,9 +131,7 @@ def apply_vector_engine_backend( # pylint: disable=too-many-branches if engine == "Milvus": from pymilvus import connections, utility - connections.connect( - host=host, port=int(port or 19530), user=user or "", password=password or "" - ) + connections.connect(host=host, port=int(port or 19530), user=user or "", password=password or "") # Test if we can list collections _ = utility.list_collections() connections.disconnect("default") @@ -193,9 +189,7 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: test_url = llm_settings.openai_embedding_api_base + "/embeddings" headers = {"Authorization": f"Bearer {arg1}"} data = {"model": arg3, "input": "test"} - status_code = test_api_connection( - test_url, method="POST", headers=headers, body=data, origin_call=origin_call - ) + status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call) elif embedding_option == "ollama/local": llm_settings.ollama_embedding_host = arg1 llm_settings.ollama_embedding_port = int(arg2) @@ -290,26 +284,20 @@ def apply_llm_config( setattr(llm_settings, f"openai_{current_llm_config}_language_model", model_name) setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(max_tokens)) - test_url = ( - getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions" - ) + test_url = getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions" data = { "model": model_name, "temperature": 0.01, "messages": [{"role": "user", "content": "test"}], } headers = {"Authorization": f"Bearer {api_key_or_host}"} - status_code = test_api_connection( - test_url, method="POST", headers=headers, body=data, origin_call=origin_call - ) + status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call) elif llm_option == "ollama/local": setattr(llm_settings, f"ollama_{current_llm_config}_host", api_key_or_host) setattr(llm_settings, f"ollama_{current_llm_config}_port", int(api_base_or_port)) setattr(llm_settings, f"ollama_{current_llm_config}_language_model", model_name) - status_code = test_api_connection( - f"http://{api_key_or_host}:{api_base_or_port}", origin_call=origin_call - ) + status_code = test_api_connection(f"http://{api_key_or_host}:{api_base_or_port}", origin_call=origin_call) elif llm_option == "litellm": setattr(llm_settings, f"litellm_{current_llm_config}_api_key", api_key_or_host) @@ -317,9 +305,7 @@ def apply_llm_config( setattr(llm_settings, f"litellm_{current_llm_config}_language_model", model_name) setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(max_tokens)) - status_code = test_litellm_chat( - api_key_or_host, api_base_or_port, model_name, int(max_tokens) - ) + status_code = test_litellm_chat(api_key_or_host, api_base_or_port, model_name, int(max_tokens)) gr.Info("Configured!") llm_settings.update_env() @@ -361,9 +347,7 @@ def create_configs_block() -> list: ), ] graph_config_button = gr.Button("Apply Configuration") - graph_config_button.click( - apply_graph_config, inputs=graph_config_input - ) # pylint: disable=no-member + graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member # TODO : use OOP to refactor the following code with gr.Accordion("2. Set up the LLM.", open=False): @@ -445,20 +429,14 @@ def chat_llm_settings(llm_type): llm_config_button = gr.Button("Apply configuration") llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input) # Determine whether there are Settings in the.env file - env_path = os.path.join( - os.getcwd(), ".env" - ) # Load .env from the current working directory + env_path = os.path.join(os.getcwd(), ".env") # Load .env from the current working directory env_vars = dotenv_values(env_path) api_extract_key = env_vars.get("OPENAI_EXTRACT_API_KEY") api_text2sql_key = env_vars.get("OPENAI_TEXT2GQL_API_KEY") if not api_extract_key: - llm_config_button.click( - apply_llm_config_with_text2gql_op, inputs=llm_config_input - ) + llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input) if not api_text2sql_key: - llm_config_button.click( - apply_llm_config_with_extract_op, inputs=llm_config_input - ) + llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input) with gr.Tab(label="mini_tasks"): extract_llm_dropdown = gr.Dropdown( @@ -749,14 +727,10 @@ def vector_engine_settings(engine): gr.Textbox(value=index_settings.milvus_host, label="host"), gr.Textbox(value=str(index_settings.milvus_port), label="port"), gr.Textbox(value=index_settings.milvus_user, label="user"), - gr.Textbox( - value=index_settings.milvus_password, label="password", type="password" - ), + gr.Textbox(value=index_settings.milvus_password, label="password", type="password"), ] apply_backend_button = gr.Button("Apply Configuration") - apply_backend_button.click( - partial(apply_vector_engine_backend, "Milvus"), inputs=milvus_inputs - ) + apply_backend_button.click(partial(apply_vector_engine_backend, "Milvus"), inputs=milvus_inputs) elif engine == "Qdrant": with gr.Row(): qdrant_inputs = [ diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py index 8b78328f3..beae33e0d 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py @@ -23,17 +23,15 @@ from apscheduler.triggers.cron import CronTrigger from fastapi import FastAPI -from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, backup_data -from hugegraph_llm.utils.log import log from hugegraph_llm.demo.rag_demo.vector_graph_block import timely_update_vid_embedding +from hugegraph_llm.utils.hugegraph_utils import backup_data, init_hg_test_data, run_gremlin_query +from hugegraph_llm.utils.log import log def create_other_block(): gr.Markdown("""## Other Tools """) with gr.Row(): - inp = gr.Textbox( - value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8 - ) + inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8) out = gr.Code(label="Output", language="json", elem_classes="code-container-show") btn = gr.Button("Run Gremlin query") btn.click(fn=run_gremlin_query, inputs=[inp], outputs=out) # pylint: disable=no-member @@ -41,9 +39,7 @@ def create_other_block(): gr.Markdown("---") with gr.Row(): inp = [] - out = gr.Textbox( - label="Backup Graph Manually (Auto backup at 1:00 AM everyday)", show_copy_button=True - ) + out = gr.Textbox(label="Backup Graph Manually (Auto backup at 1:00 AM everyday)", show_copy_button=True) btn = gr.Button("Backup Graph Data") btn.click(fn=backup_data, inputs=inp, outputs=out) # pylint: disable=no-member with gr.Accordion("Init HugeGraph test data (🚧)", open=False): @@ -58,9 +54,7 @@ def create_other_block(): async def lifespan(app: FastAPI): # pylint: disable=W0621 log.info("Starting background scheduler...") scheduler = AsyncIOScheduler() - scheduler.add_job( - backup_data, trigger=CronTrigger(hour=1, minute=0), id="daily_backup", replace_existing=True - ) + scheduler.add_job(backup_data, trigger=CronTrigger(hour=1, minute=0), id="daily_backup", replace_existing=True) scheduler.start() log.info("Starting vid embedding update task...") diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 9bf04b570..d52d6005e 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -19,13 +19,14 @@ import os from typing import AsyncGenerator, Literal, Optional, Tuple -import pandas as pd + import gradio as gr +import pandas as pd from gradio.utils import NamedString +from hugegraph_llm.config import llm_settings, prompt, resource_path from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton -from hugegraph_llm.config import resource_path, prompt, llm_settings from hugegraph_llm.utils.decorators import with_task_id from hugegraph_llm.utils.log import log @@ -104,9 +105,7 @@ def rag_answer( topk_per_keyword=topk_per_keyword, ) if res.get("switch_to_bleu"): - gr.Warning( - "Online reranker fails, automatically switches to local bleu rerank." - ) + gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") return ( res.get("raw_answer", ""), res.get("vector_only_answer", ""), @@ -218,9 +217,7 @@ async def rag_answer_streaming( gremlin_prompt=gremlin_prompt, ): if res.get("switch_to_bleu"): - gr.Warning( - "Online reranker fails, automatically switches to local bleu rerank." - ) + gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") yield ( res.get("raw_answer", ""), res.get("vector_only_answer", ""), @@ -289,19 +286,11 @@ def create_rag_block(): with gr.Column(scale=1): with gr.Row(): - raw_radio = gr.Radio( - choices=[True, False], value=False, label="Basic LLM Answer" - ) - vector_only_radio = gr.Radio( - choices=[True, False], value=False, label="Vector-only Answer" - ) + raw_radio = gr.Radio(choices=[True, False], value=False, label="Basic LLM Answer") + vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") with gr.Row(): - graph_only_radio = gr.Radio( - choices=[True, False], value=True, label="Graph-only Answer" - ) - graph_vector_radio = gr.Radio( - choices=[True, False], value=False, label="Graph-Vector Answer" - ) + graph_only_radio = gr.Radio(choices=[True, False], value=True, label="Graph-only Answer") + graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") def toggle_slider(enable): return gr.update(interactive=enable) @@ -319,13 +308,9 @@ def toggle_slider(enable): label="Template Num (<0 means disable text2gql) ", precision=0, ) - graph_ratio = gr.Slider( - 0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False - ) + graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False) - graph_vector_radio.change( - toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio - ) # pylint: disable=no-member + graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio) # pylint: disable=no-member near_neighbor_first = gr.Checkbox( value=False, label="Near neighbor first(Optional)", @@ -376,9 +361,7 @@ def toggle_slider(enable): # FIXME: "demo" might conflict with the graph name, it should be modified. answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx") questions_path = os.path.join(resource_path, "demo", "questions.xlsx") - questions_template_path = os.path.join( - resource_path, "demo", "questions_template.xlsx" - ) + questions_template_path = os.path.join(resource_path, "demo", "questions_template.xlsx") def read_file_to_excel(file: NamedString, line_count: Optional[int] = None): df = None @@ -454,22 +437,14 @@ def several_rag_answer( with gr.Row(): with gr.Column(): - questions_file = gr.File( - file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)" - ) + questions_file = gr.File(file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)") with gr.Column(): - test_template_file = os.path.join( - resource_path, "demo", "questions_template.xlsx" - ) + test_template_file = os.path.join(resource_path, "demo", "questions_template.xlsx") gr.File(value=test_template_file, label="Download Template File") - answer_max_line_count = gr.Number( - 1, label="Max Lines To Show", minimum=1, maximum=40 - ) + answer_max_line_count = gr.Number(1, label="Max Lines To Show", minimum=1, maximum=40) answers_btn = gr.Button("Generate Answer (Batch)", variant="primary") # TODO: Set individual progress bars for dataframe - qa_dataframe = gr.DataFrame( - label="Questions & Answers (Preview)", headers=tests_df_headers - ) + qa_dataframe = gr.DataFrame(label="Questions & Answers (Preview)", headers=tests_df_headers) answers_btn.click( several_rag_answer, inputs=[ @@ -487,12 +462,8 @@ def several_rag_answer( ], outputs=[qa_dataframe, gr.File(label="Download Answered File", min_width=40)], ) - questions_file.change( - read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count] - ) - answer_max_line_count.change( - change_showing_excel, answer_max_line_count, qa_dataframe - ) + questions_file.change(read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count]) + answer_max_line_count.change(change_showing_excel, answer_max_line_count, qa_dataframe) return ( inp, answer_prompt_input, diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index aa9c2f0c5..8c318cb6c 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -17,19 +17,19 @@ import json import os -from datetime import datetime from dataclasses import dataclass -from typing import Any, Tuple, Dict, Literal, Optional, List +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional, Tuple import gradio as gr import pandas as pd -from hugegraph_llm.config import prompt, resource_path, huge_settings +from hugegraph_llm.config import huge_settings, prompt, resource_path from hugegraph_llm.flows import FlowName +from hugegraph_llm.flows.scheduler import SchedulerSingleton from hugegraph_llm.utils.embedding_utils import get_index_folder_name from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query from hugegraph_llm.utils.log import log -from hugegraph_llm.flows.scheduler import SchedulerSingleton @dataclass @@ -82,9 +82,7 @@ def store_schema(schema, question, gremlin_prompt): def build_example_vector_index(temp_file) -> dict: - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) index_path = os.path.join(resource_path, folder_name, "gremlin_examples") if not os.path.exists(index_path): os.makedirs(index_path) @@ -96,9 +94,7 @@ def build_example_vector_index(temp_file) -> dict: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") _, file_name = os.path.split(f"{name}_{timestamp}{ext}") log.info("Copying file to: %s", file_name) - target_file = os.path.join( - resource_path, folder_name, "gremlin_examples", file_name - ) + target_file = os.path.join(resource_path, folder_name, "gremlin_examples", file_name) try: import shutil @@ -117,9 +113,7 @@ def build_example_vector_index(temp_file) -> dict: log.critical("Unsupported file format. Please input a JSON or CSV file.") return {"error": "Unsupported file format. Please input a JSON or CSV file."} - return SchedulerSingleton.get_instance().schedule_flow( - FlowName.BUILD_EXAMPLES_INDEX, examples - ) + return SchedulerSingleton.get_instance().schedule_flow(FlowName.BUILD_EXAMPLES_INDEX, examples) def _process_schema(schema, generator, sm): @@ -188,22 +182,14 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if "vertexlabels" in schema: mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: - new_vertex = { - key: vertex[key] - for key in ["id", "name", "properties"] - if key in vertex - } + new_vertex = {key: vertex[key] for key in ["id", "name", "properties"] if key in vertex} mini_schema["vertexlabels"].append(new_vertex) # Add necessary edgelabels items (4) if "edgelabels" in schema: mini_schema["edgelabels"] = [] for edge in schema["edgelabels"]: - new_edge = { - key: edge[key] - for key in ["name", "source_label", "target_label", "properties"] - if key in edge - } + new_edge = {key: edge[key] for key in ["name", "source_label", "target_label", "properties"] if key in edge} mini_schema["edgelabels"].append(new_edge) return mini_schema @@ -280,12 +266,8 @@ def create_text2gremlin_block() -> Tuple: language="javascript", elem_classes="code-container-show", ) - initialized_out = gr.Textbox( - label="Gremlin With Template", show_copy_button=True - ) - raw_out = gr.Textbox( - label="Gremlin Without Template", show_copy_button=True - ) + initialized_out = gr.Textbox(label="Gremlin With Template", show_copy_button=True) + raw_out = gr.Textbox(label="Gremlin Without Template", show_copy_button=True) tmpl_exec_out = gr.Code( label="Query With Template Output", language="json", @@ -298,9 +280,7 @@ def create_text2gremlin_block() -> Tuple: ) with gr.Column(scale=1): - example_num_slider = gr.Slider( - minimum=0, maximum=10, step=1, value=2, label="Number of refer examples" - ) + example_num_slider = gr.Slider(minimum=0, maximum=10, step=1, value=2, label="Number of refer examples") schema_box = gr.Textbox( value=prompt.text2gql_graph_schema, label="Schema", diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py index 84d60df7e..32bb8ef04 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py @@ -23,36 +23,30 @@ import gradio as gr -from hugegraph_llm.config import huge_settings -from hugegraph_llm.config import prompt -from hugegraph_llm.config import resource_path +from hugegraph_llm.config import huge_settings, prompt, resource_path from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton from hugegraph_llm.utils.graph_index_utils import ( - get_graph_index_info, - clean_all_graph_index, + build_schema, clean_all_graph_data, - update_vid_embedding, + clean_all_graph_index, extract_graph, + get_graph_index_info, import_graph_data, - build_schema, + update_vid_embedding, ) from hugegraph_llm.utils.hugegraph_utils import check_graph_db_connection from hugegraph_llm.utils.log import log from hugegraph_llm.utils.vector_index_utils import ( - clean_vector_index, build_vector_index, + clean_vector_index, get_vector_index_info, ) def store_prompt(doc, schema, example_prompt): # update env variables: doc, schema and example_prompt - if ( - prompt.doc_input_text != doc - or prompt.graph_schema != schema - or prompt.extract_graph_prompt != example_prompt - ): + if prompt.doc_input_text != doc or prompt.graph_schema != schema or prompt.extract_graph_prompt != example_prompt: prompt.doc_input_text = doc prompt.graph_schema = schema prompt.extract_graph_prompt = example_prompt @@ -64,16 +58,12 @@ def generate_prompt_for_ui(source_text, scenario, example_name): Handles the UI logic for generating a new prompt using the new workflow architecture. """ if not all([source_text, scenario, example_name]): - gr.Warning( - "Please provide original text, expected scenario, and select an example!" - ) + gr.Warning("Please provide original text, expected scenario, and select an example!") return gr.update() try: # using new architecture scheduler = SchedulerSingleton.get_instance() - result = scheduler.schedule_flow( - FlowName.PROMPT_GENERATE, source_text, scenario, example_name - ) + result = scheduler.schedule_flow(FlowName.PROMPT_GENERATE, source_text, scenario, example_name) gr.Info("Prompt generated successfully!") return result except Exception as e: @@ -84,9 +74,7 @@ def generate_prompt_for_ui(source_text, scenario, example_name): def load_example_names(): """Load all candidate examples""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "prompt_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return [example.get("name", "Unnamed example") for example in examples] @@ -100,29 +88,19 @@ def load_query_examples(): language = getattr( prompt, "language", - ( - getattr(prompt.llm_settings, "language", "EN") - if hasattr(prompt, "llm_settings") - else "EN" - ), + (getattr(prompt.llm_settings, "language", "EN") if hasattr(prompt, "llm_settings") else "EN"), ) if language.upper() == "CN": - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples_CN.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "query_examples_CN.json") else: - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "query_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) except (FileNotFoundError, json.JSONDecodeError): try: - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "query_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -133,9 +111,7 @@ def load_query_examples(): def load_schema_fewshot_examples(): """Load few-shot examples from a JSON file""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "schema_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "schema_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -146,14 +122,10 @@ def load_schema_fewshot_examples(): def update_example_preview(example_name): """Update the display content based on the selected example name.""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "prompt_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") with open(examples_path, "r", encoding="utf-8") as f: all_examples = json.load(f) - selected_example = next( - (ex for ex in all_examples if ex.get("name") == example_name), None - ) + selected_example = next((ex for ex in all_examples if ex.get("name") == example_name), None) if selected_example: return ( @@ -181,26 +153,18 @@ def _create_prompt_helper_block(demo, input_text, info_extract_template): few_shot_dropdown = gr.Dropdown( choices=example_names, label="Select a Few-shot example as a reference", - value=( - example_names[0] - if example_names and example_names[0] != "No available examples" - else None - ), + value=(example_names[0] if example_names and example_names[0] != "No available examples" else None), ) with gr.Accordion("View example details", open=False): example_desc_preview = gr.Markdown(label="Example description") - example_text_preview = gr.Textbox( - label="Example input text", lines=5, interactive=False - ) + example_text_preview = gr.Textbox(label="Example input text", lines=5, interactive=False) example_prompt_preview = gr.Code( label="Example Graph Extract Prompt", language="markdown", interactive=False, ) - generate_prompt_btn = gr.Button( - "🚀 Auto-generate Graph Extract Prompt", variant="primary" - ) + generate_prompt_btn = gr.Button("🚀 Auto-generate Graph Extract Prompt", variant="primary") # Bind the change event of the dropdown menu few_shot_dropdown.change( fn=update_example_preview, @@ -292,9 +256,7 @@ def create_vector_graph_block(): lines=15, max_lines=29, ) - out = gr.Code( - label="Output Info", language="json", elem_classes="code-container-edit" - ) + out = gr.Code(label="Output Info", language="json", elem_classes="code-container-edit") with gr.Row(): with gr.Accordion("Get RAG Info", open=False): @@ -303,12 +265,8 @@ def create_vector_graph_block(): graph_index_btn0 = gr.Button("Get Graph Index Info", size="sm") with gr.Accordion("Clear RAG Data", open=False): with gr.Column(): - vector_index_btn1 = gr.Button( - "Clear Chunks Vector Index", size="sm" - ) - graph_index_btn1 = gr.Button( - "Clear Graph Vid Vector Index", size="sm" - ) + vector_index_btn1 = gr.Button("Clear Chunks Vector Index", size="sm") + graph_index_btn1 = gr.Button("Clear Graph Vid Vector Index", size="sm") graph_data_btn0 = gr.Button("Clear Graph Data", size="sm") vector_import_bt = gr.Button("Import into Vector", variant="primary") @@ -348,9 +306,7 @@ def create_vector_graph_block(): store_prompt, inputs=[input_text, input_schema, info_extract_template], ) - vector_import_bt.click( - build_vector_index, inputs=[input_file, input_text], outputs=out - ).then( + vector_import_bt.click(build_vector_index, inputs=[input_file, input_text], outputs=out).then( store_prompt, inputs=[input_text, input_schema, info_extract_template], ) @@ -381,9 +337,9 @@ def create_vector_graph_block(): inputs=[input_text, input_schema, info_extract_template], ) - graph_loading_bt.click( - import_graph_data, inputs=[out, input_schema], outputs=[out] - ).then(update_vid_embedding).then( + graph_loading_bt.click(import_graph_data, inputs=[out, input_schema], outputs=[out]).then( + update_vid_embedding + ).then( store_prompt, inputs=[input_text, input_schema, info_extract_template], ) diff --git a/hugegraph-llm/src/hugegraph_llm/document/__init__.py b/hugegraph-llm/src/hugegraph_llm/document/__init__.py index 07e44c7f6..fb0428f4e 100644 --- a/hugegraph-llm/src/hugegraph_llm/document/__init__.py +++ b/hugegraph-llm/src/hugegraph_llm/document/__init__.py @@ -21,7 +21,7 @@ in the HugeGraph LLM system. """ -from typing import Dict, Any, Optional, Union +from typing import Any, Dict, Optional, Union class Metadata: @@ -59,7 +59,7 @@ def __init__(self, content: str, metadata: Optional[Union[Dict[str, Any], Metada Args: content: The text content of the document. metadata: Metadata associated with the document. Can be a dictionary or Metadata object. - + Raises: ValueError: If content is None or empty string. """ diff --git a/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py index ee173b284..b7ed15f93 100644 --- a/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py @@ -16,7 +16,8 @@ # under the License. -from typing import Literal, Union, List +from typing import List, Literal, Union + from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -33,13 +34,9 @@ def __init__( else: raise ValueError("Argument `language` must be zh or en!") if split_type == "paragraph": - self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, chunk_overlap=30, separators=separators - ) + self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, separators=separators) elif split_type == "sentence": - self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=50, chunk_overlap=0, separators=separators - ) + self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=0, separators=separators) else: raise ValueError("Arg `type` must be paragraph, sentence!") diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py index 0b0cd9c26..c50e83b65 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py @@ -14,15 +14,15 @@ # limitations under the License. import json -from typing import List, Dict, Optional +from typing import Dict, List, Optional from pycgraph import GPipeline from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.index_node.build_gremlin_example_index import ( BuildGremlinExampleIndexNode, ) +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py index 1bb413ba5..16e6775a5 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py @@ -18,8 +18,8 @@ from pycgraph import GPipeline from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.llm_node.schema_build import SchemaBuildNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.utils.log import log @@ -42,9 +42,7 @@ def prepare( prepared_input.query_examples = query_examples prepared_input.few_shot_schema = few_shot_schema - def build_flow( - self, texts=None, query_examples=None, few_shot_schema=None, **kwargs - ): + def build_flow(self, texts=None, query_examples=None, few_shot_schema=None, **kwargs): pipeline = GPipeline() prepared_input = WkFlowInput() self.prepare( diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py index ccfd6d02f..cfd23c627 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py @@ -20,8 +20,7 @@ from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode from hugegraph_llm.nodes.index_node.build_vector_index import BuildVectorIndexNode -from hugegraph_llm.state.ai_state import WkFlowInput -from hugegraph_llm.state.ai_state import WkFlowState +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState # pylint: disable=arguments-differ,keyword-arg-before-vararg diff --git a/hugegraph-llm/src/hugegraph_llm/flows/common.py b/hugegraph-llm/src/hugegraph_llm/flows/common.py index d1301119e..8705a3008 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/common.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/common.py @@ -14,7 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Dict, Any, AsyncGenerator +from typing import Any, AsyncGenerator, Dict from hugegraph_llm.state.ai_state import WkFlowInput from hugegraph_llm.utils.log import log @@ -43,9 +43,7 @@ def post_deal(self, **kwargs): Post-processing interface. """ - async def post_deal_stream( - self, pipeline=None - ) -> AsyncGenerator[Dict[str, Any], None]: + async def post_deal_stream(self, pipeline=None) -> AsyncGenerator[Dict[str, Any], None]: """ Streaming post-processing interface. Subclasses can override this method as needed. diff --git a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py index 1dc98323b..d6b13d59c 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py @@ -20,8 +20,8 @@ from hugegraph_llm.config import huge_settings, index_settings from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState # pylint: disable=arguments-differ,keyword-arg-before-vararg @@ -44,7 +44,9 @@ def build_flow(self, **kwargs): def post_deal(self, pipeline=None): # Lazy import to avoid circular dependency - from hugegraph_llm.utils.vector_index_utils import get_vector_index_class # pylint: disable=import-outside-toplevel + from hugegraph_llm.utils.vector_index_utils import ( + get_vector_index_class, # pylint: disable=import-outside-toplevel + ) graph_summary_info = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index b2bfec664..0057f2b71 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -14,7 +14,9 @@ # limitations under the License. import json + from pycgraph import GPipeline + from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode @@ -46,15 +48,11 @@ def prepare( prepared_input.schema = schema prepared_input.extract_type = extract_type - def build_flow( - self, schema, texts, example_prompt, extract_type, language="zh", **kwargs - ): + def build_flow(self, schema, texts, example_prompt, extract_type, language="zh", **kwargs): pipeline = GPipeline() prepared_input = WkFlowInput() # prepare input data - self.prepare( - prepared_input, schema, texts, example_prompt, extract_type, language - ) + self.prepare(prepared_input, schema, texts, example_prompt, extract_type, language) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") @@ -64,9 +62,7 @@ def build_flow( graph_extract_node = ExtractNode() pipeline.registerGElement(schema_node, set(), "schema_node") pipeline.registerGElement(chunk_split_node, set(), "chunk_split") - pipeline.registerGElement( - graph_extract_node, {schema_node, chunk_split_node}, "graph_extract" - ) + pipeline.registerGElement(graph_extract_node, {schema_node, chunk_split_node}, "graph_extract") return pipeline diff --git a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py index 1a8ae7483..ac0f2ab1a 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py @@ -17,6 +17,7 @@ import gradio as gr from pycgraph import GPipeline + from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.hugegraph_node.commit_to_hugegraph import Commit2GraphNode from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode diff --git a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py index fe42a4420..e76968d68 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py @@ -17,8 +17,7 @@ from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.llm_node.prompt_generate import PromptGenerateNode -from hugegraph_llm.state.ai_state import WkFlowInput -from hugegraph_llm.state.ai_state import WkFlowState +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState # pylint: disable=arguments-differ,keyword-arg-before-vararg @@ -26,9 +25,7 @@ class PromptGenerateFlow(BaseFlow): def __init__(self): pass - def prepare( - self, prepared_input: WkFlowInput, source_text, scenario, example_name, **kwargs - ): + def prepare(self, prepared_input: WkFlowInput, source_text, scenario, example_name, **kwargs): """ Prepare input data for PromptGenerate workflow """ @@ -59,6 +56,4 @@ def post_deal(self, pipeline=None, **kwargs): Process the execution result of PromptGenerate workflow """ res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - return res.get( - "generated_extract_prompt", "Generation failed. Please check the logs." - ) + return res.get("generated_extract_prompt", "Generation failed. Please check the logs.") diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py index 7985e3621..66e8fb0be 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py @@ -14,35 +14,31 @@ # limitations under the License. -from typing import Optional, Literal, cast +from typing import Literal, Optional, cast -from pycgraph import GPipeline, GRegion, GCondition +from pycgraph import GCondition, GPipeline, GRegion +from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode -from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode -from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode -from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.utils.log import log class GraphRecallCondition(GCondition): def choose(self): - prepared_input: WkFlowInput = cast( - WkFlowInput, self.getGParamWithNoEmpty("wkflow_input") - ) + prepared_input: WkFlowInput = cast(WkFlowInput, self.getGParamWithNoEmpty("wkflow_input")) return 0 if prepared_input.is_graph_rag_recall else 1 class VectorOnlyCondition(GCondition): def choose(self): - prepared_input: WkFlowInput = cast( - WkFlowInput, self.getGParamWithNoEmpty("wkflow_input") - ) + prepared_input: WkFlowInput = cast(WkFlowInput, self.getGParamWithNoEmpty("wkflow_input")) return 0 if prepared_input.is_vector_only else 1 @@ -86,25 +82,15 @@ def prepare( prepared_input.graph_vector_answer = graph_vector_answer prepared_input.gremlin_tmpl_num = gremlin_tmpl_num prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt - prepared_input.max_graph_items = ( - max_graph_items or huge_settings.max_graph_items - ) - prepared_input.topk_per_keyword = ( - topk_per_keyword or huge_settings.topk_per_keyword - ) - prepared_input.topk_return_results = ( - topk_return_results or huge_settings.topk_return_results - ) + prepared_input.max_graph_items = max_graph_items or huge_settings.max_graph_items + prepared_input.topk_per_keyword = topk_per_keyword or huge_settings.topk_per_keyword + prepared_input.topk_return_results = topk_return_results or huge_settings.topk_return_results prepared_input.rerank_method = rerank_method prepared_input.near_neighbor_first = near_neighbor_first - prepared_input.keywords_extract_prompt = ( - keywords_extract_prompt or prompt.keywords_extract_prompt - ) + prepared_input.keywords_extract_prompt = keywords_extract_prompt or prompt.keywords_extract_prompt prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt prepared_input.custom_related_information = custom_related_information - prepared_input.vector_dis_threshold = ( - vector_dis_threshold or huge_settings.vector_dis_threshold - ) + prepared_input.vector_dis_threshold = vector_dis_threshold or huge_settings.vector_dis_threshold prepared_input.schema = huge_settings.graph_name prepared_input.is_graph_rag_recall = is_graph_rag_recall @@ -126,12 +112,8 @@ def build_flow(self, **kwargs): # Create nodes and register them with registerGElement only_keyword_extract_node = KeywordExtractNode("only_keyword") - only_semantic_id_query_node = SemanticIdQueryNode( - {only_keyword_extract_node}, "only_semantic" - ) - vector_region: GRegion = GRegion( - [only_keyword_extract_node, only_semantic_id_query_node] - ) + only_semantic_id_query_node = SemanticIdQueryNode({only_keyword_extract_node}, "only_semantic") + vector_region: GRegion = GRegion([only_keyword_extract_node, only_semantic_id_query_node]) only_schema_node = SchemaNode() schema_node = VectorOnlyCondition([GRegion(), only_schema_node]) @@ -150,9 +132,7 @@ def build_flow(self, **kwargs): {schema_node, vector_region}, "graph_condition", ) - pipeline.registerGElement( - answer_node, {graph_condition_region}, "answer_condition" - ) + pipeline.registerGElement(answer_node, {graph_condition_region}, "answer_condition") log.info("RAGGraphOnlyFlow pipeline built successfully") return pipeline diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py index fbfa4a44b..1d738b2ac 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py @@ -14,20 +14,20 @@ # limitations under the License. -from typing import Optional, Literal +from typing import Literal, Optional from pycgraph import GPipeline +from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode -from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode -from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode -from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode -from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode +from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.utils.log import log @@ -71,23 +71,13 @@ def prepare( prepared_input.graph_ratio = graph_ratio prepared_input.gremlin_tmpl_num = gremlin_tmpl_num prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt - prepared_input.max_graph_items = ( - max_graph_items or huge_settings.max_graph_items - ) - prepared_input.topk_return_results = ( - topk_return_results or huge_settings.topk_return_results - ) - prepared_input.topk_per_keyword = ( - topk_per_keyword or huge_settings.topk_per_keyword - ) - prepared_input.vector_dis_threshold = ( - vector_dis_threshold or huge_settings.vector_dis_threshold - ) + prepared_input.max_graph_items = max_graph_items or huge_settings.max_graph_items + prepared_input.topk_return_results = topk_return_results or huge_settings.topk_return_results + prepared_input.topk_per_keyword = topk_per_keyword or huge_settings.topk_per_keyword + prepared_input.vector_dis_threshold = vector_dis_threshold or huge_settings.vector_dis_threshold prepared_input.rerank_method = rerank_method prepared_input.near_neighbor_first = near_neighbor_first - prepared_input.keywords_extract_prompt = ( - keywords_extract_prompt or prompt.keywords_extract_prompt - ) + prepared_input.keywords_extract_prompt = keywords_extract_prompt or prompt.keywords_extract_prompt prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt prepared_input.custom_related_information = custom_related_information prepared_input.schema = huge_settings.graph_name @@ -118,19 +108,11 @@ def build_flow(self, **kwargs): # Register nodes and their dependencies pipeline.registerGElement(vector_query_node, set(), "vector") pipeline.registerGElement(keyword_extract_node, set(), "keyword") - pipeline.registerGElement( - semantic_id_query_node, {keyword_extract_node}, "semantic" - ) + pipeline.registerGElement(semantic_id_query_node, {keyword_extract_node}, "semantic") pipeline.registerGElement(schema_node, set(), "schema") - pipeline.registerGElement( - graph_query_node, {schema_node, semantic_id_query_node}, "graph" - ) - pipeline.registerGElement( - merge_rerank_node, {graph_query_node, vector_query_node}, "merge" - ) - pipeline.registerGElement( - answer_synthesize_node, {merge_rerank_node}, "graph_vector" - ) + pipeline.registerGElement(graph_query_node, {schema_node, semantic_id_query_node}, "graph") + pipeline.registerGElement(merge_rerank_node, {graph_query_node, vector_query_node}, "merge") + pipeline.registerGElement(answer_synthesize_node, {merge_rerank_node}, "graph_vector") log.info("RAGGraphVectorFlow pipeline built successfully") return pipeline diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py index 0ae2d63aa..862359c1c 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py @@ -18,10 +18,10 @@ from pycgraph import GPipeline +from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py index 98563abfb..585f5a6ab 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py @@ -14,16 +14,16 @@ # limitations under the License. -from typing import Optional, Literal +from typing import Literal, Optional from pycgraph import GPipeline +from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.config import huge_settings, prompt from hugegraph_llm.utils.log import log @@ -59,12 +59,8 @@ def prepare( prepared_input.vector_only_answer = vector_only_answer prepared_input.graph_only_answer = graph_only_answer prepared_input.graph_vector_answer = graph_vector_answer - prepared_input.vector_dis_threshold = ( - vector_dis_threshold or huge_settings.vector_dis_threshold - ) - prepared_input.topk_return_results = ( - topk_return_results or huge_settings.topk_return_results - ) + prepared_input.vector_dis_threshold = vector_dis_threshold or huge_settings.vector_dis_threshold + prepared_input.topk_return_results = topk_return_results or huge_settings.topk_return_results prepared_input.rerank_method = rerank_method prepared_input.near_neighbor_first = near_neighbor_first prepared_input.custom_related_information = custom_related_information @@ -92,9 +88,7 @@ def build_flow(self, **kwargs): # Register nodes and dependencies, keep naming consistent with original pipeline.registerGElement(only_vector_query_node, set(), "only_vector") - pipeline.registerGElement( - merge_rerank_node, {only_vector_query_node}, "merge_two" - ) + pipeline.registerGElement(merge_rerank_node, {only_vector_query_node}, "merge_two") pipeline.registerGElement(answer_synthesize_node, {merge_rerank_node}, "vector") log.info("RAGVectorOnlyFlow pipeline built successfully") return pipeline diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py index 2eef75bc8..a4ba5fa4c 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -14,25 +14,27 @@ # limitations under the License. import threading -from typing import Dict, Any +from typing import Any, Dict + from pycgraph import GPipeline, GPipelineManager + from hugegraph_llm.flows import FlowName +from hugegraph_llm.flows.build_example_index import BuildExampleIndexFlow +from hugegraph_llm.flows.build_schema import BuildSchemaFlow from hugegraph_llm.flows.build_vector_index import BuildVectorIndexFlow from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.flows.build_example_index import BuildExampleIndexFlow +from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow from hugegraph_llm.flows.graph_extract import GraphExtractFlow from hugegraph_llm.flows.import_graph_data import ImportGraphDataFlow -from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlow -from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow -from hugegraph_llm.flows.build_schema import BuildSchemaFlow from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow -from hugegraph_llm.flows.rag_flow_raw import RAGRawFlow -from hugegraph_llm.flows.rag_flow_vector_only import RAGVectorOnlyFlow from hugegraph_llm.flows.rag_flow_graph_only import RAGGraphOnlyFlow from hugegraph_llm.flows.rag_flow_graph_vector import RAGGraphVectorFlow +from hugegraph_llm.flows.rag_flow_raw import RAGRawFlow +from hugegraph_llm.flows.rag_flow_vector_only import RAGVectorOnlyFlow +from hugegraph_llm.flows.text2gremlin import Text2GremlinFlow +from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlow from hugegraph_llm.state.ai_state import WkFlowInput from hugegraph_llm.utils.log import log -from hugegraph_llm.flows.text2gremlin import Text2GremlinFlow class Scheduler: diff --git a/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py index 96ba08ba4..10818b128 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py @@ -18,13 +18,13 @@ from pycgraph import GPipeline from hugegraph_llm.flows.common import BaseFlow -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.hugegraph_node.gremlin_execute import GremlinExecuteNode from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode from hugegraph_llm.nodes.index_node.gremlin_example_index_query import ( GremlinExampleIndexQueryNode, ) from hugegraph_llm.nodes.llm_node.text2gremlin import Text2GremlinNode -from hugegraph_llm.nodes.hugegraph_node.gremlin_execute import GremlinExecuteNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState # pylint: disable=arguments-differ,keyword-arg-before-vararg diff --git a/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py index 5af09bae1..693f70012 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py @@ -18,7 +18,7 @@ from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode from hugegraph_llm.nodes.index_node.build_semantic_index import BuildSemanticIndexNode -from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState # pylint: disable=arguments-differ,keyword-arg-before-vararg diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/base.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/base.py index 2e1cfc267..feda7a24c 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/base.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/base.py @@ -56,9 +56,7 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: """ @abstractmethod - def search( - self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9 - ) -> List[Any]: + def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Any]: """ Search for the top_k most similar vectors to the query vector. diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/faiss_vector_store.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/faiss_vector_store.py index a8f23155a..d0a016c53 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/faiss_vector_store.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/faiss_vector_store.py @@ -66,9 +66,7 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: self.properties = [p for i, p in enumerate(self.properties) if i not in indices] return remove_num - def search( - self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9 - ) -> List[Any]: + def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Any]: if self.index.ntotal == 0: return [] diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/milvus_vector_store.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/milvus_vector_store.py index b83aa2836..94be957f3 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/milvus_vector_store.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/milvus_vector_store.py @@ -74,9 +74,7 @@ def __init__( def _create_collection(self): """Create a new collection in Milvus.""" id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True) - vector_field = FieldSchema( - name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.embed_dim - ) + vector_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.embed_dim) property_field = FieldSchema(name="property", dtype=DataType.VARCHAR, max_length=65535) original_id_field = FieldSchema(name="original_id", dtype=DataType.INT64) @@ -149,9 +147,7 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: finally: self.collection.release() - def search( - self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9 - ) -> List[Any]: + def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Any]: try: if self.collection.num_entities == 0: return [] diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/qdrant_vector_store.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/qdrant_vector_store.py index 52914d409..259814613 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index/qdrant_vector_store.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index/qdrant_vector_store.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Set, Union import uuid +from typing import Any, Dict, List, Set, Union from qdrant_client import QdrantClient # pylint: disable=import-error from qdrant_client.http import models # pylint: disable=import-error @@ -56,9 +56,7 @@ def _create_collection(self): """Create a new collection in Qdrant.""" self.client.create_collection( collection_name=self.name, - vectors_config=models.VectorParams( - size=self.embed_dim, distance=models.Distance.COSINE - ), + vectors_config=models.VectorParams(size=self.embed_dim, distance=models.Distance.COSINE), ) log.info("Created Qdrant collection '%s'", self.name) @@ -119,9 +117,7 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: return remove_num def search(self, query_vector: List[float], top_k: int = 5, dis_threshold: float = 0.9): - search_result = self.client.search( - collection_name=self.name, query_vector=query_vector, limit=top_k - ) + search_result = self.client.search(collection_name=self.name, query_vector=query_vector, limit=top_k) result_properties = [] diff --git a/hugegraph-llm/src/hugegraph_llm/models/__init__.py b/hugegraph-llm/src/hugegraph_llm/models/__init__.py index 514361eb6..9cdd48c31 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/__init__.py +++ b/hugegraph-llm/src/hugegraph_llm/models/__init__.py @@ -26,8 +26,6 @@ # This enables import statements like: from hugegraph_llm.models import llms # Making subpackages accessible -from . import llms -from . import embeddings -from . import rerankers +from . import embeddings, llms, rerankers __all__ = ["llms", "embeddings", "rerankers"] diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py index b9b8527aa..d7cb43985 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py @@ -16,8 +16,7 @@ # under the License. -from hugegraph_llm.config import llm_settings -from hugegraph_llm.config import LLMConfig +from hugegraph_llm.config import LLMConfig, llm_settings from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py index 3f9619cd9..06a590afa 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py @@ -66,14 +66,14 @@ def get_text_embedding(self, text: str) -> List[float]: def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: """Get embeddings for multiple texts with automatic batch splitting. - + Parameters ---------- texts : List[str] A list of text strings to be embedded. batch_size : int, optional Maximum number of texts to process in a single API call (default: 32). - + Returns ------- List[List[float]] @@ -82,7 +82,7 @@ def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[L all_embeddings = [] try: for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] + batch = texts[i : i + batch_size] response = embedding( model=self.model, input=batch, @@ -113,14 +113,14 @@ async def async_get_text_embedding(self, text: str) -> List[float]: async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: """Get embeddings for multiple texts asynchronously with automatic batch splitting. - + Parameters ---------- texts : List[str] A list of text strings to be embedded. batch_size : int, optional Maximum number of texts to process in a single API call (default: 32). - + Returns ------- List[List[float]] @@ -129,7 +129,7 @@ async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 3 all_embeddings = [] try: for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] + batch = texts[i : i + batch_size] response = await aembedding( model=self.model, input=batch, diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py index 755ba5044..195ea6f6f 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py @@ -19,6 +19,7 @@ from typing import List import ollama + from .base import BaseEmbedding @@ -73,7 +74,7 @@ def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[L all_embeddings = [] for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] + batch = texts[i : i + batch_size] response = self.client.embed(model=self.model, input=batch)["embeddings"] all_embeddings.extend([list(inner_sequence) for inner_sequence in response]) return all_embeddings @@ -83,9 +84,7 @@ async def async_get_text_embedding(self, text: str) -> List[float]: response = await self.async_client.embeddings(model=self.model, prompt=text) return list(response["embedding"]) - async def async_get_texts_embeddings( - self, texts: List[str], batch_size: int = 32 - ) -> List[List[float]]: + async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: # Ollama python client may not provide batch async embeddings; fallback per item # batch_size parameter included for consistency with base class signature results: List[List[float]] = [] diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py index 342149165..0d0058cdb 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py @@ -16,9 +16,10 @@ # under the License. -from typing import Optional, List +from typing import List, Optional + +from openai import AsyncOpenAI, OpenAI -from openai import OpenAI, AsyncOpenAI from hugegraph_llm.models.embeddings.base import BaseEmbedding @@ -67,7 +68,7 @@ def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[L """ all_embeddings = [] for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] + batch = texts[i : i + batch_size] response = self.client.embeddings.create(input=batch, model=self.model) all_embeddings.extend([data.embedding for data in response.data]) return all_embeddings @@ -94,7 +95,7 @@ async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 3 """ all_embeddings = [] for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] + batch = texts[i : i + batch_size] response = await self.aclient.embeddings.create(input=batch, model=self.model) all_embeddings.extend([data.embedding for data in response.data]) return all_embeddings diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/__init__.py b/hugegraph-llm/src/hugegraph_llm/models/llms/__init__.py index 1b0694a07..37ba2560c 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/__init__.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/__init__.py @@ -20,7 +20,7 @@ This package contains various LLM client implementations including: - OpenAI clients -- Qianfan clients +- Qianfan clients - Ollama clients - LiteLLM clients """ diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py index 69c082690..7a5851aff 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py @@ -16,7 +16,7 @@ # under the License. from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional class BaseLLM(ABC): diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py index a13641db0..fafa1bdf4 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. -from hugegraph_llm.config import LLMConfig +from hugegraph_llm.config import LLMConfig, llm_settings +from hugegraph_llm.models.llms.litellm import LiteLLMClient from hugegraph_llm.models.llms.ollama import OllamaClient from hugegraph_llm.models.llms.openai import OpenAIClient -from hugegraph_llm.models.llms.litellm import LiteLLMClient -from hugegraph_llm.config import llm_settings def get_chat_llm(llm_configs: LLMConfig): @@ -173,8 +172,4 @@ def get_text2gql_llm(self): if __name__ == "__main__": client = LLMs().get_chat_llm() print(client.generate(prompt="What is the capital of China?")) - print( - client.generate( - messages=[{"role": "user", "content": "What is the capital of China?"}] - ) - ) + print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}])) diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py index 6f3c8129c..dcf479c2f 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py @@ -15,16 +15,16 @@ # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional, Dict, Any, AsyncGenerator +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional import tiktoken -from litellm import completion, acompletion -from litellm.exceptions import RateLimitError, BudgetExceededError, APIError +from litellm import acompletion, completion +from litellm.exceptions import APIError, BudgetExceededError, RateLimitError from tenacity import ( retry, + retry_if_exception_type, stop_after_attempt, wait_exponential, - retry_if_exception_type, ) from hugegraph_llm.models.llms.base import BaseLLM diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py index 6d08ce8cd..515d8d69d 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py @@ -17,7 +17,7 @@ import json -from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional import ollama from retry import retry @@ -118,9 +118,7 @@ async def agenerate_streaming( messages = [{"role": "user", "content": prompt}] try: - async_generator = await self.async_client.chat( - model=self.model, messages=messages, stream=True - ) + async_generator = await self.async_client.chat(model=self.model, messages=messages, stream=True) async for chunk in async_generator: token = chunk.get("message", {}).get("content", "") if on_token_callback: diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py index e1088c890..f7a6d3f9c 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py @@ -15,16 +15,16 @@ # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional, Dict, Any, Generator, AsyncGenerator +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional import openai import tiktoken -from openai import OpenAI, AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError +from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, OpenAI, RateLimitError from tenacity import ( retry, + retry_if_exception_type, stop_after_attempt, wait_exponential, - retry_if_exception_type, ) from hugegraph_llm.models.llms.base import BaseLLM diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py index aef530a5b..953c58da3 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, List +from typing import List, Optional import requests diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py index 6136d61b4..aa9f0c061 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py @@ -32,7 +32,5 @@ def get_reranker(self): model=llm_settings.reranker_model, ) if self.reranker_type == "siliconflow": - return SiliconReranker( - api_key=llm_settings.reranker_api_key, model=llm_settings.reranker_model - ) + return SiliconReranker(api_key=llm_settings.reranker_api_key, model=llm_settings.reranker_model) raise Exception("Reranker type is not supported!") diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py index 903debfa9..fa35ffc64 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, List +from typing import List, Optional import requests diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py index 35e926fea..7ed8af8a7 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -14,7 +14,9 @@ # limitations under the License. from typing import Dict, Optional -from pycgraph import GNode, CStatus + +from pycgraph import CStatus, GNode + from hugegraph_llm.nodes.util import init_context from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py index 74e192c64..8ce38b560 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict + +from hugegraph_llm.config import huge_settings, llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank -from hugegraph_llm.models.embeddings.init_embedding import get_embedding -from hugegraph_llm.config import huge_settings, llm_settings from hugegraph_llm.utils.log import log @@ -39,9 +40,7 @@ def node_init(self): rerank_method = self.wk_input.rerank_method or "bleu" near_neighbor_first = self.wk_input.near_neighbor_first or False custom_related_information = self.wk_input.custom_related_information or "" - topk_return_results = ( - self.wk_input.topk_return_results or huge_settings.topk_return_results - ) + topk_return_results = self.wk_input.topk_return_results or huge_settings.topk_return_results self.operator = MergeDedupRerank( embedding=embedding, diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py index 08ada8bc8..602c1cc00 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py @@ -14,6 +14,7 @@ # limitations under the License. from pycgraph import CStatus + from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -25,11 +26,7 @@ class ChunkSplitNode(BaseNode): wk_input: WkFlowInput = None def node_init(self): - if ( - self.wk_input.texts is None - or self.wk_input.language is None - or self.wk_input.split_type is None - ): + if self.wk_input.texts is None or self.wk_input.language is None or self.wk_input.split_type is None: return CStatus(-1, "Error occurs when prepare for workflow input") texts = self.wk_input.texts language = self.wk_input.language diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py index c9d62a9d5..1b5a1b368 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py @@ -14,13 +14,14 @@ # limitations under the License. import json -from typing import Dict, Any, Tuple, List, Set, Optional +from typing import Any, Dict, List, Optional, Set, Tuple + +from pyhugegraph.client import PyHugeClient -from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.operator_list import OperatorList from hugegraph_llm.utils.log import log -from pyhugegraph.client import PyHugeClient # TODO: remove 'as('subj)' step VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()" @@ -103,13 +104,9 @@ def node_init(self): self._max_items = self.wk_input.max_graph_items or huge_settings.max_graph_items self._prop_to_match = self.wk_input.prop_to_match self._num_gremlin_generate_example = ( - self.wk_input.gremlin_tmpl_num - if self.wk_input.gremlin_tmpl_num is not None - else -1 - ) - self.gremlin_prompt = ( - self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt + self.wk_input.gremlin_tmpl_num if self.wk_input.gremlin_tmpl_num is not None else -1 ) + self.gremlin_prompt = self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt self._limit_property = huge_settings.limit_property.lower() == "true" self._max_v_prop_len = self.wk_input.max_v_prop_len or 2048 self._max_e_prop_len = self.wk_input.max_e_prop_len or 256 @@ -140,9 +137,7 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: query_embedding = context.get("query_embedding") self.operator_list.clear() - self.operator_list.example_index_query( - num_examples=self._num_gremlin_generate_example - ) + self.operator_list.example_index_query(num_examples=self._num_gremlin_generate_example) gremlin_response = self.operator_list.gremlin_generate_synthesize( context["simple_schema"], vertices=vertices, @@ -158,23 +153,18 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: result = self._client.gremlin().exec(gremlin=gremlin)["data"] if result == [None]: result = [] - context["graph_result"] = [ - json.dumps(item, ensure_ascii=False) for item in result - ] + context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result] if context["graph_result"]: context["graph_result_flag"] = 1 context["graph_context_head"] = ( - f"The following are graph query result " - f"from gremlin query `{gremlin}`.\n" + f"The following are graph query result from gremlin query `{gremlin}`.\n" ) except Exception as e: # pylint: disable=broad-except,broad-exception-caught log.error(e) context["graph_result"] = [] return context - def _limit_property_query( - self, value: Optional[str], item_type: str - ) -> Optional[str]: + def _limit_property_query(self, value: Optional[str], item_type: str) -> Optional[str]: # NOTE: we skip the filter for list/set type (e.g., list of string, add it if needed) if not self._limit_property or not isinstance(value, str): return value @@ -193,19 +183,13 @@ def _process_vertex( use_id_to_match: bool, v_cache: Set[str], ) -> Tuple[str, int, int]: - matched_str = ( - item["id"] if use_id_to_match else item["props"][self._prop_to_match] - ) + matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match] if matched_str in node_cache: flat_rel = flat_rel[:-prior_edge_str_len] return flat_rel, prior_edge_str_len, depth node_cache.add(matched_str) - props_str = ", ".join( - f"{k}: {self._limit_property_query(v, 'v')}" - for k, v in item["props"].items() - if v - ) + props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v) # TODO: we may remove label id or replace with label name if matched_str in v_cache: @@ -228,16 +212,10 @@ def _process_edge( use_id_to_match: bool, e_cache: Set[Tuple[str, str, str]], ) -> Tuple[str, int]: - props_str = ", ".join( - f"{k}: {self._limit_property_query(v, 'e')}" - for k, v in item["props"].items() - if v - ) + props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v) props_str = f"{{{props_str}}}" if props_str else "" prev_matched_str = ( - raw_flat_rel[i - 1]["id"] - if use_id_to_match - else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] + raw_flat_rel[i - 1]["id"] if use_id_to_match else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] ) edge_key = (item["inV"], item["label"], item["outV"]) @@ -247,11 +225,7 @@ def _process_edge( else: edge_label = item["label"] - edge_str = ( - f"--[{edge_label}]-->" - if item["outV"] == prev_matched_str - else f"<--[{edge_label}]--" - ) + edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--" path_str += edge_str prior_edge_str_len = len(edge_str) return path_str, prior_edge_str_len @@ -293,17 +267,13 @@ def _process_path( return flat_rel, nodes_with_degree - def _update_vertex_degree_list( - self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str] - ) -> None: + def _update_vertex_degree_list(self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str]) -> None: for depth, node_str in enumerate(nodes_with_degree): if depth >= len(vertex_degree_list): vertex_degree_list.append(set()) vertex_degree_list[depth].add(node_str) - def _format_graph_query_result( - self, query_paths - ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: + def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: use_id_to_match = self._prop_to_match is None subgraph = set() subgraph_with_degree = {} @@ -313,9 +283,7 @@ def _format_graph_query_result( for path in query_paths: # 1. Process each path - path_str, vertex_with_degree = self._process_path( - path, use_id_to_match, v_cache, e_cache - ) + path_str, vertex_with_degree = self._process_path(path, use_id_to_match, v_cache, e_cache) subgraph.add(path_str) subgraph_with_degree[path_str] = vertex_with_degree # 2. Update vertex degree list @@ -333,17 +301,13 @@ def _get_graph_schema(self, refresh: bool = False) -> str: relationships = schema.getRelations() self._schema = ( - f"Vertex properties: {vertex_schema}\n" - f"Edge properties: {edge_schema}\n" - f"Relationships: {relationships}\n" + f"Vertex properties: {vertex_schema}\nEdge properties: {edge_schema}\nRelationships: {relationships}\n" ) log.debug("Link(Relation): %s", relationships) return self._schema @staticmethod - def _extract_label_names( - source: str, head: str = "name: ", tail: str = ", " - ) -> List[str]: + def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]: result = [] for s in source.split(head): end = s.find(tail) @@ -356,12 +320,8 @@ def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]: schema = self._get_graph_schema() vertex_props_str, edge_props_str = schema.split("\n")[:2] # TODO: rename to vertex (also need update in the schema) - vertex_props_str = ( - vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]") - ) - edge_props_str = ( - edge_props_str[len("Edge properties: ") :].strip("[").strip("]") - ) + vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]") + edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]") vertex_labels = self._extract_label_names(vertex_props_str) edge_labels = self._extract_label_names(edge_props_str) return vertex_labels, edge_labels @@ -439,13 +399,9 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: max_deep=self._max_deep, max_items=self._max_items, ) - log.warning( - "Unable to find vid, downgraded to property query, please confirm if it meets expectation." - ) + log.warning("Unable to find vid, downgraded to property query, please confirm if it meets expectation.") - paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)[ - "data" - ] + paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"] ( graph_chain_knowledge, vertex_degree_list, @@ -455,9 +411,7 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: context["graph_result"] = list(graph_chain_knowledge) if context["graph_result"]: context["graph_result_flag"] = 0 - context["vertex_degree_list"] = [ - list(vertex_degree) for vertex_degree in vertex_degree_list - ] + context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list] context["knowledge_with_degree"] = knowledge_with_degree context["graph_context_head"] = ( f"The following are graph knowledge in {self._max_deep} depth, e.g:\n" @@ -491,9 +445,7 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: data_json = self._subgraph_query(data_json) if data_json.get("graph_result"): - log.debug( - "Knowledge from Graph:\n%s", "\n".join(data_json["graph_result"]) - ) + log.debug("Knowledge from Graph:\n%s", "\n".join(data_json["graph_result"])) else: log.debug("No Knowledge Extracted from Graph") diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py index 61e5f0488..e9f9c6082 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -16,6 +16,7 @@ import json from pycgraph import CStatus + from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py index f43713ab7..697a88ca7 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py @@ -40,9 +40,7 @@ def node_init(self): vector_index = get_vector_index_class(index_settings.cur_vector_index) embedding = Embeddings().get_embedding() - self.build_gremlin_example_index_op = BuildGremlinExampleIndex( - embedding, examples, vector_index - ) + self.build_gremlin_example_index_op = BuildGremlinExampleIndex(embedding, examples, vector_index) return super().node_init() def operator_schedule(self, data_json): diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py index 462f5ce21..fee1539ed 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py @@ -20,11 +20,11 @@ from pycgraph import CStatus from hugegraph_llm.config import index_settings +from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.index_op.gremlin_example_index_query import ( GremlinExampleIndexQuery, ) -from hugegraph_llm.models.embeddings.init_embedding import Embeddings class GremlinExampleIndexQueryNode(BaseNode): diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py index 1fe19d05b..fff4c6190 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict from pycgraph import CStatus + +from hugegraph_llm.config import huge_settings, index_settings +from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery -from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.config import huge_settings, index_settings from hugegraph_llm.utils.log import log @@ -45,21 +46,13 @@ def node_init(self): vector_index = get_vector_index_class(index_settings.cur_vector_index) embedding = Embeddings().get_embedding() - by = ( - self.wk_input.semantic_by - if self.wk_input.semantic_by is not None - else "keywords" - ) + by = self.wk_input.semantic_by if self.wk_input.semantic_by is not None else "keywords" topk_per_keyword = ( self.wk_input.topk_per_keyword if self.wk_input.topk_per_keyword is not None else huge_settings.topk_per_keyword ) - topk_per_query = ( - self.wk_input.topk_per_query - if self.wk_input.topk_per_query is not None - else 10 - ) + topk_per_query = self.wk_input.topk_per_query if self.wk_input.topk_per_query is not None else 10 vector_dis_threshold = ( self.wk_input.vector_dis_threshold if self.wk_input.vector_dis_threshold is not None @@ -98,8 +91,6 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: semantic_result = self.semantic_id_query.run(data_json) match_vids = semantic_result.get("match_vids", []) - log.info( - "Semantic query completed, found %d matching vertex IDs", len(match_vids) - ) + log.info("Semantic query completed, found %d matching vertex IDs", len(match_vids)) return semantic_result diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py index 59c7b6144..50d1f368d 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict + from hugegraph_llm.config import index_settings +from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery -from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py index 6997cd781..5ede8673b 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict + from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py index 1d1abb003..9ac26970c 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py @@ -14,6 +14,7 @@ # limitations under the License. from pycgraph import CStatus + from hugegraph_llm.config import llm_settings from hugegraph_llm.models.llms.init_llm import get_chat_llm from hugegraph_llm.nodes.base_node import BaseNode diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py index 60542ddc1..cb713c1dc 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract @@ -31,9 +31,7 @@ def node_init(self): """ Initialize the keyword extraction operator. """ - max_keywords = ( - self.wk_input.max_keywords if self.wk_input.max_keywords is not None else 5 - ) + max_keywords = self.wk_input.max_keywords if self.wk_input.max_keywords is not None else 5 extract_template = self.wk_input.keywords_extract_prompt self.operator = KeywordExtract( diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py index b0c1dbe14..01708533a 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py @@ -14,6 +14,7 @@ # limitations under the License. from pycgraph import CStatus + from hugegraph_llm.config import llm_settings from hugegraph_llm.models.llms.init_llm import get_chat_llm from hugegraph_llm.nodes.base_node import BaseNode diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py index 69d731eb3..01b2ca64d 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -16,11 +16,12 @@ import json from pycgraph import CStatus -from hugegraph_llm.nodes.base_node import BaseNode -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.models.llms.init_llm import get_chat_llm + from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.schema_build import SchemaBuilder +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.utils.log import log @@ -61,16 +62,12 @@ def node_init(self): # few_shot_schema: already parsed dict or raw JSON string few_shot_schema = {} - fss_src = ( - self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None - ) + fss_src = self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None if fss_src: try: few_shot_schema = json.loads(fss_src) except json.JSONDecodeError as e: - return CStatus( - -1, f"Few Shot Schema is not in a valid JSON format: {e}" - ) + return CStatus(-1, f"Few Shot Schema is not in a valid JSON format: {e}") _context_payload = { "raw_texts": raw_texts, diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py index 0904b9920..a4e621e64 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py @@ -18,11 +18,11 @@ import json from typing import Any, Dict, Optional - +from hugegraph_llm.config import llm_settings +from hugegraph_llm.config import prompt as prompt_cfg +from hugegraph_llm.models.llms.init_llm import get_text2gql_llm from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize -from hugegraph_llm.models.llms.init_llm import get_text2gql_llm -from hugegraph_llm.config import llm_settings, prompt as prompt_cfg def _stable_schema_string(state_json: Dict[str, Any]) -> str: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index 47b0f060f..63618aaec 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -72,9 +72,7 @@ def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): property_label_set = {label["name"] for label in property_labels} return property_labels, property_label_set - def _process_vertex_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: + def _process_vertex_labels(self, schema: Dict[str, Any], property_labels: list, property_label_set: set) -> None: for vertex_label in schema["vertexlabels"]: self._validate_vertex_label(vertex_label) properties = vertex_label["properties"] @@ -86,9 +84,7 @@ def _process_vertex_labels( vertex_label["nullable_keys"] = nullable_keys self._add_missing_properties(properties, property_labels, property_label_set) - def _process_edge_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: + def _process_edge_labels(self, schema: Dict[str, Any], property_labels: list, property_label_set: set) -> None: for edge_label in schema["edgelabels"]: self._validate_edge_label(edge_label) properties = edge_label.get("properties", []) @@ -111,14 +107,8 @@ def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.") - if ( - "name" not in edge_label - or "source_label" not in edge_label - or "target_label" not in edge_label - ): - log_and_raise( - "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." - ) + if "name" not in edge_label or "source_label" not in edge_label or "target_label" not in edge_label: + log_and_raise("EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'.") check_type(edge_label["name"], str, "'name' in edge_label is not of correct type.") check_type( edge_label["source_label"], @@ -137,9 +127,7 @@ def _process_keys(self, label: Dict[str, Any], key_type: str, default_keys: list new_keys = [key for key in keys if key in label["properties"]] return new_keys - def _add_missing_properties( - self, properties: list, property_labels: list, property_label_set: set - ) -> None: + def _add_missing_properties(self, properties: list, property_labels: list, property_label_set: set) -> None: for prop in properties: if prop not in property_label_set: property_labels.append( diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index dc5b15e00..a61fbc1b7 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -16,7 +16,7 @@ # under the License. -from typing import Literal, Dict, Any, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import jieba import requests @@ -126,20 +126,15 @@ def _rerank_with_vertex_degree( reranker = Rerankers().get_reranker() try: vertex_rerank_res = [ - reranker.get_rerank_lists(query, vertex_degree) + [""] - for vertex_degree in vertex_degree_list + reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list ] except requests.exceptions.RequestException as e: - log.warning( - "Online reranker fails, automatically switches to local bleu method: %s", e - ) + log.warning("Online reranker fails, automatically switches to local bleu method: %s", e) self.method = "bleu" self.switch_to_bleu = True if self.method == "bleu": - vertex_rerank_res = [ - _bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list - ] + vertex_rerank_res = [_bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list] depth = len(vertex_degree_list) for result in results: @@ -149,9 +144,7 @@ def _rerank_with_vertex_degree( knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result])) def sort_key(res: str) -> Tuple[int, ...]: - return tuple( - vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth) - ) + return tuple(vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth)) sorted_results = sorted(results, key=sort_key) return sorted_results[:topn] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py index 30b2c6494..a1a660702 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py @@ -18,8 +18,8 @@ import os import sys from pathlib import Path -from typing import List, Optional, Dict -from urllib.error import URLError, HTTPError +from typing import Dict, List, Optional +from urllib.error import HTTPError, URLError import nltk from nltk.corpus import stopwords @@ -83,7 +83,7 @@ def check_nltk_data(self): 'punkt': 'tokenizers/punkt', 'punkt_tab': 'tokenizers/punkt_tab', 'averaged_perceptron_tagger': 'taggers/averaged_perceptron_tagger', - "averaged_perceptron_tagger_eng": 'taggers/averaged_perceptron_tagger_eng' + "averaged_perceptron_tagger_eng": 'taggers/averaged_perceptron_tagger_eng', } for package, path in required_packages.items(): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py index c31e77af7..a22e4de88 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py @@ -16,7 +16,7 @@ # under the License. -from typing import Literal, Dict, Any, Optional, Union, List +from typing import Any, Dict, List, Literal, Optional, Union from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -56,9 +56,7 @@ def _get_text_splitter(self, split_type: str): chunk_size=500, chunk_overlap=30, separators=self.separators ).split_text if split_type == SPLIT_TYPE_SENTENCE: - return RecursiveCharacterTextSplitter( - chunk_size=50, chunk_overlap=0, separators=self.separators - ).split_text + return RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=0, separators=self.separators).split_text raise ValueError("Type must be paragraph, sentence, html or markdown") def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py index 1bd17c733..fdbb76668 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py @@ -37,15 +37,16 @@ def __init__(self, keyword_num: int = 5, window_size: int = 3): self.pos_filter = { 'chinese': ('n', 'nr', 'ns', 'nt', 'nrt', 'nz', 'v', 'vd', 'vn', "eng", "j", "l"), - 'english': ('NN', 'NNS', 'NNP', 'NNPS', 'VB', 'VBG', 'VBN', 'VBZ') + 'english': ('NN', 'NNS', 'NNP', 'NNPS', 'VB', 'VBG', 'VBN', 'VBZ'), } - self.rules = [r"https?://\S+|www\.\S+", - r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", - r"\b\w+(?:[-’\']\w+)+\b", - r"\b\d+[,.]\d+\b"] + self.rules = [ + r"https?://\S+|www\.\S+", + r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", + r"\b\w+(?:[-’\']\w+)+\b", + r"\b\d+[,.]\d+\b", + ] def _word_mask(self, text): - placeholder_id_counter = 0 placeholder_map = {} @@ -64,11 +65,7 @@ def _create_placeholder(match_obj): @staticmethod def _get_valid_tokens(masked_text): - patterns_to_keep = [ - r'__shieldword_\d+__', - r'\b\w+\b', - r'[\u4e00-\u9fff]+' - ] + patterns_to_keep = [r'__shieldword_\d+__', r'\b\w+\b', r'[\u4e00-\u9fff]+'] combined_pattern = re.compile('|'.join(patterns_to_keep), re.IGNORECASE) tokens = combined_pattern.findall(masked_text) text_for_nltk = ' '.join(tokens) @@ -96,8 +93,7 @@ def _multi_preprocess(self, text): if re.compile('[\u4e00-\u9fff]').search(word): jieba_tokens = pseg.cut(word) for ch_word, ch_flag in jieba_tokens: - if len(ch_word) >= 1 and ch_flag in self.pos_filter['chinese'] \ - and ch_word not in ch_stop_words: + if len(ch_word) >= 1 and ch_flag in self.pos_filter['chinese'] and ch_word not in ch_stop_words: words.append(ch_word) elif len(word) >= 1 and flag in self.pos_filter['english'] and word.lower() not in en_stop_words: words.append(word) @@ -127,7 +123,7 @@ def _rank_nodes(self): pagerank_scores = self.graph.pagerank(directed=False, damping=0.85, weights='weight') if max(pagerank_scores) > 0: - pagerank_scores = [scores/max(pagerank_scores) for scores in pagerank_scores] + pagerank_scores = [scores / max(pagerank_scores) for scores in pagerank_scores] node_names = self.graph.vs['name'] return dict(zip(node_names, pagerank_scores)) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py index 4fee8d486..174752097 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py @@ -17,7 +17,7 @@ import re -from typing import Dict, Any, Optional, List +from typing import Any, Dict, List, Optional import jieba @@ -74,8 +74,6 @@ def _filter_keywords( results.add(token) sub_tokens = re.findall(r"\w+", token) if len(sub_tokens) > 1: - results.update( - {w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)} - ) + results.update({w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)}) return list(results) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index ba4392f7c..d464b80ef 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -15,14 +15,15 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Any +from typing import Any, Dict + +from pyhugegraph.client import PyHugeClient +from pyhugegraph.utils.exceptions import CreateError, NotFoundError from hugegraph_llm.config import huge_settings from hugegraph_llm.enums.property_cardinality import PropertyCardinality from hugegraph_llm.enums.property_data_type import PropertyDataType, default_value_map from hugegraph_llm.utils.log import log -from pyhugegraph.client import PyHugeClient -from pyhugegraph.utils.exceptions import NotFoundError, CreateError class Commit2Graph: @@ -41,17 +42,13 @@ def run(self, data: dict) -> Dict[str, Any]: vertices = data.get("vertices", []) edges = data.get("edges", []) if not vertices and not edges: - log.critical( - "(Loading) Both vertices and edges are empty. Please check the input data again." - ) + log.critical("(Loading) Both vertices and edges are empty. Please check the input data again.") raise ValueError("Both vertices and edges input are empty.") if not schema: # TODO: ensure the function works correctly (update the logic later) self.schema_free_mode(data.get("triples", [])) - log.warning( - "Using schema_free mode, could try schema_define mode for better effect!" - ) + log.warning("Using schema_free mode, could try schema_define mode for better effect!") else: self.init_schema_if_need(schema) self.load_into_graph(vertices, edges, schema) @@ -67,9 +64,7 @@ def _set_default_property(self, key, input_properties, property_label_map): # list or set default_value = [] input_properties[key] = default_value - log.warning( - "Property '%s' missing in vertex, set to '%s' for now", key, default_value - ) + log.warning("Property '%s' missing in vertex, set to '%s' for now", key, default_value) def _handle_graph_creation(self, func, *args, **kwargs): try: @@ -83,13 +78,9 @@ def _handle_graph_creation(self, func, *args, **kwargs): def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many-statements # pylint: disable=R0912 (too-many-branches) - vertex_label_map = { - v_label["name"]: v_label for v_label in schema["vertexlabels"] - } + vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]} edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]} - property_label_map = { - p_label["name"]: p_label for p_label in schema["propertykeys"] - } + property_label_map = {p_label["name"]: p_label for p_label in schema["propertykeys"]} for vertex in vertices: input_label = vertex["label"] @@ -105,9 +96,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- vertex_label = vertex_label_map[input_label] primary_keys = vertex_label["primary_keys"] nullable_keys = vertex_label.get("nullable_keys", []) - non_null_keys = [ - key for key in vertex_label["properties"] if key not in nullable_keys - ] + non_null_keys = [key for key in vertex_label["properties"] if key not in nullable_keys] has_problem = False # 2. Handle primary-keys mode vertex @@ -139,9 +128,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- # 3. Ensure all non-nullable props are set for key in non_null_keys: if key not in input_properties: - self._set_default_property( - key, input_properties, property_label_map - ) + self._set_default_property(key, input_properties, property_label_map) # 4. Check all data type value is right for key, value in input_properties.items(): @@ -159,9 +146,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- continue # TODO: we could try batch add vertices first, setback to single-mode if failed - vid = self._handle_graph_creation( - self.client.graph().addVertex, input_label, input_properties - ).id + vid = self._handle_graph_creation(self.client.graph().addVertex, input_label, input_properties).id vertex["id"] = vid for edge in edges: @@ -178,9 +163,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- continue # TODO: we could try batch add edges first, setback to single-mode if failed - self._handle_graph_creation( - self.client.graph().addEdge, label, start, end, properties - ) + self._handle_graph_creation(self.client.graph().addEdge, label, start, end, properties) def init_schema_if_need(self, schema: dict): properties = schema["propertykeys"] @@ -204,27 +187,19 @@ def init_schema_if_need(self, schema: dict): source_vertex_label = edge["source_label"] target_vertex_label = edge["target_label"] properties = edge["properties"] - self.schema.edgeLabel(edge_label).sourceLabel( - source_vertex_label - ).targetLabel(target_vertex_label).properties(*properties).nullableKeys( - *properties - ).ifNotExist().create() + self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( + target_vertex_label + ).properties(*properties).nullableKeys(*properties).ifNotExist().create() def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() - self.schema.vertexLabel("vertex").useCustomizeStringId().properties( + self.schema.vertexLabel("vertex").useCustomizeStringId().properties("name").ifNotExist().create() + self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties( "name" ).ifNotExist().create() - self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel( - "vertex" - ).properties("name").ifNotExist().create() - self.schema.indexLabel("vertexByName").onV("vertex").by( - "name" - ).secondary().ifNotExist().create() - self.schema.indexLabel("edgeByName").onE("edge").by( - "name" - ).secondary().ifNotExist().create() + self.schema.indexLabel("vertexByName").onV("vertex").by("name").secondary().ifNotExist().create() + self.schema.indexLabel("edgeByName").onE("edge").by("name").secondary().ifNotExist().create() for item in data: s, p, o = (element.strip() for element in item) @@ -277,9 +252,7 @@ def _set_property_data_type(self, property_key, data_type): log.warning("UUID type is not supported, use text instead") property_key.asText() else: - log.error( - "Unknown data type %s for property_key %s", data_type, property_key - ) + log.error("Unknown data type %s for property_key %s", data_type, property_key) def _set_property_cardinality(self, property_key, cardinality): if cardinality == PropertyCardinality.SINGLE: @@ -289,13 +262,9 @@ def _set_property_cardinality(self, property_key, cardinality): elif cardinality == PropertyCardinality.SET: property_key.valueSet() else: - log.error( - "Unknown cardinality %s for property_key %s", cardinality, property_key - ) + log.error("Unknown cardinality %s for property_key %s", cardinality, property_key) - def _check_property_data_type( - self, data_type: str, cardinality: str, value - ) -> bool: + def _check_property_data_type(self, data_type: str, cardinality: str, value) -> bool: if cardinality in ( PropertyCardinality.LIST.value, PropertyCardinality.SET.value, @@ -325,9 +294,7 @@ def _check_single_data_type(self, data_type: str, value) -> bool: if data_type in (PropertyDataType.TEXT.value, PropertyDataType.UUID.value): return isinstance(value, str) # TODO: check ok below - if ( - data_type == PropertyDataType.DATE.value - ): # the format should be "yyyy-MM-dd" + if data_type == PropertyDataType.DATE.value: # the format should be "yyyy-MM-dd" import re return isinstance(value, str) and re.match(r"^\d{4}-\d{2}-\d{2}$", value) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py index 4c4c167c4..c3f427e93 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py @@ -16,7 +16,7 @@ # under the License. -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from pyhugegraph.client import PyHugeClient diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index 2f0643a77..c265646fa 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional -from hugegraph_llm.config import huge_settings from pyhugegraph.client import PyHugeClient +from hugegraph_llm.config import huge_settings + class SchemaManager: def __init__(self, graph_name: str): @@ -39,9 +40,7 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: if "vertexlabels" in schema: mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: - new_vertex = { - key: vertex[key] for key in ["id", "name", "properties"] if key in vertex - } + new_vertex = {key: vertex[key] for key in ["id", "name", "properties"] if key in vertex} mini_schema["vertexlabels"].append(new_vertex) # Add necessary edgelabels items (4) @@ -49,9 +48,7 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: mini_schema["edgelabels"] = [] for edge in schema["edgelabels"]: new_edge = { - key: edge[key] - for key in ["name", "source_label", "target_label", "properties"] - if key in edge + key: edge[key] for key in ["name", "source_label", "target_label", "properties"] if key in edge } mini_schema["edgelabels"].append(new_edge) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index 5e9e8f449..1c75ea4b6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -29,9 +29,7 @@ class BuildSemanticIndex: def __init__(self, embedding: BaseEmbedding, vector_index: type[VectorStoreBase]): - self.vid_index = vector_index.from_name( - embedding.get_embedding_dim(), huge_settings.graph_name, "graph_vids" - ) + self.vid_index = vector_index.from_name(embedding.get_embedding_dim(), huge_settings.graph_name, "graph_vids") self.embedding = embedding self.sm = SchemaManager(huge_settings.graph_name) @@ -45,9 +43,7 @@ async def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]: async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any: async with sem: loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, self.embedding.get_texts_embeddings, vid_list - ) + return await loop.run_in_executor(None, self.embedding.get_texts_embeddings, vid_list) vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids), batch_size)] tasks = [get_embeddings_with_semaphore(batch) for batch in vid_batches] @@ -62,9 +58,7 @@ async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any: def run(self, context: Dict[str, Any]) -> Dict[str, Any]: vertexlabels = self.sm.schema.getSchema()["vertexlabels"] - all_pk_flag = bool(vertexlabels) and all( - data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels - ) + all_pk_flag = bool(vertexlabels) and all(data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels) past_vids = self.vid_index.get_all_properties() # TODO: We should build vid vector index separately, especially when the vertices may be very large diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index e3eea9f07..345db47a5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -40,14 +40,10 @@ def __init__( self.num_examples = num_examples if not vector_index.exist("gremlin_examples"): log.warning("No gremlin example index found, will generate one.") - self.vector_index = vector_index.from_name( - self.embedding.get_embedding_dim(), "gremlin_examples" - ) + self.vector_index = vector_index.from_name(self.embedding.get_embedding_dim(), "gremlin_examples") self._build_default_example_index() else: - self.vector_index = vector_index.from_name( - self.embedding.get_embedding_dim(), "gremlin_examples" - ) + self.vector_index = vector_index.from_name(self.embedding.get_embedding_dim(), "gremlin_examples") def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[str, Any]]: if self.num_examples <= 0: @@ -59,18 +55,14 @@ def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[st return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8) def _build_default_example_index(self): - properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict( - orient="records" - ) + properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records") from concurrent.futures import ThreadPoolExecutor # TODO: reuse the logic in build_semantic_index.py (consider extract the batch-embedding method) with ThreadPoolExecutor() as executor: embeddings = list( tqdm( - executor.map( - self.embedding.get_text_embedding, [row["query"] for row in properties] - ), + executor.map(self.embedding.get_text_embedding, [row["query"] for row in properties]), total=len(properties), ) ) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py index 5ced65b41..979998a1a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py @@ -25,14 +25,10 @@ class VectorIndexQuery: - def __init__( - self, vector_index: type[VectorStoreBase], embedding: BaseEmbedding, topk: int = 3 - ): + def __init__(self, vector_index: type[VectorStoreBase], embedding: BaseEmbedding, topk: int = 3): self.embedding = embedding self.topk = topk - self.vector_index = vector_index.from_name( - embedding.get_embedding_dim(), huge_settings.graph_name, "chunks" - ) + self.vector_index = vector_index.from_name(embedding.get_embedding_dim(), huge_settings.graph_name, "chunks") def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 9138f9e9b..45b327462 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -15,6 +15,12 @@ # specific language governing permissions and limitations # under the License. +""" +TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. +Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on +prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. +""" + # pylint: disable=W0621 import asyncio @@ -25,11 +31,6 @@ from hugegraph_llm.models.llms.init_llm import LLMs from hugegraph_llm.utils.log import log -""" -TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. -Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on -prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. -""" DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt @@ -62,13 +63,9 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context_head_str, context_tail_str = self.init_llm(context) if self._context_body is not None: - context_str = ( - f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{self._context_body}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) response = self._llm.generate(prompt=final_prompt) return {"answer": response} @@ -104,9 +101,7 @@ def handle_vector_graph(self, context): vector_result_context = "No (vector)phrase related to the query." graph_result = context.get("graph_result") if graph_result: - graph_context_head = context.get( - "graph_context_head", "Knowledge from graphdb for the query:\n" - ) + graph_context_head = context.get("graph_context_head", "Knowledge from graphdb for the query:\n") graph_result_context = graph_context_head + "\n".join( f"{i + 1}. {res}" for i, res in enumerate(graph_result) ) @@ -119,13 +114,9 @@ async def run_streaming(self, context: Dict[str, Any]) -> AsyncGenerator[Dict[st context_head_str, context_tail_str = self.init_llm(context) if self._context_body is not None: - context_str = ( - f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{self._context_body}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) response = self._llm.generate(prompt=final_prompt) yield {"answer": response} return @@ -151,45 +142,23 @@ async def async_generate( final_prompt = self._question async_tasks["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) if self._vector_only_answer: - context_str = ( - f"{context_head_str}\n" - f"{vector_result_context}\n" - f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{vector_result_context}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) - async_tasks["vector_only_task"] = asyncio.create_task( - self._llm.agenerate(prompt=final_prompt) - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_tasks["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) if self._graph_only_answer: - context_str = ( - f"{context_head_str}\n" - f"{graph_result_context}\n" - f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{graph_result_context}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) - async_tasks["graph_only_task"] = asyncio.create_task( - self._llm.agenerate(prompt=final_prompt) - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_tasks["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" if context.get("graph_ratio", 0.5) < 0.5: context_body_str = f"{graph_result_context}\n{vector_result_context}" - context_str = ( - f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{context_body_str}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) - async_tasks["graph_vector_task"] = asyncio.create_task( - self._llm.agenerate(prompt=final_prompt) - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_tasks["graph_vector_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) async_tasks_mapping = { "raw_task": "raw_answer", @@ -229,21 +198,13 @@ async def async_streaming_generate( if self._raw_answer: final_prompt = self._question async_generators.append( - self.__llm_generate_with_meta_info( - task_id=auto_id, target_key="raw_answer", prompt=final_prompt - ) + self.__llm_generate_with_meta_info(task_id=auto_id, target_key="raw_answer", prompt=final_prompt) ) auto_id += 1 if self._vector_only_answer: - context_str = ( - f"{context_head_str}\n" - f"{vector_result_context}\n" - f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{vector_result_context}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) async_generators.append( self.__llm_generate_with_meta_info( task_id=auto_id, target_key="vector_only_answer", prompt=final_prompt @@ -251,32 +212,20 @@ async def async_streaming_generate( ) auto_id += 1 if self._graph_only_answer: - context_str = ( - f"{context_head_str}\n" - f"{graph_result_context}\n" - f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{graph_result_context}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) async_generators.append( - self.__llm_generate_with_meta_info( - task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt - ) + self.__llm_generate_with_meta_info(task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt) ) auto_id += 1 if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" if context.get("graph_ratio", 0.5) < 0.5: context_body_str = f"{graph_result_context}\n{vector_result_context}" - context_str = ( - f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n") - ) + context_str = f"{context_head_str}\n{context_body_str}\n{context_tail_str}".strip("\n") - final_prompt = self._prompt_template.format( - context_str=context_str, query_str=self._question - ) + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) async_generators.append( self.__llm_generate_with_meta_info( task_id=auto_id, target_key="graph_vector_answer", prompt=final_prompt diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index 2ac2eafff..754ded9e5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -16,7 +16,7 @@ # under the License. -from typing import Dict, List, Any +from typing import Any, Dict, List from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.operators.llm_op.info_extract import extract_triples_by_regex @@ -28,10 +28,10 @@ def generate_disambiguate_prompt(triples): {triples} If the second element of the triples expresses the same meaning but in different ways, unify them and keep the most concise expression. - + For example, if the input is: [("Alice", "friend", "Bob"), ("Simon", "is friends with", "Bob")] - + The output should be: [("Alice", "friend", "Bob"), ("Simon", "friend", "Bob")] """ @@ -51,10 +51,7 @@ def run(self, data: Dict) -> Dict[str, List[Any]]: llm_output = self.llm.generate(prompt=prompt) data["triples"] = [] extract_triples_by_regex(llm_output, data) - print( - f"LLM {self.__class__.__name__} input:{prompt} \n" - f" output: {llm_output} \n data: {data}" - ) + print(f"LLM {self.__class__.__name__} input:{prompt} \n output: {llm_output} \n data: {data}") data["call_count"] = data.get("call_count", 0) + 1 return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py index 650834300..c3b4b9d0f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py @@ -18,7 +18,7 @@ import asyncio import json import re -from typing import Optional, List, Dict, Any, Union +from typing import Any, Dict, List, Optional, Union from hugegraph_llm.config import prompt from hugegraph_llm.models.llms.base import BaseLLM @@ -53,10 +53,7 @@ def _format_examples(self, examples: Optional[List[Dict[str, str]]]) -> Optional return None example_strings = [] for example in examples: - example_strings.append( - f"- query: {example['query']}\n" - f"- gremlin:\n```gremlin\n{example['gremlin']}\n```" - ) + example_strings.append(f"- query: {example['query']}\n- gremlin:\n```gremlin\n{example['gremlin']}\n```") return "\n\n".join(example_strings) def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]: @@ -90,9 +87,7 @@ async def async_generate(self, context: Dict[str, Any]): vertices=self._format_vertices(vertices=self.vertices), properties=self._format_properties(properties=None), ) - async_tasks["initialized_answer"] = asyncio.create_task( - self.llm.agenerate(prompt=init_prompt) - ) + async_tasks["initialized_answer"] = asyncio.create_task(self.llm.agenerate(prompt=init_prompt)) raw_response = await async_tasks["raw_answer"] initialized_response = await async_tasks["initialized_answer"] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 8897e0fea..a786e52d4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -16,7 +16,7 @@ # under the License. import re -from typing import List, Any, Dict, Optional +from typing import Any, Dict, List, Optional from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM @@ -75,10 +75,7 @@ def generate_extract_triple_prompt(text, schema=None) -> str: if schema: return schema_real_prompt - log.warning( - "Recommend to provide a graph schema to improve the extraction accuracy. " - "Now using the default schema." - ) + log.warning("Recommend to provide a graph schema to improve the extraction accuracy. Now using the default schema.") return text_based_prompt @@ -107,9 +104,7 @@ def extract_triples_by_regex_with_schema(schema, text, graph): # TODO: use a more efficient way to compare the extract & input property p_lower = p.lower() for vertex in schema["vertices"]: - if vertex["vertex_label"] == label and any( - pp.lower() == p_lower for pp in vertex["properties"] - ): + if vertex["vertex_label"] == label and any(pp.lower() == p_lower for pp in vertex["properties"]): id = f"{label}-{s}" if id not in vertices_dict: vertices_dict[id] = { @@ -199,7 +194,5 @@ def valid(self, element_id: str, max_length: int = 256) -> bool: def _filter_long_id(self, graph) -> Dict[str, List[Any]]: graph["vertices"] = [vertex for vertex in graph["vertices"] if self.valid(vertex["id"])] - graph["edges"] = [ - edge for edge in graph["edges"] if self.valid(edge["start"]) and self.valid(edge["end"]) - ] + graph["edges"] = [edge for edge in graph["edges"] if self.valid(edge["start"]) and self.valid(edge["end"])] return graph diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py index 48369b4ec..0d8b4a179 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py @@ -19,7 +19,7 @@ import time from typing import Any, Dict, Optional -from hugegraph_llm.config import prompt, llm_settings +from hugegraph_llm.config import llm_settings, prompt from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.models.llms.init_llm import LLMs from hugegraph_llm.operators.document_op.textrank_word_extract import ( @@ -44,9 +44,7 @@ def __init__( self._max_keywords = max_keywords self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL self._extract_method = llm_settings.keyword_extract_type.lower() - self._textrank_model = MultiLingualTextRank( - keyword_num=max_keywords, window_size=llm_settings.window_size - ) + self._textrank_model = MultiLingualTextRank(keyword_num=max_keywords, window_size=llm_settings.window_size) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if self._query is None: @@ -68,11 +66,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: max_keyword_num = self._max_keywords self._max_keywords = max(1, max_keyword_num) - method = ( - (context.get("extract_method", self._extract_method) or "LLM") - .strip() - .lower() - ) + method = (context.get("extract_method", self._extract_method) or "LLM").strip().lower() if method == "llm": # LLM method ranks = self._extract_with_llm() @@ -101,9 +95,7 @@ def _extract_with_llm(self) -> Dict[str, float]: response = self._llm.generate(prompt=prompt_run) end_time = time.perf_counter() log.debug("LLM Keyword extraction time: %.2f seconds", end_time - start_time) - keywords = self._extract_keywords_from_response( - response=response, lowercase=False, start_token="KEYWORDS:" - ) + keywords = self._extract_keywords_from_response(response=response, lowercase=False, start_token="KEYWORDS:") return keywords def _extract_with_textrank(self) -> Dict[str, float]: @@ -117,9 +109,7 @@ def _extract_with_textrank(self) -> Dict[str, float]: except MemoryError as e: log.critical("TextRank memory error (text too large?): %s", e) end_time = time.perf_counter() - log.debug( - "TextRank Keyword extraction time: %.2f seconds", end_time - start_time - ) + log.debug("TextRank Keyword extraction time: %.2f seconds", end_time - start_time) return ranks def _extract_with_hybrid(self) -> Dict[str, float]: @@ -180,9 +170,7 @@ def _extract_keywords_from_response( continue score_val = float(score_raw) if not 0.0 <= score_val <= 1.0: - log.warning( - "Score out of range for %s: %s", word_raw, score_val - ) + log.warning("Score out of range for %s: %s", word_raw, score_val) score_val = min(1.0, max(0.0, score_val)) word_out = word_raw.lower() if lowercase else word_raw results[word_out] = score_val diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py index 058d1bce9..b5951e449 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py @@ -18,9 +18,10 @@ import json import os -from typing import Dict, Any +from typing import Any, Dict -from hugegraph_llm.config import resource_path, prompt as prompt_tpl +from hugegraph_llm.config import prompt as prompt_tpl +from hugegraph_llm.config import resource_path from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index acdd7a950..bf3630192 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -19,7 +19,7 @@ import json import re -from typing import List, Any, Dict +from typing import Any, Dict, List from hugegraph_llm.config import prompt from hugegraph_llm.document.chunk_split import ChunkSplitter @@ -125,8 +125,7 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: json_match = re.search(r"({.*})", text, re.DOTALL) if not json_match: log.critical( - "Invalid property graph! No JSON object found, " - "please check the output format example in prompt." + "Invalid property graph! No JSON object found, please check the output format example in prompt." ) return [] json_str = json_match.group(1).strip() @@ -135,11 +134,7 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: try: property_graph = json.loads(json_str) # Expect property_graph to be a dict with keys "vertices" and "edges" - if not ( - isinstance(property_graph, dict) - and "vertices" in property_graph - and "edges" in property_graph - ): + if not (isinstance(property_graph, dict) and "vertices" in property_graph and "edges" in property_graph): log.critical("Invalid property graph format; expecting 'vertices' and 'edges'.") return items @@ -170,7 +165,5 @@ def process_items(item_list, valid_labels, item_type): process_items(property_graph["vertices"], vertex_label_set, "vertex") process_items(property_graph["edges"], edge_label_set, "edge") except json.JSONDecodeError: - log.critical( - "Invalid property graph JSON! Please check the extracted JSON data carefully" - ) + log.critical("Invalid property graph JSON! Please check the extracted JSON data carefully") return items diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py index 928948413..5fa130e26 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py @@ -17,7 +17,7 @@ import json import re -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.models.llms.init_llm import LLMs diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/unstructured_data_utils.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/unstructured_data_utils.py index 38eabb16e..6beeb0291 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/unstructured_data_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/unstructured_data_utils.py @@ -20,10 +20,7 @@ import re REGEX = ( - r"Nodes:\s+(.*?)\s?\s?" - r"Relationships:\s?\s?" - r"NodesSchemas:\s+(.*?)\s?\s?" - r"RelationshipsSchemas:\s?\s?(.*)" + r"Nodes:\s+(.*?)\s?\s?" r"Relationships:\s?\s?" r"NodesSchemas:\s+(.*?)\s?\s?" r"RelationshipsSchemas:\s?\s?(.*)" ) INTERNAL_REGEX = r"\[(.*?)\]" JSON_REGEX = r"\{.*\}" diff --git a/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py b/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py index 6b6bf48e2..b7ff379e4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py @@ -14,37 +14,38 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, List, Literal, Union +from typing import List, Literal, Optional, Union + +from pyhugegraph.client import PyHugeClient from hugegraph_llm.config import huge_settings from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.operators.common_op.check_schema import CheckSchema +from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank from hugegraph_llm.operators.common_op.print_result import PrintResult +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit +from hugegraph_llm.operators.document_op.word_extract import WordExtract +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph +from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.operators.index_op.build_gremlin_example_index import ( BuildGremlinExampleIndex, ) +from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex +from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex from hugegraph_llm.operators.index_op.gremlin_example_index_query import ( GremlinExampleIndexQuery, ) -from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize -from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_rpm -from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit -from hugegraph_llm.operators.llm_op.info_extract import InfoExtract -from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract -from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData -from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph -from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex -from hugegraph_llm.operators.document_op.word_extract import WordExtract -from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery -from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize -from pyhugegraph.client import PyHugeClient +from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData +from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize +from hugegraph_llm.operators.llm_op.info_extract import InfoExtract +from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract +from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract +from hugegraph_llm.utils.decorators import log_operator_time, log_time, record_rpm class OperatorList: @@ -68,9 +69,7 @@ def example_index_build(self, examples): self.operators.append(BuildGremlinExampleIndex(self.embedding, examples)) return self - def import_schema( - self, from_hugegraph=None, from_extraction=None, from_user_defined=None - ): + def import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None): if from_hugegraph: self.operators.append(SchemaManager(from_hugegraph)) elif from_user_defined: @@ -91,9 +90,7 @@ def gremlin_generate_synthesize( gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None, ): - self.operators.append( - GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt) - ) + self.operators.append(GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt)) return self def print_result(self): @@ -125,9 +122,7 @@ def extract_info( elif extract_type == "property_graph": self.operators.append(PropertyGraphExtract(self.llm, example_prompt)) else: - raise ValueError( - f"invalid extract_type: {extract_type!r}, expected 'triples' or 'property_graph'" - ) + raise ValueError(f"invalid extract_type: {extract_type!r}, expected 'triples' or 'property_graph'") return self def disambiguate_word_sense(self): @@ -168,9 +163,7 @@ def extract_keywords( :param extract_template: Template for keyword extraction. :return: Self-instance for chaining. """ - self.operators.append( - KeywordExtract(text=text, extract_template=extract_template) - ) + self.operators.append(KeywordExtract(text=text, extract_template=extract_template)) return self def keywords_to_vid( diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py index 7b39776fc..7dce7e857 100644 --- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AsyncGenerator, Union, List, Optional, Any, Dict -from pycgraph import GParam, CStatus +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +from pycgraph import CStatus, GParam from hugegraph_llm.utils.log import log @@ -259,11 +260,7 @@ def to_json(self): dict: A dictionary containing non-None instance members and their serialized values. """ # Only export instance attributes (excluding methods and class attributes) whose values are not None - return { - k: v - for k, v in self.__dict__.items() - if not k.startswith("_") and v is not None - } + return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None} # Implement a method that assigns keys from data_json as WkFlowState member variables def assign_from_json(self, data_json: dict): @@ -274,6 +271,4 @@ def assign_from_json(self, data_json: dict): if hasattr(self, k): setattr(self, k, v) else: - log.warning( - "key %s should be a member of WkFlowState & type %s", k, type(v) - ) + log.warning("key %s should be a member of WkFlowState & type %s", k, type(v)) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py index 2914c4b28..8445661c8 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py @@ -18,7 +18,7 @@ import asyncio import time from functools import wraps -from typing import Optional, Any, Callable +from typing import Any, Callable, Optional from hugegraph_llm.utils.log import log diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py index b2f485cea..45eb18626 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py @@ -24,9 +24,7 @@ from hugegraph_llm.models.embeddings.base import BaseEmbedding -async def _get_batch_with_progress( - embedding: BaseEmbedding, batch: list[str], pbar: tqdm -) -> list[Any]: +async def _get_batch_with_progress(embedding: BaseEmbedding, batch: list[str], pbar: tqdm) -> list[Any]: result = await embedding.async_get_texts_embeddings(batch) pbar.update(1) return result diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index e8080b631..423526ea4 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -17,17 +17,18 @@ import traceback -from typing import Dict, Any, Union, List +from typing import Any, Dict, List, Union import gradio as gr +from pyhugegraph.client import PyHugeClient + from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton -from pyhugegraph.client import PyHugeClient +from ..config import huge_settings from .hugegraph_utils import clean_hg_data from .log import log from .vector_index_utils import read_documents -from ..config import huge_settings def get_graph_index_info(): @@ -41,8 +42,8 @@ def get_graph_index_info(): def clean_all_graph_index(): # Lazy import to avoid circular dependency - from .vector_index_utils import get_vector_index_class # pylint: disable=import-outside-toplevel from ..config import index_settings # pylint: disable=import-outside-toplevel + from .vector_index_utils import get_vector_index_class # pylint: disable=import-outside-toplevel vector_index = get_vector_index_class(index_settings.cur_vector_index) vector_index.clean(huge_settings.graph_name, "graph_vids") @@ -51,9 +52,7 @@ def clean_all_graph_index(): gr.Info("Clear graph index and text2gql index successfully!") -def get_vertex_details( - vertex_ids: List[str], context: Dict[str, Any] -) -> List[Dict[str, Any]]: +def get_vertex_details(vertex_ids: List[str], context: Dict[str, Any]) -> List[Dict[str, Any]]: if isinstance(context.get("graph_client"), PyHugeClient): client = context["graph_client"] else: @@ -85,9 +84,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: return "ERROR: please input with correct schema/format." try: - return scheduler.schedule_flow( - FlowName.GRAPH_EXTRACT, schema, texts, example_prompt, "property_graph" - ) + return scheduler.schedule_flow(FlowName.GRAPH_EXTRACT, schema, texts, example_prompt, "property_graph") except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) @@ -117,9 +114,7 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: def build_schema(input_text, query_example, few_shot): scheduler = SchedulerSingleton.get_instance() try: - return scheduler.schedule_flow( - FlowName.BUILD_SCHEMA, input_text, query_example, few_shot - ) + return scheduler.schedule_flow(FlowName.BUILD_SCHEMA, input_text, query_example, few_shot) except Exception as e: # pylint: disable=broad-exception-caught log.error("Schema generation failed: %s", e) raise gr.Error(f"Schema generation failed: {e}") diff --git a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py index 147c0074c..a151110f0 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py @@ -19,12 +19,13 @@ import os import shutil from datetime import datetime + import requests +from pyhugegraph.client import PyHugeClient from requests.auth import HTTPBasicAuth from hugegraph_llm.config import huge_settings, resource_path from hugegraph_llm.utils.log import log -from pyhugegraph.client import PyHugeClient MAX_BACKUP_DIRS = 7 MAX_VERTICES = 100000 @@ -53,9 +54,7 @@ def init_hg_test_data(): schema = client.schema() schema.propertyKey("name").asText().ifNotExist().create() schema.propertyKey("birthDate").asText().ifNotExist().create() - schema.vertexLabel("Person").properties( - "name", "birthDate" - ).useCustomizeStringId().ifNotExist().create() + schema.vertexLabel("Person").properties("name", "birthDate").useCustomizeStringId().ifNotExist().create() schema.vertexLabel("Movie").properties("name").useCustomizeStringId().ifNotExist().create() schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create() @@ -140,10 +139,7 @@ def write_backup_file(client, backup_subdir, filename, query, all_pk_flag): elif filename == "vertices.json": data_full = client.gremlin().exec(query)["data"][0]["vertices"] data = ( - [ - {key: value for key, value in vertex.items() if key != "id"} - for vertex in data_full - ] + [{key: value for key, value in vertex.items() if key != "id"} for vertex in data_full] if all_pk_flag else data_full ) @@ -152,9 +148,7 @@ def write_backup_file(client, backup_subdir, filename, query, all_pk_flag): data_full = query if isinstance(data_full, dict) and "schema" in data_full: groovy_filename = filename.replace(".json", ".groovy") - with open( - os.path.join(backup_subdir, groovy_filename), "w", encoding="utf-8" - ) as groovy_file: + with open(os.path.join(backup_subdir, groovy_filename), "w", encoding="utf-8") as groovy_file: groovy_file.write(str(data_full["schema"])) else: data = data_full @@ -164,9 +158,7 @@ def write_backup_file(client, backup_subdir, filename, query, all_pk_flag): def manage_backup_retention(): try: backup_dirs = [ - os.path.join(BACKUP_DIR, d) - for d in os.listdir(BACKUP_DIR) - if os.path.isdir(os.path.join(BACKUP_DIR, d)) + os.path.join(BACKUP_DIR, d) for d in os.listdir(BACKUP_DIR) if os.path.isdir(os.path.join(BACKUP_DIR, d)) ] backup_dirs.sort(key=os.path.getctime) if len(backup_dirs) > MAX_BACKUP_DIRS: diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index e96a5adec..bb9536d25 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -47,9 +47,7 @@ def read_documents(input_file, input_text): texts.append(text) elif full_path.endswith(".pdf"): # TODO: support PDF file - raise gr.Error( - "PDF will be supported later! Try to upload text/docx now" - ) + raise gr.Error("PDF will be supported later! Try to upload text/docx now") else: raise gr.Error("Please input txt or docx file.") else: diff --git a/hugegraph-llm/src/tests/config/test_config.py b/hugegraph-llm/src/tests/config/test_config.py index 7f480befa..63dc6cb94 100644 --- a/hugegraph-llm/src/tests/config/test_config.py +++ b/hugegraph-llm/src/tests/config/test_config.py @@ -22,6 +22,7 @@ class TestConfig(unittest.TestCase): def test_config(self): import nltk + from hugegraph_llm.config import resource_path nltk.data.path.append(resource_path) diff --git a/hugegraph-llm/src/tests/conftest.py b/hugegraph-llm/src/tests/conftest.py index 32e3c6bf2..ff38f3d54 100644 --- a/hugegraph-llm/src/tests/conftest.py +++ b/hugegraph-llm/src/tests/conftest.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. +import logging import os import sys -import logging + import nltk # Get project root directory @@ -27,6 +28,8 @@ # Add src directory to Python path src_path = os.path.join(project_root, "src") sys.path.insert(0, src_path) + + # Download NLTK resources def download_nltk_resources(): try: @@ -34,6 +37,8 @@ def download_nltk_resources(): except LookupError: logging.info("Downloading NLTK stopwords resource...") nltk.download("stopwords", quiet=True) + + # Download NLTK resources before tests start download_nltk_resources() # Set environment variable to skip external service tests diff --git a/hugegraph-llm/src/tests/data/documents/sample.txt b/hugegraph-llm/src/tests/data/documents/sample.txt index 4e4726dae..e76fb0b67 100644 --- a/hugegraph-llm/src/tests/data/documents/sample.txt +++ b/hugegraph-llm/src/tests/data/documents/sample.txt @@ -3,4 +3,4 @@ Bob is 30 years old and is a data scientist at DataInc. Alice and Bob are colleagues and they collaborate on AI projects. They are working on a knowledge graph project that uses natural language processing. The project aims to extract structured information from unstructured text. -TechCorp and DataInc are partner companies in the technology sector. \ No newline at end of file +TechCorp and DataInc are partner companies in the technology sector. diff --git a/hugegraph-llm/src/tests/data/kg/schema.json b/hugegraph-llm/src/tests/data/kg/schema.json index 386b88b66..51e9f3ec9 100644 --- a/hugegraph-llm/src/tests/data/kg/schema.json +++ b/hugegraph-llm/src/tests/data/kg/schema.json @@ -39,4 +39,4 @@ "properties": [] } ] -} \ No newline at end of file +} diff --git a/hugegraph-llm/src/tests/data/prompts/test_prompts.yaml b/hugegraph-llm/src/tests/data/prompts/test_prompts.yaml index b55f7b258..e622d4419 100644 --- a/hugegraph-llm/src/tests/data/prompts/test_prompts.yaml +++ b/hugegraph-llm/src/tests/data/prompts/test_prompts.yaml @@ -23,10 +23,10 @@ rag_prompt: user: | Context: {context} - + Question: {query} - + Answer: kg_extraction_prompt: @@ -36,10 +36,10 @@ kg_extraction_prompt: user: | Text: {text} - + Schema: {schema} - + Extract entities and relationships from the text according to the schema: summarization_prompt: @@ -49,5 +49,5 @@ summarization_prompt: user: | Text: {text} - - Please provide a concise summary: \ No newline at end of file + + Please provide a concise summary: diff --git a/hugegraph-llm/src/tests/document/test_text_loader.py b/hugegraph-llm/src/tests/document/test_text_loader.py index e552d8950..f9ac6730b 100644 --- a/hugegraph-llm/src/tests/document/test_text_loader.py +++ b/hugegraph-llm/src/tests/document/test_text_loader.py @@ -39,9 +39,7 @@ def setUp(self): # pylint: disable=consider-using-with self.temp_dir = tempfile.TemporaryDirectory() self.temp_file_path = os.path.join(self.temp_dir.name, "test_file.txt") - self.test_content = ( - "This is a test file.\nIt has multiple lines.\nThis is for testing the TextLoader." - ) + self.test_content = "This is a test file.\nIt has multiple lines.\nThis is for testing the TextLoader." # Write test content to the file with open(self.temp_file_path, "w", encoding="utf-8") as f: diff --git a/hugegraph-llm/src/tests/integration/test_graph_rag_pipeline.py b/hugegraph-llm/src/tests/integration/test_graph_rag_pipeline.py index 35b6d0857..243963c3b 100644 --- a/hugegraph-llm/src/tests/integration/test_graph_rag_pipeline.py +++ b/hugegraph-llm/src/tests/integration/test_graph_rag_pipeline.py @@ -20,6 +20,7 @@ import tempfile import unittest from unittest.mock import MagicMock + from tests.utils.mock import MockEmbedding @@ -192,8 +193,7 @@ def setUp(self): self.mock_answer_synthesize = MagicMock() self.mock_answer_synthesize.return_value = { "answer": ( - "John Doe is a 30-year-old software engineer. " - "The Matrix is a science fiction movie released in 1999." + "John Doe is a 30-year-old software engineer. The Matrix is a science fiction movie released in 1999." ) } diff --git a/hugegraph-llm/src/tests/integration/test_kg_construction.py b/hugegraph-llm/src/tests/integration/test_kg_construction.py index 52f3667d8..daa7b191b 100644 --- a/hugegraph-llm/src/tests/integration/test_kg_construction.py +++ b/hugegraph-llm/src/tests/integration/test_kg_construction.py @@ -187,10 +187,10 @@ def test_kg_construction_end_to_end(self, *args): ] # Mock KG constructor methods - with patch.object( - self.kg_constructor, "extract_entities", return_value=mock_entities - ), patch.object(self.kg_constructor, "extract_relations", return_value=mock_relations): - + with ( + patch.object(self.kg_constructor, "extract_entities", return_value=mock_entities), + patch.object(self.kg_constructor, "extract_relations", return_value=mock_relations), + ): # Construct knowledge graph - use only one document to avoid duplicate relations from mocking kg = self.kg_constructor.construct_from_documents([self.test_docs[0]]) diff --git a/hugegraph-llm/src/tests/integration/test_rag_pipeline.py b/hugegraph-llm/src/tests/integration/test_rag_pipeline.py index 72b4663b6..bb5e43642 100644 --- a/hugegraph-llm/src/tests/integration/test_rag_pipeline.py +++ b/hugegraph-llm/src/tests/integration/test_rag_pipeline.py @@ -26,9 +26,9 @@ with_mock_openai_client, with_mock_openai_embedding, ) - from tests.utils.mock import VectorIndex + # 创建模拟类,替代缺失的模块 class Document: """模拟的Document类""" diff --git a/hugegraph-llm/src/tests/middleware/test_middleware.py b/hugegraph-llm/src/tests/middleware/test_middleware.py index 3691da309..90e71272a 100644 --- a/hugegraph-llm/src/tests/middleware/test_middleware.py +++ b/hugegraph-llm/src/tests/middleware/test_middleware.py @@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from fastapi import FastAPI + from hugegraph_llm.middleware.middleware import UseTimeMiddleware diff --git a/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py b/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py index 1d1fecc40..c919a2d65 100644 --- a/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py +++ b/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py @@ -27,15 +27,13 @@ class TestOllamaEmbedding(unittest.TestCase): def setUp(self): self.skip_external = os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true" - @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", - "Skipping external service tests") + @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", "Skipping external service tests") def test_get_text_embedding(self): ollama_embedding = OllamaEmbedding(model_name="quentinz/bge-large-zh-v1.5") embedding = ollama_embedding.get_text_embedding("hello world") print(embedding) - @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", - "Skipping external service tests") + @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", "Skipping external service tests") def test_get_cosine_similarity(self): ollama_embedding = OllamaEmbedding(model_name="quentinz/bge-large-zh-v1.5") embedding1 = ollama_embedding.get_text_embedding("hello world") diff --git a/hugegraph-llm/src/tests/models/llms/test_ollama_client.py b/hugegraph-llm/src/tests/models/llms/test_ollama_client.py index 8f8cc48ef..bbc014fd2 100644 --- a/hugegraph-llm/src/tests/models/llms/test_ollama_client.py +++ b/hugegraph-llm/src/tests/models/llms/test_ollama_client.py @@ -25,21 +25,17 @@ class TestOllamaClient(unittest.TestCase): def setUp(self): self.skip_external = os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true" - @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", - "Skipping external service tests") + @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", "Skipping external service tests") def test_generate(self): ollama_client = OllamaClient(model="llama3:8b-instruct-fp16") response = ollama_client.generate(prompt="What is the capital of France?") print(response) - @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", - "Skipping external service tests") + @unittest.skipIf(os.getenv("SKIP_EXTERNAL_SERVICES", "false").lower() == "true", "Skipping external service tests") def test_stream_generate(self): ollama_client = OllamaClient(model="llama3:8b-instruct-fp16") def on_token_callback(chunk): print(chunk, end="", flush=True) - ollama_client.generate_streaming( - prompt="What is the capital of France?", on_token_callback=on_token_callback - ) + ollama_client.generate_streaming(prompt="What is the capital of France?", on_token_callback=on_token_callback) diff --git a/hugegraph-llm/src/tests/models/llms/test_openai_client.py b/hugegraph-llm/src/tests/models/llms/test_openai_client.py index 18b55daa1..b9f8a113d 100644 --- a/hugegraph-llm/src/tests/models/llms/test_openai_client.py +++ b/hugegraph-llm/src/tests/models/llms/test_openai_client.py @@ -27,9 +27,7 @@ def setUp(self): """Set up test fixtures and common mock objects.""" # Create mock completion response self.mock_completion_response = MagicMock() - self.mock_completion_response.choices = [ - MagicMock(message=MagicMock(content="Paris")) - ] + self.mock_completion_response.choices = [MagicMock(message=MagicMock(content="Paris"))] self.mock_completion_response.usage = MagicMock() self.mock_completion_response.usage.model_dump_json.return_value = ( '{"prompt_tokens": 10, "completion_tokens": 5}' @@ -136,9 +134,11 @@ def on_token_callback(chunk): collected_tokens.append(chunk) # Collect all tokens from the generator - tokens = list(openai_client.generate_streaming( - prompt="What is the capital of France?", on_token_callback=on_token_callback - )) + tokens = list( + openai_client.generate_streaming( + prompt="What is the capital of France?", on_token_callback=on_token_callback + ) + ) # Verify the response self.assertEqual(tokens, ["Pa", "ris"]) @@ -203,6 +203,7 @@ def test_generate_authentication_error(self, mock_openai_class): """Test generate method with authentication error.""" # Setup mock client to raise OpenAI 的认证错误 from openai import AuthenticationError + mock_client = MagicMock() # Create a properly formatted AuthenticationError @@ -211,9 +212,7 @@ def test_generate_authentication_error(self, mock_openai_class): mock_response.headers = {} auth_error = AuthenticationError( - message="Invalid API key", - response=mock_response, - body={"error": {"message": "Invalid API key"}} + message="Invalid API key", response=mock_response, body={"error": {"message": "Invalid API key"}} ) mock_client.chat.completions.create.side_effect = auth_error mock_openai_class.return_value = mock_client diff --git a/hugegraph-llm/src/tests/operators/common_op/test_merge_dedup_rerank.py b/hugegraph-llm/src/tests/operators/common_op/test_merge_dedup_rerank.py index a9284a3ff..62f2c4934 100644 --- a/hugegraph-llm/src/tests/operators/common_op/test_merge_dedup_rerank.py +++ b/hugegraph-llm/src/tests/operators/common_op/test_merge_dedup_rerank.py @@ -38,8 +38,7 @@ def setUp(self): self.vector_results = [ "Artificial intelligence is a branch of computer science.", "AI is the simulation of human intelligence by machines.", - "Artificial intelligence involves creating systems that can " - "perform tasks requiring human intelligence.", + "Artificial intelligence involves creating systems that can perform tasks requiring human intelligence.", ] self.graph_results = [ "AI research includes reasoning, knowledge representation, " @@ -193,9 +192,7 @@ def test_rerank_with_vertex_degree(self, mock_rerankers_class, mock_llm_settings } # Call the method - reranked = merger._rerank_with_vertex_degree( - self.query, results, 2, vertex_degree_list, knowledge_with_degree - ) + reranked = merger._rerank_with_vertex_degree(self.query, results, 2, vertex_degree_list, knowledge_with_degree) # Verify that reranker was called for each vertex degree list self.assertEqual(mock_reranker.get_rerank_lists.call_count, 2) @@ -213,9 +210,7 @@ def test_rerank_with_vertex_degree_no_list(self): merger._dedup_and_rerank.return_value = ["result1", "result2"] # Call the method with empty vertex_degree_list - reranked = merger._rerank_with_vertex_degree( - self.query, ["result1", "result2"], 2, [], {} - ) + reranked = merger._rerank_with_vertex_degree(self.query, ["result1", "result2"], 2, [], {}) # Verify that _dedup_and_rerank was called merger._dedup_and_rerank.assert_called_once() diff --git a/hugegraph-llm/src/tests/operators/hugegraph_op/test_commit_to_hugegraph.py b/hugegraph-llm/src/tests/operators/hugegraph_op/test_commit_to_hugegraph.py index 7227a0535..634fdb961 100644 --- a/hugegraph-llm/src/tests/operators/hugegraph_op/test_commit_to_hugegraph.py +++ b/hugegraph-llm/src/tests/operators/hugegraph_op/test_commit_to_hugegraph.py @@ -17,12 +17,12 @@ # pylint: disable=protected-access,no-member import unittest - from unittest.mock import MagicMock, patch -from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph from pyhugegraph.utils.exceptions import CreateError, NotFoundError +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph + class TestCommit2Graph(unittest.TestCase): def setUp(self): diff --git a/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py b/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py index 858158ac4..64c093eda 100644 --- a/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py +++ b/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py @@ -128,7 +128,7 @@ def test_run_with_partial_result(self): {"edge_num": 200}, {}, # Missing vertices {}, # Missing edges - {"note": "Only ≤10000 VIDs and ≤ 200 EIDs for brief overview ."} + {"note": "Only ≤10000 VIDs and ≤ 200 EIDs for brief overview ."}, ] } self.mock_gremlin.exec.return_value = partial_result diff --git a/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py b/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py index 787cd25c8..a20857aec 100644 --- a/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py +++ b/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py @@ -31,9 +31,7 @@ def setUp(self): # Create SchemaManager instance self.graph_name = "test_graph" - with patch( - "hugegraph_llm.operators.hugegraph_op.schema_manager.PyHugeClient" - ) as mock_client_class: + with patch("hugegraph_llm.operators.hugegraph_op.schema_manager.PyHugeClient") as mock_client_class: mock_client_class.return_value = self.mock_client self.schema_manager = SchemaManager(self.graph_name) @@ -129,9 +127,7 @@ def test_simple_schema_with_empty_schema(self): def test_simple_schema_with_partial_schema(self): """Test simple_schema method with a partial schema.""" - partial_schema = { - "vertexlabels": [{"id": 1, "name": "person", "properties": ["name", "age"]}] - } + partial_schema = {"vertexlabels": [{"id": 1, "name": "person", "properties": ["name", "age"]}]} simple_schema = self.schema_manager.simple_schema(partial_schema) self.assertIn("vertexlabels", simple_schema) self.assertNotIn("edgelabels", simple_schema) @@ -162,9 +158,7 @@ def test_run_with_empty_schema(self): self.schema_manager.run({}) # Verify the exception message - self.assertIn( - f"Can not get {self.graph_name}'s schema from HugeGraph!", str(cm.exception) - ) + self.assertIn(f"Can not get {self.graph_name}'s schema from HugeGraph!", str(cm.exception)) def test_run_with_existing_context(self): """Test run method with an existing context.""" diff --git a/hugegraph-llm/src/tests/operators/index_op/test_build_gremlin_example_index.py b/hugegraph-llm/src/tests/operators/index_op/test_build_gremlin_example_index.py index 773a83cb4..239e377ae 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_build_gremlin_example_index.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_build_gremlin_example_index.py @@ -17,14 +17,13 @@ import unittest from unittest.mock import MagicMock, patch + from hugegraph_llm.indices.vector_index.base import VectorStoreBase from hugegraph_llm.models.embeddings.base import BaseEmbedding - from hugegraph_llm.operators.index_op.build_gremlin_example_index import BuildGremlinExampleIndex class TestBuildGremlinExampleIndex(unittest.TestCase): - def setUp(self): # Mock embedding model self.mock_embedding = MagicMock(spec=BaseEmbedding) @@ -44,9 +43,7 @@ def setUp(self): # Create instance self.index_builder = BuildGremlinExampleIndex( - embedding=self.mock_embedding, - examples=self.examples, - vector_index=self.mock_vector_store_class + embedding=self.mock_embedding, examples=self.examples, vector_index=self.mock_vector_store_class ) def test_init(self): @@ -91,9 +88,7 @@ def test_run_with_empty_examples(self, mock_get_embeddings_parallel, mock_asynci # Create instance with empty examples empty_index_builder = BuildGremlinExampleIndex( - embedding=self.mock_embedding, - examples=[], - vector_index=mock_vector_store_class + embedding=self.mock_embedding, examples=[], vector_index=mock_vector_store_class ) # Setup mocks - empty embeddings @@ -119,9 +114,7 @@ def test_run_single_example(self, mock_get_embeddings_parallel, mock_asyncio_run # Create instance with single example single_example = [{"query": "g.V().count()", "description": "Count all vertices"}] single_index_builder = BuildGremlinExampleIndex( - embedding=self.mock_embedding, - examples=single_example, - vector_index=mock_vector_store_class + embedding=self.mock_embedding, examples=single_example, vector_index=mock_vector_store_class ) # Setup mocks diff --git a/hugegraph-llm/src/tests/operators/index_op/test_build_semantic_index.py b/hugegraph-llm/src/tests/operators/index_op/test_build_semantic_index.py index d0e6a95fb..2befaa383 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_build_semantic_index.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_build_semantic_index.py @@ -82,9 +82,7 @@ def test_init(self): self.assertEqual(builder.vid_index, self.mock_vector_store) # Verify from_name was called with correct parameters - self.mock_vector_store_class.from_name.assert_called_once_with( - 384, "test_graph", "graph_vids" - ) + self.mock_vector_store_class.from_name.assert_called_once_with(384, "test_graph", "graph_vids") def test_extract_names(self): # Create a builder diff --git a/hugegraph-llm/src/tests/operators/index_op/test_build_vector_index.py b/hugegraph-llm/src/tests/operators/index_op/test_build_vector_index.py index d2d4634d6..622f05d5c 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_build_vector_index.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_build_vector_index.py @@ -20,8 +20,8 @@ import unittest from unittest.mock import MagicMock, patch -from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.indices.vector_index.base import VectorStoreBase +from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex @@ -60,9 +60,7 @@ def test_init(self): self.assertEqual(builder.vector_index, self.mock_vector_store) # Check if from_name was called with correct parameters - self.mock_vector_store_class.from_name.assert_called_once_with( - 128, "test_graph", "chunks" - ) + self.mock_vector_store_class.from_name.assert_called_once_with(128, "test_graph", "chunks") def test_run_with_chunks(self): # Mock get_embeddings_parallel to return embeddings @@ -146,9 +144,7 @@ def test_logging(self, mock_log): builder.run(context) # Check if debug log was called - mock_log.debug.assert_called_once_with( - "Building vector index for %s chunks...", 1 - ) + mock_log.debug.assert_called_once_with("Building vector index for %s chunks...", 1) if __name__ == "__main__": diff --git a/hugegraph-llm/src/tests/operators/index_op/test_gremlin_example_index_query.py b/hugegraph-llm/src/tests/operators/index_op/test_gremlin_example_index_query.py index 3c8f0e860..4ec869f84 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_gremlin_example_index_query.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_gremlin_example_index_query.py @@ -20,9 +20,10 @@ import shutil import tempfile import unittest -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, Mock, patch import pandas as pd + from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery @@ -60,11 +61,7 @@ def test_init_with_existing_index(self): mock_vector_index_class.from_name.return_value = mock_index_instance # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=2 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=2) # Verify the instance was initialized correctly self.assertEqual(query.embedding, mock_embedding) @@ -73,9 +70,7 @@ def test_init_with_existing_index(self): # Verify that exist() and from_name() were called mock_vector_index_class.exist.assert_called_once_with("gremlin_examples") - mock_vector_index_class.from_name.assert_called_once_with( - self.embed_dim, "gremlin_examples" - ) + mock_vector_index_class.from_name.assert_called_once_with(self.embed_dim, "gremlin_examples") @patch("hugegraph_llm.operators.index_op.gremlin_example_index_query.resource_path", "/mock/path") @patch("pandas.read_csv") @@ -110,17 +105,11 @@ def test_init_without_existing_index(self, mock_join, mock_log, mock_tqdm, mock_ mock_tqdm.return_value = self.vectors # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=1 - ) + GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=1) # Verify that the index was built mock_vector_index_class.exist.assert_called_once_with("gremlin_examples") - mock_vector_index_class.from_name.assert_called_once_with( - self.embed_dim, "gremlin_examples" - ) + mock_vector_index_class.from_name.assert_called_once_with(self.embed_dim, "gremlin_examples") mock_index_instance.add.assert_called_once_with(self.vectors, self.properties) mock_index_instance.save_index_by_name.assert_called_once_with("gremlin_examples") mock_log.warning.assert_called_once_with("No gremlin example index found, will generate one.") @@ -145,11 +134,7 @@ def test_run_with_query(self): context = {"query": "find all persons"} # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=1 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=1) # Run the query result_context = query.run(context) @@ -181,17 +166,10 @@ def test_run_with_query_embedding(self): mock_index_instance.search.return_value = [self.properties[0]] # Create a context with a pre-computed query embedding - context = { - "query": "find all persons", - "query_embedding": [1.0, 0.0, 0.0, 0.0] - } + context = {"query": "find all persons", "query_embedding": [1.0, 0.0, 0.0, 0.0]} # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=1 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=1) # Run the query result_context = query.run(context) @@ -227,11 +205,7 @@ def test_run_with_zero_examples(self): context = {"query": "find all persons"} # Create a GremlinExampleIndexQuery instance with num_examples=0 - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=0 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=0) # Run the query result_context = query.run(context) @@ -261,11 +235,7 @@ def test_run_without_query(self): context = {} # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=1 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=1) # Run the query and expect a ValueError with self.assertRaises(ValueError) as cm: @@ -289,10 +259,7 @@ def test_init_with_default_embedding(self, mock_embeddings_class): mock_embeddings_class.return_value.get_embedding.return_value = mock_embedding_instance # Create instance without embedding parameter - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - num_examples=1 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, num_examples=1) # Verify default embedding was used self.assertEqual(query.embedding, mock_embedding_instance) @@ -318,9 +285,7 @@ def test_run_with_negative_examples(self): # Create a GremlinExampleIndexQuery instance with negative num_examples query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=-1 + vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=-1 ) # Run the query @@ -350,11 +315,7 @@ def test_get_match_result_with_non_list_embedding(self): mock_index_instance.search.return_value = [self.properties[0]] # Create a GremlinExampleIndexQuery instance - query = GremlinExampleIndexQuery( - vector_index=mock_vector_index_class, - embedding=mock_embedding, - num_examples=1 - ) + query = GremlinExampleIndexQuery(vector_index=mock_vector_index_class, embedding=mock_embedding, num_examples=1) # Test with non-list query_embedding (should use embedding service) context = {"query": "find all persons", "query_embedding": "not_a_list"} diff --git a/hugegraph-llm/src/tests/operators/index_op/test_semantic_id_query.py b/hugegraph-llm/src/tests/operators/index_op/test_semantic_id_query.py index 26df22af6..622bed39b 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_semantic_id_query.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_semantic_id_query.py @@ -75,14 +75,13 @@ def test_init(self, mock_settings, mock_resource_path): mock_settings.graph_name = "test_graph" mock_settings.topk_per_keyword = 5 mock_settings.vector_dis_threshold = 1.5 - mock_resource_path = "/mock/path" with patch("os.path.join", return_value=self.test_dir): query = SemanticIdQuery( self.embedding, self.mock_vector_store_class, # 传递 vector_index 参数 by="query", - topk_per_query=3 + topk_per_query=3, ) # Verify the instance was initialized correctly @@ -99,17 +98,11 @@ def test_run_by_query(self, mock_settings, mock_resource_path): mock_settings.graph_name = "test_graph" mock_settings.topk_per_keyword = 5 mock_settings.vector_dis_threshold = 1.5 - mock_resource_path = "/mock/path" context = {"query": "query1"} with patch("os.path.join", return_value=self.test_dir): - query = SemanticIdQuery( - self.embedding, - self.mock_vector_store_class, - by="query", - topk_per_query=2 - ) + query = SemanticIdQuery(self.embedding, self.mock_vector_store_class, by="query", topk_per_query=2) # Mock the search result query.vector_index.search.return_value = ["1:vid1", "2:vid2"] @@ -130,17 +123,11 @@ def test_run_by_keywords_with_exact_match(self, mock_settings, mock_resource_pat mock_settings.graph_name = "test_graph" mock_settings.topk_per_keyword = 2 mock_settings.vector_dis_threshold = 1.5 - mock_resource_path = "/mock/path" context = {"keywords": ["keyword1", "keyword2"]} with patch("os.path.join", return_value=self.test_dir): - query = SemanticIdQuery( - self.embedding, - self.mock_vector_store_class, - by="keywords", - topk_per_keyword=2 - ) + query = SemanticIdQuery(self.embedding, self.mock_vector_store_class, by="keywords", topk_per_keyword=2) result_context = query.run(context) @@ -156,16 +143,11 @@ def test_run_with_empty_keywords(self, mock_settings, mock_resource_path): mock_settings.graph_name = "test_graph" mock_settings.topk_per_keyword = 5 mock_settings.vector_dis_threshold = 1.5 - mock_resource_path = "/mock/path" context = {"keywords": []} with patch("os.path.join", return_value=self.test_dir): - query = SemanticIdQuery( - self.embedding, - self.mock_vector_store_class, - by="keywords" - ) + query = SemanticIdQuery(self.embedding, self.mock_vector_store_class, by="keywords") result_context = query.run(context) diff --git a/hugegraph-llm/src/tests/operators/index_op/test_vector_index_query.py b/hugegraph-llm/src/tests/operators/index_op/test_vector_index_query.py index de302e9aa..c9bd67b4c 100644 --- a/hugegraph-llm/src/tests/operators/index_op/test_vector_index_query.py +++ b/hugegraph-llm/src/tests/operators/index_op/test_vector_index_query.py @@ -43,11 +43,7 @@ def test_init(self, mock_settings): mock_settings.graph_name = "test_graph" # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=3 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=3) # Verify initialization self.assertEqual(query.embedding, self.mock_embedding) @@ -55,9 +51,7 @@ def test_init(self, mock_settings): self.assertEqual(query.vector_index, self.mock_vector_index) # Verify vector store was initialized correctly - self.mock_vector_store_class.from_name.assert_called_once_with( - 4, "test_graph", "chunks" - ) + self.mock_vector_store_class.from_name.assert_called_once_with(4, "test_graph", "chunks") @patch("hugegraph_llm.operators.index_op.vector_index_query.huge_settings") def test_run_with_query(self, mock_settings): @@ -66,11 +60,7 @@ def test_run_with_query(self, mock_settings): mock_settings.graph_name = "test_graph" # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=2 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=2) # Prepare context with query context = {"query": "test query"} @@ -86,9 +76,7 @@ def test_run_with_query(self, mock_settings): self.mock_embedding.get_texts_embeddings.assert_called_once_with(["test query"]) # Verify vector search was called correctly - self.mock_vector_index.search.assert_called_once_with( - [1.0, 0.0, 0.0, 0.0], 2, dis_threshold=2 - ) + self.mock_vector_index.search.assert_called_once_with([1.0, 0.0, 0.0, 0.0], 2, dis_threshold=2) @patch("hugegraph_llm.operators.index_op.vector_index_query.huge_settings") def test_run_with_none_query(self, mock_settings): @@ -97,11 +85,7 @@ def test_run_with_none_query(self, mock_settings): mock_settings.graph_name = "test_graph" # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=2 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=2) # Prepare context without query or with None query context = {"query": None} @@ -123,11 +107,7 @@ def test_run_with_empty_context(self, mock_settings): mock_settings.graph_name = "test_graph" # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=2 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=2) # Prepare empty context context = {} @@ -152,11 +132,7 @@ def test_run_with_different_topk(self, mock_settings): self.mock_vector_index.search.return_value = ["doc1", "doc2", "doc3", "doc4", "doc5"] # Create VectorIndexQuery instance with different topk - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=5 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=5) # Prepare context context = {"query": "test query"} @@ -168,9 +144,7 @@ def test_run_with_different_topk(self, mock_settings): self.assertEqual(result_context["vector_result"], ["doc1", "doc2", "doc3", "doc4", "doc5"]) # Verify vector search was called with correct topk - self.mock_vector_index.search.assert_called_once_with( - [1.0, 0.0, 0.0, 0.0], 5, dis_threshold=2 - ) + self.mock_vector_index.search.assert_called_once_with([1.0, 0.0, 0.0, 0.0], 5, dis_threshold=2) @patch("hugegraph_llm.operators.index_op.vector_index_query.huge_settings") def test_run_with_different_embedding_result(self, mock_settings): @@ -182,11 +156,7 @@ def test_run_with_different_embedding_result(self, mock_settings): self.mock_embedding.get_texts_embeddings.return_value = [[0.0, 1.0, 0.0, 0.0]] # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=2 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=2) # Prepare context context = {"query": "another query"} @@ -195,9 +165,7 @@ def test_run_with_different_embedding_result(self, mock_settings): _ = query.run(context) # Verify vector search was called with correct embedding - self.mock_vector_index.search.assert_called_once_with( - [0.0, 1.0, 0.0, 0.0], 2, dis_threshold=2 - ) + self.mock_vector_index.search.assert_called_once_with([0.0, 1.0, 0.0, 0.0], 2, dis_threshold=2) @patch("hugegraph_llm.operators.index_op.vector_index_query.huge_settings") def test_context_preservation(self, mock_settings): @@ -206,18 +174,10 @@ def test_context_preservation(self, mock_settings): mock_settings.graph_name = "test_graph" # Create VectorIndexQuery instance - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=self.mock_embedding, - topk=2 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=self.mock_embedding, topk=2) # Prepare context with existing data - context = { - "query": "test query", - "existing_key": "existing_value", - "another_key": 123 - } + context = {"query": "test query", "existing_key": "existing_value", "another_key": 123} # Run the query result_context = query.run(context) @@ -239,20 +199,14 @@ def test_init_with_custom_parameters(self, mock_settings): custom_embedding.get_embedding_dim.return_value = 256 # Create VectorIndexQuery instance with custom parameters - query = VectorIndexQuery( - vector_index=self.mock_vector_store_class, - embedding=custom_embedding, - topk=10 - ) + query = VectorIndexQuery(vector_index=self.mock_vector_store_class, embedding=custom_embedding, topk=10) # Verify initialization with custom parameters self.assertEqual(query.topk, 10) self.assertEqual(query.embedding, custom_embedding) # Verify vector store was initialized with custom parameters - self.mock_vector_store_class.from_name.assert_called_once_with( - 256, "custom_graph", "chunks" - ) + self.mock_vector_store_class.from_name.assert_called_once_with(256, "custom_graph", "chunks") if __name__ == "__main__": diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_gremlin_generate.py b/hugegraph-llm/src/tests/operators/llm_op/test_gremlin_generate.py index 80d3b5dd5..557eade8a 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_gremlin_generate.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_gremlin_generate.py @@ -52,8 +52,7 @@ def setUpClass(cls): ] cls.sample_gremlin_response = ( - "Here is the Gremlin query:\n```gremlin\n" - "g.V().has('person', 'name', 'Tom Hanks').out('acted_in')\n```" + "Here is the Gremlin query:\n```gremlin\ng.V().has('person', 'name', 'Tom Hanks').out('acted_in')\n```" ) cls.sample_gremlin_query = "g.V().has('person', 'name', 'Tom Hanks').out('acted_in')" @@ -75,10 +74,6 @@ def _create_mock_llm(self): mock_llm.generate.return_value = self.__class__.sample_gremlin_response return mock_llm - - - - def test_init_with_defaults(self): """Test initialization with default values.""" with patch("hugegraph_llm.operators.llm_op.gremlin_generate.LLMs") as mock_llms_class: @@ -179,9 +174,7 @@ def test_run_with_empty_query(self): def test_async_generate(self): """Test the run method with async functionality.""" # Create generator with schema and vertices - generator = GremlinGenerateSynthesize( - llm=self.mock_llm, schema=self.schema, vertices=self.vertices - ) + generator = GremlinGenerateSynthesize(llm=self.mock_llm, schema=self.schema, vertices=self.vertices) # Run the method result = generator.run({"query": self.query}) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py index 4053f929f..80359bd52 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py @@ -104,14 +104,7 @@ def test_extract_by_regex_with_schema(self): }, ] - expected_edges = [ - { - "start": "person-Alice", - "end": "person-Bob", - "type": "roommate", - "properties": {} - } - ] + expected_edges = [{"start": "person-Alice", "end": "person-Bob", "type": "roommate", "properties": {}}] # Sort vertices and edges for consistent comparison actual_vertices = sorted(graph["vertices"], key=lambda x: x["id"]) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_keyword_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_keyword_extract.py index 566e4ffe5..e56729546 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_keyword_extract.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_keyword_extract.py @@ -34,14 +34,10 @@ def setUp(self): ) # Sample query - self.query = ( - "What are the latest advancements in artificial intelligence and machine learning?" - ) + self.query = "What are the latest advancements in artificial intelligence and machine learning?" # Create KeywordExtract instance (language is now set from llm_settings) - self.extractor = KeywordExtract( - text=self.query, llm=self.mock_llm, max_keywords=5 - ) + self.extractor = KeywordExtract(text=self.query, llm=self.mock_llm, max_keywords=5) def test_init_with_parameters(self): """Test initialization with provided parameters.""" @@ -146,7 +142,6 @@ def test_run_with_no_query_raises_assertion_error(self): extractor = KeywordExtract(llm=self.mock_llm) # Create context with no query - context = {} # Call the method and expect an assertion error with self.assertRaises(AssertionError) as cm: @@ -202,12 +197,9 @@ def test_run_with_existing_call_count(self): def test_extract_keywords_from_response_with_start_token(self): """Test _extract_keywords_from_response method with start token.""" response = ( - "Some text\nKEYWORDS: artificial intelligence:0.9, machine learning:0.8, " - "neural networks:0.7\nMore text" - ) - keywords = self.extractor._extract_keywords_from_response( - response, lowercase=False, start_token="KEYWORDS:" + "Some text\nKEYWORDS: artificial intelligence:0.9, machine learning:0.8, neural networks:0.7\nMore text" ) + keywords = self.extractor._extract_keywords_from_response(response, lowercase=False, start_token="KEYWORDS:") # Check for keywords - now returns dict with scores self.assertIn("artificial intelligence", keywords) @@ -227,9 +219,7 @@ def test_extract_keywords_from_response_without_start_token(self): def test_extract_keywords_from_response_with_lowercase(self): """Test _extract_keywords_from_response method with lowercase=True.""" response = "KEYWORDS: Artificial Intelligence:0.9, Machine Learning:0.8, Neural Networks:0.7" - keywords = self.extractor._extract_keywords_from_response( - response, lowercase=True, start_token="KEYWORDS:" - ) + keywords = self.extractor._extract_keywords_from_response(response, lowercase=True, start_token="KEYWORDS:") # Check for keywords in lowercase - now returns dict with scores self.assertIn("artificial intelligence", keywords) @@ -239,9 +229,7 @@ def test_extract_keywords_from_response_with_lowercase(self): def test_extract_keywords_from_response_with_multi_word_tokens(self): """Test _extract_keywords_from_response method with multi-word tokens.""" response = "KEYWORDS: artificial intelligence:0.9, machine learning:0.8" - keywords = self.extractor._extract_keywords_from_response( - response, start_token="KEYWORDS:" - ) + keywords = self.extractor._extract_keywords_from_response(response, start_token="KEYWORDS:") # Should include the keywords - returns dict with scores self.assertIn("artificial intelligence", keywords) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_property_graph_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_property_graph_extract.py index 24bdcf4fa..5a2dee09e 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_property_graph_extract.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_property_graph_extract.py @@ -131,9 +131,7 @@ def test_generate_extract_property_graph_prompt(self): def test_split_text(self): """Test the split_text function.""" - with patch( - "hugegraph_llm.operators.llm_op.property_graph_extract.ChunkSplitter" - ) as mock_splitter_class: + with patch("hugegraph_llm.operators.llm_op.property_graph_extract.ChunkSplitter") as mock_splitter_class: mock_splitter = MagicMock() mock_splitter.split.return_value = ["chunk1", "chunk2"] mock_splitter_class.return_value = mock_splitter diff --git a/hugegraph-llm/src/tests/test_utils.py b/hugegraph-llm/src/tests/test_utils.py index edb1db983..f1833334c 100644 --- a/hugegraph-llm/src/tests/test_utils.py +++ b/hugegraph-llm/src/tests/test_utils.py @@ -19,8 +19,10 @@ from unittest.mock import MagicMock, patch from hugegraph_llm.document import Document + from .utils.mock import VectorIndex + # Check if external service tests should be skipped def should_skip_external(): return os.environ.get("SKIP_EXTERNAL_SERVICES") == "true" diff --git a/hugegraph-llm/src/tests/utils/mock.py b/hugegraph-llm/src/tests/utils/mock.py index 88b74a69d..7db602080 100644 --- a/hugegraph-llm/src/tests/utils/mock.py +++ b/hugegraph-llm/src/tests/utils/mock.py @@ -19,6 +19,7 @@ from hugegraph_llm.models.embeddings.base import BaseEmbedding + class MockEmbedding(BaseEmbedding): """Mock embedding class for testing""" diff --git a/hugegraph-ml/src/hugegraph_ml/data/hugegraph2dgl.py b/hugegraph-ml/src/hugegraph_ml/data/hugegraph2dgl.py index 3d33caa5f..be2b33d4a 100644 --- a/hugegraph-ml/src/hugegraph_ml/data/hugegraph2dgl.py +++ b/hugegraph-ml/src/hugegraph_ml/data/hugegraph2dgl.py @@ -19,15 +19,15 @@ # pylint: disable=C0304 import warnings -from typing import Optional, List import dgl +import networkx as nx import torch from pyhugegraph.api.gremlin import GremlinManager from pyhugegraph.client import PyHugeClient from hugegraph_ml.data.hugegraph_dataset import HugeGraphDataset -import networkx as nx + class HugeGraph2DGL: def __init__( @@ -36,11 +36,9 @@ def __init__( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): - self._client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + self._client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) self._graph_germlin: GremlinManager = self._client.gremlin() def convert_graph( @@ -49,7 +47,7 @@ def convert_graph( edge_label: str, feat_key: str = "feat", label_key: str = "label", - mask_keys: Optional[List[str]] = None, + mask_keys: list[str] | None = None, ): if mask_keys is None: mask_keys = ["train_mask", "val_mask", "test_mask"] @@ -61,11 +59,11 @@ def convert_graph( def convert_hetero_graph( self, - vertex_labels: List[str], - edge_labels: List[str], + vertex_labels: list[str], + edge_labels: list[str], feat_key: str = "feat", label_key: str = "label", - mask_keys: Optional[List[str]] = None, + mask_keys: list[str] | None = None, ): if mask_keys is None: mask_keys = ["train_mask", "val_mask", "test_mask"] @@ -75,7 +73,7 @@ def convert_hetero_graph( for vertex_label in vertex_labels: vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"] if len(vertices) == 0: - warnings.warn(f"Graph has no vertices of vertex_label: {vertex_label}", Warning) + warnings.warn(f"Graph has no vertices of vertex_label: {vertex_label}", Warning, stacklevel=2) else: vertex_ids = [v["id"] for v in vertices] id2idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} @@ -98,7 +96,7 @@ def convert_hetero_graph( for edge_label in edge_labels: edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"] if len(edges) == 0: - warnings.warn(f"Graph has no edges of edge_label: {edge_label}", Warning) + warnings.warn(f"Graph has no edges of edge_label: {edge_label}", Warning, stacklevel=2) else: src_vertex_label = edges[0]["outVLabel"] src_idx = [vertex_label_id2idx[src_vertex_label][e["outV"]] for e in edges] @@ -113,7 +111,6 @@ def convert_hetero_graph( return hetero_graph - def convert_graph_dataset( self, graph_vertex_label: str, @@ -132,10 +129,8 @@ def convert_graph_dataset( label = graph_vertex["properties"][label_key] graph_labels.append(label) # get this graph's vertices and edges - vertices = self._graph_germlin.exec( - f"g.V().hasLabel('{vertex_label}').has('graph_id', {graph_id})")["data"] - edges = self._graph_germlin.exec( - f"g.E().hasLabel('{edge_label}').has('graph_id', {graph_id})")["data"] + vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}').has('graph_id', {graph_id})")["data"] + edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}').has('graph_id', {graph_id})")["data"] graph_dgl = self._convert_graph_from_v_e(vertices, edges, feat_key) graphs.append(graph_dgl) # record max num of node @@ -167,7 +162,7 @@ def convert_graph_with_edge_feat( node_feat_key: str = "feat", edge_feat_key: str = "edge_feat", label_key: str = "label", - mask_keys: Optional[List[str]] = None, + mask_keys: list[str] | None = None, ): if mask_keys is None: mask_keys = ["train_mask", "val_mask", "test_mask"] @@ -182,23 +177,19 @@ def convert_graph_with_edge_feat( def convert_graph_ogb(self, vertex_label: str, edge_label: str, split_label: str): vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"] edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"] - graph_dgl, vertex_id_to_idx = self._convert_graph_from_ogb( - vertices, edges, "feat", "year", "weight" - ) - edges_split = self._graph_germlin.exec(f"g.E().hasLabel('{split_label}')")[ - "data" - ] + graph_dgl, vertex_id_to_idx = self._convert_graph_from_ogb(vertices, edges, "feat", "year", "weight") + edges_split = self._graph_germlin.exec(f"g.E().hasLabel('{split_label}')")["data"] split_edge = self._convert_split_edge_from_ogb(edges_split, vertex_id_to_idx) return graph_dgl, split_edge def convert_hetero_graph_bgnn( self, - vertex_labels: List[str], - edge_labels: List[str], + vertex_labels: list[str], + edge_labels: list[str], feat_key: str = "feat", label_key: str = "class", cat_key: str = "cat_features", - mask_keys: Optional[List[str]] = None, + mask_keys: list[str] | None = None, ): if mask_keys is None: mask_keys = ["train_mask", "val_mask", "test_mask"] @@ -206,13 +197,9 @@ def convert_hetero_graph_bgnn( vertex_label_data = {} # for each vertex label for vertex_label in vertex_labels: - vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")[ - "data" - ] + vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"] if len(vertices) == 0: - warnings.warn( - f"Graph has no vertices of vertex_label: {vertex_label}", Warning - ) + warnings.warn(f"Graph has no vertices of vertex_label: {vertex_label}", Warning, stacklevel=2) else: vertex_ids = [v["id"] for v in vertices] id2idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} @@ -250,18 +237,12 @@ def convert_hetero_graph_bgnn( for edge_label in edge_labels: edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"] if len(edges) == 0: - warnings.warn( - f"Graph has no edges of edge_label: {edge_label}", Warning - ) + warnings.warn(f"Graph has no edges of edge_label: {edge_label}", Warning, stacklevel=2) else: src_vertex_label = edges[0]["outVLabel"] - src_idx = [ - vertex_label_id2idx[src_vertex_label][e["outV"]] for e in edges - ] + src_idx = [vertex_label_id2idx[src_vertex_label][e["outV"]] for e in edges] dst_vertex_label = edges[0]["inVLabel"] - dst_idx = [ - vertex_label_id2idx[dst_vertex_label][e["inV"]] for e in edges - ] + dst_idx = [vertex_label_id2idx[dst_vertex_label][e["inV"]] for e in edges] edge_data_dict[(src_vertex_label, edge_label, dst_vertex_label)] = ( src_idx, dst_idx, @@ -270,16 +251,14 @@ def convert_hetero_graph_bgnn( hetero_graph = dgl.heterograph(edge_data_dict) for vertex_label in vertex_labels: for prop in vertex_label_data[vertex_label]: - hetero_graph.nodes[vertex_label].data[prop] = vertex_label_data[ - vertex_label - ][prop] + hetero_graph.nodes[vertex_label].data[prop] = vertex_label_data[vertex_label][prop] return hetero_graph @staticmethod def _convert_graph_from_v_e(vertices, edges, feat_key=None, label_key=None, mask_keys=None): if len(vertices) == 0: - warnings.warn("This graph has no vertices", Warning) + warnings.warn("This graph has no vertices", Warning, stacklevel=2) return dgl.graph(()) vertex_ids = [v["id"] for v in vertices] vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} @@ -304,15 +283,13 @@ def _convert_graph_from_v_e(vertices, edges, feat_key=None, label_key=None, mask @staticmethod def _convert_graph_from_v_e_nx(vertices, edges): if len(vertices) == 0: - warnings.warn("This graph has no vertices", Warning) + warnings.warn("This graph has no vertices", Warning, stacklevel=2) return nx.Graph(()) vertex_ids = [v["id"] for v in vertices] vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} new_vertex_ids = [vertex_id_to_idx[id] for id in vertex_ids] edge_list = [(edge["outV"], edge["inV"]) for edge in edges] - new_edge_list = [ - (vertex_id_to_idx[src], vertex_id_to_idx[dst]) for src, dst in edge_list - ] + new_edge_list = [(vertex_id_to_idx[src], vertex_id_to_idx[dst]) for src, dst in edge_list] graph_nx = nx.Graph() graph_nx.add_nodes_from(new_vertex_ids) graph_nx.add_edges_from(new_edge_list) @@ -328,7 +305,7 @@ def _convert_graph_from_v_e_with_edge_feat( mask_keys=None, ): if len(vertices) == 0: - warnings.warn("This graph has no vertices", Warning) + warnings.warn("This graph has no vertices", Warning, stacklevel=2) return dgl.graph(()) vertex_ids = [v["id"] for v in vertices] vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} @@ -356,7 +333,7 @@ def _convert_graph_from_v_e_with_edge_feat( @staticmethod def _convert_graph_from_ogb(vertices, edges, feat_key, year_key, weight_key): if len(vertices) == 0: - warnings.warn("This graph has no vertices", Warning) + warnings.warn("This graph has no vertices", Warning, stacklevel=2) return dgl.graph(()) vertex_ids = [v["id"] for v in vertices] vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)} @@ -364,10 +341,7 @@ def _convert_graph_from_ogb(vertices, edges, feat_key, year_key, weight_key): dst_idx = [vertex_id_to_idx[e["inV"]] for e in edges] graph_dgl = dgl.graph((src_idx, dst_idx)) if feat_key and feat_key in vertices[0]["properties"]: - node_feats = [ - v["properties"][feat_key] - for v in vertices[0 : graph_dgl.number_of_nodes()] - ] + node_feats = [v["properties"][feat_key] for v in vertices[0 : graph_dgl.number_of_nodes()]] graph_dgl.ndata["feat"] = torch.tensor(node_feats, dtype=torch.float32) if year_key and year_key in edges[0]["properties"]: year = [e["properties"][year_key] for e in edges] @@ -394,39 +368,29 @@ def _convert_split_edge_from_ogb(edges, vertex_id_to_idx): for edge in edges: if edge["properties"]["train_edge_mask"] == 1: - train_edge_list.append( - [vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]] - ) + train_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]]) if edge["properties"]["train_year_mask"] != -1: train_year_list.append(edge["properties"]["train_year_mask"]) if edge["properties"]["train_weight_mask"] != -1: train_weight_list.append(edge["properties"]["train_weight_mask"]) if edge["properties"]["valid_edge_mask"] == 1: - valid_edge_list.append( - [vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]] - ) + valid_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]]) if edge["properties"]["valid_year_mask"] != -1: valid_year_list.append(edge["properties"]["valid_year_mask"]) if edge["properties"]["valid_weight_mask"] != -1: valid_weight_list.append(edge["properties"]["valid_weight_mask"]) if edge["properties"]["valid_edge_neg_mask"] == 1: - valid_edge_neg_list.append( - [vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]] - ) + valid_edge_neg_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]]) if edge["properties"]["test_edge_mask"] == 1: - test_edge_list.append( - [vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]] - ) + test_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]]) if edge["properties"]["test_year_mask"] != -1: test_year_list.append(edge["properties"]["test_year_mask"]) if edge["properties"]["test_weight_mask"] != -1: test_weight_list.append(edge["properties"]["test_weight_mask"]) if edge["properties"]["test_edge_neg_mask"] == 1: - test_edge_neg_list.append( - [vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]] - ) + test_edge_neg_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]]) split_edge = { "train": { @@ -450,6 +414,7 @@ def _convert_split_edge_from_ogb(edges, vertex_id_to_idx): return split_edge + if __name__ == "__main__": hg2d = HugeGraph2DGL() hg2d.convert_graph(vertex_label="CORA_vertex", edge_label="CORA_edge") @@ -460,24 +425,20 @@ def _convert_split_edge_from_ogb(edges, vertex_id_to_idx): ) hg2d.convert_hetero_graph( vertex_labels=["ACM_paper_v", "ACM_author_v", "ACM_field_v"], - edge_labels=["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"] + edge_labels=["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"], ) hg2d.convert_graph_nx(vertex_label="CAVEMAN_vertex", edge_label="CAVEMAN_edge") - hg2d.convert_graph_with_edge_feat( - vertex_label="CORA_edge_feat_vertex", edge_label="CORA_edge_feat_edge" - ) + hg2d.convert_graph_with_edge_feat(vertex_label="CORA_edge_feat_vertex", edge_label="CORA_edge_feat_edge") hg2d.convert_graph_ogb( vertex_label="ogbl-collab_vertex", edge_label="ogbl-collab_edge", split_label="ogbl-collab_split_edge", ) - hg2d.convert_hetero_graph_bgnn( - vertex_labels=["AVAZU__N_v"], edge_labels=["AVAZU__E_e"] - ) + hg2d.convert_hetero_graph_bgnn(vertex_labels=["AVAZU__N_v"], edge_labels=["AVAZU__E_e"]) hg2d.convert_hetero_graph( vertex_labels=["AMAZONGATNE__N_v"], edge_labels=[ "AMAZONGATNE_1_e", "AMAZONGATNE_2_e", ], - ) \ No newline at end of file + ) diff --git a/hugegraph-ml/src/hugegraph_ml/data/hugegraph_dataset.py b/hugegraph-ml/src/hugegraph_ml/data/hugegraph_dataset.py index e9cb32496..b4d2ac4f0 100644 --- a/hugegraph-ml/src/hugegraph_ml/data/hugegraph_dataset.py +++ b/hugegraph-ml/src/hugegraph_ml/data/hugegraph_dataset.py @@ -17,6 +17,7 @@ from torch.utils.data import Dataset + class HugeGraphDataset(Dataset): def __init__(self, graphs, labels, info): self.graphs = graphs diff --git a/hugegraph-ml/src/hugegraph_ml/examples/appnp_example.py b/hugegraph-ml/src/hugegraph_ml/examples/appnp_example.py index 6754b7472..610493cf6 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/appnp_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/appnp_example.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. +import torch.nn.functional as F + from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.models.appnp import APPNP from hugegraph_ml.tasks.node_classify import NodeClassify -import torch.nn.functional as F def appnp_example(n_epochs=200): diff --git a/hugegraph-ml/src/hugegraph_ml/examples/arma_example.py b/hugegraph-ml/src/hugegraph_ml/examples/arma_example.py index 0c75b5be1..0ee40bc01 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/arma_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/arma_example.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. +from torch import nn + from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.models.arma import ARMA4NC from hugegraph_ml.tasks.node_classify import NodeClassify -from torch import nn def arma_example(n_epochs=200): diff --git a/hugegraph-ml/src/hugegraph_ml/examples/bgnn_example.py b/hugegraph-ml/src/hugegraph_ml/examples/bgnn_example.py index c395a1d4a..0cc56655c 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/bgnn_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/bgnn_example.py @@ -17,26 +17,22 @@ # pylint: disable=C0103 +from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.models.bgnn import ( - GNNModelDGL, BGNNPredictor, + GNNModelDGL, + convert_data, encode_cat_features, replace_na, - convert_data, ) -from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL def bgnn_example(): hg2d = HugeGraph2DGL() - g = hg2d.convert_hetero_graph_bgnn( - vertex_labels=["AVAZU__N_v"], edge_labels=["AVAZU__E_e"] - ) + g = hg2d.convert_hetero_graph_bgnn(vertex_labels=["AVAZU__N_v"], edge_labels=["AVAZU__E_e"]) X, y, cat_features, train_mask, val_mask, test_mask = convert_data(g) encoded_X = X.copy() - encoded_X = encode_cat_features( - encoded_X, y, cat_features, train_mask, val_mask, test_mask - ) + encoded_X = encode_cat_features(encoded_X, y, cat_features, train_mask, val_mask, test_mask) encoded_X = replace_na(encoded_X, train_mask) gnn_model = GNNModelDGL(in_dim=y.shape[1], hidden_dim=128, out_dim=y.shape[1]) bgnn = BGNNPredictor( @@ -50,7 +46,7 @@ def bgnn_example(): gbdt_depth=6, gbdt_lr=0.1, ) - _ = bgnn.fit( + metrics = bgnn.fit( g, encoded_X, y, @@ -63,6 +59,7 @@ def bgnn_example(): patience=10, metric_name="loss", ) + print(metrics) if __name__ == "__main__": diff --git a/hugegraph-ml/src/hugegraph_ml/examples/bgrl_example.py b/hugegraph-ml/src/hugegraph_ml/examples/bgrl_example.py index f73ab0cc7..b23c826f6 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/bgrl_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/bgrl_example.py @@ -25,7 +25,7 @@ def bgrl_example(n_epochs_embed=300, n_epochs_clf=400): hg2d = HugeGraph2DGL() graph = hg2d.convert_graph(vertex_label="CORA_vertex", edge_label="CORA_edge") - encoder = GCN([graph.ndata["feat"].size(1)] + [256, 128]) + encoder = GCN([graph.ndata["feat"].size(1), 256, 128]) predictor = MLP_Predictor( input_size=128, output_size=128, @@ -39,7 +39,7 @@ def bgrl_example(n_epochs_embed=300, n_epochs_clf=400): model = MLPClassifier( n_in_feat=embedded_graph.ndata["feat"].shape[1], n_out_feat=embedded_graph.ndata["label"].unique().shape[0], - n_hidden=128 + n_hidden=128, ) node_clf_task = NodeClassify(graph=embedded_graph, model=model) node_clf_task.train(lr=1e-3, n_epochs=n_epochs_clf, patience=30) diff --git a/hugegraph-ml/src/hugegraph_ml/examples/care_gnn_example.py b/hugegraph-ml/src/hugegraph_ml/examples/care_gnn_example.py index 723d37d0e..dd031760f 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/care_gnn_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/care_gnn_example.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. +import torch + from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.models.care_gnn import CAREGNN from hugegraph_ml.tasks.fraud_detector_caregnn import DetectorCaregnn -import torch + def care_gnn_example(n_epochs=200): hg2d = HugeGraph2DGL() diff --git a/hugegraph-ml/src/hugegraph_ml/examples/correct_and_smooth_example.py b/hugegraph-ml/src/hugegraph_ml/examples/correct_and_smooth_example.py index 921fa9483..e407f5124 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/correct_and_smooth_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/correct_and_smooth_example.py @@ -19,6 +19,7 @@ from hugegraph_ml.models.correct_and_smooth import MLP from hugegraph_ml.tasks.node_classify import NodeClassify + def cs_example(n_epochs=200): hg2d = HugeGraph2DGL() graph = hg2d.convert_graph(vertex_label="CORA_vertex", edge_label="CORA_edge") diff --git a/hugegraph-ml/src/hugegraph_ml/examples/deepergcn_example.py b/hugegraph-ml/src/hugegraph_ml/examples/deepergcn_example.py index 197826e28..1c7be6bf4 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/deepergcn_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/deepergcn_example.py @@ -22,9 +22,7 @@ def deepergcn_example(n_epochs=1000): hg2d = HugeGraph2DGL() - graph = hg2d.convert_graph_with_edge_feat( - vertex_label="CORA_vertex", edge_label="CORA_edge" - ) + graph = hg2d.convert_graph_with_edge_feat(vertex_label="CORA_vertex", edge_label="CORA_edge") model = DeeperGCN( node_feat_dim=graph.ndata["feat"].shape[1], edge_feat_dim=graph.edata["feat"].shape[1], diff --git a/hugegraph-ml/src/hugegraph_ml/examples/diffpool_example.py b/hugegraph-ml/src/hugegraph_ml/examples/diffpool_example.py index 09ee632c3..0728f08da 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/diffpool_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/diffpool_example.py @@ -29,7 +29,7 @@ def diffpool_example(n_epochs=1000): n_in_feats=dataset.info["n_feat_dim"], n_out_feats=dataset.info["n_classes"], max_n_nodes=dataset.info["max_n_nodes"], - pool_ratio=0.2 + pool_ratio=0.2, ) graph_clf_task = GraphClassify(dataset, model) graph_clf_task.train(lr=1e-3, n_epochs=n_epochs, patience=300, early_stopping_monitor="accuracy") diff --git a/hugegraph-ml/src/hugegraph_ml/examples/gin_example.py b/hugegraph-ml/src/hugegraph_ml/examples/gin_example.py index 75fac72d3..63e62d8db 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/gin_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/gin_example.py @@ -25,11 +25,7 @@ def gin_example(n_epochs=1000): dataset = hg2d.convert_graph_dataset( graph_vertex_label="MUTAG_graph_vertex", vertex_label="MUTAG_vertex", edge_label="MUTAG_edge" ) - model = GIN( - n_in_feats=dataset.info["n_feat_dim"], - n_out_feats=dataset.info["n_classes"], - pooling="max" - ) + model = GIN(n_in_feats=dataset.info["n_feat_dim"], n_out_feats=dataset.info["n_classes"], pooling="max") graph_clf_task = GraphClassify(dataset, model) graph_clf_task.train(lr=1e-4, n_epochs=n_epochs) print(graph_clf_task.evaluate()) diff --git a/hugegraph-ml/src/hugegraph_ml/examples/grace_example.py b/hugegraph-ml/src/hugegraph_ml/examples/grace_example.py index 232cae4ad..c66d8e942 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/grace_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/grace_example.py @@ -34,7 +34,7 @@ def grace_example(n_epochs_embed=300, n_epochs_clf=400): model = MLPClassifier( n_in_feat=embedded_graph.ndata["feat"].shape[1], n_out_feat=embedded_graph.ndata["label"].unique().shape[0], - n_hidden=128 + n_hidden=128, ) node_clf_task = NodeClassify(graph=embedded_graph, model=model) node_clf_task.train(lr=1e-3, n_epochs=n_epochs_clf, patience=30) diff --git a/hugegraph-ml/src/hugegraph_ml/examples/pgnn_example.py b/hugegraph-ml/src/hugegraph_ml/examples/pgnn_example.py index 7297de237..3b7a4e709 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/pgnn_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/pgnn_example.py @@ -23,9 +23,7 @@ def pgnn_example(n_epochs=200): hg2d = HugeGraph2DGL() - graph = hg2d.convert_graph_nx( - vertex_label="CAVEMAN_vertex", edge_label="CAVEMAN_edge" - ) + graph = hg2d.convert_graph_nx(vertex_label="CAVEMAN_vertex", edge_label="CAVEMAN_edge") model = PGNN(input_dim=get_dataset(graph)["feature"].shape[1]) link_pre_task = LinkPredictionPGNN(graph, model) link_pre_task.train(lr=0.005, weight_decay=0.0005, n_epochs=n_epochs) diff --git a/hugegraph-ml/src/hugegraph_ml/examples/seal_example.py b/hugegraph-ml/src/hugegraph_ml/examples/seal_example.py index 3d6e7d3be..2e292a987 100644 --- a/hugegraph-ml/src/hugegraph_ml/examples/seal_example.py +++ b/hugegraph-ml/src/hugegraph_ml/examples/seal_example.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. +import torch + from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.models.seal import DGCNN, data_prepare from hugegraph_ml.tasks.link_prediction_seal import LinkPredictionSeal -import torch def seal_example(n_epochs=200): @@ -45,6 +46,10 @@ def seal_example(n_epochs=200): link_pre_task = LinkPredictionSeal(graph, split_edge, model) link_pre_task.train(lr=0.005, n_epochs=n_epochs) + # 在训练结束后,最后一个epoch的评估结果已经在train方法中计算并存储在summary_test中 + # 这里我们可以简单地打印一条消息,表示训练已完成 + print("Training completed. Evaluation metrics were calculated during training.") + if __name__ == "__main__": seal_example() diff --git a/hugegraph-ml/src/hugegraph_ml/models/agnn.py b/hugegraph-ml/src/hugegraph_ml/models/agnn.py index c83058f85..8be570d3f 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/agnn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/agnn.py @@ -21,15 +21,14 @@ References ---------- Paper: https://arxiv.org/abs/1803.03735 -Author's code: +Author's code: DGL code: https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/agnnconv.py """ - - +import torch.nn.functional as F from dgl.nn.pytorch.conv import AGNNConv from torch import nn -import torch.nn.functional as F + class AGNN(nn.Module): def __init__(self, num_layers, in_dim, hid_dim, out_dim, dropout): diff --git a/hugegraph-ml/src/hugegraph_ml/models/appnp.py b/hugegraph-ml/src/hugegraph_ml/models/appnp.py index f63fb29b9..24e045e5b 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/appnp.py +++ b/hugegraph-ml/src/hugegraph_ml/models/appnp.py @@ -25,9 +25,8 @@ DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/appnp """ -from torch import nn - from dgl.nn.pytorch.conv import APPNPConv +from torch import nn class APPNP(nn.Module): @@ -42,7 +41,7 @@ def __init__( alpha, k, ): - super(APPNP, self).__init__() + super().__init__() self.layers = nn.ModuleList() # input layer self.layers.append(nn.Linear(in_feats, hiddens[0])) diff --git a/hugegraph-ml/src/hugegraph_ml/models/arma.py b/hugegraph-ml/src/hugegraph_ml/models/arma.py index 7fb21b5c6..77737a655 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/arma.py +++ b/hugegraph-ml/src/hugegraph_ml/models/arma.py @@ -23,15 +23,16 @@ References ---------- Paper: https://arxiv.org/abs/1901.01343 -Author's code: +Author's code: DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/arma """ import math + import dgl.function as fn import torch -from torch import nn import torch.nn.functional as F +from torch import nn def glorot(tensor): @@ -56,7 +57,7 @@ def __init__( dropout=0.0, bias=True, ): - super(ARMAConv, self).__init__() + super().__init__() self.in_dim = in_dim self.out_dim = out_dim @@ -66,17 +67,11 @@ def __init__( self.dropout = nn.Dropout(p=dropout) # init weight - self.w_0 = nn.ModuleDict( - {str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)} - ) + self.w_0 = nn.ModuleDict({str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)}) # deeper weight - self.w = nn.ModuleDict( - {str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)} - ) + self.w = nn.ModuleDict({str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)}) # v - self.v = nn.ModuleDict( - {str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)} - ) + self.v = nn.ModuleDict({str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)}) # bias if bias: self.bias = nn.Parameter(torch.Tensor(self.K, self.T, 1, self.out_dim)) @@ -105,14 +100,11 @@ def forward(self, g, feats): for t in range(self.T): feats = feats * norm g.ndata["h"] = feats - g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) # pylint: disable=E1101 + g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) # pylint: disable=E1101 feats = g.ndata.pop("h") feats = feats * norm - if t == 0: - feats = self.w_0[str(k)](feats) - else: - feats = self.w[str(k)](feats) + feats = self.w_0[str(k)](feats) if t == 0 else self.w[str(k)](feats) feats += self.dropout(self.v[str(k)](init_feats)) feats += self.v[str(k)](self.dropout(init_feats)) @@ -138,7 +130,7 @@ def __init__( activation=None, dropout=0.0, ): - super(ARMA4NC, self).__init__() + super().__init__() self.conv1 = ARMAConv( in_dim=in_dim, diff --git a/hugegraph-ml/src/hugegraph_ml/models/bgnn.py b/hugegraph-ml/src/hugegraph_ml/models/bgnn.py index 51689ef27..b551fdc73 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/bgnn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/bgnn.py @@ -30,24 +30,31 @@ import itertools import time from collections import defaultdict as ddict + import dgl import numpy as np import pandas as pd import torch import torch.nn.functional as F from catboost import CatBoostClassifier, CatBoostRegressor, Pool, sum_models -from sklearn import preprocessing -from sklearn.metrics import r2_score -from tqdm import tqdm from category_encoders import CatBoostEncoder from dgl.nn.pytorch import ( AGNNConv as AGNNConvDGL, +) +from dgl.nn.pytorch import ( APPNPConv, + GraphConv, +) +from dgl.nn.pytorch import ( ChebConv as ChebConvDGL, +) +from dgl.nn.pytorch import ( GATConv as GATConvDGL, - GraphConv, ) -from torch.nn import Dropout, ELU, Linear, ReLU, Sequential +from sklearn import preprocessing +from sklearn.metrics import r2_score +from torch.nn import ELU, Dropout, Linear, ReLU, Sequential +from tqdm import tqdm class BGNNPredictor: @@ -168,9 +175,7 @@ def train_gbdt( if epoch == 0 and self.task == "classification": self.base_gbdt = epoch_gbdt_model else: - self.gbdt_model = self.append_gbdt_model( - epoch_gbdt_model, weights=[1, gbdt_alpha] - ) + self.gbdt_model = self.append_gbdt_model(epoch_gbdt_model, weights=[1, gbdt_alpha]) def update_node_features(self, node_features, X, original_X): # get predictions from gbdt model @@ -191,30 +196,19 @@ def update_node_features(self, node_features, X, original_X): axis=1, ) # replace old predictions with new predictions else: - predictions = np.append( - X, predictions, axis=1 - ) # append original features with new predictions + predictions = np.append(X, predictions, axis=1) # append original features with new predictions predictions = torch.from_numpy(predictions).to(self.device) node_features.data = predictions.float().data def update_gbdt_targets(self, node_features, node_features_before, train_mask): - return ( - (node_features - node_features_before) - .detach() - .cpu() - .numpy()[train_mask, -self.out_dim :] - ) + return (node_features - node_features_before).detach().cpu().numpy()[train_mask, -self.out_dim :] def init_node_features(self, X): - node_features = torch.empty( - X.shape[0], self.in_dim, requires_grad=True, device=self.device - ) + node_features = torch.empty(X.shape[0], self.in_dim, requires_grad=True, device=self.device) if self.append_gbdt_pred: - node_features.data[:, : -self.out_dim] = torch.from_numpy( - X.to_numpy(copy=True) - ) + node_features.data[:, : -self.out_dim] = torch.from_numpy(X.to_numpy(copy=True)) return node_features def init_optimizer(self, node_features, optimize_node_features, learning_rate): @@ -239,9 +233,7 @@ def train_model(self, model_in, target_labels, train_mask, optimizer): elif self.task == "classification": loss = F.cross_entropy(pred, y.long()) else: - raise NotImplementedError( - "Unknown task. Supported tasks: classification, regression." - ) + raise NotImplementedError("Unknown task. Supported tasks: classification, regression.") optimizer.zero_grad() loss.backward() @@ -255,18 +247,12 @@ def evaluate_model(self, logits, target_labels, mask): pred = logits[mask] if self.task == "regression": metrics["loss"] = torch.sqrt(F.mse_loss(pred, y).squeeze() + 1e-8) - metrics["rmsle"] = torch.sqrt( - F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze() + 1e-8 - ) + metrics["rmsle"] = torch.sqrt(F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze() + 1e-8) metrics["mae"] = F.l1_loss(pred, y) - metrics["r2"] = torch.Tensor( - [r2_score(y.cpu().numpy(), pred.cpu().numpy())] - ) + metrics["r2"] = torch.Tensor([r2_score(y.cpu().numpy(), pred.cpu().numpy())]) elif self.task == "classification": metrics["loss"] = F.cross_entropy(pred, y.long()) - metrics["accuracy"] = torch.Tensor( - [(y == pred.max(1)[1]).sum().item() / y.shape[0]] - ) + metrics["accuracy"] = torch.Tensor([(y == pred.max(1)[1]).sum().item() / y.shape[0]]) return metrics @@ -311,10 +297,8 @@ def update_early_stopping( metric_name, lower_better=False, ): - train_metric, val_metric, test_metric = metrics[metric_name][-1] - if (lower_better and val_metric < best_metric[1]) or ( - not lower_better and val_metric > best_metric[1] - ): + _train_metric, val_metric, _test_metric = metrics[metric_name][-1] + if (lower_better and val_metric < best_metric[1]) or (not lower_better and val_metric > best_metric[1]): best_metric = metrics[metric_name][-1] best_val_epoch = epoch epochs_since_last_best_metric = 0 @@ -393,10 +377,9 @@ def fit( """ # initialize for early stopping and metrics - if metric_name in ["r2", "accuracy"]: - best_metric = [np.cfloat("-inf")] * 3 # for train/val/test - else: - best_metric = [np.cfloat("inf")] * 3 # for train/val/test + best_metric = ( + [np.cfloat("-inf")] * 3 if metric_name in ["r2", "accuracy"] else [np.cfloat("inf")] * 3 + ) # for train/val/test best_val_epoch = 0 epochs_since_last_best_metric = 0 @@ -408,9 +391,7 @@ def fit( self.out_dim = y.shape[1] elif self.task == "classification": self.out_dim = len(set(y.iloc[test_mask, 0])) - self.in_dim = ( - self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim - ) + self.in_dim = self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim if original_X is None: original_X = X.copy() @@ -422,9 +403,7 @@ def fit( self.gbdt_model = None node_features = self.init_node_features(X) - optimizer = self.init_optimizer( - node_features, optimize_node_features=True, learning_rate=self.lr - ) + optimizer = self.init_optimizer(node_features, optimize_node_features=True, learning_rate=self.lr) y = torch.from_numpy(y.to_numpy(copy=True)).float().squeeze().to(self.device) graph = graph.to(self.device) @@ -456,9 +435,7 @@ def fit( metrics, self.backprop_per_epoch, ) - gbdt_y_train = self.update_gbdt_targets( - node_features, node_features_before, train_mask - ) + gbdt_y_train = self.update_gbdt_targets(node_features, node_features_before, train_mask) self.log_epoch( pbar, @@ -488,14 +465,8 @@ def fit( break if np.isclose(gbdt_y_train.sum(), 0.0): - print("Node embeddings do not change anymore. Stopping...") break - print( - "Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}".format( - metric_name, best_val_epoch, *best_metric - ) - ) return metrics def predict(self, graph, X, test_mask): @@ -522,7 +493,7 @@ def plot_interactive( metric_results = metrics[metric_name] xs = [list(range(len(metric_results)))] * len(metric_results[0]) - ys = list(zip(*metric_results)) + ys = list(zip(*metric_results, strict=False)) fig = go.Figure() for i in range(len(ys)): @@ -540,9 +511,9 @@ def plot_interactive( title_x=0.5, xaxis_title="Epoch", yaxis_title=metric_name, - font=dict( - size=40, - ), + font={ + "size": 40, + }, height=600, ) @@ -566,7 +537,7 @@ def __init__( use_mlp=False, join_with_mlp=False, ): - super(GNNModelDGL, self).__init__() + super().__init__() self.name = name self.use_mlp = use_mlp self.join_with_mlp = join_with_mlp @@ -599,14 +570,10 @@ def __init__( self.l2 = ChebConvDGL(hidden_dim, out_dim, k=3) self.drop = Dropout(p=dropout) elif name == "agnn": - self.lin1 = Sequential( - Dropout(p=dropout), Linear(in_dim, hidden_dim), ELU() - ) + self.lin1 = Sequential(Dropout(p=dropout), Linear(in_dim, hidden_dim), ELU()) self.l1 = AGNNConvDGL(learn_beta=False) self.l2 = AGNNConvDGL(learn_beta=True) - self.lin2 = Sequential( - Dropout(p=dropout), Linear(hidden_dim, out_dim), ELU() - ) + self.lin2 = Sequential(Dropout(p=dropout), Linear(hidden_dim, out_dim), ELU()) elif name == "appnp": self.lin1 = Sequential( Dropout(p=dropout), @@ -621,10 +588,7 @@ def forward(self, graph, features): h = features logits = None if self.use_mlp: - if self.join_with_mlp: - h = torch.cat((h, self.mlp(features)), 1) - else: - h = self.mlp(features) + h = torch.cat((h, self.mlp(features)), 1) if self.join_with_mlp else self.mlp(features) if self.name == "gat": h = self.l1(graph, h).flatten(1) logits = self.l2(graph, h).mean(1) @@ -648,6 +612,7 @@ def forward(self, graph, features): return logits + def normalize_features(X, train_mask, val_mask, test_mask): min_max_scaler = preprocessing.MinMaxScaler() A = X.to_numpy(copy=True) @@ -666,12 +631,8 @@ def encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask): enc = CatBoostEncoder() A = X.to_numpy(copy=True) b = y.to_numpy(copy=True) - A[np.ix_(train_mask, cat_features)] = enc.fit_transform( - A[np.ix_(train_mask, cat_features)], b[train_mask] - ) - A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform( - A[np.ix_(val_mask + test_mask, cat_features)] - ) + A[np.ix_(train_mask, cat_features)] = enc.fit_transform(A[np.ix_(train_mask, cat_features)], b[train_mask]) + A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform(A[np.ix_(val_mask + test_mask, cat_features)]) A = A.astype(float) return pd.DataFrame(A, columns=X.columns) diff --git a/hugegraph-ml/src/hugegraph_ml/models/bgrl.py b/hugegraph-ml/src/hugegraph_ml/models/bgrl.py index 0000e546c..c4e1e5542 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/bgrl.py +++ b/hugegraph-ml/src/hugegraph_ml/models/bgrl.py @@ -28,16 +28,19 @@ """ import copy +import itertools + +import dgl +import numpy as np import torch +from dgl.nn.pytorch.conv import GraphConv +from dgl.transforms import Compose, DropEdge, FeatMask from torch import nn from torch.nn import BatchNorm1d from torch.nn.functional import cosine_similarity -import dgl -from dgl.nn.pytorch.conv import GraphConv -from dgl.transforms import Compose, DropEdge, FeatMask -import numpy as np -class MLP_Predictor(nn.Module): + +class MLPPredictor(nn.Module): r"""MLP used for predictor. The MLP has one hidden layer. Args: input_size (int): Size of input features. @@ -67,10 +70,10 @@ def reset_parameters(self): class GCN(nn.Module): def __init__(self, layer_sizes, batch_norm_mm=0.99): - super(GCN, self).__init__() + super().__init__() self.layers = nn.ModuleList() - for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]): + for in_dim, out_dim in itertools.pairwise(layer_sizes): self.layers.append(GraphConv(in_dim, out_dim)) self.layers.append(BatchNorm1d(out_dim, momentum=batch_norm_mm)) self.layers.append(nn.PReLU()) @@ -78,10 +81,7 @@ def __init__(self, layer_sizes, batch_norm_mm=0.99): def forward(self, g, feats): x = feats for layer in self.layers: - if isinstance(layer, GraphConv): - x = layer(g, x) - else: - x = layer(x) + x = layer(g, x) if isinstance(layer, GraphConv) else layer(x) return x def reset_parameters(self): @@ -89,6 +89,7 @@ def reset_parameters(self): if hasattr(layer, "reset_parameters"): layer.reset_parameters() + class BGRL(nn.Module): r"""BGRL architecture for Graph representation learning. Args: @@ -100,7 +101,7 @@ class BGRL(nn.Module): """ def __init__(self, encoder, predictor): - super(BGRL, self).__init__() + super().__init__() # online network self.online_encoder = encoder self.predictor = predictor @@ -117,9 +118,7 @@ def __init__(self, encoder, predictor): def trainable_parameters(self): r"""Returns the parameters that will be updated via an optimizer.""" - return list(self.online_encoder.parameters()) + list( - self.predictor.parameters() - ) + return list(self.online_encoder.parameters()) + list(self.predictor.parameters()) @torch.no_grad() def update_target_network(self, mm): @@ -127,18 +126,12 @@ def update_target_network(self, mm): Args: mm (float): Momentum used in moving average update. """ - for param_q, param_k in zip( - self.online_encoder.parameters(), self.target_encoder.parameters() - ): + for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters(), strict=False): param_k.data.mul_(mm).add_(param_q.data, alpha=1.0 - mm) def forward(self, graph, feat): - transform_1 = get_graph_drop_transform( - drop_edge_p=0.3, feat_mask_p=0.3 - ) - transform_2 = get_graph_drop_transform( - drop_edge_p=0.2, feat_mask_p=0.4 - ) + transform_1 = get_graph_drop_transform(drop_edge_p=0.3, feat_mask_p=0.3) + transform_2 = get_graph_drop_transform(drop_edge_p=0.2, feat_mask_p=0.4) online_x = transform_1(graph) target_x = transform_2(graph) online_x, target_x = dgl.add_self_loop(online_x), dgl.add_self_loop(target_x) @@ -159,8 +152,8 @@ def forward(self, graph, feat): target_y2 = self.target_encoder(online_x, online_feats).detach() loss = ( 2 - - cosine_similarity(online_q1, target_y1.detach(), dim=-1).mean() # pylint: disable=E1102 - - cosine_similarity(online_q2, target_y2.detach(), dim=-1).mean() # pylint: disable=E1102 + - cosine_similarity(online_q1, target_y1.detach(), dim=-1).mean() # pylint: disable=E1102 + - cosine_similarity(online_q2, target_y2.detach(), dim=-1).mean() # pylint: disable=E1102 ) return loss @@ -183,6 +176,7 @@ def get_embedding(self, graph, feats): h = self.target_encoder(graph, feats) # Encode the node features with GCN return h.detach() # Detach from computation graph for evaluation + def compute_representations(net, dataset, device): r"""Pre-computes the representations for the entire data. Returns: @@ -211,6 +205,7 @@ def compute_representations(net, dataset, device): labels = torch.cat(labels, dim=0) return [reps, labels] + class CosineDecayScheduler: def __init__(self, max_val, warmup_steps, total_steps): self.max_val = max_val @@ -223,23 +218,15 @@ def get(self, step): elif self.warmup_steps <= step <= self.total_steps: return ( self.max_val - * ( - 1 - + np.cos( - (step - self.warmup_steps) - * np.pi - / (self.total_steps - self.warmup_steps) - ) - ) + * (1 + np.cos((step - self.warmup_steps) * np.pi / (self.total_steps - self.warmup_steps))) / 2 ) else: - raise ValueError( - f"Step ({step}) > total number of steps ({self.total_steps})." - ) + raise ValueError(f"Step ({step}) > total number of steps ({self.total_steps}).") + def get_graph_drop_transform(drop_edge_p, feat_mask_p): - transforms = list() + transforms = [] # make copy of graph transforms.append(copy.deepcopy) diff --git a/hugegraph-ml/src/hugegraph_ml/models/care_gnn.py b/hugegraph-ml/src/hugegraph_ml/models/care_gnn.py index 994513e14..ef9eaa621 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/care_gnn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/care_gnn.py @@ -45,7 +45,7 @@ def __init__( activation=None, step_size=0.02, ): - super(CAREConv, self).__init__() + super().__init__() self.activation = activation self.step_size = step_size @@ -87,9 +87,7 @@ def _top_p_sampling(self, g, p): num_neigh = th.ceil(g.in_degrees(node) * p).int().item() neigh_dist = dist[edges] if neigh_dist.shape[0] > num_neigh: - neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[ - :num_neigh - ] + neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh] else: neigh_index = np.arange(num_neigh) neigh_list.append(edges[neigh_index]) @@ -137,7 +135,7 @@ def __init__( activation=None, step_size=0.02, ): - super(CAREGNN, self).__init__() + super().__init__() self.in_dim = in_dim self.hid_dim = hid_dim self.num_classes = num_classes diff --git a/hugegraph-ml/src/hugegraph_ml/models/cluster_gcn.py b/hugegraph-ml/src/hugegraph_ml/models/cluster_gcn.py index 6bc078a8b..b96ed687a 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/cluster_gcn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/cluster_gcn.py @@ -26,10 +26,10 @@ DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/cluster_gcn """ -from torch import nn +import dgl.nn as dglnn import torch.nn.functional as F +from torch import nn -import dgl.nn as dglnn class SAGE(nn.Module): # pylint: disable=E1101 @@ -43,9 +43,9 @@ def __init__(self, in_feats, n_hidden, n_classes): def forward(self, sg, x): h = x - for l, layer in enumerate(self.layers): + for layer_idx, layer in enumerate(self.layers): h = layer(sg, h) - if l != len(self.layers) - 1: + if layer_idx != len(self.layers) - 1: h = F.relu(h) h = self.dropout(h) return h diff --git a/hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py b/hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py index 50481c60c..a01b11cb0 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py +++ b/hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py @@ -29,13 +29,13 @@ import dgl.function as fn import torch -from torch import nn import torch.nn.functional as F +from torch import nn class MLPLinear(nn.Module): def __init__(self, in_dim, out_dim): - super(MLPLinear, self).__init__() + super().__init__() self.linear = nn.Linear(in_dim, out_dim) self.reset_parameters() self.criterion = nn.CrossEntropyLoss() @@ -55,7 +55,7 @@ def inference(self, graph, feats): class MLP(nn.Module): def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0): - super(MLP, self).__init__() + super().__init__() assert num_layers >= 2 self.linears = nn.ModuleList() @@ -80,7 +80,7 @@ def reset_parameters(self): layer.reset_parameters() def forward(self, graph, x): - for linear, bn in zip(self.linears[:-1], self.bns): + for linear, bn in zip(self.linears[:-1], self.bns, strict=False): x = linear(x) x = F.relu(x, inplace=True) x = bn(x) @@ -119,7 +119,7 @@ class LabelPropagation(nn.Module): """ def __init__(self, num_layers, alpha, adj="DAD"): - super(LabelPropagation, self).__init__() + super().__init__() self.num_layers = num_layers self.alpha = alpha @@ -138,11 +138,7 @@ def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)): last = (1 - self.alpha) * y degs = g.in_degrees().float().clamp(min=1) - norm = ( - torch.pow(degs, -0.5 if self.adj == "DAD" else -1) - .to(labels.device) - .unsqueeze(1) - ) + norm = torch.pow(degs, -0.5 if self.adj == "DAD" else -1).to(labels.device).unsqueeze(1) for _ in range(self.num_layers): # Assume the graphs to be undirected @@ -202,17 +198,13 @@ def __init__( autoscale=True, scale=1.0, ): - super(CorrectAndSmooth, self).__init__() + super().__init__() self.autoscale = autoscale self.scale = scale - self.prop1 = LabelPropagation( - num_correction_layers, correction_alpha, correction_adj - ) - self.prop2 = LabelPropagation( - num_smoothing_layers, smoothing_alpha, smoothing_adj - ) + self.prop1 = LabelPropagation(num_correction_layers, correction_alpha, correction_adj) + self.prop2 = LabelPropagation(num_smoothing_layers, smoothing_alpha, smoothing_adj) def correct(self, g, y_soft, y_true, mask): with g.local_scope(): @@ -227,9 +219,7 @@ def correct(self, g, y_soft, y_true, mask): error[mask] = y_true - y_soft[mask] if self.autoscale: - smoothed_error = self.prop1( - g, error, post_step=lambda x: x.clamp_(-1.0, 1.0) - ) + smoothed_error = self.prop1(g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)) sigma = error[mask].abs().sum() / numel scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True) scale[scale.isinf() | (scale > 1000)] = 1.0 diff --git a/hugegraph-ml/src/hugegraph_ml/models/dagnn.py b/hugegraph-ml/src/hugegraph_ml/models/dagnn.py index c455cc1f9..66a7f008a 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/dagnn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/dagnn.py @@ -30,13 +30,13 @@ import dgl.function as fn import torch from torch import nn -from torch.nn import functional as F, Parameter - +from torch.nn import Parameter +from torch.nn import functional as F class DAGNNConv(nn.Module): def __init__(self, in_dim, k): - super(DAGNNConv, self).__init__() + super().__init__() self.s = Parameter(torch.FloatTensor(in_dim, 1)) self.k = k @@ -58,7 +58,7 @@ def forward(self, graph, feats): for _ in range(self.k): feats = feats * norm graph.ndata["h"] = feats - graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) # pylint: disable=E1101 + graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) # pylint: disable=E1101 feats = graph.ndata["h"] feats = feats * norm results.append(feats) @@ -73,7 +73,7 @@ def forward(self, graph, feats): class MLPLayer(nn.Module): def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0): - super(MLPLayer, self).__init__() + super().__init__() self.linear = nn.Linear(in_dim, out_dim, bias=bias) self.activation = activation @@ -108,7 +108,7 @@ def __init__( activation=F.relu, dropout=0, ): - super(DAGNN, self).__init__() + super().__init__() self.mlp = nn.ModuleList() self.mlp.append( MLPLayer( diff --git a/hugegraph-ml/src/hugegraph_ml/models/deepergcn.py b/hugegraph-ml/src/hugegraph_ml/models/deepergcn.py index 26b41fca5..05203c1f5 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/deepergcn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/deepergcn.py @@ -16,7 +16,6 @@ # under the License. - """ DeeperGCN @@ -27,15 +26,15 @@ DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/deepergcn """ +import dgl.function as fn import torch -from torch import nn import torch.nn.functional as F - -import dgl.function as fn from dgl.nn.functional import edge_softmax +from torch import nn # pylint: disable=E1101,E0401 + class DeeperGCN(nn.Module): r""" @@ -78,7 +77,7 @@ def __init__( aggr="softmax", mlp_layers=1, ): - super(DeeperGCN, self).__init__() + super().__init__() self.num_layers = num_layers self.dropout = dropout @@ -178,7 +177,7 @@ def __init__( mlp_layers=1, eps=1e-7, ): - super(GENConv, self).__init__() + super().__init__() self.aggr = aggregator self.eps = eps @@ -192,9 +191,7 @@ def __init__( self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None self.beta = ( - nn.Parameter(torch.Tensor([beta]), requires_grad=True) - if learn_beta and self.aggr == "softmax" - else beta + nn.Parameter(torch.Tensor([beta]), requires_grad=True) if learn_beta and self.aggr == "softmax" else beta ) self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p @@ -252,7 +249,7 @@ def __init__(self, channels, act="relu", dropout=0.0, bias=True): layers.append(nn.ReLU()) layers.append(nn.Dropout(dropout)) - super(MLP, self).__init__(*layers) + super().__init__(*layers) class MessageNorm(nn.Module): @@ -267,7 +264,7 @@ class MessageNorm(nn.Module): """ def __init__(self, learn_scale=False): - super(MessageNorm, self).__init__() + super().__init__() self.scale = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=learn_scale) def forward(self, feats, msg, p=2): diff --git a/hugegraph-ml/src/hugegraph_ml/models/dgi.py b/hugegraph-ml/src/hugegraph_ml/models/dgi.py index 50caaaa37..981bdd95b 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/dgi.py +++ b/hugegraph-ml/src/hugegraph_ml/models/dgi.py @@ -49,7 +49,7 @@ class DGI(nn.Module): """ def __init__(self, n_in_feats, n_hidden=512, n_layers=2, p_drop=0): - super(DGI, self).__init__() + super().__init__() self.encoder = GCNEncoder(n_in_feats, n_hidden, n_layers, p_drop) # Initialize the GCN-based encoder self.discriminator = Discriminator(n_hidden) # Initialize the discriminator for mutual information maximization self.loss = nn.BCEWithLogitsLoss() # Binary cross-entropy loss with logits for classification @@ -118,7 +118,7 @@ class GCNEncoder(nn.Module): """ def __init__(self, n_in_feats, n_hidden, n_layers, p_drop): - super(GCNEncoder, self).__init__() + super().__init__() assert n_layers >= 2, "The number of GNN layers must be at least 2." self.input_layer = GraphConv( n_in_feats, n_hidden, activation=nn.PReLU(n_hidden) @@ -170,7 +170,7 @@ class Discriminator(nn.Module): """ def __init__(self, n_hidden): - super(Discriminator, self).__init__() + super().__init__() self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden)) # Define weights for bilinear transformation self.uniform_weight() # Initialize the weights uniformly diff --git a/hugegraph-ml/src/hugegraph_ml/models/diffpool.py b/hugegraph-ml/src/hugegraph_ml/models/diffpool.py index a09d78ac4..102e25061 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/diffpool.py +++ b/hugegraph-ml/src/hugegraph_ml/models/diffpool.py @@ -48,7 +48,7 @@ def __init__( pool_ratio=0.1, concat=False, ): - super(DiffPool, self).__init__() + super().__init__() self.link_pred = True self.concat = concat self.n_pooling = n_pooling @@ -73,17 +73,10 @@ def __init__( self.gc_before_pool.append( SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=F.relu) ) - self.gc_before_pool.append( - SAGEConv(n_hidden, n_embedding, aggregator_type, feat_drop=dropout, activation=None) - ) + self.gc_before_pool.append(SAGEConv(n_hidden, n_embedding, aggregator_type, feat_drop=dropout, activation=None)) assign_dims = [self.assign_dim] - if self.concat: - # diffpool layer receive pool_embedding_dim node feature tensor - # and return pool_embedding_dim node embedding - pool_embedding_dim = n_hidden * (n_layers - 1) + n_embedding - else: - pool_embedding_dim = n_embedding + pool_embedding_dim = n_hidden * (n_layers - 1) + n_embedding if self.concat else n_embedding self.first_diffpool_layer = _DiffPoolBatchedGraphLayer( pool_embedding_dim, @@ -103,9 +96,7 @@ def __init__( self.assign_dim = int(self.assign_dim * pool_ratio) # each pooling module for _ in range(n_pooling - 1): - self.diffpool_layers.append( - _BatchedDiffPool(pool_embedding_dim, self.assign_dim, n_hidden, self.link_pred) - ) + self.diffpool_layers.append(_BatchedDiffPool(pool_embedding_dim, self.assign_dim, n_hidden, self.link_pred)) gc_after_per_pool = nn.ModuleList() for _ in range(n_layers - 1): gc_after_per_pool.append(_BatchedGraphSAGE(n_hidden, n_hidden)) @@ -167,10 +158,7 @@ def forward(self, g, feat): if self.num_aggs == 2: readout, _ = torch.max(h, dim=1) out_all.append(readout) - if self.concat or self.num_aggs > 1: - final_readout = torch.cat(out_all, dim=1) - else: - final_readout = readout + final_readout = torch.cat(out_all, dim=1) if self.concat or self.num_aggs > 1 else readout ypred = self.pred_layer(final_readout) return ypred @@ -228,7 +216,7 @@ def forward(self, x, adj): class _BatchedDiffPool(nn.Module): def __init__(self, n_feat, n_next, n_hid, link_pred=False, entropy=True): - super(_BatchedDiffPool, self).__init__() + super().__init__() self.link_pred = link_pred self.link_pred_layer = _LinkPredLoss() self.embed = _BatchedGraphSAGE(n_feat, n_hid) @@ -262,7 +250,7 @@ def __init__( aggregator_type, link_pred, ): - super(_DiffPoolBatchedGraphLayer, self).__init__() + super().__init__() self.embedding_dim = input_dim self.assign_dim = assign_dim self.hidden_dim = output_feat_dim @@ -338,8 +326,8 @@ def _batch2tensor(batch_adj, batch_feat, node_per_pool_graph): end = (i + 1) * node_per_pool_graph adj_list.append(batch_adj[start:end, start:end]) feat_list.append(batch_feat[start:end, :]) - adj_list = list(map(lambda x: torch.unsqueeze(x, 0), adj_list)) - feat_list = list(map(lambda x: torch.unsqueeze(x, 0), feat_list)) + adj_list = [torch.unsqueeze(x, 0) for x in adj_list] + feat_list = [torch.unsqueeze(x, 0) for x in feat_list] adj = torch.cat(adj_list, dim=0) feat = torch.cat(feat_list, dim=0) @@ -373,10 +361,7 @@ def _gcn_forward(g, h, gc_layers, cat=False): block_readout.append(h) h = gc_layers[-1](g, h) block_readout.append(h) - if cat: - block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ... - else: - block = h + block = torch.cat(block_readout, dim=1) if cat else h return block @@ -385,8 +370,5 @@ def _gcn_forward_tensorized(h, adj, gc_layers, cat=False): for gc_layer in gc_layers: h = gc_layer(h, adj) block_readout.append(h) - if cat: - block = torch.cat(block_readout, dim=2) # N x F, F = F1 + F2 + ... - else: - block = h + block = torch.cat(block_readout, dim=2) if cat else h return block diff --git a/hugegraph-ml/src/hugegraph_ml/models/gatne.py b/hugegraph-ml/src/hugegraph_ml/models/gatne.py index 91bb582b2..3c10f7046 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/gatne.py +++ b/hugegraph-ml/src/hugegraph_ml/models/gatne.py @@ -28,27 +28,26 @@ """ import math -import time import multiprocessing +import time from functools import partial, reduce +import dgl +import dgl.function as fn import numpy as np - import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.nn.parameter import Parameter -import dgl -import dgl.function as fn -class NeighborSampler(object): +class NeighborSampler: def __init__(self, g, num_fanouts): self.g = g self.num_fanouts = num_fanouts def sample(self, pairs): - heads, tails, types = zip(*pairs) + heads, tails, types = zip(*pairs, strict=False) seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True) blocks = [] for fanout in reversed(self.num_fanouts): @@ -74,7 +73,7 @@ def __init__( edge_type_count, dim_a, ): - super(DGLGATNE, self).__init__() + super().__init__() self.num_nodes = num_nodes self.embedding_size = embedding_size self.embedding_u_size = embedding_u_size @@ -83,15 +82,9 @@ def __init__( self.dim_a = dim_a self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size)) - self.node_type_embeddings = Parameter( - torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size) - ) - self.trans_weights = Parameter( - torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size) - ) - self.trans_weights_s1 = Parameter( - torch.FloatTensor(edge_type_count, embedding_u_size, dim_a) - ) + self.node_type_embeddings = Parameter(torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)) + self.trans_weights = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)) + self.trans_weights_s1 = Parameter(torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)) self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1)) self.reset_parameters() @@ -118,15 +111,13 @@ def forward(self, block): block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i] block.update_all( fn.copy_u(edge_type, "m"), - fn.sum("m", edge_type), # pylint: disable=E1101 + fn.sum("m", edge_type), # pylint: disable=E1101 etype=edge_type, ) node_type_embed.append(block.dstdata[edge_type]) node_type_embed = torch.stack(node_type_embed, 1) - tmp_node_type_embed = node_type_embed.unsqueeze(2).view( - -1, 1, self.embedding_u_size - ) + tmp_node_type_embed = node_type_embed.unsqueeze(2).view(-1, 1, self.embedding_u_size) trans_w = ( self.trans_weights.unsqueeze(0) .repeat(batch_size, 1, 1, 1) @@ -137,11 +128,7 @@ def forward(self, block): .repeat(batch_size, 1, 1, 1) .view(-1, self.embedding_u_size, self.dim_a) ) - trans_w_s2 = ( - self.trans_weights_s2.unsqueeze(0) - .repeat(batch_size, 1, 1, 1) - .view(-1, self.dim_a, 1) - ) + trans_w_s2 = self.trans_weights_s2.unsqueeze(0).repeat(batch_size, 1, 1, 1).view(-1, self.dim_a, 1) attention = ( F.softmax( @@ -157,14 +144,10 @@ def forward(self, block): .repeat(1, self.edge_type_count, 1) ) - node_type_embed = torch.matmul(attention, node_type_embed).view( - -1, 1, self.embedding_u_size - ) - node_embed = node_embed[output_nodes].unsqueeze(1).repeat( - 1, self.edge_type_count, 1 - ) + torch.matmul(node_type_embed, trans_w).view( - -1, self.edge_type_count, self.embedding_size - ) + node_type_embed = torch.matmul(attention, node_type_embed).view(-1, 1, self.embedding_u_size) + node_embed = node_embed[output_nodes].unsqueeze(1).repeat(1, self.edge_type_count, 1) + torch.matmul( + node_type_embed, trans_w + ).view(-1, self.edge_type_count, self.embedding_size) last_node_embed = F.normalize(node_embed, dim=2) return last_node_embed # [batch_size, edge_type_count, embedding_size] @@ -172,19 +155,14 @@ def forward(self, block): class NSLoss(nn.Module): def __init__(self, num_nodes, num_sampled, embedding_size): - super(NSLoss, self).__init__() + super().__init__() self.num_nodes = num_nodes self.num_sampled = num_sampled self.embedding_size = embedding_size self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size)) # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)] self.sample_weights = F.normalize( - torch.Tensor( - [ - (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) - for k in range(num_nodes) - ] - ), + torch.Tensor([(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1) for k in range(num_nodes)]), dim=0, ) @@ -195,16 +173,10 @@ def reset_parameters(self): def forward(self, input, embs, label): n = input.shape[0] - log_target = torch.log( - torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1)) - ) - negs = torch.multinomial( - self.sample_weights, self.num_sampled * n, replacement=True - ).view(n, self.num_sampled) + log_target = torch.log(torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))) + negs = torch.multinomial(self.sample_weights, self.num_sampled * n, replacement=True).view(n, self.num_sampled) noise = torch.neg(self.weights[negs]) - sum_log_sampled = torch.sum( - torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1 - ).squeeze() + sum_log_sampled = torch.sum(torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1).squeeze() loss = log_target + sum_log_sampled return -loss.sum() / n @@ -226,8 +198,7 @@ def generate_pairs_parallel(walks, skip_window=None, layer_id=None): def generate_pairs(all_walks, window_size, num_workers): # for each node, choose the first neighbor and second neighbor of it to form pairs # Get all worker processes - start_time = time.time() - print(f"We are generating pairs with {num_workers} cores.") + time.time() # Start all worker processes pool = multiprocessing.Pool(processes=num_workers) @@ -236,10 +207,7 @@ def generate_pairs(all_walks, window_size, num_workers): for layer_id, walks in enumerate(all_walks): block_num = len(walks) // num_workers if block_num > 0: - walks_list = [ - walks[i * block_num : min((i + 1) * block_num, len(walks))] - for i in range(num_workers) - ] + walks_list = [walks[i * block_num : min((i + 1) * block_num, len(walks))] for i in range(num_workers)] else: walks_list = [walks] tmp_result = pool.map( @@ -253,8 +221,7 @@ def generate_pairs(all_walks, window_size, num_workers): pairs += reduce(lambda x, y: x + y, tmp_result) pool.close() - end_time = time.time() - print(f"Generate pairs end, use {end_time - start_time}s.") + time.time() return np.array([list(pair) for pair in set(pairs)]) diff --git a/hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py b/hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py index 8012e729a..2f8341afd 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py +++ b/hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from dgl.nn.pytorch.conv import GINConv -from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling, GlobalAttentionPooling, Set2Set +from dgl.nn.pytorch.glob import AvgPooling, GlobalAttentionPooling, MaxPooling, Set2Set, SumPooling from torch import nn @@ -32,13 +32,8 @@ def __init__(self, n_in_feats, n_out_feats, n_hidden=16, n_layers=5, p_drop=0.5, # five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme assert n_layers >= 2, "The number of GIN layers must be at least 2." for layer in range(n_layers - 1): - if layer == 0: - mlp = _MLP(n_in_feats, n_hidden, n_hidden) - else: - mlp = _MLP(n_hidden, n_hidden, n_hidden) - self.gin_layers.append( - GINConv(mlp, learn_eps=False) - ) # set to True if learning epsilon + mlp = _MLP(n_in_feats, n_hidden, n_hidden) if layer == 0 else _MLP(n_hidden, n_hidden, n_hidden) + self.gin_layers.append(GINConv(mlp, learn_eps=False)) # set to True if learning epsilon self.batch_norms.append(nn.BatchNorm1d(n_hidden)) # linear functions for graph sum pooling of output of each layer self.linear_prediction = nn.ModuleList() diff --git a/hugegraph-ml/src/hugegraph_ml/models/grace.py b/hugegraph-ml/src/hugegraph_ml/models/grace.py index f80230e15..1e16cb12a 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/grace.py +++ b/hugegraph-ml/src/hugegraph_ml/models/grace.py @@ -68,14 +68,14 @@ def __init__( n_hidden=128, n_out_feats=128, n_layers=2, - act_fn=nn.ReLU(), + act_fn=None, temp=0.4, edges_removing_rate_1=0.2, edges_removing_rate_2=0.4, feats_masking_rate_1=0.3, feats_masking_rate_2=0.4, ): - super(GRACE, self).__init__() + super().__init__() self.encoder = GCN(n_in_feats, n_hidden, act_fn, n_layers) # Initialize the GCN encoder # Initialize the MLP projector to map the encoded features to the contrastive space self.proj = MLP(n_hidden, n_out_feats) @@ -210,7 +210,7 @@ class GCN(nn.Module): """ def __init__(self, n_in_feats, n_out_feats, act_fn, n_layers=2): - super(GCN, self).__init__() + super().__init__() assert n_layers >= 2, "Number of layers should be at least 2." self.n_layers = n_layers # Set the number of layers self.n_hidden = n_out_feats * 2 # Set the hidden dimension as twice the output dimension @@ -255,7 +255,7 @@ class MLP(nn.Module): """ def __init__(self, n_in_feats, n_out_feats): - super(MLP, self).__init__() + super().__init__() self.fc1 = nn.Linear(n_in_feats, n_out_feats) # Define the first fully connected layer self.fc2 = nn.Linear(n_out_feats, n_out_feats) # Define the second fully connected layer diff --git a/hugegraph-ml/src/hugegraph_ml/models/grand.py b/hugegraph-ml/src/hugegraph_ml/models/grand.py index 24c1bf0a1..870de22d4 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/grand.py +++ b/hugegraph-ml/src/hugegraph_ml/models/grand.py @@ -76,7 +76,7 @@ def __init__( temp=0.5, lam=1.0, ): - super(GRAND, self).__init__() + super().__init__() self.sample = sample # Number of augmentations self.order = order # Order of propagation steps @@ -249,7 +249,7 @@ class MLP(nn.Module): """ def __init__(self, n_in_feats, n_hidden, n_out_feats, p_input_drop, p_hidden_drop, bn): - super(MLP, self).__init__() + super().__init__() self.layer1 = nn.Linear(n_in_feats, n_hidden, bias=True) # First linear layer self.layer2 = nn.Linear(n_hidden, n_out_feats, bias=True) # Second linear layer self.input_drop = nn.Dropout(p_input_drop) # Dropout for input features diff --git a/hugegraph-ml/src/hugegraph_ml/models/jknet.py b/hugegraph-ml/src/hugegraph_ml/models/jknet.py index b778068f9..5c97de9c9 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/jknet.py +++ b/hugegraph-ml/src/hugegraph_ml/models/jknet.py @@ -51,7 +51,7 @@ class JKNet(nn.Module): """ def __init__(self, n_in_feats, n_out_feats, n_hidden=32, n_layers=6, mode="cat", dropout=0.5): - super(JKNet, self).__init__() + super().__init__() self.mode = mode self.dropout = nn.Dropout(dropout) # Dropout layer to prevent overfitting diff --git a/hugegraph-ml/src/hugegraph_ml/models/mlp.py b/hugegraph-ml/src/hugegraph_ml/models/mlp.py index a659e6b1c..0a9db5bed 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/mlp.py +++ b/hugegraph-ml/src/hugegraph_ml/models/mlp.py @@ -34,7 +34,7 @@ class MLPClassifier(nn.Module): """ def __init__(self, n_in_feat, n_out_feat, n_hidden=512): - super(MLPClassifier, self).__init__() + super().__init__() # Define the first fully connected layer for projecting input features to hidden features. self.fc1 = nn.Linear(n_in_feat, n_hidden) # Define the second fully connected layer to project hidden features to output classes. diff --git a/hugegraph-ml/src/hugegraph_ml/models/pgnn.py b/hugegraph-ml/src/hugegraph_ml/models/pgnn.py index 3a870e270..623f5d8d3 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/pgnn.py +++ b/hugegraph-ml/src/hugegraph_ml/models/pgnn.py @@ -31,20 +31,19 @@ import random from multiprocessing import get_context -import torch -from torch import nn -import torch.nn.functional as F - +import dgl.function as fn import networkx as nx import numpy as np -from tqdm.auto import tqdm +import torch +import torch.nn.functional as F from sklearn.metrics import roc_auc_score +from torch import nn +from tqdm.auto import tqdm -import dgl.function as fn -class PGNN_layer(nn.Module): +class PGNNLayer(nn.Module): def __init__(self, input_dim, output_dim): - super(PGNN_layer, self).__init__() + super().__init__() self.input_dim = input_dim self.linear_hidden_u = nn.Linear(input_dim, output_dim) @@ -59,17 +58,15 @@ def forward(self, graph, feature, anchor_eid, dists_max): graph.srcdata.update({"u_feat": u_feat}) graph.dstdata.update({"v_feat": v_feat}) - graph.apply_edges(fn.u_mul_e("u_feat", "sp_dist", "u_message")) # pylint: disable=E1101 - graph.apply_edges(fn.v_add_e("v_feat", "u_message", "message")) # pylint: disable=E1101 + graph.apply_edges(fn.u_mul_e("u_feat", "sp_dist", "u_message")) # pylint: disable=E1101 + graph.apply_edges(fn.v_add_e("v_feat", "u_message", "message")) # pylint: disable=E1101 messages = torch.index_select( graph.edata["message"], 0, torch.LongTensor(anchor_eid).to(feature.device), ) - messages = messages.reshape( - dists_max.shape[0], dists_max.shape[1], messages.shape[-1] - ) + messages = messages.reshape(dists_max.shape[0], dists_max.shape[1], messages.shape[-1]) messages = self.act(messages) # n*m*d @@ -81,12 +78,12 @@ def forward(self, graph, feature, anchor_eid, dists_max): class PGNN(nn.Module): def __init__(self, input_dim, feature_dim=32, dropout=0.5): - super(PGNN, self).__init__() + super().__init__() self.dropout = nn.Dropout(dropout) self.linear_pre = nn.Linear(input_dim, feature_dim) - self.conv_first = PGNN_layer(feature_dim, feature_dim) - self.conv_out = PGNN_layer(feature_dim, feature_dim) + self.conv_first = PGNNLayer(feature_dim, feature_dim) + self.conv_out = PGNNLayer(feature_dim, feature_dim) def forward(self, data): x = data["graph"].ndata["feat"] @@ -256,9 +253,7 @@ def precompute_dist_data(edge_index, num_nodes, approximate=0): n = num_nodes dists_array = np.zeros((n, n)) - dists_dict = all_pairs_shortest_path( - graph, cutoff=approximate if approximate > 0 else None - ) + dists_dict = all_pairs_shortest_path(graph, cutoff=approximate if approximate > 0 else None) node_list = graph.nodes() for node_i in node_list: shortest_dist = dists_dict[node_i] @@ -281,9 +276,7 @@ def get_dataset(graph): approximate=-1, ) data["dists"] = torch.from_numpy(dists_removed).float() - data["edge_index"] = torch.from_numpy( - to_bidirected(data["positive_edges_train"]) - ).long() + data["edge_index"] = torch.from_numpy(to_bidirected(data["positive_edges_train"])).long() return data @@ -330,8 +323,8 @@ def get_a_graph(dists_max, dists_argmax): real_dst.extend(list(dists_argmax[i, :].numpy())) dst.extend(list(tmp_dists_argmax)) edge_weight.extend(dists_max[i, tmp_dists_argmax_idx].tolist()) - eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src)))} - anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src)] + eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src, strict=False)))} + anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src, strict=False)] g = (dst, src) return g, anchor_eid, edge_weight @@ -400,12 +393,8 @@ def get_loss(p, data, out, loss_func, device, get_auc=True): axis=-1, ) - nodes_first = torch.index_select( - out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device) - ) - nodes_second = torch.index_select( - out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device) - ) + nodes_first = torch.index_select(out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device)) + nodes_second = torch.index_select(out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device)) pred = torch.sum(nodes_first * nodes_second, dim=-1) diff --git a/hugegraph-ml/src/hugegraph_ml/models/seal.py b/hugegraph-ml/src/hugegraph_ml/models/seal.py index f743857bf..bc36045db 100644 --- a/hugegraph-ml/src/hugegraph_ml/models/seal.py +++ b/hugegraph-ml/src/hugegraph_ml/models/seal.py @@ -28,25 +28,23 @@ """ import argparse +import logging import os import os.path as osp -from copy import deepcopy -import logging import time +from copy import deepcopy +import dgl +import numpy as np import torch -from torch import nn import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset - -import dgl -from dgl import add_self_loop, NID +from dgl import NID, add_self_loop from dgl.dataloading.negative_sampler import Uniform from dgl.nn.pytorch import GraphConv, SAGEConv, SortPooling, SumPooling - -import numpy as np from ogb.linkproppred import DglLinkPropPredDataset, Evaluator from scipy.sparse.csgraph import shortest_path +from torch import nn +from torch.utils.data import DataLoader, Dataset from tqdm import tqdm @@ -85,13 +83,13 @@ def __init__( dropout=0.5, max_z=1000, ): - super(GCN, self).__init__() + super().__init__() self.num_layers = num_layers self.dropout = dropout self.pooling_type = pooling_type - self.use_attribute = False if node_attributes is None else True + self.use_attribute = node_attributes is not None self.use_embedding = use_embedding - self.use_edge_weight = False if edge_weights is None else True + self.use_edge_weight = edge_weights is not None self.z_embedding = nn.Embedding(max_z, hidden_units) if node_attributes is not None: @@ -114,21 +112,13 @@ def __init__( self.layers = nn.ModuleList() if gcn_type == "gcn": - self.layers.append( - GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True) - ) + self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)) for _ in range(num_layers - 1): - self.layers.append( - GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True) - ) + self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)) elif gcn_type == "sage": - self.layers.append( - SAGEConv(initial_dim, hidden_units, aggregator_type="gcn") - ) + self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")) for _ in range(num_layers - 1): - self.layers.append( - SAGEConv(hidden_units, hidden_units, aggregator_type="gcn") - ) + self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")) else: raise ValueError("Gcn type error.") @@ -162,10 +152,7 @@ def forward(self, g, z, node_id=None, edge_id=None): else: x = z_emb - if self.use_edge_weight: - edge_weight = self.edge_weights_lookup(edge_id) - else: - edge_weight = None + edge_weight = self.edge_weights_lookup(edge_id) if self.use_edge_weight else None if self.use_embedding: n_emb = self.node_embedding(node_id) @@ -219,12 +206,12 @@ def __init__( dropout=0.5, max_z=1000, ): - super(DGCNN, self).__init__() + super().__init__() self.num_layers = num_layers self.dropout = dropout - self.use_attribute = False if node_attributes is None else True + self.use_attribute = node_attributes is not None self.use_embedding = use_embedding - self.use_edge_weight = False if edge_weights is None else True + self.use_edge_weight = edge_weights is not None self.z_embedding = nn.Embedding(max_z, hidden_units) @@ -248,22 +235,14 @@ def __init__( self.layers = nn.ModuleList() if gcn_type == "gcn": - self.layers.append( - GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True) - ) + self.layers.append(GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)) for _ in range(num_layers - 1): - self.layers.append( - GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True) - ) + self.layers.append(GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)) self.layers.append(GraphConv(hidden_units, 1, allow_zero_in_degree=True)) elif gcn_type == "sage": - self.layers.append( - SAGEConv(initial_dim, hidden_units, aggregator_type="gcn") - ) + self.layers.append(SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")) for _ in range(num_layers - 1): - self.layers.append( - SAGEConv(hidden_units, hidden_units, aggregator_type="gcn") - ) + self.layers.append(SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")) self.layers.append(SAGEConv(hidden_units, 1, aggregator_type="gcn")) else: raise ValueError("Gcn type error.") @@ -274,9 +253,7 @@ def __init__( conv1d_kws = [total_latent_dim, 5] self.conv_1 = nn.Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]) self.maxpool1d = nn.MaxPool1d(2, 2) - self.conv_2 = nn.Conv1d( - conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1 - ) + self.conv_2 = nn.Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1) dense_dim = int((k - 2) / 2 + 1) dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] self.linear_1 = nn.Linear(dense_dim, 128) @@ -298,10 +275,7 @@ def forward(self, g, z, node_id=None, edge_id=None): x = torch.cat([z_emb, x], 1) else: x = z_emb - if self.use_edge_weight: - edge_weight = self.edge_weights_lookup(edge_id) - else: - edge_weight = None + edge_weight = self.edge_weights_lookup(edge_id) if self.use_edge_weight else None if self.use_embedding: n_emb = self.node_embedding(node_id) @@ -405,9 +379,7 @@ def drnl_node_labeling(subgraph, src, dst): dist2src = np.insert(dist2src, dst, 0, axis=0) dist2src = torch.from_numpy(dist2src) - dist2dst = shortest_path( - adj_wo_src, directed=False, unweighted=True, indices=dst - 1 - ) + dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1) dist2dst = np.insert(dist2dst, src, 0, axis=0) dist2dst = torch.from_numpy(dist2dst) @@ -465,7 +437,7 @@ def __getitem__(self, index): return (self.graph_list[index], self.tensor[index]) -class PosNegEdgesGenerator(object): +class PosNegEdgesGenerator: """ Generate positive and negative samples Attributes: @@ -484,10 +456,7 @@ def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=Tr self.shuffle = shuffle def __call__(self, split_type): - if split_type == "train": - subsample_ratio = self.subsample_ratio - else: - subsample_ratio = 1 + subsample_ratio = self.subsample_ratio if split_type == "train" else 1 pos_edges = self.split_edge[split_type]["edge"] if split_type == "train": @@ -550,7 +519,7 @@ def __getitem__(self, index): return (subgraph, self.labels[index]) -class SEALSampler(object): +class SEALSampler: """ Sampler for SEAL in paper(no-block version) The strategy is to sample all the k-hop neighbors around the two target nodes. @@ -587,12 +556,8 @@ def sample_subgraph(self, target_nodes): subgraph = dgl.node_subgraph(self.graph, sample_nodes) # Each node should have unique node id in the new subgraph - u_id = int( - torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False) - ) - v_id = int( - torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False) - ) + u_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False)) + v_id = int(torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False)) # remove link between target nodes in positive subgraphs. if subgraph.has_edges_between(u_id, v_id): @@ -608,7 +573,7 @@ def sample_subgraph(self, target_nodes): return subgraph def _collate(self, batch): - batch_graphs, batch_labels = map(list, zip(*batch)) + batch_graphs, batch_labels = map(list, zip(*batch, strict=False)) batch_graphs = dgl.batch(batch_graphs) batch_labels = torch.stack(batch_labels) @@ -637,7 +602,7 @@ def __call__(self, edges, labels): return subgraph_list, torch.cat(labels_list) -class SEALData(object): +class SEALData: """ 1. Generate positive and negative samples 2. Subgraph sampling @@ -682,26 +647,19 @@ def __init__( if use_coalesce: for k, v in g.edata.items(): g.edata[k] = v.float() # dgl.to_simple() requires data is float - self.g = dgl.to_simple( - g, copy_ndata=True, copy_edata=True, aggregator="sum" - ) + self.g = dgl.to_simple(g, copy_ndata=True, copy_edata=True, aggregator="sum") - self.ndata = {k: v for k, v in self.g.ndata.items()} - self.edata = {k: v for k, v in self.g.edata.items()} + self.ndata = dict(self.g.ndata.items()) + self.edata = dict(self.g.edata.items()) self.g.ndata.clear() self.g.edata.clear() self.print_fn("Save ndata and edata in class.") self.print_fn("Clear ndata and edata in graph.") - self.sampler = SEALSampler( - graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn - ) + self.sampler = SEALSampler(graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn) def __call__(self, split_type): - if split_type == "train": - subsample_ratio = self.subsample_ratio - else: - subsample_ratio = 1 + subsample_ratio = self.subsample_ratio if split_type == "train" else 1 path = osp.join( self.save_dir or "", @@ -741,7 +699,7 @@ def _transform_log_level(str_level): raise KeyError("Log level error") -class LightLogging(object): +class LightLogging: def __init__(self, log_path=None, log_name="lightlog", log_level="debug"): log_level = _transform_log_level(log_level) @@ -752,19 +710,10 @@ def __init__(self, log_path=None, log_name="lightlog", log_level="debug"): os.mkdir(log_path) if log_name.endswith("-") or log_name.endswith("_"): - log_name = ( - log_path - + log_name - + time.strftime("%Y-%m-%d-%H:%M", time.localtime(time.time())) - + ".log" - ) + log_name = log_path + log_name + time.strftime("%Y-%m-%d-%H:%M", time.localtime(time.time())) + ".log" else: log_name = ( - log_path - + log_name - + "_" - + time.strftime("%Y-%m-%d-%H-%M", time.localtime(time.time())) - + ".log" + log_path + log_name + "_" + time.strftime("%Y-%m-%d-%H-%M", time.localtime(time.time())) + ".log" ) logging.basicConfig( diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/fraud_detector_caregnn.py b/hugegraph-ml/src/hugegraph_ml/tasks/fraud_detector_caregnn.py index 5258b9f5d..2741d7bd8 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/fraud_detector_caregnn.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/fraud_detector_caregnn.py @@ -18,10 +18,11 @@ # pylint: disable=E0401,C0301 import torch -from torch import nn -from torch.nn.functional import softmax from dgl import DGLGraph from sklearn.metrics import recall_score, roc_auc_score +from torch import nn +from torch.nn.functional import softmax + class DetectorCaregnn: def __init__(self, graph: DGLGraph, model: nn.Module): @@ -36,60 +37,44 @@ def train( n_epochs: int = 200, gpu: int = -1, ): - - self._device = ( - f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" - ) + self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" self._model.to(self._device) self.graph = self.graph.to(self._device) labels = self.graph.ndata["label"].to(self._device) feat = self.graph.ndata["feature"].to(self._device) train_mask = self.graph.ndata["train_mask"] val_mask = self.graph.ndata["val_mask"] - train_idx = ( - torch.nonzero(train_mask, as_tuple=False).squeeze(1).to(self._device) - ) + train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1).to(self._device) val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze(1).to(self._device) - rl_idx = torch.nonzero( - train_mask.to(self._device) & labels.bool(), as_tuple=False - ).squeeze(1) + rl_idx = torch.nonzero(train_mask.to(self._device) & labels.bool(), as_tuple=False).squeeze(1) _, cnt = torch.unique(labels, return_counts=True) loss_fn = torch.nn.CrossEntropyLoss(weight=1 / cnt) - optimizer = torch.optim.Adam( - self._model.parameters(), lr=lr, weight_decay=weight_decay - ) - for epoch in range(n_epochs): + optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay) + for _epoch in range(n_epochs): self._model.train() logits_gnn, logits_sim = self._model(self.graph, feat) tr_loss = loss_fn(logits_gnn[train_idx], labels[train_idx]) + 2 * loss_fn( logits_sim[train_idx], labels[train_idx] ) - tr_recall = recall_score( + recall_score( labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu(), ) - tr_auc = roc_auc_score( + roc_auc_score( labels[train_idx].cpu(), softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu(), ) - val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + 2 * loss_fn( - logits_sim[val_idx], labels[val_idx] - ) - val_recall = recall_score( - labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu() - ) - val_auc = roc_auc_score( + loss_fn(logits_gnn[val_idx], labels[val_idx]) + 2 * loss_fn(logits_sim[val_idx], labels[val_idx]) + recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu()) + roc_auc_score( labels[val_idx].cpu(), softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu(), ) optimizer.zero_grad() tr_loss.backward() optimizer.step() - print( - f"Epoch {epoch}, Train: Recall: {tr_recall:.4f} AUC: {tr_auc:.4f} Loss: {tr_loss.item():.4f} | Val: Recall: {val_recall:.4f} AUC: {val_auc:.4f} Loss: {val_loss.item():.4f}" - ) - self._model.RLModule(self.graph, epoch, rl_idx) + self._model.RLModule(self.graph, _epoch, rl_idx) def evaluate(self): labels = self.graph.ndata["label"].to(self._device) @@ -103,9 +88,7 @@ def evaluate(self): test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + 2 * loss_fn( logits_sim[test_idx], labels[test_idx] ) - test_recall = recall_score( - labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu() - ) + test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu()) test_auc = roc_auc_score( labels[test_idx].cpu(), softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu(), diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/graph_classify.py b/hugegraph-ml/src/hugegraph_ml/tasks/graph_classify.py index 2810f7b21..edc316bc2 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/graph_classify.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/graph_classify.py @@ -56,17 +56,16 @@ def _evaluate(self, dataloader): loss = total_loss / total return {"accuracy": accuracy, "loss": loss} - def train( self, batch_size: int = 20, lr: float = 1e-3, weight_decay: float = 0, n_epochs: int = 200, - patience: int = float('inf'), + patience: int = float("inf"), early_stopping_monitor: Literal["loss", "accuracy"] = "loss", clip: float = 2.0, - gpu: int = -1 + gpu: int = -1, ): self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" self._early_stopping = EarlyStopping(patience=patience, monitor=early_stopping_monitor) diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/hetero_sample_embed_gatne.py b/hugegraph-ml/src/hugegraph_ml/tasks/hetero_sample_embed_gatne.py index 91d00ac44..0399947ca 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/hetero_sample_embed_gatne.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/hetero_sample_embed_gatne.py @@ -17,15 +17,17 @@ import random + import dgl import torch from torch import nn from tqdm.auto import tqdm + from hugegraph_ml.models.gatne import ( + NeighborSampler, + NSLoss, construct_typenodes_from_graph, generate_pairs, - NSLoss, - NeighborSampler, ) @@ -41,9 +43,7 @@ def train_and_embed( n_epochs: int = 200, gpu: int = -1, ): - self._device = ( - f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" - ) + self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" self._model = self._model.to(self._device) self.graph = self.graph.to(self._device) type_nodes = construct_typenodes_from_graph(self.graph) diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_pgnn.py b/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_pgnn.py index 13e761e02..37c20fd2c 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_pgnn.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_pgnn.py @@ -15,14 +15,15 @@ # specific language governing permissions and limitations # under the License. -import torch import dgl +import torch from torch import nn + from hugegraph_ml.models.pgnn import ( + eval_model, get_dataset, preselect_anchor, train_model, - eval_model, ) @@ -39,9 +40,7 @@ def train( n_epochs: int = 200, gpu: int = -1, ): - self._device = ( - f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" - ) + self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" self._model.to(self._device) data = get_dataset(self.graph) # pre-sample anchor nodes and compute shortest distance values for all epochs @@ -51,12 +50,9 @@ def train( dist_max_list, edge_weight_list, ) = preselect_anchor(data) - optimizer = torch.optim.Adam( - self._model.parameters(), lr=lr, weight_decay=weight_decay - ) + optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay) loss_func = nn.BCEWithLogitsLoss() best_auc_val = -1 - best_auc_test = -1 for epoch in range(n_epochs): if epoch == 200: for param_group in optimizer.param_groups: @@ -73,20 +69,9 @@ def train( train_model(data, self._model, loss_func, optimizer, self._device, g_data) - loss_train, auc_train, auc_val, auc_test = eval_model( - data, g_data, self._model, loss_func, self._device - ) + _loss_train, _auc_train, auc_val, _auc_test = eval_model(data, g_data, self._model, loss_func, self._device) if auc_val > best_auc_val: best_auc_val = auc_val - best_auc_test = auc_test if epoch % 100 == 0: - print( - epoch, - f"Loss {loss_train:.4f}", - f"Train AUC: {auc_train:.4f}", - f"Val AUC: {auc_val:.4f}", - f"Test AUC: {auc_test:.4f}", - f"Best Val AUC: {best_auc_val:.4f}", - f"Best Test AUC: {best_auc_test:.4f}", - ) + pass diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_seal.py b/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_seal.py index 58f2115e1..0b8af2143 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_seal.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/link_prediction_seal.py @@ -18,14 +18,17 @@ # pylint: disable=R1728 import time + +import numpy as np import torch -from torch.nn import BCEWithLogitsLoss -from dgl import DGLGraph, NID, EID +from dgl import EID, NID, DGLGraph from dgl.dataloading import GraphDataLoader +from torch.nn import BCEWithLogitsLoss from tqdm import tqdm -import numpy as np + from hugegraph_ml.models.seal import SEALData, evaluate_hits + class LinkPredictionSeal: def __init__(self, graph: DGLGraph, split_edge, model): self.graph = graph @@ -88,58 +91,37 @@ def train( gpu: int = -1, ): torch.manual_seed(2021) - self._device = ( - f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" - ) + self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" self._model.to(self._device) self.graph = self.graph.to(self._device) parameters = self._model.parameters() optimizer = torch.optim.Adam(parameters, lr=lr) loss_fn = BCEWithLogitsLoss() - print( - f"Total parameters: {sum([p.numel() for p in self._model.parameters()])}" - ) # train and evaluate loop summary_val = [] summary_test = [] for epoch in range(n_epochs): - start_time = time.time() - loss = self._train( + time.time() + self._train( dataloader=self.train_loader, loss_fn=loss_fn, optimizer=optimizer, num_graphs=32, total_graphs=self.train_graphs, ) - train_time = time.time() + time.time() if epoch % 5 == 0: val_pos_pred, val_neg_pred = self.evaluate(dataloader=self.val_loader) - test_pos_pred, test_neg_pred = self.evaluate( - dataloader=self.test_loader - ) + test_pos_pred, test_neg_pred = self.evaluate(dataloader=self.test_loader) - val_metric = evaluate_hits( - "ogbl-collab", val_pos_pred, val_neg_pred, 50 - ) - test_metric = evaluate_hits( - "ogbl-collab", test_pos_pred, test_neg_pred, 50 - ) - evaluate_time = time.time() - print( - f"Epoch-{epoch}, train loss: {loss:.4f}, hits@{50}: val-{val_metric:.4f}, \\" - f"test-{test_metric:.4f}, cost time: train-{train_time - start_time:.1f}s, \\" - f"total-{evaluate_time - start_time:.1f}s" - ) + val_metric = evaluate_hits("ogbl-collab", val_pos_pred, val_neg_pred, 50) + test_metric = evaluate_hits("ogbl-collab", test_pos_pred, test_neg_pred, 50) + time.time() summary_val.append(val_metric) summary_test.append(test_metric) summary_test = np.array(summary_test) - print("Experiment Results:") - print( - f"Best hits@{50}: {np.max(summary_test):.4f}, epoch: {np.argmax(summary_test)}" - ) - @torch.no_grad() def evaluate(self, dataloader): self._model.eval() diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_edge.py b/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_edge.py index 57276ff67..0ab908a6a 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_edge.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_edge.py @@ -39,15 +39,11 @@ def _check_graph(self): required_node_attrs = ["feat", "label", "train_mask", "val_mask", "test_mask"] for attr in required_node_attrs: if attr not in self.graph.ndata: - raise ValueError( - f"Graph is missing required node attribute '{attr}' in ndata." - ) + raise ValueError(f"Graph is missing required node attribute '{attr}' in ndata.") required_edge_attrs = ["feat"] for attr in required_edge_attrs: if attr not in self.graph.edata: - raise ValueError( - f"Graph is missing required edge attribute '{attr}' in edata." - ) + raise ValueError(f"Graph is missing required edge attribute '{attr}' in edata.") def _evaluate(self, edge_feats, node_feats, labels, mask): self._model.eval() @@ -69,12 +65,8 @@ def train( gpu: int = -1, ): # Set device for training - self._device = ( - f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" - ) - self._early_stopping = EarlyStopping( - patience=patience, monitor=early_stopping_monitor - ) + self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu" + self._early_stopping = EarlyStopping(patience=patience, monitor=early_stopping_monitor) self._model.to(self._device) self.graph = self.graph.to(self._device) # Get node features, labels, masks and move to device @@ -83,9 +75,7 @@ def train( labels = self.graph.ndata["label"].to(self._device) train_mask = self.graph.ndata["train_mask"].to(self._device) val_mask = self.graph.ndata["val_mask"].to(self._device) - optimizer = torch.optim.Adam( - self._model.parameters(), lr=lr, weight_decay=weight_decay - ) + optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay) # Training model epochs = trange(n_epochs) for epoch in epochs: @@ -105,9 +95,7 @@ def train( f"epoch {epoch} | train loss {loss.item():.4f} | val loss {valid_metrics['loss']:.4f}" ) # early stopping - self._early_stopping( - valid_metrics[self._early_stopping.monitor], self._model - ) + self._early_stopping(valid_metrics[self._early_stopping.monitor], self._model) torch.cuda.empty_cache() if self._early_stopping.early_stop: break diff --git a/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_sample.py b/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_sample.py index 5a9afff11..a07992099 100644 --- a/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_sample.py +++ b/hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_sample.py @@ -19,12 +19,12 @@ from typing import Literal +import dgl +import numpy as np import torch from dgl import DGLGraph from torch import nn from tqdm import trange -import dgl -import numpy as np from hugegraph_ml.utils.early_stopping import EarlyStopping @@ -60,9 +60,7 @@ def _check_graph(self): required_node_attrs = ["feat", "label", "train_mask", "val_mask", "test_mask"] for attr in required_node_attrs: if attr not in self.graph.ndata: - raise ValueError( - f"Graph is missing required node attribute '{attr}' in ndata." - ) + raise ValueError(f"Graph is missing required node attribute '{attr}' in ndata.") def train( self, @@ -73,18 +71,14 @@ def train( early_stopping_monitor: Literal["loss", "accuracy"] = "loss", ): # Set device for training - early_stopping = EarlyStopping( - patience=patience, monitor=early_stopping_monitor - ) + early_stopping = EarlyStopping(patience=patience, monitor=early_stopping_monitor) self._model.to(self._device) # Get node features, labels, masks and move to device feats = self.graph.ndata["feat"].to(self._device) labels = self.graph.ndata["label"].to(self._device) train_mask = self.graph.ndata["train_mask"].to(self._device) val_mask = self.graph.ndata["val_mask"].to(self._device) - optimizer = torch.optim.Adam( - self._model.parameters(), lr=lr, weight_decay=weight_decay - ) + optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay) # Training model loss_fn = nn.CrossEntropyLoss() epochs = trange(n_epochs) @@ -109,7 +103,8 @@ def train( ) # logs epochs.set_description( - f"epoch {epoch} | it {it} | train loss {train_loss.item():.4f} | val loss {valid_metrics['loss']:.4f}" + f"epoch {epoch} | it {it} | train loss {train_loss.item():.4f} " + f"| val loss {valid_metrics['loss']:.4f}" ) # early stopping early_stopping(valid_metrics[early_stopping.monitor], self._model) @@ -151,4 +146,3 @@ def evaluate(self): _, predicted = torch.max(test_logits, dim=1) accuracy = (predicted == test_labels[0]).sum().item() / len(test_labels[0]) return {"accuracy": accuracy, "total_loss": total_loss.item()} - \ No newline at end of file diff --git a/hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py b/hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py index 41e13fe6b..1bdce5cc8 100644 --- a/hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py +++ b/hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py @@ -19,31 +19,38 @@ # pylint: disable=too-many-statements # pylint: disable=C0302,C0103,W1514,R1735,R1734,C0206 -import os -from typing import Optional import json +import os + import dgl +import networkx as nx import numpy as np +import pandas as pd import scipy import torch -from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset, LegacyTUDataset, GINDataset, \ - get_download_dir +from dgl.data import ( + CiteseerGraphDataset, + CoraGraphDataset, + GINDataset, + LegacyTUDataset, + PubmedGraphDataset, + get_download_dir, +) from dgl.data.utils import _get_dgl_url, download, load_graphs -import networkx as nx from ogb.linkproppred import DglLinkPropPredDataset -import pandas as pd from pyhugegraph.api.graph import GraphManager from pyhugegraph.api.schema import SchemaManager from pyhugegraph.client import PyHugeClient MAX_BATCH_NUM = 500 + def clear_all_data( url: str = "http://127.0.0.1:8080", graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client.graphs().clear_graph_all_data() @@ -55,7 +62,7 @@ def import_graph_from_dgl( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): dataset_name = dataset_name.upper() if dataset_name == "CORA": @@ -91,8 +98,9 @@ def import_graph_from_dgl( vidxs = [] for idx in range(graph_dgl.number_of_nodes()): # extract props - properties = {p: int(props_value[p][idx]) if isinstance(props_value[p][idx], bool) else props_value[p][idx] - for p in props} + properties = { + p: int(props_value[p][idx]) if isinstance(props_value[p][idx], bool) else props_value[p][idx] for p in props + } vdata = [vertex_label, properties] vdatas.append(vdata) vidxs.append(idx) @@ -109,7 +117,7 @@ def import_graph_from_dgl( client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).ifNotExist().create() edges_src, edges_dst = graph_dgl.edges() edatas = [] - for src, dst in zip(edges_src.numpy(), edges_dst.numpy()): + for src, dst in zip(edges_src.numpy(), edges_dst.numpy(), strict=False): edata = [edge_label, idx_to_vertex_id[src], idx_to_vertex_id[dst], vertex_label, vertex_label, {}] edatas.append(edata) if len(edatas) == MAX_BATCH_NUM: @@ -125,7 +133,7 @@ def import_graphs_from_dgl( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): dataset_name = dataset_name.upper() # load dgl bultin dataset @@ -150,11 +158,12 @@ def import_graphs_from_dgl( client_schema.vertexLabel(graph_vertex_label).useAutomaticId().properties("label").ifNotExist().create() client_schema.vertexLabel(vertex_label).useAutomaticId().properties("feat", "graph_id").ifNotExist().create() client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).properties( - "graph_id").ifNotExist().create() + "graph_id" + ).ifNotExist().create() client_schema.indexLabel("vertex_by_graph_id").onV(vertex_label).by("graph_id").secondary().ifNotExist().create() client_schema.indexLabel("edge_by_graph_id").onE(edge_label).by("graph_id").secondary().ifNotExist().create() # import to hugegraph - for (graph_dgl, label) in dataset_dgl: + for graph_dgl, label in dataset_dgl: graph_vertex = client_graph.addVertex(label=graph_vertex_label, properties={"label": int(label)}) # refine feat prop if "feat" in graph_dgl.ndata: @@ -182,14 +191,14 @@ def import_graphs_from_dgl( # add edges of graph i for barch srcs, dsts = graph_dgl.edges() edatas = [] - for src, dst in zip(srcs.numpy(), dsts.numpy()): + for src, dst in zip(srcs.numpy(), dsts.numpy(), strict=False): edata = [ edge_label, idx_to_vertex_id[src], idx_to_vertex_id[dst], vertex_label, vertex_label, - {"graph_id": graph_vertex.id} + {"graph_id": graph_vertex.id}, ] edatas.append(edata) if len(edatas) == MAX_BATCH_NUM: @@ -205,7 +214,7 @@ def import_hetero_graph_from_dgl( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): dataset_name = dataset_name.upper() if dataset_name == "ACM": @@ -240,8 +249,10 @@ def import_hetero_graph_from_dgl( vdatas = [] idxs = [] for idx in range(hetero_graph.number_of_nodes(ntype=ntype)): - properties = {p: int(props_value[p][idx]) if isinstance(props_value[p][idx], bool) else props_value[p][idx] - for p in props} + properties = { + p: int(props_value[p][idx]) if isinstance(props_value[p][idx], bool) else props_value[p][idx] + for p in props + } vdata = [vertex_label, properties] vdatas.append(vdata) idxs.append(idx) @@ -264,14 +275,14 @@ def import_hetero_graph_from_dgl( ).ifNotExist().create() # add edges for batch of canonical_etype srcs, dsts = hetero_graph.edges(etype=canonical_etype) - for src, dst in zip(srcs.numpy(), dsts.numpy()): + for src, dst in zip(srcs.numpy(), dsts.numpy(), strict=False): edata = [ edge_label, ntype_idx_to_vertex_id[src_type][src], ntype_idx_to_vertex_id[dst_type][dst], ntype_to_vertex_label[src_type], ntype_to_vertex_label[dst_type], - {} + {}, ] edatas.append(edata) if len(edatas) == MAX_BATCH_NUM: @@ -280,13 +291,14 @@ def import_hetero_graph_from_dgl( if len(edatas) > 0: _add_batch_edges(client_graph, edatas) + def import_hetero_graph_from_dgl_no_feat( dataset_name, url: str = "http://127.0.0.1:8080", graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): # dataset download from: # https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip @@ -295,9 +307,7 @@ def import_hetero_graph_from_dgl_no_feat( hetero_graph = load_training_data_gatne() else: raise ValueError("dataset not supported") - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() @@ -331,12 +341,12 @@ def import_hetero_graph_from_dgl_no_feat( # create edge schema src_type, etype, dst_type = canonical_etype edge_label = f"{dataset_name}_{etype}_e" - client_schema.edgeLabel(edge_label).sourceLabel( - ntype_to_vertex_label[src_type] - ).targetLabel(ntype_to_vertex_label[dst_type]).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(ntype_to_vertex_label[src_type]).targetLabel( + ntype_to_vertex_label[dst_type] + ).ifNotExist().create() # add edges for batch of canonical_etype srcs, dsts = hetero_graph.edges(etype=canonical_etype) - for src, dst in zip(srcs.numpy(), dsts.numpy()): + for src, dst in zip(srcs.numpy(), dsts.numpy(), strict=False): edata = [ edge_label, ntype_idx_to_vertex_id[src_type][src], @@ -359,7 +369,7 @@ def import_graph_from_nx( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): dataset_name = dataset_name.upper() if dataset_name == "CAVEMAN": @@ -367,9 +377,7 @@ def import_graph_from_nx( else: raise ValueError("dataset not supported") - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() # create property schema @@ -394,9 +402,7 @@ def import_graph_from_nx( # add edges for batch edge_label = f"{dataset_name}_edge" - client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel( - vertex_label - ).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).ifNotExist().create() edatas = [] for edge in dataset.edges: edata = [ @@ -421,7 +427,7 @@ def import_graph_from_dgl_with_edge_feat( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): dataset_name = dataset_name.upper() if dataset_name == "CORA": @@ -434,15 +440,11 @@ def import_graph_from_dgl_with_edge_feat( raise ValueError("dataset not supported") graph_dgl = dataset_dgl[0] - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() # create property schema - client_schema.propertyKey( - "feat" - ).asDouble().valueList().ifNotExist().create() # node features + client_schema.propertyKey("feat").asDouble().valueList().ifNotExist().create() # node features client_schema.propertyKey("edge_feat").asDouble().valueList().ifNotExist().create() client_schema.propertyKey("label").asLong().ifNotExist().create() client_schema.propertyKey("train_mask").asInt().ifNotExist().create() @@ -455,9 +457,7 @@ def import_graph_from_dgl_with_edge_feat( node_props_value = {} for p in node_props: node_props_value[p] = graph_dgl.ndata[p].tolist() - client_schema.vertexLabel(vertex_label).useAutomaticId().properties( - *node_props - ).ifNotExist().create() + client_schema.vertexLabel(vertex_label).useAutomaticId().properties(*node_props).ifNotExist().create() # add vertices for batch (note MAX_BATCH_NUM) idx_to_vertex_id = {} vdatas = [] @@ -487,12 +487,12 @@ def import_graph_from_dgl_with_edge_feat( edge_label = f"{dataset_name}_edge_feat_edge" edge_all_props = ["edge_feat"] - client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel( - vertex_label - ).properties(*edge_all_props).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).properties( + *edge_all_props + ).ifNotExist().create() edges_src, edges_dst = graph_dgl.edges() edatas = [] - for src, dst in zip(edges_src.numpy(), edges_dst.numpy()): + for src, dst in zip(edges_src.numpy(), edges_dst.numpy(), strict=False): properties = {p: (torch.rand(8).tolist()) for p in edge_all_props} edata = [ edge_label, @@ -516,7 +516,7 @@ def import_graph_from_ogb( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): if dataset_name == "ogbl-collab": dataset_dgl = DglLinkPropPredDataset(name=dataset_name) @@ -524,15 +524,11 @@ def import_graph_from_ogb( raise ValueError("dataset not supported") graph_dgl = dataset_dgl[0] - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() # create property schema - client_schema.propertyKey( - "feat" - ).asDouble().valueList().ifNotExist().create() # node features + client_schema.propertyKey("feat").asDouble().valueList().ifNotExist().create() # node features client_schema.propertyKey("year").asDouble().valueList().ifNotExist().create() client_schema.propertyKey("weight").asDouble().valueList().ifNotExist().create() @@ -543,9 +539,7 @@ def import_graph_from_ogb( node_props_value = {} for p in node_props: node_props_value[p] = graph_dgl.ndata[p].tolist() - client_schema.vertexLabel(vertex_label).useAutomaticId().properties( - *node_props - ).ifNotExist().create() + client_schema.vertexLabel(vertex_label).useAutomaticId().properties(*node_props).ifNotExist().create() # add vertices for batch (note MAX_BATCH_NUM) idx_to_vertex_id = {} @@ -567,9 +561,7 @@ def import_graph_from_ogb( vdatas.append(vdata) vidxs.append(idx) if len(vdatas) == MAX_BATCH_NUM: - idx_to_vertex_id.update( - _add_batch_vertices(client_graph, vdatas, vidxs) - ) + idx_to_vertex_id.update(_add_batch_vertices(client_graph, vdatas, vidxs)) vdatas.clear() vidxs.clear() # add rest vertices @@ -582,12 +574,12 @@ def import_graph_from_ogb( edge_props_value = {} for p in edge_all_props: edge_props_value[p] = graph_dgl.edata[p].tolist() - client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel( - vertex_label - ).properties(*edge_all_props).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).properties( + *edge_all_props + ).ifNotExist().create() edges_src, edges_dst = graph_dgl.edges() edatas = [] - for src, dst in zip(edges_src.numpy(), edges_dst.numpy()): + for src, dst in zip(edges_src.numpy(), edges_dst.numpy(), strict=False): if src <= max_nodes and dst <= max_nodes: properties = { p: ( @@ -611,7 +603,6 @@ def import_graph_from_ogb( edatas.clear() if len(edatas) > 0: _add_batch_edges(client_graph, edatas) - print("begin edge split") import_split_edge_from_ogb( dataset_name=dataset_name, idx_to_vertex_id=idx_to_vertex_id, @@ -627,7 +618,7 @@ def import_split_edge_from_ogb( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): if dataset_name == "ogbl-collab": dataset_dgl = DglLinkPropPredDataset(name=dataset_name) @@ -635,9 +626,7 @@ def import_split_edge_from_ogb( raise ValueError("dataset not supported") split_edges = dataset_dgl.get_edge_split() - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() # create property schema @@ -675,9 +664,9 @@ def import_split_edge_from_ogb( # add edges for batch vertex_label = f"{dataset_name}_vertex" edge_label = f"{dataset_name}_split_edge" - client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel( - vertex_label - ).properties(*edge_all_props).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(vertex_label).targetLabel(vertex_label).properties( + *edge_all_props + ).ifNotExist().create() edges = {} edges["train_edge_mask"] = split_edges["train"]["edge"] edges["train_year_mask"] = split_edges["train"]["year"] @@ -763,7 +752,7 @@ def import_hetero_graph_from_dgl_bgnn( graph: str = "hugegraph", user: str = "", pwd: str = "", - graphspace: Optional[str] = None, + graphspace: str | None = None, ): # dataset download from : https://www.dropbox.com/s/verx1evkykzli88/datasets.zip # Extract zip folder in this directory @@ -772,9 +761,7 @@ def import_hetero_graph_from_dgl_bgnn( hetero_graph = read_input() else: raise ValueError("dataset not supported") - client: PyHugeClient = PyHugeClient( - url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace - ) + client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace) client_schema: SchemaManager = client.schema() client_graph: GraphManager = client.graph() @@ -801,9 +788,7 @@ def import_hetero_graph_from_dgl_bgnn( ] # check properties props = [p for p in all_props if p in hetero_graph.nodes[ntype].data] - client_schema.vertexLabel(vertex_label).useAutomaticId().properties( - *props - ).ifNotExist().create() + client_schema.vertexLabel(vertex_label).useAutomaticId().properties(*props).ifNotExist().create() props_value = {} for p in props: props_value[p] = hetero_graph.nodes[ntype].data[p].tolist() @@ -813,11 +798,7 @@ def import_hetero_graph_from_dgl_bgnn( idxs = [] for idx in range(hetero_graph.number_of_nodes(ntype=ntype)): properties = { - p: ( - int(props_value[p][idx]) - if isinstance(props_value[p][idx], bool) - else props_value[p][idx] - ) + p: (int(props_value[p][idx]) if isinstance(props_value[p][idx], bool) else props_value[p][idx]) for p in props } vdata = [vertex_label, properties] @@ -837,12 +818,12 @@ def import_hetero_graph_from_dgl_bgnn( # create edge schema src_type, etype, dst_type = canonical_etype edge_label = f"{dataset_name}_{etype}_e" - client_schema.edgeLabel(edge_label).sourceLabel( - ntype_to_vertex_label[src_type] - ).targetLabel(ntype_to_vertex_label[dst_type]).ifNotExist().create() + client_schema.edgeLabel(edge_label).sourceLabel(ntype_to_vertex_label[src_type]).targetLabel( + ntype_to_vertex_label[dst_type] + ).ifNotExist().create() # add edges for batch of canonical_etype srcs, dsts = hetero_graph.edges(etype=canonical_etype) - for src, dst in zip(srcs.numpy(), dsts.numpy()): + for src, dst in zip(srcs.numpy(), dsts.numpy(), strict=False): edata = [ edge_label, ntype_idx_to_vertex_id[src_type][src], @@ -914,6 +895,7 @@ def init_ogb_split_edge( if len(edatas) > 0: _add_batch_edges(client_graph, edatas) + def _add_batch_vertices(client_graph, vdatas, vidxs): vertices = client_graph.addVertices(vdatas) assert len(vertices) == len(vidxs) @@ -932,7 +914,6 @@ def load_acm_raw(): url = "dataset/ACM.mat" data_path = get_download_dir() + "/ACM.mat" if not os.path.exists(data_path): - print(f"File {data_path} not found, downloading...") download(_get_dgl_url(url), path=data_path) data = scipy.io.loadmat(data_path) @@ -967,16 +948,14 @@ def load_acm_raw(): pc_p, pc_c = p_vs_c.nonzero() labels = np.zeros(len(p_selected), dtype=np.int64) - for conf_id, label_id in zip(conf_ids, label_ids): + for conf_id, label_id in zip(conf_ids, label_ids, strict=False): labels[pc_p[pc_c == conf_id]] = label_id labels = torch.LongTensor(labels) float_mask = np.zeros(len(pc_p)) for conf_id in conf_ids: pc_c_mask = pc_c == conf_id - float_mask[pc_c_mask] = np.random.permutation( - np.linspace(0, 1, pc_c_mask.sum()) - ) + float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum())) train_idx = np.where(float_mask <= 0.2)[0] val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0] test_idx = np.where(float_mask > 0.3)[0] @@ -994,6 +973,7 @@ def load_acm_raw(): return hgraph + def read_input(): # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/bgnn/run.py # I added X, y, cat_features and masks into graph @@ -1050,18 +1030,17 @@ def load_training_data_gatne(): # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/utils.py # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/main.py f_name = "dataset/amazon/train.txt" - print("We are loading data from:", f_name) - edge_data_by_type = dict() - with open(f_name, "r") as f: + edge_data_by_type = {} + with open(f_name) as f: for line in f: words = line[:-1].split(" ") # line[-1] == '\n' if words[0] not in edge_data_by_type: - edge_data_by_type[words[0]] = list() + edge_data_by_type[words[0]] = [] x, y = words[1], words[2] edge_data_by_type[words[0]].append((x, y)) nodes, index2word = [], [] for edge_type in edge_data_by_type: - node1, node2 = zip(*edge_data_by_type[edge_type]) + node1, node2 = zip(*edge_data_by_type[edge_type], strict=False) index2word = index2word + list(node1) + list(node2) index2word = list(set(index2word)) vocab = {} @@ -1070,12 +1049,12 @@ def load_training_data_gatne(): vocab[word] = i i = i + 1 for edge_type in edge_data_by_type: - node1, node2 = zip(*edge_data_by_type[edge_type]) + node1, node2 = zip(*edge_data_by_type[edge_type], strict=False) tmp_nodes = list(set(list(node1) + list(node2))) tmp_nodes = [vocab[word] for word in tmp_nodes] nodes.append(tmp_nodes) node_type = "_N" # '_N' can be replaced by an arbitrary name - data_dict = dict() + data_dict = {} num_nodes_dict = {node_type: len(vocab)} for edge_type in edge_data_by_type: tmp_data = edge_data_by_type[edge_type] @@ -1088,6 +1067,7 @@ def load_training_data_gatne(): graph = dgl.heterograph(data_dict, num_nodes_dict) return graph + def _get_mask(size, indices): mask = torch.zeros(size) mask[indices] = 1 diff --git a/hugegraph-ml/src/tests/conftest.py b/hugegraph-ml/src/tests/conftest.py index a5f9839a4..303149014 100644 --- a/hugegraph-ml/src/tests/conftest.py +++ b/hugegraph-ml/src/tests/conftest.py @@ -18,13 +18,16 @@ import pytest -from hugegraph_ml.utils.dgl2hugegraph_utils import clear_all_data, import_graph_from_dgl, import_graphs_from_dgl, \ - import_hetero_graph_from_dgl +from hugegraph_ml.utils.dgl2hugegraph_utils import ( + clear_all_data, + import_graph_from_dgl, + import_graphs_from_dgl, + import_hetero_graph_from_dgl, +) @pytest.fixture(scope="session", autouse=True) def setup_and_teardown(): - print("Setup: Importing DGL dataset to HugeGraph") clear_all_data() import_graph_from_dgl("CORA") import_graphs_from_dgl("MUTAG") @@ -32,5 +35,4 @@ def setup_and_teardown(): yield - print("Teardown: Clearing HugeGraph data") clear_all_data() diff --git a/hugegraph-ml/src/tests/test_data/test_hugegraph2dgl.py b/hugegraph-ml/src/tests/test_data/test_hugegraph2dgl.py index 5527fba5f..f8ee4a381 100644 --- a/hugegraph-ml/src/tests/test_data/test_hugegraph2dgl.py +++ b/hugegraph-ml/src/tests/test_data/test_hugegraph2dgl.py @@ -20,6 +20,7 @@ import torch from dgl.data import CoraGraphDataset, GINDataset + from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL from hugegraph_ml.utils.dgl2hugegraph_utils import load_acm_raw @@ -86,13 +87,9 @@ def test_convert_graph_dataset(self): edge_label="MUTAG_edge", ) - self.assertEqual( - len(dataset_dgl), len(self.mutag_dataset), "Number of graphs does not match." - ) + self.assertEqual(len(dataset_dgl), len(self.mutag_dataset), "Number of graphs does not match.") - self.assertEqual( - dataset_dgl.info["n_graphs"], len(self.mutag_dataset), "Number of graphs does not match." - ) + self.assertEqual(dataset_dgl.info["n_graphs"], len(self.mutag_dataset), "Number of graphs does not match.") self.assertEqual( dataset_dgl.info["n_classes"], self.mutag_dataset.gclasses, "Number of graph classes does not match." @@ -105,21 +102,19 @@ def test_convert_hetero_graph(self): hg2d = HugeGraph2DGL() hetero_graph = hg2d.convert_hetero_graph( vertex_labels=["ACM_paper_v", "ACM_author_v", "ACM_field_v"], - edge_labels=["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"] + edge_labels=["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"], ) for ntype in self.acm_data.ntypes: self.assertIn( - self.ntype_map[ntype], - hetero_graph.ntypes, - f"Node type {ntype} is missing in converted graph." + self.ntype_map[ntype], hetero_graph.ntypes, f"Node type {ntype} is missing in converted graph." ) acm_node_count = self.acm_data.num_nodes(ntype) hetero_node_count = hetero_graph.num_nodes(self.ntype_map[ntype]) self.assertEqual( acm_node_count, hetero_node_count, - f"Node count for type {ntype} does not match: {acm_node_count} != {hetero_node_count}" + f"Node count for type {ntype} does not match: {acm_node_count} != {hetero_node_count}", ) for c_etypes in self.acm_data.canonical_etypes: @@ -127,12 +122,12 @@ def test_convert_hetero_graph(self): self.assertIn( mapped_c_etypes, hetero_graph.canonical_etypes, - f"Edge type {mapped_c_etypes} is missing in converted graph." + f"Edge type {mapped_c_etypes} is missing in converted graph.", ) acm_edge_count = self.acm_data.num_edges(etype=c_etypes) hetero_edge_count = hetero_graph.num_edges(etype=mapped_c_etypes) self.assertEqual( acm_edge_count, hetero_edge_count, - f"Edge count for type {mapped_c_etypes} does not match: {acm_edge_count} != {hetero_edge_count}" + f"Edge count for type {mapped_c_etypes} does not match: {acm_edge_count} != {hetero_edge_count}", ) diff --git a/hugegraph-ml/src/tests/test_examples/test_examples.py b/hugegraph-ml/src/tests/test_examples/test_examples.py index 2712d9bc1..08dfd228f 100644 --- a/hugegraph-ml/src/tests/test_examples/test_examples.py +++ b/hugegraph-ml/src/tests/test_examples/test_examples.py @@ -17,12 +17,6 @@ import unittest -from hugegraph_ml.examples.dgi_example import dgi_example -from hugegraph_ml.examples.diffpool_example import diffpool_example -from hugegraph_ml.examples.gin_example import gin_example -from hugegraph_ml.examples.grace_example import grace_example -from hugegraph_ml.examples.grand_example import grand_example -from hugegraph_ml.examples.jknet_example import jknet_example from hugegraph_ml.examples.agnn_example import agnn_example from hugegraph_ml.examples.appnp_example import appnp_example from hugegraph_ml.examples.arma_example import arma_example @@ -32,9 +26,16 @@ from hugegraph_ml.examples.correct_and_smooth_example import cs_example from hugegraph_ml.examples.dagnn_example import dagnn_example from hugegraph_ml.examples.deepergcn_example import deepergcn_example +from hugegraph_ml.examples.dgi_example import dgi_example +from hugegraph_ml.examples.diffpool_example import diffpool_example +from hugegraph_ml.examples.gin_example import gin_example +from hugegraph_ml.examples.grace_example import grace_example +from hugegraph_ml.examples.grand_example import grand_example +from hugegraph_ml.examples.jknet_example import jknet_example from hugegraph_ml.examples.pgnn_example import pgnn_example from hugegraph_ml.examples.seal_example import seal_example + class TestHugegraph2DGL(unittest.TestCase): def setUp(self): self.test_n_epochs = 3 diff --git a/hugegraph-ml/src/tests/test_tasks/test_node_classify.py b/hugegraph-ml/src/tests/test_tasks/test_node_classify.py index d659d9d8d..898ff61da 100644 --- a/hugegraph-ml/src/tests/test_tasks/test_node_classify.py +++ b/hugegraph-ml/src/tests/test_tasks/test_node_classify.py @@ -34,18 +34,17 @@ def test_check_graph(self): graph=self.graph, model=JKNet( n_in_feats=self.graph.ndata["feat"].shape[1], - n_out_feats=self.graph.ndata["label"].unique().shape[0] + n_out_feats=self.graph.ndata["label"].unique().shape[0], ), ) except ValueError as e: - self.fail(f"_check_graph failed: {str(e)}") + self.fail(f"_check_graph failed: {e!s}") def test_train_and_evaluate(self): node_classify_task = NodeClassify( graph=self.graph, model=JKNet( - n_in_feats=self.graph.ndata["feat"].shape[1], - n_out_feats=self.graph.ndata["label"].unique().shape[0] + n_in_feats=self.graph.ndata["feat"].shape[1], n_out_feats=self.graph.ndata["label"].unique().shape[0] ), ) node_classify_task.train(n_epochs=10, patience=3) diff --git a/hugegraph-ml/src/tests/test_tasks/test_node_embed.py b/hugegraph-ml/src/tests/test_tasks/test_node_embed.py index e385fea1f..7520f8077 100644 --- a/hugegraph-ml/src/tests/test_tasks/test_node_embed.py +++ b/hugegraph-ml/src/tests/test_tasks/test_node_embed.py @@ -36,7 +36,7 @@ def test_check_graph(self): model=DGI(n_in_feats=self.graph.ndata["feat"].shape[1], n_hidden=self.embed_size), ) except ValueError as e: - self.fail(f"_check_graph failed: {str(e)}") + self.fail(f"_check_graph failed: {e!s}") def test_train_and_embed(self): node_embed_task = NodeEmbed( diff --git a/hugegraph-python-client/README.md b/hugegraph-python-client/README.md index 67150d60f..ebb39e24f 100644 --- a/hugegraph-python-client/README.md +++ b/hugegraph-python-client/README.md @@ -1,6 +1,6 @@ # hugegraph-python-client -The `hugegraph-python-client` is a Python client/SDK for HugeGraph Database. +The `hugegraph-python-client` is a Python client/SDK for HugeGraph Database. It is used to define graph structures, perform CRUD operations on graph data, manage schemas, and execute Gremlin queries. Both the `hugegraph-llm` and `hugegraph-ml` modules depend on this foundational library. @@ -42,9 +42,9 @@ You can use the `hugegraph-python-client` to define graph structures. Below is a from pyhugegraph.client import PyHugeClient # Initialize the client -# For HugeGraph API version ≥ v3: (Or enable graphspace function) +# For HugeGraph API version ≥ v3: (Or enable graphspace function) # - The 'graphspace' parameter becomes relevant if graphspaces are enabled.(default name is 'DEFAULT') -# - Otherwise, the graphspace parameter is optional and can be ignored. +# - Otherwise, the graphspace parameter is optional and can be ignored. client = PyHugeClient("127.0.0.1", "8080", user="admin", pwd="admin", graph="hugegraph", graphspace="DEFAULT") """" diff --git a/hugegraph-python-client/src/pyhugegraph/api/auth.py b/hugegraph-python-client/src/pyhugegraph/api/auth.py index d127c4f6d..7d7e74990 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/auth.py +++ b/hugegraph-python-client/src/pyhugegraph/api/auth.py @@ -18,22 +18,18 @@ import json -from typing import Optional, Dict from pyhugegraph.api.common import HugeParamsBase from pyhugegraph.utils import huge_router as router class AuthManager(HugeParamsBase): - @router.http("GET", "auth/users") def list_users(self, limit=None): params = {"limit": limit} if limit is not None else {} return self._invoke_request(params=params) @router.http("POST", "auth/users") - def create_user( - self, user_name, user_password, user_phone=None, user_email=None - ) -> Optional[Dict]: + def create_user(self, user_name, user_password, user_phone=None, user_email=None) -> dict | None: return self._invoke_request( data=json.dumps( { @@ -46,7 +42,7 @@ def create_user( ) @router.http("DELETE", "auth/users/{user_id}") - def delete_user(self, user_id) -> Optional[Dict]: # pylint: disable=unused-argument + def delete_user(self, user_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/users/{user_id}") @@ -57,7 +53,7 @@ def modify_user( user_password=None, user_phone=None, user_email=None, - ) -> Optional[Dict]: + ) -> dict | None: return self._invoke_request( data=json.dumps( { @@ -70,21 +66,21 @@ def modify_user( ) @router.http("GET", "auth/users/{user_id}") - def get_user(self, user_id) -> Optional[Dict]: # pylint: disable=unused-argument + def get_user(self, user_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/groups") - def list_groups(self, limit=None) -> Optional[Dict]: + def list_groups(self, limit=None) -> dict | None: params = {"limit": limit} if limit is not None else {} return self._invoke_request(params=params) @router.http("POST", "auth/groups") - def create_group(self, group_name, group_description=None) -> Optional[Dict]: + def create_group(self, group_name, group_description=None) -> dict | None: data = {"group_name": group_name, "group_description": group_description} return self._invoke_request(data=json.dumps(data)) @router.http("DELETE", "auth/groups/{group_id}") - def delete_group(self, group_id) -> Optional[Dict]: # pylint: disable=unused-argument + def delete_group(self, group_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/groups/{group_id}") @@ -93,16 +89,16 @@ def modify_group( group_id, # pylint: disable=unused-argument group_name=None, group_description=None, - ) -> Optional[Dict]: + ) -> dict | None: data = {"group_name": group_name, "group_description": group_description} return self._invoke_request(data=json.dumps(data)) @router.http("GET", "auth/groups/{group_id}") - def get_group(self, group_id) -> Optional[Dict]: # pylint: disable=unused-argument + def get_group(self, group_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("POST", "auth/accesses") - def grant_accesses(self, group_id, target_id, access_permission) -> Optional[Dict]: + def grant_accesses(self, group_id, target_id, access_permission) -> dict | None: return self._invoke_request( data=json.dumps( { @@ -114,29 +110,25 @@ def grant_accesses(self, group_id, target_id, access_permission) -> Optional[Dic ) @router.http("DELETE", "auth/accesses/{access_id}") - def revoke_accesses(self, access_id) -> Optional[Dict]: # pylint: disable=unused-argument + def revoke_accesses(self, access_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/accesses/{access_id}") - def modify_accesses( - self, access_id, access_description - ) -> Optional[Dict]: # pylint: disable=unused-argument + def modify_accesses(self, access_id, access_description) -> dict | None: # pylint: disable=unused-argument # The permission of access can\'t be updated data = {"access_description": access_description} return self._invoke_request(data=json.dumps(data)) @router.http("GET", "auth/accesses/{access_id}") - def get_accesses(self, access_id) -> Optional[Dict]: # pylint: disable=unused-argument + def get_accesses(self, access_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/accesses") - def list_accesses(self) -> Optional[Dict]: + def list_accesses(self) -> dict | None: return self._invoke_request() @router.http("POST", "auth/targets") - def create_target( - self, target_name, target_graph, target_url, target_resources - ) -> Optional[Dict]: + def create_target(self, target_name, target_graph, target_url, target_resources) -> dict | None: return self._invoke_request( data=json.dumps( { @@ -160,7 +152,7 @@ def update_target( target_graph, target_url, target_resources, - ) -> Optional[Dict]: + ) -> dict | None: return self._invoke_request( data=json.dumps( { @@ -173,17 +165,15 @@ def update_target( ) @router.http("GET", "auth/targets/{target_id}") - def get_target( - self, target_id, response=None - ) -> Optional[Dict]: # pylint: disable=unused-argument + def get_target(self, target_id, response=None) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/targets") - def list_targets(self) -> Optional[Dict]: + def list_targets(self) -> dict | None: return self._invoke_request() @router.http("POST", "auth/belongs") - def create_belong(self, user_id, group_id) -> Optional[Dict]: + def create_belong(self, user_id, group_id) -> dict | None: data = {"user": user_id, "group": group_id} return self._invoke_request(data=json.dumps(data)) @@ -192,16 +182,14 @@ def delete_belong(self, belong_id) -> None: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/belongs/{belong_id}") - def update_belong( - self, belong_id, description - ) -> Optional[Dict]: # pylint: disable=unused-argument + def update_belong(self, belong_id, description) -> dict | None: # pylint: disable=unused-argument data = {"belong_description": description} return self._invoke_request(data=json.dumps(data)) @router.http("GET", "auth/belongs/{belong_id}") - def get_belong(self, belong_id) -> Optional[Dict]: # pylint: disable=unused-argument + def get_belong(self, belong_id) -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/belongs") - def list_belongs(self) -> Optional[Dict]: + def list_belongs(self) -> dict | None: return self._invoke_request() diff --git a/hugegraph-python-client/src/pyhugegraph/api/common.py b/hugegraph-python-client/src/pyhugegraph/api/common.py index 631a2779a..d72c02abd 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/common.py +++ b/hugegraph-python-client/src/pyhugegraph/api/common.py @@ -18,10 +18,9 @@ import re -from abc import ABC -from pyhugegraph.utils.log import log -from pyhugegraph.utils.huge_router import RouterMixin from pyhugegraph.utils.huge_requests import HGraphSession +from pyhugegraph.utils.huge_router import RouterMixin +from pyhugegraph.utils.log import log # todo: rename -> HGraphMetaData or delete @@ -44,7 +43,7 @@ def get_keys(self): return self._dic.keys() -class HGraphContext(ABC): +class HGraphContext: def __init__(self, sess: HGraphSession) -> None: self._sess = sess self._cache = {} # todo: move parameter_holder to cache diff --git a/hugegraph-python-client/src/pyhugegraph/api/graph.py b/hugegraph-python-client/src/pyhugegraph/api/graph.py index 4555eeda4..4b6aab1c0 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/graph.py +++ b/hugegraph-python-client/src/pyhugegraph/api/graph.py @@ -16,7 +16,6 @@ # under the License. import json -from typing import Optional, List from pyhugegraph.api.common import HugeParamsBase from pyhugegraph.structure.edge_data import EdgeData @@ -26,7 +25,6 @@ class GraphManager(HugeParamsBase): - @router.http("POST", "graph/vertices") def addVertex(self, label, properties, id=None): data = {} @@ -108,7 +106,7 @@ def removeVertexById(self, vertex_id): # pylint: disable=unused-argument return self._invoke_request() @router.http("POST", "graph/edges") - def addEdge(self, edge_label, out_id, in_id, properties) -> Optional[EdgeData]: + def addEdge(self, edge_label, out_id, in_id, properties) -> EdgeData | None: data = { "label": edge_label, "outV": out_id, @@ -120,7 +118,7 @@ def addEdge(self, edge_label, out_id, in_id, properties) -> Optional[EdgeData]: return None @router.http("POST", "graph/edges/batch") - def addEdges(self, input_data) -> Optional[List[EdgeData]]: + def addEdges(self, input_data) -> list[EdgeData] | None: data = [] for item in input_data: data.append( @@ -139,22 +137,26 @@ def addEdges(self, input_data) -> Optional[List[EdgeData]]: @router.http("PUT", "graph/edges/{edge_id}?action=append") def appendEdge( - self, edge_id, properties # pylint: disable=unused-argument - ) -> Optional[EdgeData]: + self, + edge_id, + properties, # pylint: disable=unused-argument + ) -> EdgeData | None: if response := self._invoke_request(data=json.dumps({"properties": properties})): return EdgeData(response) return None @router.http("PUT", "graph/edges/{edge_id}?action=eliminate") def eliminateEdge( - self, edge_id, properties # pylint: disable=unused-argument - ) -> Optional[EdgeData]: + self, + edge_id, + properties, # pylint: disable=unused-argument + ) -> EdgeData | None: if response := self._invoke_request(data=json.dumps({"properties": properties})): return EdgeData(response) return None @router.http("GET", "graph/edges/{edge_id}") - def getEdgeById(self, edge_id) -> Optional[EdgeData]: # pylint: disable=unused-argument + def getEdgeById(self, edge_id) -> EdgeData | None: # pylint: disable=unused-argument if response := self._invoke_request(): return EdgeData(response) return None @@ -194,7 +196,7 @@ def getEdgeByPage( def removeEdgeById(self, edge_id) -> dict: # pylint: disable=unused-argument return self._invoke_request() - def getVerticesById(self, vertex_ids) -> Optional[List[VertexData]]: + def getVerticesById(self, vertex_ids) -> list[VertexData] | None: if not vertex_ids: return [] path = "traversers/vertices?" @@ -205,7 +207,7 @@ def getVerticesById(self, vertex_ids) -> Optional[List[VertexData]]: return [VertexData(item) for item in response["vertices"]] return None - def getEdgesById(self, edge_ids) -> Optional[List[EdgeData]]: + def getEdgesById(self, edge_ids) -> list[EdgeData] | None: if not edge_ids: return [] path = "traversers/edges?" diff --git a/hugegraph-python-client/src/pyhugegraph/api/graphs.py b/hugegraph-python-client/src/pyhugegraph/api/graphs.py index dbcf477a4..cde2acfe1 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/graphs.py +++ b/hugegraph-python-client/src/pyhugegraph/api/graphs.py @@ -23,7 +23,6 @@ class GraphsManager(HugeParamsBase): - @router.http("GET", "/graphs") def get_all_graphs(self) -> dict: return self._invoke_request(validator=ResponseValidation("text")) diff --git a/hugegraph-python-client/src/pyhugegraph/api/gremlin.py b/hugegraph-python-client/src/pyhugegraph/api/gremlin.py index 3261d60b3..3fa79368b 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/gremlin.py +++ b/hugegraph-python-client/src/pyhugegraph/api/gremlin.py @@ -17,15 +17,14 @@ from pyhugegraph.api.common import HugeParamsBase -from pyhugegraph.utils.exceptions import NotFoundError from pyhugegraph.structure.gremlin_data import GremlinData from pyhugegraph.structure.response_data import ResponseData from pyhugegraph.utils import huge_router as router +from pyhugegraph.utils.exceptions import NotFoundError from pyhugegraph.utils.log import log class GremlinManager(HugeParamsBase): - @router.http("POST", "/gremlin") def exec(self, gremlin): gremlin_data = GremlinData(gremlin) diff --git a/hugegraph-python-client/src/pyhugegraph/api/metric.py b/hugegraph-python-client/src/pyhugegraph/api/metric.py index 27fd715ca..f32eb55c8 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/metric.py +++ b/hugegraph-python-client/src/pyhugegraph/api/metric.py @@ -20,7 +20,6 @@ class MetricsManager(HugeParamsBase): - @router.http("GET", "/metrics/?type=json") def get_all_basic_metrics(self) -> dict: return self._invoke_request() diff --git a/hugegraph-python-client/src/pyhugegraph/api/rank.py b/hugegraph-python-client/src/pyhugegraph/api/rank.py index ba7bb0bb1..64f743b54 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/rank.py +++ b/hugegraph-python-client/src/pyhugegraph/api/rank.py @@ -16,12 +16,12 @@ # under the License. -from pyhugegraph.utils import huge_router as router +from pyhugegraph.api.common import HugeParamsBase from pyhugegraph.structure.rank_data import ( - PersonalRankParameters, NeighborRankParameters, + PersonalRankParameters, ) -from pyhugegraph.api.common import HugeParamsBase +from pyhugegraph.utils import huge_router as router class RankManager(HugeParamsBase): diff --git a/hugegraph-python-client/src/pyhugegraph/api/rebuild.py b/hugegraph-python-client/src/pyhugegraph/api/rebuild.py index 74075f8bd..1429e1709 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/rebuild.py +++ b/hugegraph-python-client/src/pyhugegraph/api/rebuild.py @@ -16,8 +16,8 @@ # under the License. -from pyhugegraph.utils import huge_router as router from pyhugegraph.api.common import HugeParamsBase +from pyhugegraph.utils import huge_router as router class RebuildManager(HugeParamsBase): diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema.py b/hugegraph-python-client/src/pyhugegraph/api/schema.py index 8095887b0..09ba30715 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema.py @@ -16,7 +16,6 @@ # under the License. -from typing import Optional, Dict, List from pyhugegraph.api.common import HugeParamsBase from pyhugegraph.api.schema_manage.edge_label import EdgeLabel from pyhugegraph.api.schema_manage.index_label import IndexLabel @@ -64,53 +63,49 @@ def indexLabel(self, name): return index_label @router.http("GET", "schema?format={_format}") - def getSchema(self, _format: str = "json") -> Optional[Dict]: # pylint: disable=unused-argument + def getSchema(self, _format: str = "json") -> dict | None: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "schema/propertykeys/{property_name}") - def getPropertyKey( - self, property_name - ) -> Optional[PropertyKeyData]: # pylint: disable=unused-argument + def getPropertyKey(self, property_name) -> PropertyKeyData | None: # pylint: disable=unused-argument if response := self._invoke_request(): return PropertyKeyData(response) return None @router.http("GET", "schema/propertykeys") - def getPropertyKeys(self) -> Optional[List[PropertyKeyData]]: + def getPropertyKeys(self) -> list[PropertyKeyData] | None: if response := self._invoke_request(): return [PropertyKeyData(item) for item in response["propertykeys"]] return None @router.http("GET", "schema/vertexlabels/{name}") - def getVertexLabel(self, name) -> Optional[VertexLabelData]: # pylint: disable=unused-argument + def getVertexLabel(self, name) -> VertexLabelData | None: # pylint: disable=unused-argument if response := self._invoke_request(): return VertexLabelData(response) log.error("VertexLabel not found: %s", str(response)) return None @router.http("GET", "schema/vertexlabels") - def getVertexLabels(self) -> Optional[List[VertexLabelData]]: + def getVertexLabels(self) -> list[VertexLabelData] | None: if response := self._invoke_request(): return [VertexLabelData(item) for item in response["vertexlabels"]] return None @router.http("GET", "schema/edgelabels/{label_name}") - def getEdgeLabel( - self, label_name: str - ) -> Optional[EdgeLabelData]: # pylint: disable=unused-argument + def getEdgeLabel(self, label_name: str) -> EdgeLabelData | None: # pylint: disable=unused-argument if response := self._invoke_request(): return EdgeLabelData(response) log.error("EdgeLabel not found: %s", str(response)) return None @router.http("GET", "schema/edgelabels") - def getEdgeLabels(self) -> Optional[List[EdgeLabelData]]: + def getEdgeLabels(self) -> list[EdgeLabelData] | None: if response := self._invoke_request(): return [EdgeLabelData(item) for item in response["edgelabels"]] return None @router.http("GET", "schema/edgelabels") - def getRelations(self) -> Optional[List[str]]: + def getRelations(self) -> list[str] | None: """ Retrieve all edge_label links/paths from the graph-sever. @@ -124,14 +119,14 @@ def getRelations(self) -> Optional[List[str]]: return None @router.http("GET", "schema/indexlabels/{name}") - def getIndexLabel(self, name) -> Optional[IndexLabelData]: # pylint: disable=unused-argument + def getIndexLabel(self, name) -> IndexLabelData | None: # pylint: disable=unused-argument if response := self._invoke_request(): return IndexLabelData(response) log.error("IndexLabel not found: %s", str(response)) return None @router.http("GET", "schema/indexlabels") - def getIndexLabels(self) -> Optional[List[IndexLabelData]]: + def getIndexLabels(self) -> list[IndexLabelData] | None: if response := self._invoke_request(): return [IndexLabelData(item) for item in response["indexlabels"]] return None diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/edge_label.py b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/edge_label.py index 93f218001..932a819ac 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/edge_label.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/edge_label.py @@ -26,7 +26,6 @@ class EdgeLabel(HugeParamsBase): - @decorator_params def link(self, source_label, target_label) -> "EdgeLabel": self._parameter_holder.set("source_label", source_label) @@ -82,7 +81,7 @@ def nullableKeys(self, *args) -> "EdgeLabel": @decorator_params def ifNotExist(self) -> "EdgeLabel": - path = f'schema/edgelabels/{self._parameter_holder.get_value("name")}' + path = f"schema/edgelabels/{self._parameter_holder.get_value('name')}" if _ := self._sess.request(path, validator=ResponseValidation(strict=False)): self._parameter_holder.set("not_exist", False) return self @@ -117,17 +116,17 @@ def create(self): path = "schema/edgelabels" self.clean_parameter_holder() if response := self._sess.request(path, "POST", data=json.dumps(data)): - return f'create EdgeLabel success, Detail: "{str(response)}"' - log.error(f'create EdgeLabel failed, Detail: "{str(response)}"') + return f'create EdgeLabel success, Detail: "{response!s}"' + log.error(f'create EdgeLabel failed, Detail: "{response!s}"') return None @decorator_params def remove(self): - path = f'schema/edgelabels/{self._parameter_holder.get_value("name")}' + path = f"schema/edgelabels/{self._parameter_holder.get_value('name')}" self.clean_parameter_holder() if response := self._sess.request(path, "DELETE"): - return f'remove EdgeLabel success, Detail: "{str(response)}"' - log.error(f'remove EdgeLabel failed, Detail: "{str(response)}"') + return f'remove EdgeLabel success, Detail: "{response!s}"' + log.error(f'remove EdgeLabel failed, Detail: "{response!s}"') return None @decorator_params @@ -139,25 +138,23 @@ def append(self): if key in dic: data[key] = dic[key] - path = f'schema/edgelabels/{data["name"]}?action=append' + path = f"schema/edgelabels/{data['name']}?action=append" self.clean_parameter_holder() if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f'append EdgeLabel success, Detail: "{str(response)}"' - log.error(f'append EdgeLabel failed, Detail: "{str(response)}"') + return f'append EdgeLabel success, Detail: "{response!s}"' + log.error(f'append EdgeLabel failed, Detail: "{response!s}"') return None @decorator_params def eliminate(self): name = self._parameter_holder.get_value("name") user_data = ( - self._parameter_holder.get_value("user_data") - if self._parameter_holder.get_value("user_data") - else {} + self._parameter_holder.get_value("user_data") if self._parameter_holder.get_value("user_data") else {} ) path = f"schema/edgelabels/{name}?action=eliminate" data = {"name": name, "user_data": user_data} self.clean_parameter_holder() if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f'eliminate EdgeLabel success, Detail: "{str(response)}"' - log.error(f'eliminate EdgeLabel failed, Detail: "{str(response)}"') + return f'eliminate EdgeLabel success, Detail: "{response!s}"' + log.error(f'eliminate EdgeLabel failed, Detail: "{response!s}"') return None diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py index acef8f968..94563a064 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py @@ -26,7 +26,6 @@ class IndexLabel(HugeParamsBase): - @decorator_params def onV(self, vertex_label) -> "IndexLabel": self._parameter_holder.set("base_value", vertex_label) @@ -75,7 +74,7 @@ def unique(self) -> "IndexLabel": @decorator_params def ifNotExist(self) -> "IndexLabel": - path = f'schema/indexlabels/{self._parameter_holder.get_value("name")}' + path = f"schema/indexlabels/{self._parameter_holder.get_value('name')}" if _ := self._sess.request(path, validator=ResponseValidation(strict=False)): self._parameter_holder.set("not_exist", False) return self @@ -93,8 +92,8 @@ def create(self): path = "schema/indexlabels" self.clean_parameter_holder() if response := self._sess.request(path, "POST", data=json.dumps(data)): - return f'create IndexLabel success, Detail: "{str(response)}"' - log.error(f'create IndexLabel failed, Detail: "{str(response)}"') + return f'create IndexLabel success, Detail: "{response!s}"' + log.error(f'create IndexLabel failed, Detail: "{response!s}"') return None @decorator_params @@ -103,6 +102,6 @@ def remove(self): path = f"schema/indexlabels/{name}" self.clean_parameter_holder() if response := self._sess.request(path, "DELETE"): - return f'remove IndexLabel success, Detail: "{str(response)}"' - log.error(f'remove IndexLabel failed, Detail: "{str(response)}"') + return f'remove IndexLabel success, Detail: "{response!s}"' + log.error(f'remove IndexLabel failed, Detail: "{response!s}"') return None diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/property_key.py b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/property_key.py index 045995ea6..eaefb59c1 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/property_key.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/property_key.py @@ -126,7 +126,7 @@ def userdata(self, *args) -> "PropertyKey": return self def ifNotExist(self) -> "PropertyKey": - path = f'schema/propertykeys/{self._parameter_holder.get_value("name")}' + path = f"schema/propertykeys/{self._parameter_holder.get_value('name')}" if _ := self._sess.request(path, validator=ResponseValidation(strict=False)): self._parameter_holder.set("not_exist", False) return self @@ -142,7 +142,7 @@ def create(self): path = "schema/propertykeys" self.clean_parameter_holder() if response := self._sess.request(path, "POST", data=json.dumps(property_keys)): - return f"create PropertyKey success, Detail: {str(response)}" + return f"create PropertyKey success, Detail: {response!s}" log.error("create PropertyKey failed, Detail: %s", str(response)) return "" @@ -157,7 +157,7 @@ def append(self): path = f"schema/propertykeys/{property_name}/?action=append" self.clean_parameter_holder() if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f"append PropertyKey success, Detail: {str(response)}" + return f"append PropertyKey success, Detail: {response!s}" log.error("append PropertyKey failed, Detail: %s", str(response)) return "" @@ -172,16 +172,16 @@ def eliminate(self): path = f"schema/propertykeys/{property_name}/?action=eliminate" self.clean_parameter_holder() if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f"eliminate PropertyKey success, Detail: {str(response)}" + return f"eliminate PropertyKey success, Detail: {response!s}" log.error("eliminate PropertyKey failed, Detail: %s", str(response)) return "" @decorator_params def remove(self): dic = self._parameter_holder.get_dic() - path = f'schema/propertykeys/{dic["name"]}' + path = f"schema/propertykeys/{dic['name']}" self.clean_parameter_holder() if response := self._sess.request(path, "DELETE"): - return f"delete PropertyKey success, Detail: {str(response)}" + return f"delete PropertyKey success, Detail: {response!s}" log.error("delete PropertyKey failed, Detail: %s", str(response)) return "" diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/vertex_label.py b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/vertex_label.py index 597fb6f0e..fca8153d5 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/vertex_label.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/vertex_label.py @@ -24,7 +24,6 @@ class VertexLabel(HugeParamsBase): - @decorator_params def useAutomaticId(self) -> "VertexLabel": self._parameter_holder.set("id_strategy", "AUTOMATIC") @@ -77,7 +76,7 @@ def userdata(self, *args) -> "VertexLabel": return self def ifNotExist(self) -> "VertexLabel": - path = f'schema/vertexlabels/{self._parameter_holder.get_value("name")}' + path = f"schema/vertexlabels/{self._parameter_holder.get_value('name')}" if _ := self._sess.request(path, validator=ResponseValidation(strict=False)): self._parameter_holder.set("not_exist", False) return self @@ -102,17 +101,17 @@ def create(self): path = "schema/vertexlabels" self.clean_parameter_holder() if response := self._sess.request(path, "POST", data=json.dumps(data)): - return f'create VertexLabel success, Detail: "{str(response)}"' + return f'create VertexLabel success, Detail: "{response!s}"' log.error("create VertexLabel failed, Detail: %s", str(response)) return "" @decorator_params def append(self) -> None: dic = self._parameter_holder.get_dic() - properties = dic["properties"] if "properties" in dic else [] - nullable_keys = dic["nullable_keys"] if "nullable_keys" in dic else [] - user_data = dic["user_data"] if "user_data" in dic else {} - path = f'schema/vertexlabels/{dic["name"]}?action=append' + properties = dic.get("properties", []) + nullable_keys = dic.get("nullable_keys", []) + user_data = dic.get("user_data", {}) + path = f"schema/vertexlabels/{dic['name']}?action=append" data = { "name": dic["name"], "properties": properties, @@ -121,7 +120,7 @@ def append(self) -> None: } self.clean_parameter_holder() if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f'append VertexLabel success, Detail: "{str(response)}"' + return f'append VertexLabel success, Detail: "{response!s}"' log.error("append VertexLabel failed, Detail: %s", str(response)) return "" @@ -131,7 +130,7 @@ def remove(self) -> None: path = f"schema/vertexlabels/{name}" self.clean_parameter_holder() if response := self._sess.request(path, "DELETE"): - return f'remove VertexLabel success, Detail: "{str(response)}"' + return f'remove VertexLabel success, Detail: "{response!s}"' log.error("remove VertexLabel failed, Detail: %s", str(response)) return "" @@ -141,12 +140,12 @@ def eliminate(self) -> None: path = f"schema/vertexlabels/{name}/?action=eliminate" dic = self._parameter_holder.get_dic() - user_data = dic["user_data"] if "user_data" in dic else {} + user_data = dic.get("user_data", {}) data = { "name": self._parameter_holder.get_value("name"), "user_data": user_data, } if response := self._sess.request(path, "PUT", data=json.dumps(data)): - return f'eliminate VertexLabel success, Detail: "{str(response)}"' + return f'eliminate VertexLabel success, Detail: "{response!s}"' log.error("eliminate VertexLabel failed, Detail: %s", str(response)) return "" diff --git a/hugegraph-python-client/src/pyhugegraph/api/services.py b/hugegraph-python-client/src/pyhugegraph/api/services.py index f353673db..258c0469a 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/services.py +++ b/hugegraph-python-client/src/pyhugegraph/api/services.py @@ -16,9 +16,9 @@ # under the License. -from pyhugegraph.utils import huge_router as router from pyhugegraph.api.common import HugeParamsBase from pyhugegraph.structure.services_data import ServiceCreateParameters +from pyhugegraph.utils import huge_router as router class ServicesManager(HugeParamsBase): @@ -125,7 +125,6 @@ def delete_service(self, graphspace: str, service: str): # pylint: disable=unus None """ return self._sess.request( - f"/graphspaces/{graphspace}/services/{service}" - f"?confirm_message=I'm sure to delete the service", + f"/graphspaces/{graphspace}/services/{service}?confirm_message=I'm sure to delete the service", "DELETE", ) diff --git a/hugegraph-python-client/src/pyhugegraph/api/task.py b/hugegraph-python-client/src/pyhugegraph/api/task.py index 0668eb0c4..9468da746 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/task.py +++ b/hugegraph-python-client/src/pyhugegraph/api/task.py @@ -20,7 +20,6 @@ class TaskManager(HugeParamsBase): - @router.http("GET", "tasks") def list_tasks(self, status=None, limit=None): params = {} diff --git a/hugegraph-python-client/src/pyhugegraph/api/traverser.py b/hugegraph-python-client/src/pyhugegraph/api/traverser.py index 2f226522b..f296313e1 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/traverser.py +++ b/hugegraph-python-client/src/pyhugegraph/api/traverser.py @@ -21,7 +21,6 @@ class TraverserManager(HugeParamsBase): - @router.http("GET", 'traversers/kout?source="{source_id}"&max_depth={max_depth}') def k_out(self, source_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() @@ -49,9 +48,7 @@ def shortest_path(self, source_id, target_id, max_depth): # pylint: disable=unu "GET", 'traversers/allshortestpaths?source="{source_id}"&target="{target_id}"&max_depth={max_depth}', ) - def all_shortest_paths( - self, source_id, target_id, max_depth - ): # pylint: disable=unused-argument + def all_shortest_paths(self, source_id, target_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() @router.http( @@ -59,9 +56,7 @@ def all_shortest_paths( 'traversers/weightedshortestpath?source="{source_id}"&target="{target_id}"' "&weight={weight}&max_depth={max_depth}", ) - def weighted_shortest_path( - self, source_id, target_id, weight, max_depth - ): # pylint: disable=unused-argument + def weighted_shortest_path(self, source_id, target_id, weight, max_depth): # pylint: disable=unused-argument return self._invoke_request() @router.http( @@ -130,9 +125,7 @@ def advanced_paths( ) @router.http("POST", "traversers/customizedpaths") - def customized_paths( - self, sources, steps, sort_by="INCR", with_vertex=True, capacity=-1, limit=-1 - ): + def customized_paths(self, sources, steps, sort_by="INCR", with_vertex=True, capacity=-1, limit=-1): return self._invoke_request( data=json.dumps( { diff --git a/hugegraph-python-client/src/pyhugegraph/api/variable.py b/hugegraph-python-client/src/pyhugegraph/api/variable.py index c5fe84017..895a56200 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/variable.py +++ b/hugegraph-python-client/src/pyhugegraph/api/variable.py @@ -22,7 +22,6 @@ class VariableManager(HugeParamsBase): - @router.http("PUT", "variables/{key}") def set(self, key, value): # pylint: disable=unused-argument return self._invoke_request(data=json.dumps({"data": value})) diff --git a/hugegraph-python-client/src/pyhugegraph/api/version.py b/hugegraph-python-client/src/pyhugegraph/api/version.py index 3635c7718..a75da506a 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/version.py +++ b/hugegraph-python-client/src/pyhugegraph/api/version.py @@ -20,7 +20,6 @@ class VersionManager(HugeParamsBase): - @router.http("GET", "/versions") def version(self): return self._invoke_request() diff --git a/hugegraph-python-client/src/pyhugegraph/client.py b/hugegraph-python-client/src/pyhugegraph/client.py index c9f4d1027..c0d4120fe 100644 --- a/hugegraph-python-client/src/pyhugegraph/client.py +++ b/hugegraph-python-client/src/pyhugegraph/client.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Optional, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from pyhugegraph.api.auth import AuthManager from pyhugegraph.api.graph import GraphManager @@ -52,8 +53,8 @@ def __init__( graph: str, user: str, pwd: str, - graphspace: Optional[str] = None, - timeout: Optional[tuple[float, float]] = None, + graphspace: str | None = None, + timeout: tuple[float, float] | None = None, ): self.cfg = HGraphConfig(url, user, pwd, graph, graphspace, timeout or (0.5, 15.0)) diff --git a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py index d5cc0eb9d..a152ffe06 100644 --- a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py +++ b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py @@ -18,9 +18,7 @@ from pyhugegraph.client import PyHugeClient if __name__ == "__main__": - client = PyHugeClient( - url="http://127.0.0.1:8080", user="admin", pwd="admin", graph="hugegraph", graphspace=None - ) + client = PyHugeClient(url="http://127.0.0.1:8080", user="admin", pwd="admin", graph="hugegraph", graphspace=None) """schema""" schema = client.schema() @@ -29,15 +27,9 @@ schema.vertexLabel("Person").properties("name", "birthDate").usePrimaryKeyId().primaryKeys( "name" ).ifNotExist().create() - schema.vertexLabel("Movie").properties("name").usePrimaryKeyId().primaryKeys( - "name" - ).ifNotExist().create() + schema.vertexLabel("Movie").properties("name").usePrimaryKeyId().primaryKeys("name").ifNotExist().create() schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create() - print(schema.getVertexLabels()) - print(schema.getEdgeLabels()) - print(schema.getRelations()) - """graph""" g = client.graph() # add Vertex @@ -56,17 +48,11 @@ # update property # g.eliminateVertex("vertex_id", {"property_key": "property_value"}) - print(g.getVertexById(p1.id).label) # g.removeVertexById("12:Al Pacino") g.close() """gremlin""" g = client.gremlin() - print("gremlin.exec: ", g.exec("g.V().limit(10)")) """graphs""" g = client.graphs() - print("get_graph_info: ", g.get_graph_info()) - print("get_all_graphs: ", g.get_all_graphs()) - print("get_version: ", g.get_version()) - print("get_graph_config: ", g.get_graph_config()) diff --git a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_test.py b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_test.py index 2bfe6ea97..dc9c18620 100644 --- a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_test.py +++ b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_test.py @@ -31,17 +31,14 @@ def __init__( from pyhugegraph.client import PyHugeClient except ImportError: raise ValueError( - "Please install HugeGraph Python client first: " - "`pip3 install hugegraph-python-client`" + "Please install HugeGraph Python client first: `pip3 install hugegraph-python-client`" ) from ImportError self.username = username self.password = password self.url = url self.graph = graph - self.client = PyHugeClient( - url=url, user=username, pwd=password, graph=graph, graphspace=None - ) + self.client = PyHugeClient(url=url, user=username, pwd=password, graph=graph, graphspace=None) self.schema = "" def exec(self, query) -> str: diff --git a/hugegraph-python-client/src/pyhugegraph/structure/edge_data.py b/hugegraph-python-client/src/pyhugegraph/structure/edge_data.py index e0732c956..f4e30a113 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/edge_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/edge_data.py @@ -19,13 +19,13 @@ class EdgeData: def __init__(self, dic): self.__id = dic["id"] - self.__label = dic["label"] if "label" in dic else None - self.__type = dic["type"] if "type" in dic else None - self.__outV = dic["outV"] if "outV" in dic else None - self.__outVLabel = dic["outVLabel"] if "outVLabel" in dic else None - self.__inV = dic["inV"] if "inV" in dic else None - self.__inVLabel = dic["inVLabel"] if "inVLabel" in dic else None - self.__properties = dic["properties"] if "properties" in dic else None + self.__label = dic.get("label", None) + self.__type = dic.get("type", None) + self.__outV = dic.get("outV", None) + self.__outVLabel = dic.get("outVLabel", None) + self.__inV = dic.get("inV", None) + self.__inVLabel = dic.get("inVLabel", None) + self.__properties = dic.get("properties", None) @property def id(self): diff --git a/hugegraph-python-client/src/pyhugegraph/structure/index_label_data.py b/hugegraph-python-client/src/pyhugegraph/structure/index_label_data.py index 65aedfc5e..684869220 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/index_label_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/index_label_data.py @@ -18,12 +18,12 @@ class IndexLabelData: def __init__(self, dic): - self.__id = dic["id"] if "id" in dic else None - self.__base_type = dic["base_type"] if "base_type" in dic else None - self.__base_value = dic["base_value"] if "base_value" in dic else None - self.__name = dic["name"] if "name" in dic else None - self.__fields = dic["fields"] if "fields" in dic else None - self.__index_type = dic["index_type"] if "index_type" in dic else None + self.__id = dic.get("id", None) + self.__base_type = dic.get("base_type", None) + self.__base_value = dic.get("base_value", None) + self.__name = dic.get("name", None) + self.__fields = dic.get("fields", None) + self.__index_type = dic.get("index_type", None) @property def id(self): diff --git a/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py b/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py index ff50d9b2f..6fb7c36f0 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py @@ -62,7 +62,5 @@ def userdata(self): return self.__user_data def __repr__(self): - res = ( - f"name: {self.__name}, cardinality: {self.__cardinality}, data_type: {self.__data_type}" - ) + res = f"name: {self.__name}, cardinality: {self.__cardinality}, data_type: {self.__data_type}" return res diff --git a/hugegraph-python-client/src/pyhugegraph/structure/rank_data.py b/hugegraph-python-client/src/pyhugegraph/structure/rank_data.py index af622c9ff..d11ee3d06 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/rank_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/rank_data.py @@ -16,8 +16,6 @@ # under the License. import json - -from typing import List, Union from dataclasses import asdict, dataclass, field @@ -28,7 +26,7 @@ class NeighborRankStep: """ direction: str = "BOTH" - labels: List[str] = field(default_factory=list) + labels: list[str] = field(default_factory=list) max_degree: int = 10000 top: int = 100 @@ -42,11 +40,11 @@ class NeighborRankParameters: BodyParams defines the body parameters for the rank API requests. """ - source: Union[str, int] + source: str | int label: str alpha: float = 0.85 capacity: int = 10000000 - steps: List[NeighborRankStep] = field(default_factory=list) + steps: list[NeighborRankStep] = field(default_factory=list) def dumps(self): return json.dumps(asdict(self)) @@ -82,7 +80,7 @@ class PersonalRankParameters: of a different category (the other end of a bipartite graph), and "BOTH_LABEL" to keep both. """ - source: Union[str, int] + source: str | int label: str alpha: float = 0.85 max_degree: int = 10000 diff --git a/hugegraph-python-client/src/pyhugegraph/structure/services_data.py b/hugegraph-python-client/src/pyhugegraph/structure/services_data.py index 07b7a23f7..cbf1a6c3c 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/services_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/services_data.py @@ -16,8 +16,6 @@ # under the License. import json - -from typing import List, Optional from dataclasses import asdict, dataclass, field @@ -53,10 +51,10 @@ class ServiceCreateParameters: cpu_limit: int = 1 memory_limit: int = 4 storage_limit: int = 100 - route_type: Optional[str] = None - port: Optional[int] = None - urls: List[str] = field(default_factory=list) - deployment_type: Optional[str] = None + route_type: str | None = None + port: int | None = None + urls: list[str] = field(default_factory=list) + deployment_type: str | None = None def dumps(self): return json.dumps(asdict(self)) diff --git a/hugegraph-python-client/src/pyhugegraph/structure/vertex_data.py b/hugegraph-python-client/src/pyhugegraph/structure/vertex_data.py index bccee3762..adf29ee65 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/vertex_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/vertex_data.py @@ -19,9 +19,9 @@ class VertexData: def __init__(self, dic): self.__id = dic["id"] - self.__label = dic["label"] if "label" in dic else None - self.__type = dic["type"] if "type" in dic else None - self.__properties = dic["properties"] if "properties" in dic else None + self.__label = dic.get("label", None) + self.__type = dic.get("type", None) + self.__properties = dic.get("properties", None) @property def id(self): diff --git a/hugegraph-python-client/src/pyhugegraph/structure/vertex_label_data.py b/hugegraph-python-client/src/pyhugegraph/structure/vertex_label_data.py index 39da4eebb..c675adf4a 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/vertex_label_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/vertex_label_data.py @@ -65,8 +65,5 @@ def enableLabelIndex(self): return self.__enable_label_index def __repr__(self): - res = ( - f"name: {self.__name}, primary_keys: {self.__primary_keys}, " - f"properties: {self.__properties}" - ) + res = f"name: {self.__name}, primary_keys: {self.__primary_keys}, properties: {self.__properties}" return res diff --git a/hugegraph-python-client/src/pyhugegraph/utils/exceptions.py b/hugegraph-python-client/src/pyhugegraph/utils/exceptions.py index 78fee8c19..f535d8643 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/exceptions.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/exceptions.py @@ -22,7 +22,7 @@ class NotAuthorizedError(Exception): """ -class InvalidParameter(Exception): +class InvalidParameterError(Exception): """ Parameter setting error """ @@ -58,7 +58,7 @@ class DataFormatError(Exception): """ -class ServiceUnavailableException(Exception): +class ServiceUnavailableError(Exception): """ The server is too busy to be available """ diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py index 429c07c6b..69c8949fb 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py @@ -19,7 +19,6 @@ import sys import traceback from dataclasses import dataclass, field -from typing import List, Optional import requests @@ -32,10 +31,10 @@ class HGraphConfig: username: str password: str graph_name: str - graphspace: Optional[str] = None + graphspace: str | None = None timeout: tuple[float, float] = (0.5, 15.0) gs_supported: bool = field(default=False, init=False) - version: List[int] = field(default_factory=list) + version: list[int] = field(default_factory=list) def __post_init__(self): # Add URL prefix compatibility check @@ -69,6 +68,4 @@ def __post_init__(self): except Exception: # pylint: disable=broad-exception-caught exc_type, exc_value, tb = sys.exc_info() traceback.print_exception(exc_type, exc_value, tb) - log.warning( - "Failed to retrieve API version information from the server, reverting to default v1." - ) + log.warning("Failed to retrieve API version information from the server, reverting to default v1.") diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_decorator.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_decorator.py index a6dbe891e..7233cb8ce 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_decorator.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_decorator.py @@ -17,6 +17,7 @@ from decorator import decorator + from pyhugegraph.utils.exceptions import NotAuthorizedError @@ -24,7 +25,6 @@ def decorator_params(func, *args, **kwargs): parameter_holder = args[0].get_parameter_holder() if parameter_holder is None or "name" not in parameter_holder.get_keys(): - print("Parameters required, please set necessary parameters.") raise Exception("Parameters required, please set necessary parameters.") return func(*args, **kwargs) @@ -41,5 +41,5 @@ def decorator_create(func, *args, **kwargs): def decorator_auth(func, *args, **kwargs): response = args[0] if response.status_code == 401: - raise NotAuthorizedError(f"NotAuthorized: {str(response.content)}") + raise NotAuthorizedError(f"NotAuthorized: {response.content!s}") return func(*args, **kwargs) diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py index 4d99a0e45..4ced85dbe 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py @@ -16,7 +16,7 @@ # under the License. -from typing import Any, Optional +from typing import Any from urllib.parse import urljoin import requests @@ -36,7 +36,7 @@ def __init__( retries: int = 3, backoff_factor: int = 0.1, status_forcelist=(500, 502, 504), - session: Optional[requests.Session] = None, + session: requests.Session | None = None, ): """ Initialize the HGraphSession object. @@ -136,9 +136,11 @@ def request( self, path: str, method: str = "GET", - validator=ResponseValidation(), + validator=None, **kwargs: Any, ) -> dict: + if validator is None: + validator = ResponseValidation() url = self.resolve(path) response: requests.Response = getattr(self._session, method.lower())( url, diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py index f4a38a418..48a9b3817 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py @@ -15,24 +15,24 @@ # specific language governing permissions and limitations # under the License. -import re -import inspect import functools +import inspect +import re import threading - +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar + from pyhugegraph.utils.log import log from pyhugegraph.utils.util import ResponseValidation - if TYPE_CHECKING: from pyhugegraph.api.common import HGraphContext class SingletonMeta(type): - _instances = {} - _lock = threading.Lock() + _instances: ClassVar[dict] = {} + _lock: ClassVar[threading.Lock] = threading.Lock() def __call__(cls, *args, **kwargs): """ @@ -50,12 +50,12 @@ def __call__(cls, *args, **kwargs): class Route: method: str path: str - request_func: Optional[Callable] = None + request_func: Callable | None = None class RouterRegistry(metaclass=SingletonMeta): def __init__(self): - self._routers: Dict[str, Route] = {} + self._routers: dict[str, Route] = {} def register(self, key: str, route: Route): self._routers[key] = route @@ -69,7 +69,6 @@ def __repr__(self) -> str: def register(method: str, path: str) -> Callable: - def decorator(func: Callable) -> Callable: RouterRegistry().register( func.__qualname__, @@ -144,10 +143,7 @@ def wrapper(self: "HGraphContext", *args: Any, **kwargs: Any) -> Any: class RouterMixin: - - def _invoke_request_registered( - self, placeholders: dict = None, validator=ResponseValidation(), **kwargs: Any - ): + def _invoke_request_registered(self, placeholders: dict | None = None, validator=None, **kwargs: Any): """ Make an HTTP request using the stored partial request function. Args: @@ -155,6 +151,8 @@ def _invoke_request_registered( Returns: Any: The response from the HTTP request. """ + if validator is None: + validator = ResponseValidation() frame = inspect.currentframe().f_back fname = frame.f_code.co_name route = RouterRegistry().routers.get(f"{self.__class__.__name__}.{fname}") @@ -171,7 +169,7 @@ def _invoke_request_registered( ) return route.request_func(formatted_path, validator=validator, **kwargs) - def _invoke_request(self, validator=ResponseValidation(), **kwargs: Any): + def _invoke_request(self, validator=None, **kwargs: Any): """ Make an HTTP request using the stored partial request function. @@ -181,9 +179,11 @@ def _invoke_request(self, validator=ResponseValidation(), **kwargs: Any): Returns: Any: The response from the HTTP request. """ + if validator is None: + validator = ResponseValidation() frame = inspect.currentframe().f_back fname = frame.f_code.co_name log.debug( # pylint: disable=logging-fstring-interpolation - f"Invoke request: {str(self.__class__.__name__)}.{fname}" + f"Invoke request: {self.__class__.__name__!s}.{fname}" ) return getattr(self, f"_{fname}_request")(validator=validator, **kwargs) diff --git a/hugegraph-python-client/src/pyhugegraph/utils/log.py b/hugegraph-python-client/src/pyhugegraph/utils/log.py index c6f6bd074..e381d25bf 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/log.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/log.py @@ -13,17 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import atexit -import logging -import os -import sys -import time -from collections import Counter -from functools import lru_cache -from logging.handlers import RotatingFileHandler - -from rich.logging import RichHandler - """ HugeGraph Logger Util ====================== @@ -42,32 +31,44 @@ Example Usage: from pyhugegraph.utils.log import init_logger - + # Initialize logger with both console and file output log = init_logger( log_output="logs/myapp.log", log_level=logging.INFO, logger_name="myapp" ) - + # Use the log/logger log.info("Application started") log.debug("Processing data...") log.error("Error occurred: %s", error_msg) """ + +import atexit +import logging +import os +import sys +import time +from collections import Counter +from functools import cache, lru_cache +from logging.handlers import RotatingFileHandler + +from rich.logging import RichHandler + __all__ = [ - "init_logger", "fetch_log_level", - "log_first_n_times", - "log_every_n_times", + "init_logger", "log_every_n_secs", + "log_every_n_times", + "log_first_n_times", ] LOG_BUFFER_SIZE_ENV: str = "LOG_BUFFER_SIZE" DEFAULT_BUFFER_SIZE: int = 1024 * 1024 # 1MB -@lru_cache() # avoid creating multiple handlers when calling init_logger() +@lru_cache # avoid creating multiple handlers when calling init_logger() def init_logger( log_output=None, log_level=logging.INFO, @@ -134,13 +135,11 @@ def init_logger( # Cache the opened file object, so that different calls to `initialize_logger` # with the same file name can safely write to the same file. -@lru_cache(maxsize=None) +@cache def _cached_log_file(filename): """Cache the opened file object""" # Use 1K buffer if writing to cloud storage - with open( - filename, "a", buffering=_determine_buffer_size(filename), encoding="utf-8" - ) as file_io: + with open(filename, "a", buffering=_determine_buffer_size(filename), encoding="utf-8") as file_io: atexit.register(file_io.close) return file_io @@ -203,7 +202,7 @@ def log_first_n_times(level, message, n=1, *, logger_name=None, key="caller"): if "caller" in key: hash_key = hash_key + caller_key if "message" in key: - hash_key = hash_key + (message,) + hash_key = (*hash_key, message) LOG_COUNTER[hash_key] += 1 if LOG_COUNTER[hash_key] <= n: @@ -219,7 +218,7 @@ def log_every_n_times(level, message, n=1, *, logger_name=None): def log_every_n_secs(level, message, n=1, *, logger_name=None): caller_module, key = _identify_caller() - last_logged = LOG_TIMERS.get(key, None) + last_logged = LOG_TIMERS.get(key) current_time = time.time() if last_logged is None or current_time - last_logged >= n: logging.getLogger(logger_name or caller_module).log(level, message) diff --git a/hugegraph-python-client/src/pyhugegraph/utils/util.py b/hugegraph-python-client/src/pyhugegraph/utils/util.py index 56a135547..d8c833b49 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/util.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/util.py @@ -24,7 +24,7 @@ from pyhugegraph.utils.exceptions import ( NotAuthorizedError, NotFoundError, - ServiceUnavailableException, + ServiceUnavailableError, ) from pyhugegraph.utils.log import log @@ -33,9 +33,8 @@ def create_exception(response_content): try: data = json.loads(response_content) if "ServiceUnavailableException" in data.get("exception", ""): - raise ServiceUnavailableException( - f'ServiceUnavailableException, "message": "{data["message"]}",' - f' "cause": "{data["cause"]}"' + raise ServiceUnavailableError( + f'ServiceUnavailableException, "message": "{data["message"]}", "cause": "{data["cause"]}"' ) except (json.JSONDecodeError, KeyError) as e: raise Exception(f"Error parsing response content: {response_content}") from e @@ -44,9 +43,7 @@ def create_exception(response_content): def check_if_authorized(response): if response.status_code == 401: - raise NotAuthorizedError( - f"Please check your username and password. {str(response.content)}" - ) + raise NotAuthorizedError(f"Please check your username and password. {response.content!s}") return True diff --git a/hugegraph-python-client/src/tests/api/test_auth.py b/hugegraph-python-client/src/tests/api/test_auth.py index bda9c8273..2105ef0b4 100644 --- a/hugegraph-python-client/src/tests/api/test_auth.py +++ b/hugegraph-python-client/src/tests/api/test_auth.py @@ -19,6 +19,7 @@ import unittest from pyhugegraph.utils.exceptions import NotFoundError + from ..client_utils import ClientUtils @@ -138,7 +139,7 @@ def test_target_operations(self): # Delete the target self.auth.delete_target(target["id"]) # Verify the target was deleted - with self.assertRaises(Exception): + with self.assertRaises(NotFoundError): self.auth.get_target(target["id"]) def test_belong_operations(self): @@ -170,7 +171,7 @@ def test_belong_operations(self): # Delete the belong self.auth.delete_belong(belong["id"]) # Verify the belong was deleted - with self.assertRaises(Exception): + with self.assertRaises(NotFoundError): self.auth.get_belong(belong["id"]) def test_access_operations(self): @@ -205,5 +206,5 @@ def test_access_operations(self): # Delete the permission self.auth.revoke_accesses(access["id"]) # Verify the permission was deleted - with self.assertRaises(Exception): + with self.assertRaises(NotFoundError): self.auth.get_accesses(access["id"]) diff --git a/hugegraph-python-client/src/tests/api/test_graph.py b/hugegraph-python-client/src/tests/api/test_graph.py index e77992b41..a66c3fb9f 100644 --- a/hugegraph-python-client/src/tests/api/test_graph.py +++ b/hugegraph-python-client/src/tests/api/test_graph.py @@ -18,6 +18,7 @@ import unittest from pyhugegraph.utils.exceptions import NotFoundError + from ..client_utils import ClientUtils diff --git a/hugegraph-python-client/src/tests/api/test_gremlin.py b/hugegraph-python-client/src/tests/api/test_gremlin.py index 43aeb8ba2..c212c0fe7 100644 --- a/hugegraph-python-client/src/tests/api/test_gremlin.py +++ b/hugegraph-python-client/src/tests/api/test_gremlin.py @@ -18,8 +18,8 @@ import unittest import pytest - from pyhugegraph.utils.exceptions import NotFoundError + from ..client_utils import ClientUtils diff --git a/hugegraph-python-client/src/tests/api/test_task.py b/hugegraph-python-client/src/tests/api/test_task.py index 3bd122967..599d8d1f6 100644 --- a/hugegraph-python-client/src/tests/api/test_task.py +++ b/hugegraph-python-client/src/tests/api/test_task.py @@ -18,6 +18,7 @@ import unittest from pyhugegraph.utils.exceptions import NotFoundError + from ..client_utils import ClientUtils diff --git a/hugegraph-python-client/src/tests/api/test_traverser.py b/hugegraph-python-client/src/tests/api/test_traverser.py index 330675f1d..123a78e43 100644 --- a/hugegraph-python-client/src/tests/api/test_traverser.py +++ b/hugegraph-python-client/src/tests/api/test_traverser.py @@ -55,9 +55,7 @@ def test_traverser_operations(self): self.assertEqual(k_out_result["vertices"], ["1:peter", "2:ripple"]) k_neighbor_result = self.traverser.k_neighbor(marko, 2) - self.assertEqual( - k_neighbor_result["vertices"], ["1:peter", "1:josh", "2:lop", "2:ripple", "1:vadas"] - ) + self.assertEqual(k_neighbor_result["vertices"], ["1:peter", "1:josh", "2:lop", "2:ripple", "1:vadas"]) same_neighbors_result = self.traverser.same_neighbors(marko, josh) self.assertEqual(same_neighbors_result["same_neighbors"], ["2:lop"]) @@ -69,16 +67,10 @@ def test_traverser_operations(self): self.assertEqual(shortest_path_result["path"], ["1:marko", "1:josh", "2:ripple"]) all_shortest_paths_result = self.traverser.all_shortest_paths(marko, ripple, 3) - self.assertEqual( - all_shortest_paths_result["paths"], [{"objects": ["1:marko", "1:josh", "2:ripple"]}] - ) + self.assertEqual(all_shortest_paths_result["paths"], [{"objects": ["1:marko", "1:josh", "2:ripple"]}]) - weighted_shortest_path_result = self.traverser.weighted_shortest_path( - marko, ripple, "weight", 3 - ) - self.assertEqual( - weighted_shortest_path_result["vertices"], ["1:marko", "1:josh", "2:ripple"] - ) + weighted_shortest_path_result = self.traverser.weighted_shortest_path(marko, ripple, "weight", 3) + self.assertEqual(weighted_shortest_path_result["vertices"], ["1:marko", "1:josh", "2:ripple"]) single_source_shortest_path_result = self.traverser.single_source_shortest_path(marko, 2) self.assertEqual( @@ -92,9 +84,7 @@ def test_traverser_operations(self): }, ) - multi_node_shortest_path_result = self.traverser.multi_node_shortest_path( - [marko, josh], max_depth=2 - ) + multi_node_shortest_path_result = self.traverser.multi_node_shortest_path([marko, josh], max_depth=2) self.assertEqual( multi_node_shortest_path_result["vertices"], [ @@ -131,9 +121,7 @@ def test_traverser_operations(self): } ], ) - self.assertEqual( - customized_paths_result["paths"], [{"objects": ["1:marko", "2:lop"], "weights": [8.0]}] - ) + self.assertEqual(customized_paths_result["paths"], [{"objects": ["1:marko", "2:lop"], "weights": [8.0]}]) sources = {"ids": [], "label": "person", "properties": {"name": "vadas"}} @@ -186,15 +174,11 @@ def test_traverser_operations(self): sources = {"ids": ["2:lop", "2:ripple"]} path_patterns = [{"steps": [{"direction": "IN", "labels": ["created"], "max_degree": -1}]}] - customized_crosspoints_result = self.traverser.customized_crosspoints( - sources, path_patterns - ) + customized_crosspoints_result = self.traverser.customized_crosspoints(sources, path_patterns) self.assertEqual(customized_crosspoints_result["crosspoints"], ["1:josh"]) rings_result = self.traverser.rings(marko, 3) - self.assertEqual( - rings_result["rings"], [{"objects": ["1:marko", "2:lop", "1:josh", "1:marko"]}] - ) + self.assertEqual(rings_result["rings"], [{"objects": ["1:marko", "2:lop", "1:josh", "1:marko"]}]) rays_result = self.traverser.rays(marko, 2) self.assertEqual( diff --git a/hugegraph-python-client/src/tests/api/test_variable.py b/hugegraph-python-client/src/tests/api/test_variable.py index 19af6a959..d75ac5142 100644 --- a/hugegraph-python-client/src/tests/api/test_variable.py +++ b/hugegraph-python-client/src/tests/api/test_variable.py @@ -18,8 +18,8 @@ import unittest import pytest - from pyhugegraph.utils.exceptions import NotFoundError + from ..client_utils import ClientUtils diff --git a/hugegraph-python-client/src/tests/client_utils.py b/hugegraph-python-client/src/tests/client_utils.py index f711072b8..1914cdb0c 100644 --- a/hugegraph-python-client/src/tests/client_utils.py +++ b/hugegraph-python-client/src/tests/client_utils.py @@ -59,23 +59,21 @@ def init_property_key(self): def init_vertex_label(self): schema = self.schema - schema.vertexLabel("person").properties("name", "age", "city").primaryKeys( - "name" - ).nullableKeys("city").ifNotExist().create() - schema.vertexLabel("software").properties("name", "lang", "price").primaryKeys( - "name" - ).nullableKeys("price").ifNotExist().create() + schema.vertexLabel("person").properties("name", "age", "city").primaryKeys("name").nullableKeys( + "city" + ).ifNotExist().create() + schema.vertexLabel("software").properties("name", "lang", "price").primaryKeys("name").nullableKeys( + "price" + ).ifNotExist().create() schema.vertexLabel("book").useCustomizeStringId().properties("name", "price").nullableKeys( "price" ).ifNotExist().create() def init_edge_label(self): schema = self.schema - schema.edgeLabel("knows").sourceLabel("person").targetLabel( - "person" - ).multiTimes().properties("date", "city").sortKeys("date").nullableKeys( - "city" - ).ifNotExist().create() + schema.edgeLabel("knows").sourceLabel("person").targetLabel("person").multiTimes().properties( + "date", "city" + ).sortKeys("date").nullableKeys("city").ifNotExist().create() schema.edgeLabel("created").sourceLabel("person").targetLabel("software").properties( "date", "city" ).nullableKeys("city").ifNotExist().create() @@ -84,16 +82,10 @@ def init_index_label(self): schema = self.schema schema.indexLabel("personByCity").onV("person").by("city").secondary().ifNotExist().create() schema.indexLabel("personByAge").onV("person").by("age").range().ifNotExist().create() - schema.indexLabel("softwareByPrice").onV("software").by( - "price" - ).range().ifNotExist().create() - schema.indexLabel("softwareByLang").onV("software").by( - "lang" - ).secondary().ifNotExist().create() + schema.indexLabel("softwareByPrice").onV("software").by("price").range().ifNotExist().create() + schema.indexLabel("softwareByLang").onV("software").by("lang").secondary().ifNotExist().create() schema.indexLabel("knowsByDate").onE("knows").by("date").secondary().ifNotExist().create() - schema.indexLabel("createdByDate").onE("created").by( - "date" - ).secondary().ifNotExist().create() + schema.indexLabel("createdByDate").onE("created").by("date").secondary().ifNotExist().create() def init_vertices(self): graph = self.graph @@ -125,7 +117,7 @@ def _get_vertex_id(self, label, properties): def _get_vertex(self, label, properties): lst = self.graph.getVertexByCondition(label=label, limit=1, properties=properties) - assert 1 == len(lst), "Can't find vertex." + assert len(lst) == 1, "Can't find vertex." return lst[0] def clear_graph_all_data(self): diff --git a/pyproject.toml b/pyproject.toml index 2dd4161d4..9b059384e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dev = [ "pytest~=8.0.0", "pytest-cov~=5.0.0", "pylint~=3.0.0", + "ruff>=0.11.0", + "mypy>=1.16.1", + "pre-commit>=3.5.0", ] nk-llm = ["hugegraph-llm", "hugegraph-python-client", "nuitka"] @@ -67,7 +70,7 @@ build-backend = "hatchling.build" # Alternatively, configure globally: uv config --global index.url https://pypi.tuna.tsinghua.edu.cn/simple # To reset to default: uv config --global index.url https://pypi.org/simple -[tool.hatch.metadata] # Keep this if hatch is still used by submodules, otherwise remove +[tool.hatch.metadata] # Keep this if the hatch is still used by submodules, otherwise remove allow-direct-references = true [tool.hatch.build.targets.wheel] @@ -140,3 +143,43 @@ constraint-dependencies = [ # Other dependencies "python-dateutil~=2.9.0", ] + +# for code format +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.ruff.lint] +# Select a broad set of rules for comprehensive checks. +# E: pycodestyle Errors, F: Pyflakes, W: pycodestyle Warnings, I: isort +# C: flake8-comprehensions, N: pep8-naming +# UP: pyupgrade, B: flake8-bugbear, SIM: flake8-simplify, T20: flake8-print +select = ["E", "F", "W", "I", "C", "N", "UP", "B", "SIM", "T20", "RUF"] + +# Ignore specific rules +ignore = [ + "PYI041", # redundant-numeric-union: keep clear 'int | float' for type hinting + "N812", # lowercase-imported-as-non-lowercase + "N806", # non-lowercase-variable-in-function + "N803", # invalid-argument-name + "N802", # invalid-function-name (API compatibility) + "C901", # complexity (non-critical for now) + "RUF001", # ambiguous-unicode-character-string + "RUF003", # ambiguous-unicode-character-comment +] +# No need to ignore E501 (line-too-long), `ruff format` will handle it automatically. + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["T20"] +"hugegraph-ml/src/hugegraph_ml/examples/**/*.py" = ["T20"] +"hugegraph-python-client/src/pyhugegraph/structure/*.py" = ["N802"] + +[tool.ruff.lint.isort] +known-first-party = ["hugegraph_llm", "hugegraph_python_client", "hugegraph_ml", "vermeer_python_client"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = true diff --git a/rules/requirements.md b/rules/requirements.md index 10613019c..bd61058e3 100644 --- a/rules/requirements.md +++ b/rules/requirements.md @@ -86,4 +86,4 @@ `- **E2**: WHEN **重置令牌成功生成**, the **系统** shall **立即向该邮箱发送一封包含密码重置链接的邮件**。` `- **X1**: IF **用户提供的邮箱地址未在系统中注册**, THEN the **系统** shall **显示“如果该邮箱已注册,您将收到一封邮件”的通用提示,且不暴露该邮箱是否存在**。` `- **U1**: The **密码重置链接** shall **是唯一的,并在首次使用或24小时后立即失效**。` ---- \ No newline at end of file +--- diff --git a/vermeer-python-client/src/pyvermeer/api/base.py b/vermeer-python-client/src/pyvermeer/api/base.py index 0ab5fe090..84de2cb31 100644 --- a/vermeer-python-client/src/pyvermeer/api/base.py +++ b/vermeer-python-client/src/pyvermeer/api/base.py @@ -30,11 +30,7 @@ def session(self): """Return the client's session object""" return self._client.session - def _send_request(self, method: str, endpoint: str, params: dict = None): + def _send_request(self, method: str, endpoint: str, params: dict | None = None): """Unified request entry point""" self.log.debug(f"Sending {method} to {endpoint}") - return self._client.send_request( - method=method, - endpoint=endpoint, - params=params - ) + return self._client.send_request(method=method, endpoint=endpoint, params=params) diff --git a/vermeer-python-client/src/pyvermeer/api/graph.py b/vermeer-python-client/src/pyvermeer/api/graph.py index 1fabdcabe..ee9c3bcdf 100644 --- a/vermeer-python-client/src/pyvermeer/api/graph.py +++ b/vermeer-python-client/src/pyvermeer/api/graph.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from pyvermeer.structure.graph_data import GraphsResponse, GraphResponse +from pyvermeer.structure.graph_data import GraphResponse, GraphsResponse + from .base import BaseModule @@ -24,10 +25,7 @@ class GraphModule(BaseModule): def get_graph(self, graph_name: str) -> GraphResponse: """Get task list""" - response = self._send_request( - "GET", - f"/graphs/{graph_name}" - ) + response = self._send_request("GET", f"/graphs/{graph_name}") return GraphResponse(response) def get_graphs(self) -> GraphsResponse: diff --git a/vermeer-python-client/src/pyvermeer/api/task.py b/vermeer-python-client/src/pyvermeer/api/task.py index 12e1186b6..3da32b687 100644 --- a/vermeer-python-client/src/pyvermeer/api/task.py +++ b/vermeer-python-client/src/pyvermeer/api/task.py @@ -16,8 +16,7 @@ # under the License. from pyvermeer.api.base import BaseModule - -from pyvermeer.structure.task_data import TasksResponse, TaskCreateRequest, TaskCreateResponse, TaskResponse +from pyvermeer.structure.task_data import TaskCreateRequest, TaskCreateResponse, TaskResponse, TasksResponse class TaskModule(BaseModule): @@ -25,25 +24,15 @@ class TaskModule(BaseModule): def get_tasks(self) -> TasksResponse: """Get task list""" - response = self._send_request( - "GET", - "/tasks" - ) + response = self._send_request("GET", "/tasks") return TasksResponse(response) def get_task(self, task_id: int) -> TaskResponse: """Get single task information""" - response = self._send_request( - "GET", - f"/task/{task_id}" - ) + response = self._send_request("GET", f"/task/{task_id}") return TaskResponse(response) def create_task(self, create_task: TaskCreateRequest) -> TaskCreateResponse: """Create new task""" - response = self._send_request( - method="POST", - endpoint="/tasks/create", - params=create_task.to_dict() - ) + response = self._send_request(method="POST", endpoint="/tasks/create", params=create_task.to_dict()) return TaskCreateResponse(response) diff --git a/vermeer-python-client/src/pyvermeer/client/client.py b/vermeer-python-client/src/pyvermeer/client/client.py index ba6a0947d..1946f7074 100644 --- a/vermeer-python-client/src/pyvermeer/client/client.py +++ b/vermeer-python-client/src/pyvermeer/client/client.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict -from typing import Optional from pyvermeer.api.base import BaseModule from pyvermeer.api.graph import GraphModule @@ -30,12 +28,12 @@ class PyVermeerClient: """Vermeer API Client""" def __init__( - self, - ip: str, - port: int, - token: str, - timeout: Optional[tuple[float, float]] = None, - log_level: str = "INFO", + self, + ip: str, + port: int, + token: str, + timeout: tuple[float, float] | None = None, + log_level: str = "INFO", ): """Initialize the client, including configuration and session management :param ip: @@ -46,10 +44,7 @@ def __init__( """ self.cfg = VermeerConfig(ip, port, token, timeout) self.session = VermeerSession(self.cfg) - self._modules: Dict[str, BaseModule] = { - "graph": GraphModule(self), - "tasks": TaskModule(self) - } + self._modules: dict[str, BaseModule] = {"graph": GraphModule(self), "tasks": TaskModule(self)} log.setLevel(log_level) def __getattr__(self, name): @@ -58,6 +53,6 @@ def __getattr__(self, name): return self._modules[name] raise AttributeError(f"Module {name} not found") - def send_request(self, method: str, endpoint: str, params: dict = None): + def send_request(self, method: str, endpoint: str, params: dict | None = None): """Unified request method""" return self.session.request(method, endpoint, params) diff --git a/vermeer-python-client/src/pyvermeer/demo/task_demo.py b/vermeer-python-client/src/pyvermeer/demo/task_demo.py index bb0a00d85..9b23d82b6 100644 --- a/vermeer-python-client/src/pyvermeer/demo/task_demo.py +++ b/vermeer-python-client/src/pyvermeer/demo/task_demo.py @@ -27,27 +27,23 @@ def main(): token="", log_level="DEBUG", ) - task = client.tasks.get_tasks() + client.tasks.get_tasks() - print(task.to_dict()) - - create_response = client.tasks.create_task( + client.tasks.create_task( create_task=TaskCreateRequest( - task_type='load', - graph_name='DEFAULT-example', + task_type="load", + graph_name="DEFAULT-example", params={ - "load.hg_pd_peers": "[\"127.0.0.1:8686\"]", + "load.hg_pd_peers": '["127.0.0.1:8686"]', "load.hugegraph_name": "DEFAULT/example/g", "load.hugegraph_password": "xxx", "load.hugegraph_username": "xxx", "load.parallel": "10", - "load.type": "hugegraph" + "load.type": "hugegraph", }, ) ) - print(create_response.to_dict()) - if __name__ == "__main__": main() diff --git a/vermeer-python-client/src/pyvermeer/structure/base_data.py b/vermeer-python-client/src/pyvermeer/structure/base_data.py index 4d6078050..117ab390c 100644 --- a/vermeer-python-client/src/pyvermeer/structure/base_data.py +++ b/vermeer-python-client/src/pyvermeer/structure/base_data.py @@ -20,7 +20,7 @@ RESPONSE_NONE = -1 -class BaseResponse(object): +class BaseResponse: """ Base response class """ @@ -30,8 +30,8 @@ def __init__(self, dic: dict): init :param dic: """ - self.__errcode = dic.get('errcode', RESPONSE_NONE) - self.__message = dic.get('message', "") + self.__errcode = dic.get("errcode", RESPONSE_NONE) + self.__message = dic.get("message", "") @property def errcode(self) -> int: diff --git a/vermeer-python-client/src/pyvermeer/structure/graph_data.py b/vermeer-python-client/src/pyvermeer/structure/graph_data.py index 8f97ed1a1..d96626539 100644 --- a/vermeer-python-client/src/pyvermeer/structure/graph_data.py +++ b/vermeer-python-client/src/pyvermeer/structure/graph_data.py @@ -26,7 +26,7 @@ class BackendOpt: def __init__(self, dic: dict): """init""" - self.__vertex_data_backend = dic.get('vertex_data_backend', None) + self.__vertex_data_backend = dic.get("vertex_data_backend") @property def vertex_data_backend(self): @@ -35,9 +35,7 @@ def vertex_data_backend(self): def to_dict(self): """to dict""" - return { - 'vertex_data_backend': self.vertex_data_backend - } + return {"vertex_data_backend": self.vertex_data_backend} class GraphWorker: @@ -45,12 +43,12 @@ class GraphWorker: def __init__(self, dic: dict): """init""" - self.__name = dic.get('Name', '') - self.__vertex_count = dic.get('VertexCount', -1) - self.__vert_id_start = dic.get('VertIdStart', -1) - self.__edge_count = dic.get('EdgeCount', -1) - self.__is_self = dic.get('IsSelf', False) - self.__scatter_offset = dic.get('ScatterOffset', -1) + self.__name = dic.get("Name", "") + self.__vertex_count = dic.get("VertexCount", -1) + self.__vert_id_start = dic.get("VertIdStart", -1) + self.__edge_count = dic.get("EdgeCount", -1) + self.__is_self = dic.get("IsSelf", False) + self.__scatter_offset = dic.get("ScatterOffset", -1) @property def name(self) -> str: @@ -74,7 +72,7 @@ def edge_count(self) -> int: @property def is_self(self) -> bool: - """is self worker. Nonsense """ + """is self worker. Nonsense""" return self.__is_self @property @@ -85,12 +83,12 @@ def scatter_offset(self) -> int: def to_dict(self): """to dict""" return { - 'name': self.name, - 'vertex_count': self.vertex_count, - 'vert_id_start': self.vert_id_start, - 'edge_count': self.edge_count, - 'is_self': self.is_self, - 'scatter_offset': self.scatter_offset + "name": self.name, + "vertex_count": self.vertex_count, + "vert_id_start": self.vert_id_start, + "edge_count": self.edge_count, + "is_self": self.is_self, + "scatter_offset": self.scatter_offset, } @@ -99,21 +97,21 @@ class VermeerGraph: def __init__(self, dic: dict): """init""" - self.__name = dic.get('name', '') - self.__space_name = dic.get('space_name', '') - self.__status = dic.get('status', '') - self.__create_time = parse_vermeer_time(dic.get('create_time', '')) - self.__update_time = parse_vermeer_time(dic.get('update_time', '')) - self.__vertex_count = dic.get('vertex_count', 0) - self.__edge_count = dic.get('edge_count', 0) - self.__workers = [GraphWorker(w) for w in dic.get('workers', [])] - self.__worker_group = dic.get('worker_group', '') - self.__use_out_edges = dic.get('use_out_edges', False) - self.__use_property = dic.get('use_property', False) - self.__use_out_degree = dic.get('use_out_degree', False) - self.__use_undirected = dic.get('use_undirected', False) - self.__on_disk = dic.get('on_disk', False) - self.__backend_option = BackendOpt(dic.get('backend_option', {})) + self.__name = dic.get("name", "") + self.__space_name = dic.get("space_name", "") + self.__status = dic.get("status", "") + self.__create_time = parse_vermeer_time(dic.get("create_time", "")) + self.__update_time = parse_vermeer_time(dic.get("update_time", "")) + self.__vertex_count = dic.get("vertex_count", 0) + self.__edge_count = dic.get("edge_count", 0) + self.__workers = [GraphWorker(w) for w in dic.get("workers", [])] + self.__worker_group = dic.get("worker_group", "") + self.__use_out_edges = dic.get("use_out_edges", False) + self.__use_property = dic.get("use_property", False) + self.__use_out_degree = dic.get("use_out_degree", False) + self.__use_undirected = dic.get("use_undirected", False) + self.__on_disk = dic.get("on_disk", False) + self.__backend_option = BackendOpt(dic.get("backend_option", {})) @property def name(self) -> str: @@ -193,21 +191,21 @@ def backend_option(self) -> BackendOpt: def to_dict(self) -> dict: """to dict""" return { - 'name': self.__name, - 'space_name': self.__space_name, - 'status': self.__status, - 'create_time': self.__create_time.strftime("%Y-%m-%d %H:%M:%S") if self.__create_time else '', - 'update_time': self.__update_time.strftime("%Y-%m-%d %H:%M:%S") if self.__update_time else '', - 'vertex_count': self.__vertex_count, - 'edge_count': self.__edge_count, - 'workers': [w.to_dict() for w in self.__workers], - 'worker_group': self.__worker_group, - 'use_out_edges': self.__use_out_edges, - 'use_property': self.__use_property, - 'use_out_degree': self.__use_out_degree, - 'use_undirected': self.__use_undirected, - 'on_disk': self.__on_disk, - 'backend_option': self.__backend_option.to_dict(), + "name": self.__name, + "space_name": self.__space_name, + "status": self.__status, + "create_time": self.__create_time.strftime("%Y-%m-%d %H:%M:%S") if self.__create_time else "", + "update_time": self.__update_time.strftime("%Y-%m-%d %H:%M:%S") if self.__update_time else "", + "vertex_count": self.__vertex_count, + "edge_count": self.__edge_count, + "workers": [w.to_dict() for w in self.__workers], + "worker_group": self.__worker_group, + "use_out_edges": self.__use_out_edges, + "use_property": self.__use_property, + "use_out_degree": self.__use_out_degree, + "use_undirected": self.__use_undirected, + "on_disk": self.__on_disk, + "backend_option": self.__backend_option.to_dict(), } @@ -217,7 +215,7 @@ class GraphsResponse(BaseResponse): def __init__(self, dic: dict): """init""" super().__init__(dic) - self.__graphs = [VermeerGraph(g) for g in dic.get('graphs', [])] + self.__graphs = [VermeerGraph(g) for g in dic.get("graphs", [])] @property def graphs(self) -> list[VermeerGraph]: @@ -226,11 +224,7 @@ def graphs(self) -> list[VermeerGraph]: def to_dict(self) -> dict: """to dict""" - return { - 'errcode': self.errcode, - 'message': self.message, - 'graphs': [g.to_dict() for g in self.graphs] - } + return {"errcode": self.errcode, "message": self.message, "graphs": [g.to_dict() for g in self.graphs]} class GraphResponse(BaseResponse): @@ -239,7 +233,7 @@ class GraphResponse(BaseResponse): def __init__(self, dic: dict): """init""" super().__init__(dic) - self.__graph = VermeerGraph(dic.get('graph', {})) + self.__graph = VermeerGraph(dic.get("graph", {})) @property def graph(self) -> VermeerGraph: @@ -248,8 +242,4 @@ def graph(self) -> VermeerGraph: def to_dict(self) -> dict: """to dict""" - return { - 'errcode': self.errcode, - 'message': self.message, - 'graph': self.graph.to_dict() - } + return {"errcode": self.errcode, "message": self.message, "graph": self.graph.to_dict()} diff --git a/vermeer-python-client/src/pyvermeer/structure/master_data.py b/vermeer-python-client/src/pyvermeer/structure/master_data.py index de2fac774..0af10da79 100644 --- a/vermeer-python-client/src/pyvermeer/structure/master_data.py +++ b/vermeer-python-client/src/pyvermeer/structure/master_data.py @@ -26,11 +26,11 @@ class MasterInfo: def __init__(self, dic: dict): """Initialization function""" - self.__grpc_peer = dic.get('grpc_peer', '') - self.__ip_addr = dic.get('ip_addr', '') - self.__debug_mod = dic.get('debug_mod', False) - self.__version = dic.get('version', '') - self.__launch_time = parse_vermeer_time(dic.get('launch_time', '')) + self.__grpc_peer = dic.get("grpc_peer", "") + self.__ip_addr = dic.get("ip_addr", "") + self.__debug_mod = dic.get("debug_mod", False) + self.__version = dic.get("version", "") + self.__launch_time = parse_vermeer_time(dic.get("launch_time", "")) @property def grpc_peer(self) -> str: @@ -64,7 +64,7 @@ def to_dict(self): "ip_addr": self.__ip_addr, "debug_mod": self.__debug_mod, "version": self.__version, - "launch_time": self.__launch_time.strftime("%Y-%m-%d %H:%M:%S") if self.__launch_time else '' + "launch_time": self.__launch_time.strftime("%Y-%m-%d %H:%M:%S") if self.__launch_time else "", } @@ -74,7 +74,7 @@ class MasterResponse(BaseResponse): def __init__(self, dic: dict): """Initialization function""" super().__init__(dic) - self.__master_info = MasterInfo(dic['master_info']) + self.__master_info = MasterInfo(dic["master_info"]) @property def master_info(self) -> MasterInfo: diff --git a/vermeer-python-client/src/pyvermeer/structure/task_data.py b/vermeer-python-client/src/pyvermeer/structure/task_data.py index 6fa408a4b..4cf0aa227 100644 --- a/vermeer-python-client/src/pyvermeer/structure/task_data.py +++ b/vermeer-python-client/src/pyvermeer/structure/task_data.py @@ -26,8 +26,8 @@ class TaskWorker: def __init__(self, dic): """init""" - self.__name = dic.get('name', None) - self.__status = dic.get('status', None) + self.__name = dic.get("name", None) + self.__status = dic.get("status", None) @property def name(self) -> str: @@ -41,7 +41,7 @@ def status(self) -> str: def to_dict(self): """to dict""" - return {'name': self.name, 'status': self.status} + return {"name": self.name, "status": self.status} class TaskInfo: @@ -49,19 +49,19 @@ class TaskInfo: def __init__(self, dic): """init""" - self.__id = dic.get('id', 0) - self.__status = dic.get('status', '') - self.__state = dic.get('state', '') - self.__create_user = dic.get('create_user', '') - self.__create_type = dic.get('create_type', '') - self.__create_time = parse_vermeer_time(dic.get('create_time', '')) - self.__start_time = parse_vermeer_time(dic.get('start_time', '')) - self.__update_time = parse_vermeer_time(dic.get('update_time', '')) - self.__graph_name = dic.get('graph_name', '') - self.__space_name = dic.get('space_name', '') - self.__type = dic.get('type', '') - self.__params = dic.get('params', {}) - self.__workers = [TaskWorker(w) for w in dic.get('workers', [])] + self.__id = dic.get("id", 0) + self.__status = dic.get("status", "") + self.__state = dic.get("state", "") + self.__create_user = dic.get("create_user", "") + self.__create_type = dic.get("create_type", "") + self.__create_time = parse_vermeer_time(dic.get("create_time", "")) + self.__start_time = parse_vermeer_time(dic.get("start_time", "")) + self.__update_time = parse_vermeer_time(dic.get("update_time", "")) + self.__graph_name = dic.get("graph_name", "") + self.__space_name = dic.get("space_name", "") + self.__type = dic.get("type", "") + self.__params = dic.get("params", {}) + self.__workers = [TaskWorker(w) for w in dic.get("workers", [])] @property def id(self) -> int: @@ -126,19 +126,19 @@ def workers(self) -> list[TaskWorker]: def to_dict(self) -> dict: """to dict""" return { - 'id': self.__id, - 'status': self.__status, - 'state': self.__state, - 'create_user': self.__create_user, - 'create_type': self.__create_type, - 'create_time': self.__create_time.strftime("%Y-%m-%d %H:%M:%S") if self.__start_time else '', - 'start_time': self.__start_time.strftime("%Y-%m-%d %H:%M:%S") if self.__start_time else '', - 'update_time': self.__update_time.strftime("%Y-%m-%d %H:%M:%S") if self.__update_time else '', - 'graph_name': self.__graph_name, - 'space_name': self.__space_name, - 'type': self.__type, - 'params': self.__params, - 'workers': [w.to_dict() for w in self.__workers], + "id": self.__id, + "status": self.__status, + "state": self.__state, + "create_user": self.__create_user, + "create_type": self.__create_type, + "create_time": self.__create_time.strftime("%Y-%m-%d %H:%M:%S") if self.__start_time else "", + "start_time": self.__start_time.strftime("%Y-%m-%d %H:%M:%S") if self.__start_time else "", + "update_time": self.__update_time.strftime("%Y-%m-%d %H:%M:%S") if self.__update_time else "", + "graph_name": self.__graph_name, + "space_name": self.__space_name, + "type": self.__type, + "params": self.__params, + "workers": [w.to_dict() for w in self.__workers], } @@ -153,11 +153,7 @@ def __init__(self, task_type, graph_name, params): def to_dict(self) -> dict: """to dict""" - return { - 'task_type': self.task_type, - 'graph': self.graph_name, - 'params': self.params - } + return {"task_type": self.task_type, "graph": self.graph_name, "params": self.params} class TaskCreateResponse(BaseResponse): @@ -166,7 +162,7 @@ class TaskCreateResponse(BaseResponse): def __init__(self, dic): """init""" super().__init__(dic) - self.__task = TaskInfo(dic.get('task', {})) + self.__task = TaskInfo(dic.get("task", {})) @property def task(self) -> TaskInfo: @@ -188,7 +184,7 @@ class TasksResponse(BaseResponse): def __init__(self, dic): """init""" super().__init__(dic) - self.__tasks = [TaskInfo(t) for t in dic.get('tasks', [])] + self.__tasks = [TaskInfo(t) for t in dic.get("tasks", [])] @property def tasks(self) -> list[TaskInfo]: @@ -197,11 +193,7 @@ def tasks(self) -> list[TaskInfo]: def to_dict(self) -> dict: """to dict""" - return { - "errcode": self.errcode, - "message": self.message, - "tasks": [t.to_dict() for t in self.tasks] - } + return {"errcode": self.errcode, "message": self.message, "tasks": [t.to_dict() for t in self.tasks]} class TaskResponse(BaseResponse): @@ -210,7 +202,7 @@ class TaskResponse(BaseResponse): def __init__(self, dic): """init""" super().__init__(dic) - self.__task = TaskInfo(dic.get('task', {})) + self.__task = TaskInfo(dic.get("task", {})) @property def task(self) -> TaskInfo: diff --git a/vermeer-python-client/src/pyvermeer/structure/worker_data.py b/vermeer-python-client/src/pyvermeer/structure/worker_data.py index a35cfe9be..d8b19ff8e 100644 --- a/vermeer-python-client/src/pyvermeer/structure/worker_data.py +++ b/vermeer-python-client/src/pyvermeer/structure/worker_data.py @@ -26,15 +26,15 @@ class Worker: def __init__(self, dic): """init""" - self.__id = dic.get('id', 0) - self.__name = dic.get('name', '') - self.__grpc_addr = dic.get('grpc_addr', '') - self.__ip_addr = dic.get('ip_addr', '') - self.__state = dic.get('state', '') - self.__version = dic.get('version', '') - self.__group = dic.get('group', '') - self.__init_time = parse_vermeer_time(dic.get('init_time', '')) - self.__launch_time = parse_vermeer_time(dic.get('launch_time', '')) + self.__id = dic.get("id", 0) + self.__name = dic.get("name", "") + self.__grpc_addr = dic.get("grpc_addr", "") + self.__ip_addr = dic.get("ip_addr", "") + self.__state = dic.get("state", "") + self.__version = dic.get("version", "") + self.__group = dic.get("group", "") + self.__init_time = parse_vermeer_time(dic.get("init_time", "")) + self.__launch_time = parse_vermeer_time(dic.get("launch_time", "")) @property def id(self) -> int: @@ -102,7 +102,7 @@ class WorkersResponse(BaseResponse): def __init__(self, dic): """init""" super().__init__(dic) - self.__workers = [Worker(worker) for worker in dic['workers']] + self.__workers = [Worker(worker) for worker in dic["workers"]] @property def workers(self) -> list[Worker]: diff --git a/vermeer-python-client/src/pyvermeer/utils/exception.py b/vermeer-python-client/src/pyvermeer/utils/exception.py index ddb36d811..65aebf79a 100644 --- a/vermeer-python-client/src/pyvermeer/utils/exception.py +++ b/vermeer-python-client/src/pyvermeer/utils/exception.py @@ -20,25 +20,25 @@ class ConnectError(Exception): """Raised when there is an issue connecting to the server.""" def __init__(self, message): - super().__init__(f"Connection error: {str(message)}") + super().__init__(f"Connection error: {message!s}") class TimeOutError(Exception): """Raised when a request times out.""" def __init__(self, message): - super().__init__(f"Request timed out: {str(message)}") + super().__init__(f"Request timed out: {message!s}") class JsonDecodeError(Exception): """Raised when the response from the server cannot be decoded as JSON.""" def __init__(self, message): - super().__init__(f"Failed to decode JSON response: {str(message)}") + super().__init__(f"Failed to decode JSON response: {message!s}") class UnknownError(Exception): """Raised for any other unknown errors.""" def __init__(self, message): - super().__init__(f"Unknown API error: {str(message)}") + super().__init__(f"Unknown API error: {message!s}") diff --git a/vermeer-python-client/src/pyvermeer/utils/log.py b/vermeer-python-client/src/pyvermeer/utils/log.py index cc5199f31..856a80d8d 100644 --- a/vermeer-python-client/src/pyvermeer/utils/log.py +++ b/vermeer-python-client/src/pyvermeer/utils/log.py @@ -21,6 +21,7 @@ class VermeerLogger: """vermeer API log""" + _instance = None def __new__(cls, name: str = "VermeerClient"): @@ -38,8 +39,7 @@ def _initialize(self, name: str): if not self.logger.handlers: # Console output format console_format = logging.Formatter( - '[%(asctime)s] [%(levelname)s] %(name)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + "[%(asctime)s] [%(levelname)s] %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) # Console handler diff --git a/vermeer-python-client/src/pyvermeer/utils/vermeer_config.py b/vermeer-python-client/src/pyvermeer/utils/vermeer_config.py index dfd4d5060..047f69558 100644 --- a/vermeer-python-client/src/pyvermeer/utils/vermeer_config.py +++ b/vermeer-python-client/src/pyvermeer/utils/vermeer_config.py @@ -18,6 +18,7 @@ class VermeerConfig: """The configuration of a Vermeer instance.""" + ip: str port: int token: str @@ -25,11 +26,7 @@ class VermeerConfig: username: str graph_space: str - def __init__(self, - ip: str, - port: int, - token: str, - timeout: tuple[float, float] = (0.5, 15.0)): + def __init__(self, ip: str, port: int, token: str, timeout: tuple[float, float] = (0.5, 15.0)): """Initialize the configuration for a Vermeer instance.""" self.ip = ip self.port = port diff --git a/vermeer-python-client/src/pyvermeer/utils/vermeer_datetime.py b/vermeer-python-client/src/pyvermeer/utils/vermeer_datetime.py index a76dfee77..41f3d0b0b 100644 --- a/vermeer-python-client/src/pyvermeer/utils/vermeer_datetime.py +++ b/vermeer-python-client/src/pyvermeer/utils/vermeer_datetime.py @@ -28,5 +28,5 @@ def parse_vermeer_time(vm_dt: str) -> datetime: return dt -if __name__ == '__main__': - print(parse_vermeer_time('2025-02-17T15:45:05.396311145+08:00').strftime("%Y%m%d")) +if __name__ == "__main__": + pass diff --git a/vermeer-python-client/src/pyvermeer/utils/vermeer_requests.py b/vermeer-python-client/src/pyvermeer/utils/vermeer_requests.py index 118484c4d..c81cb9c4e 100644 --- a/vermeer-python-client/src/pyvermeer/utils/vermeer_requests.py +++ b/vermeer-python-client/src/pyvermeer/utils/vermeer_requests.py @@ -16,14 +16,13 @@ # under the License. import json -from typing import Optional from urllib.parse import urljoin import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from pyvermeer.utils.exception import JsonDecodeError, ConnectError, TimeOutError, UnknownError +from pyvermeer.utils.exception import ConnectError, JsonDecodeError, TimeOutError, UnknownError from pyvermeer.utils.log import log from pyvermeer.utils.vermeer_config import VermeerConfig @@ -32,12 +31,12 @@ class VermeerSession: """vermeer session""" def __init__( - self, - cfg: VermeerConfig, - retries: int = 3, - backoff_factor: int = 0.1, - status_forcelist=(500, 502, 504), - session: Optional[requests.Session] = None, + self, + cfg: VermeerConfig, + retries: int = 3, + backoff_factor: int = 0.1, + status_forcelist=(500, 502, 504), + session: requests.Session | None = None, ): """ Initialize the Session. @@ -89,20 +88,13 @@ def close(self): """ self._session.close() - def request( - self, - method: str, - path: str, - params: dict = None - ) -> dict: + def request(self, method: str, path: str, params: dict | None = None) -> dict: """request""" try: log.debug(f"Request made to {path} with params {json.dumps(params)}") - response = self._session.request(method, - self.resolve(path), - headers=self._headers, - data=json.dumps(params), - timeout=self._timeout) + response = self._session.request( + method, self.resolve(path), headers=self._headers, data=json.dumps(params), timeout=self._timeout + ) log.debug(f"Response code:{response.status_code}, received: {response.text}") return response.json() except requests.ConnectionError as e: