From 48e0a186bd620b782f5e265d80c82c74e58e1630 Mon Sep 17 00:00:00 2001 From: danqi Date: Wed, 26 Feb 2025 19:27:56 +0800 Subject: [PATCH 1/2] repo-sync-2025-02-26T19:27:49+0800 --- .circleci/config.yml | 58 + .circleci/continue-config.yml | 116 ++ .circleci/release-config.yml | 72 + .condarc | 11 + .gitignore | 11 + README.md | 2 +- api_linter_config.json | 13 - build.sh | 11 + dev-requirements.txt | 2 + docs/locale/zh_CN/LC_MESSAGES/README.po | 1 + docs/locale/zh_CN/LC_MESSAGES/index.po | 1 + docs/locale/zh_CN/LC_MESSAGES/intro.po | 1 + docs/locale/zh_CN/LC_MESSAGES/spec.po | 1 + docs/requirements.txt | 1 + docs/spec.md | 35 +- requirements.txt | 2 + secretflow_spec/__init__.py | 97 ++ secretflow_spec/core/__init__.py | 13 + secretflow_spec/core/component.py | 149 ++ secretflow_spec/core/definition.py | 1281 +++++++++++++++++ secretflow_spec/core/discovery.py | 101 ++ secretflow_spec/core/dist_data/__init__.py | 13 + secretflow_spec/core/dist_data/base.py | 33 + secretflow_spec/core/dist_data/file.py | 97 ++ secretflow_spec/core/dist_data/report.py | 196 +++ secretflow_spec/core/dist_data/vtable.py | 553 +++++++ secretflow_spec/core/registry.py | 130 ++ secretflow_spec/core/storage/__init__.py | 38 + secretflow_spec/core/storage/base.py | 67 + secretflow_spec/core/storage/local.py | 103 ++ secretflow_spec/core/storage/s3.py | 149 ++ secretflow_spec/core/types.py | 79 + secretflow_spec/core/utils.py | 188 +++ secretflow_spec/core/version.py | 17 + secretflow_spec/protos/api_linter_config.json | 11 + .../protos/run_api_linter.sh | 6 +- secretflow_spec/protos/run_protoc.sh | 38 + .../secretflow_spec}/v1/component.proto | 69 +- .../protos/secretflow_spec}/v1/data.proto | 49 +- .../secretflow_spec}/v1/evaluation.proto | 35 +- .../protos/secretflow_spec}/v1/report.proto | 6 +- secretflow_spec/v1/__init__.py | 13 + secretflow_spec/v1/component_pb2.py | 47 + secretflow_spec/v1/component_pb2.pyi | 563 ++++++++ secretflow_spec/v1/data_pb2.py | 50 + secretflow_spec/v1/data_pb2.pyi | 407 ++++++ secretflow_spec/v1/evaluation_pb2.py | 31 + secretflow_spec/v1/evaluation_pb2.pyi | 124 ++ secretflow_spec/v1/report_pb2.py | 44 + secretflow_spec/v1/report_pb2.pyi | 254 ++++ secretflow_spec/version.py | 17 + setup.py | 116 ++ tests/__init__.py | 13 + tests/comps/__init__.py | 13 + tests/comps/my_comp.py | 20 + tests/spec/extend/__init__.py | 13 + tests/spec/extend/calculate_rules.proto | 59 + tests/spec/extend/calculate_rules_pb2.py | 28 + tests/spec/extend/calculate_rules_pb2.pyi | 124 ++ tests/test_definition.py | 325 +++++ tests/test_discovery.py | 25 + tests/test_dist_data.py | 182 +++ tests/test_registry.py | 57 + tests/test_storage.py | 49 + tests/test_utils.py | 31 + 65 files changed, 6370 insertions(+), 91 deletions(-) create mode 100644 .circleci/config.yml create mode 100644 .circleci/continue-config.yml create mode 100644 .circleci/release-config.yml create mode 100644 .condarc delete mode 100644 api_linter_config.json create mode 100755 build.sh create mode 100644 dev-requirements.txt create mode 100644 requirements.txt create mode 100644 secretflow_spec/__init__.py create mode 100644 secretflow_spec/core/__init__.py create mode 100644 secretflow_spec/core/component.py create mode 100644 secretflow_spec/core/definition.py create mode 100644 secretflow_spec/core/discovery.py create mode 100644 secretflow_spec/core/dist_data/__init__.py create mode 100644 secretflow_spec/core/dist_data/base.py create mode 100644 secretflow_spec/core/dist_data/file.py create mode 100644 secretflow_spec/core/dist_data/report.py create mode 100644 secretflow_spec/core/dist_data/vtable.py create mode 100644 secretflow_spec/core/registry.py create mode 100644 secretflow_spec/core/storage/__init__.py create mode 100644 secretflow_spec/core/storage/base.py create mode 100644 secretflow_spec/core/storage/local.py create mode 100644 secretflow_spec/core/storage/s3.py create mode 100644 secretflow_spec/core/types.py create mode 100644 secretflow_spec/core/utils.py create mode 100644 secretflow_spec/core/version.py create mode 100644 secretflow_spec/protos/api_linter_config.json rename run_api_linter.sh => secretflow_spec/protos/run_api_linter.sh (80%) create mode 100755 secretflow_spec/protos/run_protoc.sh rename {secretflow/spec => secretflow_spec/protos/secretflow_spec}/v1/component.proto (73%) rename {secretflow/spec => secretflow_spec/protos/secretflow_spec}/v1/data.proto (86%) rename {secretflow/spec => secretflow_spec/protos/secretflow_spec}/v1/evaluation.proto (76%) rename {secretflow/spec => secretflow_spec/protos/secretflow_spec}/v1/report.proto (94%) create mode 100644 secretflow_spec/v1/__init__.py create mode 100644 secretflow_spec/v1/component_pb2.py create mode 100644 secretflow_spec/v1/component_pb2.pyi create mode 100644 secretflow_spec/v1/data_pb2.py create mode 100644 secretflow_spec/v1/data_pb2.pyi create mode 100644 secretflow_spec/v1/evaluation_pb2.py create mode 100644 secretflow_spec/v1/evaluation_pb2.pyi create mode 100644 secretflow_spec/v1/report_pb2.py create mode 100644 secretflow_spec/v1/report_pb2.pyi create mode 100644 secretflow_spec/version.py create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/comps/__init__.py create mode 100644 tests/comps/my_comp.py create mode 100644 tests/spec/extend/__init__.py create mode 100644 tests/spec/extend/calculate_rules.proto create mode 100644 tests/spec/extend/calculate_rules_pb2.py create mode 100644 tests/spec/extend/calculate_rules_pb2.pyi create mode 100644 tests/test_definition.py create mode 100644 tests/test_discovery.py create mode 100644 tests/test_dist_data.py create mode 100644 tests/test_registry.py create mode 100644 tests/test_storage.py create mode 100644 tests/test_utils.py diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..94a4c16 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,58 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 2.1 + +setup: true + +orbs: + path-filtering: circleci/path-filtering@1.0.0 + continuation: circleci/continuation@1.0.0 + +parameters: + GHA_Actor: + type: string + default: "" + GHA_Action: + type: string + default: "" + GHA_Event: + type: string + default: "" + GHA_Meta: + type: string + default: "" + +workflows: + unittest-workflow: + when: + and: + - not: << pipeline.parameters.GHA_Action >> + - not: << pipeline.parameters.GHA_Meta >> + jobs: + - path-filtering/filter: + base-revision: main + config-path: .circleci/continue-config.yml + mapping: | + requirements.txt base true + dev-requirements.txt base true + .circleci/continue-config.yml base true + secretflow_spec/.* spec true + tests/.* test true + publish-workflow: + when: + equal: ["publish_pypi", << pipeline.parameters.GHA_Meta >>] + jobs: + - continuation/continue: + configuration_path: .circleci/release-config.yml diff --git a/.circleci/continue-config.yml b/.circleci/continue-config.yml new file mode 100644 index 0000000..5ba3622 --- /dev/null +++ b/.circleci/continue-config.yml @@ -0,0 +1,116 @@ +# Use the latest 2.1 version of CircleCI pipeline process engine. +# See: https://circleci.com/docs/2.0/configuration-reference +version: 2.1 +parameters: + base: + type: boolean + default: false + spec: + type: boolean + default: false + test: + type: boolean + default: false + +executors: + linux_executor: # declares a reusable executor + parameters: + resource_class: + type: string + docker: + - image: secretflow/ubuntu-base-ci:latest + resource_class: << parameters.resource_class >> + shell: /bin/bash --login -eo pipefail + +commands: + kill_countdown: + steps: + - run: + name: Cancel job after set time + background: true + command: | + sleep 2400 + echo "Canceling workflow as too much time has elapsed" + curl -X POST --header "Content-Type: application/json" "https://circleci.com/api/v2/workflow/${CIRCLE_WORKFLOW_ID}/cancel?circle-token=${BUILD_TIMER_TOKEN}" + pytest_wrapper: + parameters: + target_folder: + type: string + steps: + - restore_cache: + name: restore pip cache + key: pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }} + - run: + name: Install test tools + command: | + conda init + pip install -r dev-requirements.txt + pip install --force-reinstall pytest + - run: + name: "Run tests" + command: | + conda init + pytest --suppress-no-test-exit-code -n auto --junitxml=results.xml -v -x --capture=no --cov=secretflow_spec/ --cov-report=xml:coverage.xml << parameters.target_folder >> + - store_test_results: + path: ./results.xml + run_test: + steps: + - kill_countdown + - when: + condition: + or: + - << pipeline.parameters.spec >> + - << pipeline.parameters.test >> + steps: + - pytest_wrapper: + target_folder: tests +jobs: + linux_build: + parameters: + resource_class: + type: string + executor: + name: linux_executor + resource_class: << parameters.resource_class >> + steps: + - checkout + - restore_cache: + name: restore pip cache + key: &pip-cache pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }} + - run: + name: Install python deps + command: | + conda init + arch=$(uname -i) + mkdir -p artifacts + pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu + - persist_to_workspace: + root: . + paths: + - artifacts + - save_cache: + key: *pip-cache + paths: + - /root/miniconda3/lib/python3.10/site-packages + run_test: + parameters: + resource_class: + type: string + executor: + name: linux_executor + resource_class: << parameters.resource_class >> + steps: + - checkout + - run_test + +# Invoke jobs via workflows +# See: https://circleci.com/docs/2.0/configuration-reference/#workflows +workflows: + build_and_test: + jobs: + - linux_build: + name: linux_build-<> + matrix: + parameters: + resource_class: ["2xlarge+"] + - run_test \ No newline at end of file diff --git a/.circleci/release-config.yml b/.circleci/release-config.yml new file mode 100644 index 0000000..5d85021 --- /dev/null +++ b/.circleci/release-config.yml @@ -0,0 +1,72 @@ +# Use the latest 2.1 version of CircleCI pipeline process engine. +# See: https://circleci.com/docs/2.0/configuration-reference +version: 2.1 +parameters: + GHA_Actor: + type: string + default: "" + GHA_Action: + type: string + default: "" + GHA_Event: + type: string + default: "" + GHA_Meta: + type: string + default: "" + +executors: + linux_x64_executor: # declares a reusable executor + docker: + - image: secretflow/release-ci:latest + resource_class: 2xlarge + shell: /bin/bash --login -eo pipefail + +commands: + build_and_upload: + parameters: + python_ver: + type: string + steps: + - checkout + - run: + name: "build package and publish" + command: | + conda create -n build python=<< parameters.python_ver >> -y + conda activate build + + python3 setup.py bdist_wheel + python3 setup.py clean + + ls dist/*.whl + python3 -m pip install twine + python3 -m twine upload -r pypi -u __token__ -p ${PYPI_TWINE_TOKEN} dist/*.whl + +# Define a job to be invoked later in a workflow. +# See: https://circleci.com/docs/2.0/configuration-reference/#jobs +jobs: + linux_publish: + parameters: + python_ver: + type: string + executor: + type: string + executor: <> + steps: + - checkout + - build_and_upload: + python_ver: <> + +# Invoke jobs via workflows +# See: https://circleci.com/docs/2.0/configuration-reference/#workflows +workflows: + publish: + jobs: + - linux_publish: + matrix: + parameters: + python_ver: ["3.10"] + executor: [ "linux_x64_executor"] + filters: + tags: + only: /.*/ \ No newline at end of file diff --git a/.condarc b/.condarc new file mode 100644 index 0000000..83ed619 --- /dev/null +++ b/.condarc @@ -0,0 +1,11 @@ +channels: + - defaults +show_channel_urls: true +default_channels: + - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/ + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ +custom_channels: + conda-forge: https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ + msys2: https://mirrors.ustc.edu.cn/anaconda/cloud/msys2/ + bioconda: https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/ + menpo: https://mirrors.ustc.edu.cn/anaconda/cloud/menpo/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7192ea3..713854a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,14 @@ bazel-* # macos .DS_Store + +.pytest_cache/ +.coverage +coverage.xml +results.xml +dist + +# build +build/ +secretflow_spec.egg-info/ +secretflow_spec/protos/protoc-26.1/ \ No newline at end of file diff --git a/README.md b/README.md index e1f2096..cae97df 100644 --- a/README.md +++ b/README.md @@ -14,5 +14,5 @@ After you modified protos, please check with API Linter. ```bash go install github.com/googleapis/api-linter/cmd/api-linter@latest -sh run_api_linter.sh +sh secretflow_spec/protos/run_api_linter.sh ``` diff --git a/api_linter_config.json b/api_linter_config.json deleted file mode 100644 index 8ac8271..0000000 --- a/api_linter_config.json +++ /dev/null @@ -1,13 +0,0 @@ -[ - { - "included_paths": [ - "secretflow/**/*.proto" - ], - "disabled_rules": [ - "core::0192::only-leading-comments", - "core::0192::has-comments", - "core::0146::any", - "core::0123::resource-annotation" - ] - } -] \ No newline at end of file diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..26eabca --- /dev/null +++ b/build.sh @@ -0,0 +1,11 @@ +#!/bin/bash +cd "$(dirname "$(readlink -f "$0")")" + +FILE="secretflow_spec/version.py" +CACHE=$(cat "$FILE") + +rm -f dist/*.whl +python3 setup.py bdist_wheel +rm -rf ./build ./secretflow_spec.egg-info + +echo "$CACHE" > "$FILE" \ No newline at end of file diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000..08ec1a8 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,2 @@ +pytest==7.3.1 +pytest-cov==4.0.0 \ No newline at end of file diff --git a/docs/locale/zh_CN/LC_MESSAGES/README.po b/docs/locale/zh_CN/LC_MESSAGES/README.po index e59ea61..ca9f8c7 100644 --- a/docs/locale/zh_CN/LC_MESSAGES/README.po +++ b/docs/locale/zh_CN/LC_MESSAGES/README.po @@ -43,3 +43,4 @@ msgstr "" #: ../../README.md:22 msgid "The generated html is at **_build/html/**" msgstr "" + diff --git a/docs/locale/zh_CN/LC_MESSAGES/index.po b/docs/locale/zh_CN/LC_MESSAGES/index.po index e11aaed..4b958ea 100644 --- a/docs/locale/zh_CN/LC_MESSAGES/index.po +++ b/docs/locale/zh_CN/LC_MESSAGES/index.po @@ -93,3 +93,4 @@ msgid "" "We officially launch the first version of Specification with SecretFlow " "1.0.0." msgstr "我们伴随 SecretFlow 1.0.0 发布了第一个版本。" + diff --git a/docs/locale/zh_CN/LC_MESSAGES/intro.po b/docs/locale/zh_CN/LC_MESSAGES/intro.po index 1eba8bf..45c1071 100644 --- a/docs/locale/zh_CN/LC_MESSAGES/intro.po +++ b/docs/locale/zh_CN/LC_MESSAGES/intro.po @@ -955,3 +955,4 @@ msgstr "" msgid "In SecretFlow, the type str for Report is *sf.report*." msgstr "" "在SecretFlow中,Report的类型为 *sf.report* 。" + diff --git a/docs/locale/zh_CN/LC_MESSAGES/spec.po b/docs/locale/zh_CN/LC_MESSAGES/spec.po index 8f8784f..d0a1793 100644 --- a/docs/locale/zh_CN/LC_MESSAGES/spec.po +++ b/docs/locale/zh_CN/LC_MESSAGES/spec.po @@ -1600,3 +1600,4 @@ msgstr "" #~ "The path of attributes. The attribute" #~ " path for a TableAttrDef is `(input" #~ msgstr "属性路径。 一个TableAttrDef的属性路径为 `(input" + diff --git a/docs/requirements.txt b/docs/requirements.txt index 673bb04..11295d6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,3 +3,4 @@ sphinx==5.3.0 myst-parser==0.18.1 sphinx-intl==2.1.0 pydata-sphinx-theme + diff --git a/docs/spec.md b/docs/spec.md index a47b3e1..006102b 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -7,7 +7,7 @@ ### [DATA](#DATA) - + - Messages - [DistData](#distdata) - [DistData.DataRef](#distdatadataref) @@ -17,15 +17,15 @@ - [SystemInfo](#systeminfo) - [TableSchema](#tableschema) - [VerticalTable](#verticaltable) - - - + + + ### [COMPONENT](#COMPONENT) - + - Messages - [Attribute](#attribute) - [AttributeDef](#attributedef) @@ -35,31 +35,31 @@ - [ComponentDef](#componentdef) - [IoDef](#iodef) - [IoDef.TableAttrDef](#iodeftableattrdef) - - - + + + - Enums - [AttrType](#attrtype) - + ### [EVALUATION](#EVALUATION) - + - Messages - [NodeEvalParam](#nodeevalparam) - [NodeEvalResult](#nodeevalresult) - - - + + + ### [REPORT](#REPORT) - + - Messages - [Descriptions](#descriptions) - [Descriptions.Item](#descriptionsitem) @@ -70,9 +70,9 @@ - [Table](#table) - [Table.HeaderItem](#tableheaderitem) - [Table.Row](#tablerow) - - - + + + @@ -646,3 +646,4 @@ Displays rows of data. |

bool | | bool | boolean | boolean | |

string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode | |

bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str | + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8e84ced --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +protobuf>=4,<5 +s3fs==2024.2.0 \ No newline at end of file diff --git a/secretflow_spec/__init__.py b/secretflow_spec/__init__.py new file mode 100644 index 0000000..fccc29c --- /dev/null +++ b/secretflow_spec/__init__.py @@ -0,0 +1,97 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from secretflow_spec.core.component import ( + BuiltinType, + Component, + Input, + Output, + UnionGroup, + UnionSelection, +) +from secretflow_spec.core.definition import Definition, Field, Interval +from secretflow_spec.core.discovery import load_component_modules +from secretflow_spec.core.dist_data.file import ObjectFile +from secretflow_spec.core.dist_data.report import Reporter +from secretflow_spec.core.dist_data.vtable import ( + VTable, + VTableField, + VTableFieldKind, + VTableFieldType, + VTableFormat, + VTableParty, + VTableSchema, +) +from secretflow_spec.core.registry import Registry, register +from secretflow_spec.core.storage import ( + LocalStorage, + S3Storage, + Storage, + StorageType, + make_storage, +) +from secretflow_spec.core.types import StrEnum, Version +from secretflow_spec.core.utils import build_node_eval_param, to_attribute, to_type +from secretflow_spec.core.version import ( + SPEC_VERSION, + SPEC_VERSION_MAJOR, + SPEC_VERSION_MINOR, +) + +__all__ = [ + "SPEC_VERSION", + "SPEC_VERSION_MAJOR", + "SPEC_VERSION_MINOR", + # component + "BuiltinType", + "Component", + "Input", + "Output", + "UnionGroup", + "UnionSelection", + # definition + "Definition", + "Field", + "Interval", + # registry + "Registry", + "register", + # discovery + "load_component_modules", + # dist_data.file + "ObjectFile", + # dist_data.report + "Reporter", + # dist_data.vtable + "VTable", + "VTableField", + "VTableFieldKind", + "VTableFieldType", + "VTableFormat", + "VTableParty", + "VTableSchema", + # storage + "make_storage", + "StorageType", + "Storage", + "S3Storage", + "LocalStorage", + # types + "StrEnum", + "Version", + # utils + "to_type", + "to_attribute", + "build_node_eval_param", +] diff --git a/secretflow_spec/core/__init__.py b/secretflow_spec/core/__init__.py new file mode 100644 index 0000000..086637a --- /dev/null +++ b/secretflow_spec/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/secretflow_spec/core/component.py b/secretflow_spec/core/component.py new file mode 100644 index 0000000..36cc1e9 --- /dev/null +++ b/secretflow_spec/core/component.py @@ -0,0 +1,149 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, is_dataclass +from typing import TypeVar + +from secretflow_spec.v1.data_pb2 import DistData + +try: + from typing import dataclass_transform +except ImportError: + # Define a no-op decorator if dataclass_transform is not available + def dataclass_transform(*args, **kwargs): + def wrapper(cls): + return cls + + return wrapper + + +class BuiltinType: ... + + +Input = DistData + + +@dataclass +class Output(BuiltinType): + uri: str + data: DistData = None + + +@dataclass +class UnionSelection(BuiltinType): + name: str + desc: str + minor_min: int = 0 + minor_max: int = -1 + + +class UnionGroup(BuiltinType): + """ + A group of union attrs. + """ + + def __init__(self) -> None: + self._selected: str = "" + + def is_selected(self, v: str) -> bool: + return self._selected == v + + def set_selected(self, v: str): + self._selected = v + + def get_selected(self) -> str: + return self._selected + + +@dataclass_transform() +@dataclass +class Component(ABC): + """ + Component is the base class of all components. + + It uses the metadata within the field function to describe the fields in the component. + And it leverages the dataclass decorator to automatically generate methods for these described fields, + Then we can use __dataclass_fields__ to reflect the Component and convert to ComponentDef + + NOTE: + you do not need to manually add the dataclass decorator to your component, + as it has already been automatically applied by the __init_subclass__ method. + + More detail for dataclasses: + https://peps.python.org/pep-0557/ + + When you create a new component, you should define its domain and version explicitly. + The name of the component can be inferred from the class name it is defined within, + and the description can be taken from the docstring. + However, for more control over these values, you can set them manually. + + Examples: + >>> @register(domain="test", version="0.0.1") + >>> class DemoComponent(Component): + >>> value : int = Field.attr(desc="value", default=1) + >>> input0: Input = Field.input(desc="input table", types=[DistDataType.INDIVIDUAL_TABLE]) + >>> # implement abstractmethod + >>> def evaluate(self, ctx: Context): + >>> print(self.value) + >>> # create instance directly + >>> comp = DemoComponent(value=1) + >>> assert comp.value == 1 + >>> # use Definition to reflect the component and create instance + >>> definition = Definition(DemoComponent) + >>> args = {"_minor":0, "value":1, "input0": DistData()} + >>> comp : DemoComponent = definition.make_component(args) + >>> asset comp.value == 1 + """ + + _minor: int = field(default=-1, kw_only=True) + + @property + def minor(self) -> int: + return self._minor + + def is_supported(self, v: int) -> bool: + return self._minor > v + + def __init_subclass__(cls, **kwargs): + def _is_custom_class(cls): + if not isinstance(cls, type) or cls.__module__ == "builtins": + return False + if issubclass(cls, (DistData, BuiltinType)): + return False + return True + + def _check_field_dataclass(cls): + for f in cls.__dataclass_fields__.values(): + if f.name.startswith("_"): + continue + if not _is_custom_class(f.type): + continue + if not is_dataclass(f.type): + f.type = dataclass(f.type) + _check_field_dataclass(f.type) + + dataclass(cls, kw_only=True) + _check_field_dataclass(cls) + + super().__init_subclass__(**kwargs) + + @abstractmethod + def evaluate(self, *args, **kwargs) -> None: + raise NotImplementedError(f"{type(self)} evaluate is not implemented.") + + +# TComponent is used to implement the mixin design pattern. +TComponent = TypeVar("TComponent", bound=Component) diff --git a/secretflow_spec/core/definition.py b/secretflow_spec/core/definition.py new file mode 100644 index 0000000..aa4516c --- /dev/null +++ b/secretflow_spec/core/definition.py @@ -0,0 +1,1281 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import MISSING +from dataclasses import Field as DField +from dataclasses import dataclass, field, is_dataclass +from enum import Enum, auto +from typing import Any, Type, get_args, get_origin + +from google.protobuf import json_format +from google.protobuf.message import Message + +from secretflow_spec.core.component import ( + Component, + Input, + Output, + UnionGroup, + UnionSelection, +) +from secretflow_spec.core.dist_data.base import DistDataType +from secretflow_spec.core.utils import clean_text +from secretflow_spec.v1.component_pb2 import ( + Attribute, + AttributeDef, + AttrType, + ComponentDef, + IoDef, +) +from secretflow_spec.v1.data_pb2 import DistData +from secretflow_spec.v1.evaluation_pb2 import NodeEvalParam + +LABEL_LEN_MAX = 64 + + +class Interval: + def __init__( + self, + lower: float | int = None, + upper: float | int = None, + lower_closed: bool = False, + upper_closed: bool = False, + ): + if lower is not None and upper is not None: + assert upper >= lower + + self.lower = lower + self.upper = upper + self.lower_closed = lower_closed + self.upper_closed = upper_closed + + @staticmethod + def open(lower: float | int | None, upper: float | int | None) -> "Interval": + """return (lower, upper)""" + return Interval( + lower=lower, upper=upper, lower_closed=False, upper_closed=False + ) + + @staticmethod + def closed(lower: float | int | None, upper: float | int | None) -> "Interval": + """return [lower, upper]""" + return Interval(lower=lower, upper=upper, lower_closed=True, upper_closed=True) + + @staticmethod + def open_closed(lower: float | int | None, upper: float | int | None) -> "Interval": + """return (lower, upper]""" + return Interval(lower=lower, upper=upper, lower_closed=False, upper_closed=True) + + @staticmethod + def closed_open(lower: float | int | None, upper: float | int | None) -> "Interval": + """return [lower, upper)""" + return Interval(lower=lower, upper=upper, lower_closed=True, upper_closed=False) + + def astype(self, typ: type): + assert typ in [float, int] + if self.lower is not None: + self.lower = typ(self.lower) + if self.upper is not None: + self.upper = typ(self.upper) + + def enforce_closed(self): + if self.lower != None: + if isinstance(self.lower, float) and not self.lower.is_integer(): + raise ValueError(f"Lower bound must be an integer, {self.lower}") + self.lower = int(self.lower) + if not self.lower_closed: + self.lower += 1 + self.lower_closed = True + + if self.upper != None: + if isinstance(self.upper, float) and not self.upper.is_integer(): + raise ValueError(f"Upper bound must be an integer, {self.upper}") + self.upper = int(self.upper) + if not self.upper_closed: + self.upper -= 1 + self.upper_closed = True + + def check(self, v: float | int) -> tuple[bool, str]: + if self.upper is not None: + if self.upper_closed: + if v > self.upper: + return ( + False, + f"should be less than or equal {self.upper}, but got {v}", + ) + else: + if v >= self.upper: + return ( + False, + f"should be less than {self.upper}, but got {v}", + ) + if self.lower is not None: + if self.lower_closed: + if v < self.lower: + return ( + False, + f"should be greater than or equal {self.lower}, but got {v}", + ) + else: + if v <= self.lower: + return ( + False, + f"should be greater than {self.lower}, but got {v}", + ) + return True, "" + + +class FieldKind(Enum): + BasicAttr = auto() + PartyAttr = auto() + CustomAttr = auto() + StructAttr = auto() + UnionAttr = auto() + SelectionAttr = auto() + TableColumnAttr = auto() + Input = auto() + Output = auto() + + +def is_deprecated_field(minor_max: int) -> bool: + return minor_max != -1 + + +def is_deprecated_minor(minor_max: int, minor: int) -> bool: + return minor_max != -1 and minor > minor_max + + +@dataclass +class _Metadata: + prefixes: list = None + fullname: str = "" + name: str = "" + type: Type = None + kind: FieldKind = None + desc: str = None + is_optional: bool = False + choices: list = None + bound_limit: Interval = None + list_limit: Interval = None + default: Any = None + selections: dict[str, UnionSelection] = None # only used in union_group + input_name: str = None # only used in table_column_attr + is_checkpoint: bool = False # if true it will be save when dump checkpoint + types: list[str] = None # only used in input/output + minor_min: int = 0 # it's supported only if minor >= minor_min + minor_max: int = -1 # it's deprecated if minor > minor_max and minor_max != -1 + + +class Field: + @staticmethod + def _field( + kind: FieldKind, + minor_min: int, + minor_max: int, + desc: str, + md: _Metadata | None = None, + default: Any = None, + init=True, + ): + assert minor_max is not None + if minor_max != -1 and minor_min > minor_max: + raise ValueError(f"invalid minor version, {minor_min}, {minor_max}") + if md is None: + md = _Metadata() + md.kind = kind + md.desc = clean_text(desc) + md.minor_min = minor_min + md.minor_max = minor_max + + if isinstance(default, list): + default = MISSING + default_factory = lambda: default + else: + default_factory = MISSING + return field( + default=default, + default_factory=default_factory, + init=init, + kw_only=True, + metadata={"md": md}, + ) + + @staticmethod + def attr( + desc: str = "", + is_optional: bool | None = None, + default: Any | None = None, + choices: list | None = None, + bound_limit: Interval | None = None, + list_limit: Interval | None = None, + is_checkpoint: bool = False, + minor_min: int = 0, + minor_max: int = -1, + ): + if is_optional is None: + is_optional = default != MISSING and default is not None + + md = _Metadata( + is_optional=is_optional, + choices=choices, + bound_limit=bound_limit, + list_limit=list_limit, + is_checkpoint=is_checkpoint, + default=default if default != MISSING else None, + ) + return Field._field( + FieldKind.BasicAttr, minor_min, minor_max, desc, md, default + ) + + @staticmethod + def party_attr( + desc: str = "", + list_limit: Interval | None = None, + minor_min: int = 0, + minor_max: int = -1, + ): + md = _Metadata(list_limit=list_limit) + return Field._field(FieldKind.PartyAttr, minor_min, minor_max, desc, md) + + @staticmethod + def struct_attr(desc: str = "", minor_min: int = 0, minor_max: int = -1): + return Field._field(FieldKind.StructAttr, minor_min, minor_max, desc) + + @staticmethod + def union_attr( + desc: str = "", + default: str = "", + selections: list[UnionSelection] | None = None, # only used when type is str + minor_min: int = 0, + minor_max: int = -1, + ): + if selections: + selections = {s.name: s for s in selections} + md = _Metadata(default=default, selections=selections) + return Field._field(FieldKind.UnionAttr, minor_min, minor_max, desc, md) + + @staticmethod + def selection_attr(desc: str = "", minor_min: int = 0, minor_max: int = -1): + return Field._field(FieldKind.SelectionAttr, minor_min, minor_max, desc) + + @staticmethod + def custom_attr(desc: str = "", minor_min: int = 0, minor_max: int = -1): + return Field._field(FieldKind.CustomAttr, minor_min, minor_max, desc) + + @staticmethod + def table_column_attr( + input_name: str, + desc: str = "", + limit: Interval | None = None, + is_checkpoint: bool = False, + minor_min: int = 0, + minor_max: int = -1, + ): + if input_name == "": + raise ValueError("input_name cannot be empty") + md = _Metadata( + input_name=input_name, + list_limit=limit, + is_checkpoint=is_checkpoint, + ) + return Field._field(FieldKind.TableColumnAttr, minor_min, minor_max, desc, md) + + @staticmethod + def input( + desc: str = "", + types: list[str] = [], + is_checkpoint: bool = False, + list_limit: Interval = None, + minor_min: int = 0, + minor_max: int = -1, + ): + """ + the last input can be variable and the type must be list[Input] + """ + if not types: + raise ValueError("input types is none") + types = [str(s) for s in types] + md = _Metadata(types=types, is_checkpoint=is_checkpoint, list_limit=list_limit) + return Field._field(FieldKind.Input, minor_min, minor_max, desc, md) + + @staticmethod + def output( + desc: str = "", + types: list[str] = [], + minor_min: int = 0, + minor_max: int = -1, + ): + if not types: + raise ValueError("output types is none") + types = [str(s) for s in types] + md = _Metadata(types=types) + return Field._field(FieldKind.Output, minor_min, minor_max, desc, md) + + +class Creator: + def __init__(self, check_exist: bool) -> None: + self._check_exist = check_exist + + def make(self, cls: Type, kwargs: dict, minor: int): + args = {} + for name, field in cls.__dataclass_fields__.items(): + if name == MINOR_NAME: + continue + args[name] = self._make_field(field, kwargs, minor) + if len(kwargs) > 0: + unused = {k: self._check_unused_type(cls, k, minor) for k in kwargs.keys()} + raise ValueError(f"unused fields {unused}") + + args[MINOR_NAME] = minor + ins = cls(**args) + setattr(ins, MINOR_NAME, minor) + return ins + + def _check_unused_type(self, cls: Type, key: str, minor: int) -> str: + UNKNOWN = "unknown" + DEPRECATED = "deprecated" + + tokens = key.split("/") + cur_cls = cls + for token in tokens: + if not is_dataclass(cur_cls): + return UNKNOWN + if token not in cur_cls.__dataclass_fields__: + return UNKNOWN + + field = cur_cls.__dataclass_fields__[token] + md: _Metadata = field.metadata["md"] + if is_deprecated_minor(md.minor_max, minor): + return DEPRECATED + cur_cls = md.type + return UNKNOWN + + def _make_field(self, field: DField, kwargs: dict, minor: int): + md: _Metadata = field.metadata["md"] + if is_deprecated_minor(md.minor_max, minor): + return None + + if md.kind == FieldKind.StructAttr: + return self._make_struct(md, kwargs, minor) + elif md.kind == FieldKind.UnionAttr: + return self._make_union(md, kwargs, minor) + + if minor < md.minor_min: + return md.default + + if md.fullname not in kwargs: + if self._check_exist and not md.is_optional: + raise ValueError(f"{md.fullname} is required") + else: + return md.default + + value = kwargs.pop(md.fullname, md.default) + + if md.kind == FieldKind.Input: + if not isinstance(value, (DistData, list)): + raise ValueError(f"type of {md.name} should be DistData") + + return ( + value + if isinstance(value, list) or value.type != DistDataType.NULL + else None + ) + elif md.kind == FieldKind.Output: + if not isinstance(value, (Output, str)): + raise ValueError( + f"type of {md.name} should be str or Output, but got {type(value)}" + ) + return value if isinstance(value, Output) else Output(uri=value, data=None) + elif md.kind == FieldKind.TableColumnAttr: + return self._make_str_or_list(md, value) + elif md.kind == FieldKind.PartyAttr: + return self._make_str_or_list(md, value) + elif md.kind == FieldKind.CustomAttr: + pb_inst = md.type() + return json_format.Parse(value, pb_inst) + elif md.kind == FieldKind.BasicAttr: + return self._make_basic(md, value) + else: + raise ValueError(f"invalid field kind, {md.fullname}, {md.kind}") + + def _make_struct(self, md: _Metadata, kwargs: dict, minor: int): + cls = md.type + args = {} + for name, field in cls.__dataclass_fields__.items(): + args[name] = self._make_field(field, kwargs, minor) + + return cls(**args) + + def _make_union(self, md: _Metadata, kwargs: dict, minor: int): + union_type = md.type + if minor < md.minor_min: + selected_key = md.default + else: + selected_key = kwargs.pop(md.fullname, md.default) + + if not isinstance(selected_key, str): + raise ValueError( + f"{md.fullname} should be a str, but got {type(selected_key)}" + ) + if union_type == str: + if selected_key not in md.selections: + raise ValueError(f"{selected_key} not in {md.selections.keys()}") + selection = md.selections[selected_key] + if is_deprecated_minor(selection.minor_max, minor): + raise ValueError(f"{selected_key} is deprecated") + return selected_key + + choices = union_type.__dataclass_fields__.keys() + if selected_key not in choices: + raise ValueError(f"{selected_key} should be one of {choices}") + + selected_field = md.type.__dataclass_fields__[selected_key] + selected_md: _Metadata = selected_field.metadata["md"] + if is_deprecated_minor(selected_md.minor_max, minor): + raise ValueError(f"{selected_key} is deprecated") + + args = {} + if selected_md.kind != FieldKind.SelectionAttr: + value = self._make_field(selected_field, kwargs, minor) + args = {selected_key: value} + res: UnionGroup = md.type(**args) + res.set_selected(selected_key) + return res + + def _make_basic(self, md: _Metadata, value): + is_list = isinstance(value, list) + if is_list and md.list_limit: + is_valid, err_str = md.list_limit.check(len(value)) + if not is_valid: + raise ValueError(f"length of {md.fullname} is valid, {err_str}") + + check_list = value if is_list else [value] + if md.bound_limit is not None: + for v in check_list: + is_valid, err_str = md.bound_limit.check(v) + if not is_valid: + raise ValueError(f"value of {md.fullname} is valid, {err_str}") + if md.choices is not None: + for v in check_list: + if v not in md.choices: + raise ValueError( + f"value {v} must be in {md.choices}, name is {md.fullname}" + ) + return value + + def _make_str_or_list(self, md: _Metadata, value): + if value is None: + raise ValueError(f"{md.name} can not be none") + is_list = get_origin(md.type) is list + if not is_list: + if isinstance(value, list): + if len(value) != 1: + raise ValueError(f"{md.name} can only have one element") + value = value[0] + assert isinstance( + value, str + ), f"{md.name} must be str, but got {type(value)}" + return value + else: + assert isinstance( + value, list + ), f"{md.name} must be list[str], but got {type(value)}" + if md.list_limit is not None: + is_valid, err_str = md.list_limit.check(len(value)) + if not is_valid: + raise ValueError(f"length of {md.name} is invalid, {err_str}") + + return value + + +MINOR_NAME = "_minor" +RESERVED = ["input", "output"] + + +class Reflector: + def __init__(self, cls, name: str, minor: int): + self._cls = cls + self._name = name + self._minor = minor + self._inputs: list[IoDef] = [] + self._outputs: list[IoDef] = [] + self._attrs: list[AttributeDef] = [] + self._attr_types: dict[str, AttrType] = {} + + def get_inputs(self) -> list[IoDef]: + return self._inputs + + def get_outputs(self) -> list[IoDef]: + return self._outputs + + def get_attrs(self) -> list[AttributeDef]: + return self._attrs + + def get_attr_types(self) -> dict[str, AttrType]: + return self._attr_types + + def reflect(self): + """ + Reflect dataclass to ComponentDef. + """ + self._force_dataclass(self._cls) + + attrs: list[_Metadata] = [] + for field in self._cls.__dataclass_fields__.values(): + if field.name == MINOR_NAME: + continue + md = self._build_metadata(field, []) + if md.kind == FieldKind.Input: + is_list, prim_type = self._check_list(md.type) + if prim_type != Input: + raise ValueError("input type must be Input") + if is_list and DistDataType.NULL in md.types: + raise ValueError("input type cannot be null if is variable") + io_def = self._reflect_io(md, is_list) + self._inputs.append(io_def) + elif md.kind == FieldKind.Output: + if md.type != Output: + raise ValueError("output type must be Output") + io_def = self._reflect_io(md) + self._outputs.append(io_def) + else: + attrs.append(md) + + for md in attrs: + self._reflect_attr(md) + + # check input variable + for idx, io in enumerate(self._inputs): + if is_deprecated_field(io.minor_max): + continue + if io.is_variable and idx != len(self._inputs) - 1: + raise ValueError(f"variable input must be the last one") + + def _reflect_io(self, md: _Metadata, is_list: bool = False): + assert ( + DistDataType.OUTDATED_VERTICAL_TABLE not in md.types + ), f"sf.table.vertical_table is deprecated, please use sf.table.vertical in {md.fullname}" + variable_min, variable_max = 0, -1 + if is_list and md.list_limit: + l = md.list_limit + l.enforce_closed() + variable_min = l.lower if l.lower else 0 + variable_max = l.upper if l.upper else -1 + + is_optional = DistDataType.NULL in md.types + return IoDef( + name=md.name, + desc=md.desc, + types=md.types, + is_optional=is_optional, + is_variable=is_list, + variable_min=variable_min, + variable_max=variable_max, + minor_min=md.minor_min, + minor_max=md.minor_max, + ) + + def _reflect_party_attr(self, md: _Metadata): + is_list, org_type = self._check_list(md.type) + if org_type != str: + raise ValueError(f"the type of party attr should be str or list[str]") + list_min_length_inclusive, list_max_length_inclusive = self._build_list_limit( + is_list, md.list_limit + ) + if list_min_length_inclusive <= 0: + md.is_optional = True + atomic = AttributeDef.AtomicAttrDesc( + list_min_length_inclusive=list_min_length_inclusive, + list_max_length_inclusive=list_max_length_inclusive, + ) + self._append_attr(AttrType.AT_PARTY, md, atomic=atomic) + + def _reflect_table_column_attr(self, md: _Metadata): + is_list, prim_type = self._check_list(md.type) + if prim_type != str: + raise ValueError( + f"input_table_attr's type must be str or list[str], but got {md.type}]" + ) + + input_name = md.input_name + io_def = next((io for io in self._inputs if io.name == input_name), None) + if io_def is None: + raise ValueError(f"cannot find input io, {input_name}") + + if not input_name: + raise ValueError(f"input_name cannot be empty in field<{md.fullname}>") + + for t in io_def.types: + if t not in [ + str(DistDataType.VERTICAL_TABLE), + str(DistDataType.INDIVIDUAL_TABLE), + ]: + raise ValueError(f"{input_name} is not defined correctly in input.") + + col_min_cnt_inclusive, col_max_cnt_inclusive = self._build_list_limit( + is_list, md.list_limit + ) + if col_min_cnt_inclusive <= 0: + md.is_optional = True + if md.prefixes: + atomic = AttributeDef.AtomicAttrDesc( + list_min_length_inclusive=col_min_cnt_inclusive, + list_max_length_inclusive=col_max_cnt_inclusive, + ) + self._append_attr( + AttrType.AT_COL_PARAMS, + md, + atomic=atomic, + col_params_binded_table=md.input_name, + ) + else: + if col_max_cnt_inclusive < 0: + col_max_cnt_inclusive = 0 + preifx = md.input_name + "_" + if md.name.startswith(preifx): + name = md.name[len(preifx) :] + else: + name = md.name + tbl_attr = IoDef.TableAttrDef( + name=name, + desc=md.desc, + col_min_cnt_inclusive=col_min_cnt_inclusive, + col_max_cnt_inclusive=col_max_cnt_inclusive, + ) + io_def.attrs.append(tbl_attr) + self._attr_types[md.fullname] = AttrType.AT_STRINGS + + def _reflect_attr(self, md: _Metadata): + if md.kind == FieldKind.StructAttr: + self._reflect_struct_attr(md) + elif md.kind == FieldKind.UnionAttr: + self._reflect_union_attr(md) + elif md.kind == FieldKind.BasicAttr: + self._reflect_basic_attr(md) + elif md.kind == FieldKind.CustomAttr: + self._reflect_custom_attr(md) + elif md.kind == FieldKind.TableColumnAttr: + self._reflect_table_column_attr(md) + elif md.kind == FieldKind.PartyAttr: + self._reflect_party_attr(md) + else: + raise ValueError(f"{md.kind} not supported, metadata={md}.") + + def _reflect_struct_attr(self, md: _Metadata): + self._force_dataclass(md.type) + + self._append_attr(AttrType.AT_STRUCT_GROUP, md) + + prefixes = md.prefixes + [md.name] + for field in md.type.__dataclass_fields__.values(): + sub_md = self._build_metadata(field, prefixes, md) + self._reflect_attr(sub_md) + + def _reflect_union_attr(self, md: _Metadata): + sub_mds = [] + prefixes = md.prefixes + [md.name] + + if md.type == str: + if not md.selections: + raise ValueError(f"no selections in {md.name}") + prefix = "/".join(prefixes) + for s in md.selections.values(): + fullname = f"{prefix}/{s.name}" + sub_md: _Metadata = _Metadata( + kind=FieldKind.SelectionAttr, + type=str, + prefixes=prefixes, + fullname=fullname, + name=s.name, + desc=clean_text(s.desc), + minor_min=s.minor_min, + minor_max=s.minor_max, + ) + sub_mds.append(sub_md) + else: + if md.selections: + raise ValueError( + f"cannot assign selections when type is not str, {md.name}" + ) + if not issubclass(md.type, UnionGroup): + raise ValueError( + f"type<{md.type}> of {md.name} must be subclass of UnionGroup." + ) + + self._force_dataclass(md.type) + + for field in md.type.__dataclass_fields__.values(): + sub_md: _Metadata = self._build_metadata(field, prefixes, parent=md) + sub_mds.append(sub_md) + + md.choices = [] + for sub_md in sub_mds: + if not is_deprecated_field(sub_md.minor_max): + md.choices.append(sub_md.name) + + if len(md.choices) == 0: + raise ValueError(f"union {md.name} must have at least one choice.") + + if md.default == "": + md.default = md.choices[0] + elif md.default not in md.choices: + raise ValueError( + f"{md.default} not in {md.choices}, union name is {md.name}" + ) + + union_desc = AttributeDef.UnionAttrGroupDesc(default_selection=md.default) + self._append_attr(AttrType.AT_UNION_GROUP, md, union=union_desc) + + for sub_md in sub_mds: + if sub_md.kind == FieldKind.SelectionAttr: + self._append_attr(AttrType.ATTR_TYPE_UNSPECIFIED, sub_md) + else: + self._reflect_attr(sub_md) + + def _reflect_custom_attr(self, md: _Metadata): + pb_cls = md.type + assert issubclass(pb_cls, Message), f"support protobuf class only, got {pb_cls}" + extend_path = "secretflow.spec.extend." + module = pb_cls.__module__ + if module.startswith(extend_path): + module = module[len(extend_path) :] + pb_cls_name = f"{module}.{pb_cls.__qualname__}" + self._append_attr(AttrType.AT_CUSTOM_PROTOBUF, md, pb_cls=pb_cls_name) + + def _reflect_basic_attr(self, md: _Metadata): + is_list, prim_type = self._check_list(md.type) + attr_type = self._to_attr_type(prim_type, is_list) + if attr_type == AttrType.ATTR_TYPE_UNSPECIFIED: + raise ValueError(f"invalid primative type {prim_type}, name is {md.name}.") + + if is_list: + list_min_length_inclusive, list_max_length_inclusive = ( + self._build_list_limit(True, md.list_limit) + ) + else: + list_min_length_inclusive, list_max_length_inclusive = None, None + + # check bound + lower_bound_enabled = False + lower_bound_inclusive = False + lower_bound = None + upper_bound_enabled = False + upper_bound_inclusive = False + upper_bound = None + + if md.bound_limit is not None: + if prim_type not in [int, float]: + raise ValueError( + f"bound limit is not supported for {prim_type}, name is {md.name}." + ) + md.bound_limit.astype(prim_type) + if md.choices is not None: + for v in md.choices: + is_valid, err_str = md.bound_limit.check(v) + if not is_valid: + raise ValueError( + f"choices of {md.fullname} is valid, {err_str}" + ) + if md.bound_limit.lower is not None: + lower_bound_enabled = True + lower_bound_inclusive = md.bound_limit.lower_closed + lower_bound = self._to_attr(prim_type(md.bound_limit.lower)) + if md.bound_limit.upper is not None: + upper_bound_enabled = True + upper_bound_inclusive = md.bound_limit.upper_closed + upper_bound = self._to_attr(prim_type(md.bound_limit.upper)) + + default_value = None + allowed_values = None + if md.is_optional and md.default is None: + raise ValueError(f"no default value for optional field, {md.name}") + if md.default is not None: + if is_list and not isinstance(md.default, list): + raise ValueError("Default value for list must be a list") + + # make sure the default type is correct + if not isinstance(md.default, list): + md.default = md.type(md.default) + else: + for idx, v in enumerate(md.default): + md.default[idx] = prim_type(md.default[idx]) + if md.choices is not None: + values = md.default if is_list else [md.default] + for v in values: + if v not in md.choices: + raise ValueError( + f"Default value for {v} must be one of {md.choices}" + ) + default_value = self._to_attr(md.default, prim_type) + + if md.choices is not None: + allowed_values = self._to_attr(md.choices, prim_type) + + atomic = AttributeDef.AtomicAttrDesc( + default_value=default_value, + allowed_values=allowed_values, + is_optional=md.is_optional, + list_min_length_inclusive=list_min_length_inclusive, + list_max_length_inclusive=list_max_length_inclusive, + lower_bound_enabled=lower_bound_enabled, + lower_bound_inclusive=lower_bound_inclusive, + lower_bound=lower_bound, + upper_bound_enabled=upper_bound_enabled, + upper_bound_inclusive=upper_bound_inclusive, + upper_bound=upper_bound, + ) + self._append_attr(attr_type, md, atomic=atomic) + + def _append_attr( + self, + typ: str, + md: _Metadata, + atomic=None, + union=None, + pb_cls=None, + col_params_binded_table=None, + ): + attr = AttributeDef( + type=typ, + name=md.name, + desc=md.desc, + prefixes=md.prefixes, + atomic=atomic, + union=union, + custom_protobuf_cls=pb_cls, + col_params_binded_table=col_params_binded_table, + minor_min=md.minor_min, + minor_max=md.minor_max, + ) + self._attrs.append(attr) + if typ not in [AttrType.ATTR_TYPE_UNSPECIFIED, AttrType.AT_STRUCT_GROUP]: + self._attr_types[md.fullname] = typ + + @staticmethod + def _check_list(field_type) -> tuple[bool, type]: + origin = get_origin(field_type) + if origin is list: + args = get_args(field_type) + if not args: + raise ValueError("list must have type.") + return (True, args[0]) + else: + return (False, field_type) + + def _build_metadata( + self, field: DField, prefixes: list[str], parent: _Metadata = None + ) -> _Metadata: + if field.name in RESERVED: + raise ValueError(f"{field.name} is a reserved word.") + + if "md" not in field.metadata: + raise ValueError(f"md not exist in {field.name}, {field.metadata}") + md: _Metadata = field.metadata["md"] + md.name = field.name + md.type = field.type + md.prefixes = prefixes + md.fullname = Reflector._to_fullname(prefixes, field.name) + + assert ( + self._minor >= md.minor_min and self._minor >= md.minor_max + ), f"{self._minor} shoule be greater than {md.minor_min} and {md.minor_max}" + + if parent != None: + # inherit parent‘s minor_min version if it is zero + if md.minor_min == 0: + md.minor_min = parent.minor_min + elif md.minor_min < parent.minor_min: + raise ValueError( + f"minor version of {md.name} must be greater than or equal to {parent.minor_min}" + ) + return md + + @staticmethod + def _build_list_limit(is_list: bool, limit: Interval | None) -> tuple[int, int]: + if not is_list and limit is None: + # limit must be 1 if target type is not list + return (1, 1) + if limit is None: + return (0, -1) + + limit.enforce_closed() + list_min_length_inclusive = 0 + list_max_length_inclusive = -1 + if limit.lower != None: + assert limit.lower >= 0, f"list min size should be 1" + list_min_length_inclusive = int(limit.lower) + if limit.upper != None: + list_max_length_inclusive = int(limit.upper) + return (list_min_length_inclusive, list_max_length_inclusive) + + @staticmethod + def _to_attr_type(prim_type, is_list) -> str: + if prim_type is float: + return AttrType.AT_FLOATS if is_list else AttrType.AT_FLOAT + elif prim_type is int: + return AttrType.AT_INTS if is_list else AttrType.AT_INT + elif prim_type is str: + return AttrType.AT_STRINGS if is_list else AttrType.AT_STRING + elif prim_type is bool: + return AttrType.AT_BOOLS if is_list else AttrType.AT_BOOL + else: + return AttrType.ATTR_TYPE_UNSPECIFIED + + @staticmethod + def _to_attr(v: Any, prim_type: type | None = None) -> Attribute: + is_list = isinstance(v, list) + if prim_type == None: + if is_list: + raise ValueError(f"unknown list primitive type for {v}") + prim_type = type(v) + + if prim_type == bool: + return Attribute(bs=v) if is_list else Attribute(b=v) + elif prim_type == int: + return Attribute(i64s=v) if is_list else Attribute(i64=v) + elif prim_type == float: + return Attribute(fs=v) if is_list else Attribute(f=v) + elif prim_type == str: + return Attribute(ss=v) if is_list else Attribute(s=v) + else: + raise ValueError(f"unsupported primitive type {prim_type}") + + @staticmethod + def _to_fullname(prefixes: list, name: str) -> str: + if prefixes is not None and len(prefixes) > 0: + return "/".join(prefixes) + "/" + name + else: + return name + + @staticmethod + def _force_dataclass(cls): + if "__dataclass_params__" not in cls.__dict__: + dataclass(cls) + + +class Definition: + def __init__( + self, + cls: type[Component], + domain: str, + version: str, + name: str = "", + desc: str = None, + labels: dict[str, str | bool | int | float] = None, + ): + if not issubclass(cls, Component): + raise ValueError(f"{cls} must be subclass of Component") + + if name == "": + name = re.sub(r"(? LABEL_LEN_MAX or len(v) > LABEL_LEN_MAX: + raise ValueError( + f"length of {k} or {v} must be less than {LABEL_LEN_MAX} in {name}:{version}" + ) + labels[k] = v + + root_package = cls.__module__.split(".")[0] + + self.name = name + self.domain = domain + self.version = version + self.desc = clean_text(desc, no_line_breaks=False) + self.labels = labels + self.root_package = root_package + + self._minor = self.parse_minor(version) + self._comp_cls = cls + self._comp_id = self.build_id(domain, name, version) + + self._comp_def: ComponentDef = None + self._inputs_map: dict[str, IoDef] = None + self._attr_types: dict[str, AttrType] = None + self.reflect() + + def __str__(self) -> str: + return json_format.MessageToJson(self._comp_def, indent=0) + + @staticmethod + def build_id(domain: str, name: str, version: str) -> str: + return f"{domain}/{name}:{version}" + + @staticmethod + def parse_id(comp_id: str) -> tuple[str, str, str]: + pattern = r"(?P[^/]+)/(?P[^:]+):(?P.+)" + match = re.match(pattern, comp_id) + + if match: + return match.group("domain"), match.group("name"), match.group("version") + else: + raise ValueError(f"comp_id<{comp_id}> format is incorrect") + + @staticmethod + def parse_minor(version: str) -> int: + tokens = version.split(".") + if len(tokens) != 3: + raise ValueError(f"version must be in format of x.y.z, but got {version}") + minor = int(tokens[1]) + assert minor >= 0, f"invalid minor<{minor}>" + return minor + + @property + def component_id(self) -> str: + return self._comp_id + + @property + def component_cls(self) -> type[Component]: + return self._comp_cls + + @property + def component_def(self) -> ComponentDef: + if self._comp_def is None: + self.reflect() + return self._comp_def + + @staticmethod + def _get_io(io_defs: list[IoDef], minor: int) -> list[IoDef]: + result = [] + for io in io_defs: + if minor < io.minor_min: + continue + if is_deprecated_minor(io.minor_max, minor): + continue + result.append(io) + + return result + + def get_input_defs(self, minor: int) -> list[IoDef]: + return self._get_io(self.component_def.inputs, minor) + + def get_output_defs(self, minor: int) -> list[IoDef]: + return self._get_io(self.component_def.outputs, minor) + + def reflect(self): + r = Reflector(self._comp_cls, self.name, self._minor) + r.reflect() + self._comp_def = ComponentDef( + name=self.name, + desc=self.desc, + domain=self.domain, + version=self.version, + labels=self.labels, + inputs=r.get_inputs(), + outputs=r.get_outputs(), + attrs=r.get_attrs(), + ) + self._inputs_map = {io.name: io for io in r.get_inputs()} + self._attr_types = r.get_attr_types() + + def make_checkpoint_params(self, param: NodeEvalParam | dict) -> dict: + kwargs, minor = self._to_kwargs(param) + + args = {} + cls = self._comp_cls + for name, field in cls.__dataclass_fields__.items(): + if name == MINOR_NAME: + continue + md: _Metadata = field.metadata["md"] + if md.kind not in [ + FieldKind.BasicAttr, + FieldKind.TableColumnAttr, + FieldKind.Input, + ]: + continue + if is_deprecated_minor(md.minor_max, minor) or not md.is_checkpoint: + continue + + value = kwargs[md.fullname] if md.fullname in kwargs else md.default + args[md.fullname] = value + return args + + def make_component( + self, param: NodeEvalParam | dict, check_exist: bool = True + ) -> type[Component]: + kwargs, minor = self._to_kwargs(param) + + creator = Creator(check_exist=check_exist) + ins = creator.make(self._comp_cls, kwargs, minor) + return ins + + def _to_kwargs(self, param: NodeEvalParam | dict) -> tuple[dict, int]: + if isinstance(param, NodeEvalParam): + kwargs = self.parse_param(param) + elif isinstance(param, dict): + kwargs = {} + for k, v in param.items(): + self._fix_vertical_table(v) + + k = self._trim_input_prefix(k) + kwargs[k] = v + else: + raise ValueError(f"unsupported param type {type(param)}") + + if MINOR_NAME not in kwargs: + raise KeyError(f"kwargs must contain {MINOR_NAME}") + minor = int(kwargs.pop(MINOR_NAME)) + + return kwargs, minor + + def parse_param( + self, + param: NodeEvalParam, + input_params: list[DistData] = None, + output_params: list[str] | list[DistData] = None, + ) -> dict: + _, _, version = self.parse_id(param.comp_id) + minor = self.parse_minor(version) + + attrs = self._parse_attrs(param) + + assert all( + isinstance(item, DistData) for item in param.inputs + ), f"type of inputs must be DistData" + assert all( + isinstance(item, str) for item in param.output_uris + ), f"type of output_uris must be str" + + # parse input + if input_params is None: + input_params = param.inputs + input_defs = self.get_input_defs(minor) + input_size = len(input_params) + if len(input_defs) > 0 and input_defs[-1].is_variable: + in_var = input_defs[-1] + expected_min = len(input_defs) - 1 + in_var.variable_min + expected_max = len(input_defs) - 1 + in_var.variable_max + if input_size < expected_min: + raise ValueError( + f"input size<{input_size}> should be not less than {expected_min}" + ) + if in_var.variable_max > -1 and input_size > expected_max: + raise ValueError( + f"input size<{input_size}> should be not greater than {expected_max}" + ) + elif len(input_defs) != input_size: + raise ValueError( + f"input size<{input_size}> mismatch, expect {input_defs} but got {input_params}" + ) + + inputs = {} + + for idx, io_def in enumerate(input_defs): + if io_def.is_variable: + assert idx == len(input_defs) - 1 + sub_params = input_params[idx:] + else: + sub_params = [input_params[idx]] + + # check type + for in_param in sub_params: + self._fix_vertical_table(in_param) + if in_param.type not in io_def.types: + raise ValueError( + f"input type<{in_param.type}> mismatch, expect {io_def.types}" + ) + + if io_def.is_variable: + inputs[io_def.name] = sub_params + else: + inputs[io_def.name] = sub_params[0] + + # parse output + if output_params is None: + output_params = param.output_uris + + output_defs = self.get_output_defs(minor) + assert len(output_defs) == len( + output_params + ), f"input size<{len(output_params)}> mismatch, expect {output_defs} but got {output_params}" + + outputs = {} + for idx, io_def in enumerate(output_defs): + output_data = output_params[idx] + if isinstance(output_data, str): + output = Output(output_data, None) + elif isinstance(output_data, DistData): + output = Output("", output_data) + else: + raise ValueError( + f"unsupport output type<{type(output_data)}>, name={io_def.name}" + ) + outputs[io_def.name] = output + + return {**attrs, **inputs, **outputs, MINOR_NAME: minor} + + def _parse_attrs(self, param: NodeEvalParam) -> dict: + attrs = {} + for path, attr in zip(list(param.attr_paths), list(param.attrs)): + path = self._trim_input_prefix(path) + if path not in self._attr_types: + raise KeyError(f"unknown attr key {path}") + at = self._attr_types[path] + attrs[path] = self._from_attr(attr, at) + return attrs + + def _trim_input_prefix(self, p: str) -> str: + if p.startswith("input/"): + tokens = p.split("/", maxsplit=3) + if len(tokens) != 3: + raise ValueError(f"invalid input, {p}") + assert ( + tokens[1] in self._inputs_map + ), f"unknown input table name<{p}> in {self.component_id}" + key = "_".join(tokens[1:]) + if key in self._attr_types: + return key + return tokens[2] + return p + + @staticmethod + def _from_attr(value: Attribute, at: AttrType) -> Any: + if at == AttrType.ATTR_TYPE_UNSPECIFIED: + raise ValueError("Type of Attribute is undefined.") + elif at == AttrType.AT_FLOAT: + return value.f + elif at == AttrType.AT_INT: + return value.i64 + elif at == AttrType.AT_STRING: + return value.s + elif at == AttrType.AT_BOOL: + return value.b + elif at == AttrType.AT_FLOATS: + return list(value.fs) + elif at == AttrType.AT_INTS: + return list(value.i64s) + elif at == AttrType.AT_BOOLS: + return list(value.bs) + elif at == AttrType.AT_CUSTOM_PROTOBUF: + return value.s + elif at == AttrType.AT_UNION_GROUP: + return value.s + elif at in [AttrType.AT_STRINGS, AttrType.AT_PARTY, AttrType.AT_COL_PARAMS]: + return list(value.ss) + elif at == AttrType.AT_STRUCT_GROUP: + raise ValueError(f"AT_STRUCT_GROUP should be ignore") + else: + raise ValueError(f"unsupported type: {at}.") + + @staticmethod + def _fix_vertical_table(dd: DistData): + if not isinstance(dd, DistData): + return + if dd.type == DistDataType.OUTDATED_VERTICAL_TABLE: + dd.type = DistDataType.VERTICAL_TABLE diff --git a/secretflow_spec/core/discovery.py b/secretflow_spec/core/discovery.py new file mode 100644 index 0000000..bb0e959 --- /dev/null +++ b/secretflow_spec/core/discovery.py @@ -0,0 +1,101 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import glob +import importlib +import importlib.metadata +import logging +import os +import sys + + +def _check_module_usage(file_path: str) -> bool: + module_name = "secretflow_spec" + with open(file_path, "r", encoding="utf-8") as file: + file_content = file.read() + + has_parse_import = False + tree = ast.parse(file_content) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + has_parse_import = True + for alias in node.names: + if alias.name.startswith(module_name): + return True + elif isinstance(node, ast.ImportFrom): + has_parse_import = True + # Do not compare module names because of indirect import + if any(item.name == "Component" for item in node.names): + return True + elif has_parse_import: + # early stop, just parse import of file header. + return False + + return False + + +def load_component_modules( + root_path: str, + module_prefix: str = "", + ignore_dirs: list[str] = [], + ignore_keys: list[str] = [], + ignore_root_files: bool = True, + verbose: bool = False, +): + if root_path not in sys.path: + sys.path.append(root_path) + + def is_ignore_file(file): + for key in ignore_keys: + if key in file: + return True + return False + + if ignore_root_files: + root_dirs = [ + f + for f in os.listdir(root_path) + if os.path.isdir(os.path.join(root_path, f)) + ] + root_dirs = [x for x in root_dirs if x not in ignore_dirs] + else: + root_dirs = [root_path] + + for dir_name in root_dirs: + if dir_name.startswith("__"): # ignore __pycache__ + continue + pattern = os.path.join(root_path, dir_name, "**/*.py") + for pyfile in glob.glob(pattern, recursive=True): + if pyfile.endswith("__init__.py") or is_ignore_file(pyfile): + continue + if not _check_module_usage(pyfile): + continue + + module_name = ( + os.path.relpath(pyfile, root_path) + .removesuffix(".py") + .replace(os.path.sep, ".") + ) + if module_prefix: + module_name = f"{module_prefix}.{module_name}" + + try: + importlib.import_module(module_name) + if verbose: + logging.warning(f"import component {module_name} from {pyfile}") + except Exception as e: + raise ValueError( + f"import component fail, file={pyfile}, module={module_name}, err={e}" + ) diff --git a/secretflow_spec/core/dist_data/__init__.py b/secretflow_spec/core/dist_data/__init__.py new file mode 100644 index 0000000..086637a --- /dev/null +++ b/secretflow_spec/core/dist_data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/secretflow_spec/core/dist_data/base.py b/secretflow_spec/core/dist_data/base.py new file mode 100644 index 0000000..0e58e49 --- /dev/null +++ b/secretflow_spec/core/dist_data/base.py @@ -0,0 +1,33 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +from secretflow_spec.core.types import StrEnum + + +@enum.unique +class DistDataType(StrEnum): + """ + builtin distdata type + """ + + # tables + OUTDATED_VERTICAL_TABLE = "sf.table.vertical_table" # deprecated + VERTICAL_TABLE = "sf.table.vertical" + INDIVIDUAL_TABLE = "sf.table.individual" + # report + REPORT = "sf.report" + # if input of component is optional, then the corresponding type can be NULL + NULL = "sf.null" diff --git a/secretflow_spec/core/dist_data/file.py b/secretflow_spec/core/dist_data/file.py new file mode 100644 index 0000000..f9d7d45 --- /dev/null +++ b/secretflow_spec/core/dist_data/file.py @@ -0,0 +1,97 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any + +from secretflow_spec.core.types import Version +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.data_pb2 import DistData, ObjectFileInfo, SystemInfo + + +class ObjectFile: + def __init__( + self, + name: str, + type: str, + data_refs: list[DistData.DataRef], + version: Version, + public_info: Any, + attributes: dict[str, str] = None, + system_info: SystemInfo = None, + ): + self.name = name + self.type = type + self.version = version + self.public_info = public_info + self.attributes = attributes + self.data_refs = data_refs + self.system_info = system_info + + @staticmethod + def from_distdata(dd: DistData) -> "ObjectFile": + meta = ObjectFileInfo() + dd.meta.Unpack(meta) + attributes = dict(meta.attributes) + if not ("version" in attributes and "public_info" in attributes): + raise ValueError(f"invalid FileInfo format {attributes}") + version = Version.from_str(attributes.pop("version")) + public_info = json.loads(attributes.pop("public_info")) + + return ObjectFile( + name=dd.name, + type=dd.type, + data_refs=list(dd.data_refs), + version=version, + public_info=public_info, + attributes=attributes, + system_info=dd.system_info, + ) + + def to_distdata(self) -> "DistData": + if self.name == "": + raise ValueError(f"dist_data file name is empty") + if self.type == "": + raise ValueError(f"dist_data type is empty") + + attributes = { + "version": str(self.version), + "public_info": json.dumps(self.public_info), + } + if self.attributes: + attributes.update(self.attributes) + meta = ObjectFileInfo(attributes=attributes) + + dd = DistData( + version=SPEC_VERSION, + name=self.name, + type=self.type, + data_refs=self.data_refs, + system_info=self.system_info, + ) + dd.meta.Pack(meta) + return dd + + def check(self, file_type: str = None, max_version: Version = None): + if file_type and file_type != self.type: + raise ValueError(f"type mismatch, expect {file_type} but got {self.type}") + + if max_version and not ( + max_version.major == self.version.major + and max_version.minor >= self.version.minor + ): + raise ValueError( + f"max_version mismatch, expect {max_version} but got {self.version}" + ) diff --git a/secretflow_spec/core/dist_data/report.py b/secretflow_spec/core/dist_data/report.py new file mode 100644 index 0000000..09451a6 --- /dev/null +++ b/secretflow_spec/core/dist_data/report.py @@ -0,0 +1,196 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from secretflow_spec.core.utils import to_attribute, to_type +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.data_pb2 import DistData, SystemInfo +from secretflow_spec.v1.report_pb2 import Descriptions, Div, Report, Tab, Table + +from .base import DistDataType + + +class Reporter: + def __init__( + self, + name: str = "", + desc: str = "", + tabs: list[Tab] = None, + system_info: SystemInfo = None, + type: str = DistDataType.REPORT, + ) -> None: + self._name = name + self._desc = desc + self._tabs = tabs if tabs else [] + self._system_info = system_info + self._type = str(type) + + @staticmethod + def from_distdata(dd: DistData) -> "Reporter": + if dd.meta: + report = Report() + dd.meta.Unpack(report) + return Reporter( + report.name, report.desc, report.tabs, dd.system_info, type=dd.type + ) + + return Report(dd.name) + + def to_distdata(self) -> DistData: + dd = DistData( + version=SPEC_VERSION, + name=self._name, + type=self._type, + system_info=self._system_info, + ) + meta = self.report() + if meta: + dd.meta.Pack(meta) + + return dd + + def report(self) -> Report: + if self._tabs: + return Report(name=self._name, desc=self._desc, tabs=self._tabs) + return None + + def add_tab( + self, + obj: list[Div] | Div | Table | Descriptions | dict, + name: str = None, + desc: str = None, + ): + divs: list[Div] = [] + if isinstance(obj, list): + assert all( + isinstance(item, Div) for item in obj + ), f"all item should be instance of Div, {obj}" + divs = obj + elif isinstance(obj, Div): + divs.append(obj) + else: + child = self.build_div_child(obj) + divs.append(Div(children=[child])) + + self._tabs.append(Tab(name=name, desc=desc, divs=divs)) + + @staticmethod + def build_table( + obj: dict, + name: str = None, + desc: str = None, + columns: dict[str, Table.HeaderItem | str] = None, + index: list[str] = None, + prefix: str = "", + ) -> Table: + """ + name: table name + desc: table description + columns: columns header info, if type of dict value is str, it represents column description + index: row index + prefix: row index name prefix + """ + pb_headers, pb_rows = [], [] + df = _to_dict(obj) + for col_name in df.keys(): + dtype = _to_type_str(df[col_name][0]) + if columns and col_name in columns: + v = columns[col_name] + header = ( + v + if isinstance(v, Table.HeaderItem) + else Table.HeaderItem(name=col_name, desc=v, type=dtype) + ) + else: + header = Table.HeaderItem(name=col_name, desc="", type=dtype) + pb_headers.append(header) + + row_size = len(next(iter(df.values()))) + for idx in range(row_size): + items = [] + for k in df.keys(): + value = df[k][idx] + items.append(to_attribute(value)) + idx_name = index[idx] if index and idx < len(index) else str(idx) + row_name = f"{prefix}{idx_name}" + pb_rows.append(Table.Row(name=row_name, items=items)) + return Table(name=name, desc=desc, headers=pb_headers, rows=pb_rows) + + @staticmethod + def build_descriptions( + values: dict[str, int | float | bool | str], name: str = None, desc: str = None + ) -> Descriptions: + items = [ + Descriptions.Item(name=k, type=_to_type_str(v), value=to_attribute(v)) + for k, v in values.items() + ] + return Descriptions(name=name, desc=desc, items=items) + + @staticmethod + def build_div_child(obj: Table | Descriptions | Div | dict) -> Div.Child: + if isinstance(obj, Table): + return Div.Child(type="table", table=obj) + elif isinstance(obj, Descriptions): + return Div.Child(type="descriptions", descriptions=obj) + elif isinstance(obj, Div): + return Div.Child(type="div", div=obj) + else: + obj = _to_dict(obj) + if _is_table_dict(obj): + table = Reporter.build_table(obj) + return Div.Child(type="table", table=table) + else: + descriptions = Reporter.build_descriptions(obj) + return Div.Child(type="descriptions", descriptions=descriptions) + + @staticmethod + def build_div( + obj: Table | Descriptions | Div | dict, name: str = None, desc: str = None + ) -> Div: + child = Reporter.build_div_child(obj) + return Div(name=name, desc=desc, children=[child]) + + +def _to_type_str(dt) -> str: + dt = to_type(dt) + return dt.__name__ + + +def _is_table_dict(value: dict) -> bool: + if not value: + return False + + # {"A": [1,2,3], "B": [0.1,0.2,0.3]} + sizes = set() + for x in value.values(): + if not isinstance(x, list): + return False + sizes.add(len(x)) + + if len(sizes) != 1 or list(sizes)[0] < 1: + return False + + return True + + +def _to_dict(obj) -> dict: + if isinstance(obj, dict): + return obj + + # only support pd.DataFrame + if hasattr(obj, "to_dict"): + method = getattr(obj, "to_dict") + if callable(method): + return method(orient="list") + + raise ValueError(f"unsupport type, {type(obj)}") diff --git a/secretflow_spec/core/dist_data/vtable.py b/secretflow_spec/core/dist_data/vtable.py new file mode 100644 index 0000000..972ea0d --- /dev/null +++ b/secretflow_spec/core/dist_data/vtable.py @@ -0,0 +1,553 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from dataclasses import dataclass +from enum import IntFlag + +from secretflow_spec.core.types import StrEnum +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.data_pb2 import ( + DistData, + IndividualTable, + SystemInfo, + TableSchema, + VerticalTable, +) + +from .base import DistDataType + + +@enum.unique +class VTableFormat(StrEnum): + CSV = "csv" + ORC = "orc" + + +_same_type = { + "int": "int64", + "float": "float64", + "int64": "int", + "float64": "float", +} + + +@enum.unique +class VTableFieldType(StrEnum): + STR = "str" + BOOL = "bool" + INT = "int" + FLOAT = "float" + INT8 = "int8" + INT16 = "int16" + INT32 = "int32" + INT64 = "int64" + UINT8 = "uint8" + UINT16 = "uint16" + UINT32 = "uint32" + UINT64 = "uint64" + FLOAT16 = "float16" + FLOAT32 = "float32" + FLOAT64 = "float64" + + def is_string(self) -> bool: + return self.value == "str" + + def is_bool(self) -> bool: + return self.value == "bool" + + def is_integer(self) -> bool: + v = str(self.value) + return v.startswith("int") or v.startswith("uint") + + def is_float(self) -> bool: + v = str(self.value) + return v.startswith("float") + + @staticmethod + def is_same_type(t1: str, t2: str) -> bool: + return t1 == t2 or (t1 in _same_type and _same_type[t1] == t2) + + +class VTableFieldKind(IntFlag): + UNKNOWN = 0 + FEATURE = 1 << 0 + LABEL = 1 << 1 + ID = 1 << 2 + + FEATURE_LABEL = FEATURE | LABEL + ALL = FEATURE | LABEL | ID + + @staticmethod + def from_str(str_value: str) -> "VTableFieldKind": + if str_value == "UNKNOWN": + return VTableFieldKind.UNKNOWN + + value = VTableFieldKind.UNKNOWN + fields = str_value.split("|") + for key in fields: + value |= VTableFieldKind[key].value + + return VTableFieldKind(value) + + def __str__(self): + if self.value == VTableFieldKind.UNKNOWN: + return "UNKNOWN" + + members = [VTableFieldKind.FEATURE, VTableFieldKind.LABEL, VTableFieldKind.ID] + fields = [m.name for m in members if self.value & m] + + return "|".join(fields) + + +@dataclass +class VTableField: + name: str + type: VTableFieldType + kind: VTableFieldKind + + def __post_init__(self): + self.type = VTableFieldType(self.type) + + +class VTableSchema: + def __init__(self, fields: list[VTableField] | dict[str, VTableField]) -> None: + if isinstance(fields, list): + fields = {f.name: f for f in fields} + self.fields: dict[str, VTableField] = fields + + def __getitem__(self, key: int | str) -> VTableField: + return self.get_field(key) + + def __eq__(self, value: object) -> bool: + if isinstance(value, VTableSchema): + return self.fields == value.fields + + return False + + def __contains__(self, keys: list[str] | str) -> bool: + if isinstance(keys, list): + return all(item in self.fields for item in keys) + return keys in self.fields + + @property + def names(self) -> list[str]: + return [f.name for f in self.fields.values()] + + @property + def kinds(self) -> dict[str, VTableFieldKind]: + return {f.name: f.kind for f in self.fields.values()} + + @property + def types(self) -> dict[str, str]: + return {f.name: f.type for f in self.fields.values()} + + def get_field(self, key: int | str) -> VTableField: + if isinstance(key, int): + keys = self.fields.keys() + key = next(iter(keys)) if key == 0 else list(keys)[key] + + return self.fields[key] + + def select(self, columns: list[str]) -> "VTableSchema": + fields = {n: self.fields[n] for n in columns} + return VTableSchema(fields) + + @staticmethod + def from_dict( + features: dict[str, str] = None, + labels: dict[str, str] = None, + ids: dict[str, str] = None, + ) -> "VTableSchema": + kinds = [VTableFieldKind.FEATURE, VTableFieldKind.LABEL, VTableFieldKind.ID] + values = [features, labels, ids] + fields = [] + for kind, value in zip(kinds, values): + if not value: + continue + fields.extend([VTableField(name, typ, kind) for name, typ in value.items()]) + + return VTableSchema(fields) + + @staticmethod + def from_pb_str(pb_str: str) -> "VTableSchema": + pb = TableSchema() + pb.ParseFromString(pb_str) + return VTableSchema.from_pb(pb) + + @staticmethod + def from_pb(schema: TableSchema) -> "VTableSchema": + fields: list[VTableField] = [] + kind_list = [VTableFieldKind.ID, VTableFieldKind.FEATURE, VTableFieldKind.LABEL] + name_list = [schema.ids, schema.features, schema.labels] + type_list = [schema.id_types, schema.feature_types, schema.label_types] + for kind, names, types in zip(kind_list, name_list, type_list): + res = [VTableField(n, t, kind) for n, t in zip(names, types)] + fields.extend(res) + return VTableSchema(fields) + + def to_pb(self) -> TableSchema: + features, feature_types = [], [] + labels, label_types = [], [] + ids, id_types = [], [] + + for f in self.fields.values(): + if f.kind == VTableFieldKind.FEATURE: + feature_types.append(str(f.type)) + features.append(f.name) + elif f.kind == VTableFieldKind.LABEL: + label_types.append(str(f.type)) + labels.append(f.name) + elif f.kind == VTableFieldKind.ID: + id_types.append(str(f.type)) + ids.append(f.name) + else: + raise ValueError(f"invalid vtable field kind: {f}") + + return TableSchema( + features=features, + feature_types=feature_types, + labels=labels, + label_types=label_types, + ids=ids, + id_types=id_types, + ) + + +@dataclass +class VTableParty: + party: str = "" + uri: str = "" + format: str = "" + null_strs: list = None + schema: VTableSchema = None + + @property + def columns(self) -> list[str]: + return self.schema.names + + @property + def kinds(self) -> dict[str, VTableFieldKind]: + return self.schema.kinds + + @property + def types(self) -> dict[str, str]: + return self.schema.types + + @staticmethod + def from_dict( + party: str = "", + format: str = "", + uri: str = "", + null_strs: list = None, + features: dict[str, str] = None, + labels: dict[str, str] = None, + ids: dict[str, str] = None, + ) -> "VTableParty": + return VTableParty( + party=party, + uri=uri, + format=format, + null_strs=null_strs, + schema=VTableSchema.from_dict(features, labels, ids), + ) + + @staticmethod + def from_pb(dr: DistData.DataRef, pb_schema: TableSchema) -> "VTableParty": + return VTableParty( + party=dr.party, + uri=dr.uri, + format=dr.format, + null_strs=list(dr.null_strs), + schema=VTableSchema.from_pb(pb_schema), + ) + + def to_pb(self) -> tuple[DistData.DataRef, TableSchema]: + pb_dr = DistData.DataRef( + party=self.party, + uri=self.uri, + format=self.format, + null_strs=self.null_strs, + ) + pb_schema = self.schema.to_pb() + return pb_dr, pb_schema + + +class VTable: + def __init__( + self, + name: str, + parties: dict[str, VTableParty] | list[VTableParty], + line_count: int = -1, + system_info: SystemInfo = None, + ): + type = ( + DistDataType.INDIVIDUAL_TABLE + if len(parties) == 1 + else DistDataType.VERTICAL_TABLE + ) + + if isinstance(parties, list): + parties = {dr.party: dr for dr in parties} + + self.name = name + self.type = str(type) + self.parties = parties + self.line_count = line_count + self.system_info = system_info + + @property + def is_individual(self) -> bool: + return self.type == DistDataType.INDIVIDUAL_TABLE + + @property + def columns(self) -> list[str]: + ret = [] + for p in self.schemas.values(): + ret.extend(p.names) + return ret + + @property + def schemas(self) -> dict[str, VTableSchema]: + return {name: p.schema for name, p in self.parties.items()} + + @property + def flatten_schema(self) -> VTableSchema: + if len(self.parties) == 1: + return next(iter(self.parties.values())).schema + else: + fields = [] + for s in self.parties.values(): + fields.extend(s.schema.fields.values()) + return VTableSchema(fields) + + def get_party(self, key: str | int) -> VTableParty: + if isinstance(key, int): + keys = self.parties.keys() + key = next(iter(keys)) if key == 0 else list(keys)[key] + + return self.parties[key] + + def get_schema(self, key: str | int) -> VTableSchema: + return self.get_party(key).schema + + @staticmethod + def from_output_uri( + output_uri: str, + schemas: dict[str, VTableSchema], + line_count: int = -1, + name: str = None, + format: str = VTableFormat.ORC, + null_strs: list[str] = None, + system_info: SystemInfo = None, + ) -> "VTable": + assert len(schemas) > 0, f"empty schema, uri={output_uri}" + parties = { + name: VTableParty( + party=name, + uri=output_uri, + format=str(format), + null_strs=null_strs, + schema=s, + ) + for name, s in schemas.items() + } + return VTable( + name=name if name else output_uri, + parties=parties, + line_count=line_count, + system_info=system_info, + ) + + @staticmethod + def from_distdata(dd: DistData, columns: list[str] = None) -> "VTable": + dd_type = dd.type.lower() + if dd_type not in [ + DistDataType.VERTICAL_TABLE, + DistDataType.INDIVIDUAL_TABLE, + ]: + raise ValueError(f"Unsupported DistData type {dd_type}") + # parse meta + is_individual = dd.type == DistDataType.INDIVIDUAL_TABLE + meta = IndividualTable() if is_individual else VerticalTable() + dd.meta.Unpack(meta) + pb_schemas = [meta.schema] if is_individual else meta.schemas + if len(pb_schemas) == 0: + raise ValueError(f"empty schema") + if len(dd.data_refs) != len(pb_schemas): + raise ValueError( + f"schemas<{len(pb_schemas)}> and data_refs<{len(dd.data_refs)}> mismatch" + ) + + parties = { + dr.party: VTableParty.from_pb(dr, ps) + for dr, ps in zip(dd.data_refs, pb_schemas) + } + + vtbl = VTable( + name=dd.name, + parties=parties, + system_info=dd.system_info, + line_count=meta.line_count, + ) + if columns: + vtbl = vtbl.select(columns) + return vtbl + + def to_distdata(self) -> DistData: + pb_schemas = [] + pb_data_refs = [] + for p in self.parties.values(): + pb_dr, pb_schema = p.to_pb() + pb_data_refs.append(pb_dr) + pb_schemas.append(pb_schema) + + if len(pb_schemas) == 1: + meta = IndividualTable(schema=pb_schemas[0], line_count=self.line_count) + else: + meta = VerticalTable(schemas=pb_schemas, line_count=self.line_count) + dd = DistData( + version=SPEC_VERSION, + name=self.name, + type=self.type, + system_info=self.system_info, + data_refs=pb_data_refs, + ) + if meta: + dd.meta.Pack(meta) + return dd + + def _copy(self, schemas: dict[str, VTableSchema]) -> "VTable": + parties = {} + for key, schema in schemas.items(): + p = self.parties[key] + parties[key] = VTableParty( + party=p.party, + uri=p.uri, + format=p.format, + null_strs=p.null_strs, + schema=schema, + ) + + return VTable( + name=self.name, + parties=parties, + line_count=self.line_count, + system_info=self.system_info, + ) + + def sort_partitions(self, orders: list[str]) -> "VTable": + if set(orders) != set(self.parties.keys()): + raise ValueError(f"parties mismatch, {orders}<>{self.parties.keys()}") + + parties = {} + for key in orders: + parties[key] = self.parties[key] + + return VTable(self.name, parties, self.line_count, self.system_info) + + def drop(self, columns: list[str]) -> "VTable": + """ + drop some columns, return new VTable + """ + if not columns: + raise ValueError(f"empty exclude columns set") + + excludes_set = set(columns) + schemas = {} + for party, p in self.parties.items(): + if len(excludes_set) == 0: + schemas[party] = p.schema + break + fields = {} + for f in p.schema.fields.values(): + if f.name in excludes_set: + excludes_set.remove(f.name) + continue + fields[f.name] = f + if len(fields) == 0: + continue + schemas[party] = VTableSchema(fields) + + if len(excludes_set) > 0: + raise ValueError(f"unknowns columns, {excludes_set}") + + return self._copy(schemas) + + def select(self, columns: list[str]) -> "VTable": + """ + select and sort by column names, return new VTable + """ + if not columns: + raise ValueError(f"columns cannot be empty") + + seen = set() + duplicates = set(x for x in columns if x in seen or seen.add(x)) + if duplicates: + raise f"has duplicate items<{duplicates}> in {columns}" + + columns_map = {name: idx for idx, name in enumerate(columns)} + + schemas = {} + for party, p in self.parties.items(): + fields = { + name: field + for name, field in p.schema.fields.items() + if field.name in columns_map + } + if len(fields) == 0: + continue + + # sort by keys + fields = {n: fields[n] for n in columns_map.keys() if n in fields} + + for n in fields.keys(): + del columns_map[n] + + schemas[party] = VTableSchema(fields) + if len(columns_map) == 0: + continue + + if len(columns_map) > 0: + raise ValueError(f"unknowns columns, {columns_map.keys()}") + + return self._copy(schemas) + + def select_by_kinds(self, kinds: VTableFieldKind) -> "VTable": + if kinds == VTableFieldKind.ALL: + return self + + schemas = {} + for party, p in self.parties.items(): + schema = p.schema + fields = { + name: field + for name, field in schema.fields.items() + if field.kind & kinds + } + if len(fields) > 0: + schemas[party] = VTableSchema(fields) + + return self._copy(schemas) + + def check_kinds(self, kinds: VTableFieldKind): + assert kinds != 0 and kinds != VTableFieldKind.ALL + mismatch = {} + for p in self.parties.values(): + for f in p.schema.fields.values(): + if not (kinds & f.kind): + mismatch[f.name] = str(f.kind) + + if len(mismatch) > 0: + raise ValueError(f"kind of {mismatch} mismatch, expected {kinds}") diff --git a/secretflow_spec/core/registry.py b/secretflow_spec/core/registry.py new file mode 100644 index 0000000..839239b --- /dev/null +++ b/secretflow_spec/core/registry.py @@ -0,0 +1,130 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +from typing import Iterable + +from secretflow_spec.core.component import Component +from secretflow_spec.core.definition import Definition +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.component_pb2 import CompListDef + +_reg_defs_by_key: dict[str, Definition] = {} +_reg_defs_by_cls: dict[str, Definition] = {} +_reg_defs_by_pkg: dict[str, list[Definition]] = defaultdict(list) + + +def _parse_major(version: str) -> str: + tokens = version.split(".") + if len(tokens) != 3: + raise ValueError(f"version must be in format of x.y.z, but got {version}") + return tokens[0] + + +def _gen_reg_key(domain: str, name: str, version: str) -> str: + return f"{domain}/{name}:{_parse_major(version)}" + + +def _gen_class_id(cls: Component | type[Component]) -> str: + if isinstance(cls, Component): + cls = type(cls) + return f"{cls.__module__}:{cls.__qualname__}" + + +class Registry: + @staticmethod + def register(d: Definition): + key = _gen_reg_key(d.domain, d.name, d.version) + if key in _reg_defs_by_key: + raise ValueError(f"{key} is already registered") + class_id = _gen_class_id(d.component_cls) + _reg_defs_by_key[key] = d + _reg_defs_by_cls[class_id] = d + _reg_defs_by_pkg[d.root_package].append(d) + + @staticmethod + def unregister(domain: str, name: str, version: str) -> bool: + key = _gen_reg_key(domain, name, version) + if key not in _reg_defs_by_key: + return False + d = _reg_defs_by_key.pop(key) + class_id = _gen_class_id(d.component_cls) + del _reg_defs_by_cls[class_id] + _reg_defs_by_pkg[d.root_package].remove(d) + return True + + @staticmethod + def get_definition(domain: str, name: str, version: str) -> Definition: + key = _gen_reg_key(domain, name, version) + return _reg_defs_by_key.get(key) + + @staticmethod + def get_definitions(root_pkg: str = None) -> Iterable[Definition]: + if root_pkg and root_pkg != "*": + return _reg_defs_by_pkg.get(root_pkg, None) + + return _reg_defs_by_key.values() + + @staticmethod + def get_definition_keys() -> Iterable[str]: + return _reg_defs_by_key.keys() + + @staticmethod + def get_definition_by_key(key: str) -> Definition: + return _reg_defs_by_key.get(key) + + @staticmethod + def get_definition_by_id(id: str) -> Definition: + prefix, version = id.split(":") + key = f"{prefix}:{_parse_major(version)}" + comp_def = _reg_defs_by_key.get(key) + + return comp_def + + @staticmethod + def get_definition_by_class(cls: Component | type[Component]) -> Definition: + class_id = _gen_class_id(cls) + return _reg_defs_by_cls.get(class_id) + + @staticmethod + def build_comp_list_def( + name: str, + desc: str, + components: Iterable[Definition], + version: str = SPEC_VERSION, + ) -> CompListDef: + comps = [d.component_def for d in components] + comps = sorted(comps, key=lambda k: (k.domain, k.name, k.version)) + return CompListDef(name=name, desc=desc, version=version, comps=comps) + + +def register( + domain: str, + version: str, + name: str = "", + desc: str = None, + labels: dict[str, str | bool | int | float] = None, +): + if domain == "" or version == "": + raise ValueError( + f"domain<{domain}> and version<{version}> cannot be empty in register" + ) + + def wrap(cls): + d = Definition(cls, domain, version, name, desc, labels=labels) + Registry.register(d) + return cls + + return wrap diff --git a/secretflow_spec/core/storage/__init__.py b/secretflow_spec/core/storage/__init__.py new file mode 100644 index 0000000..b09c172 --- /dev/null +++ b/secretflow_spec/core/storage/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from secretflow_spec.v1.data_pb2 import StorageConfig + +from .base import Storage, StorageType +from .local import LocalStorage +from .s3 import S3Storage + + +def make_storage(config: StorageConfig) -> Storage: + if config.type == StorageType.LOCAL_FS: + return LocalStorage(config) + elif config.type == StorageType.S3: + return S3Storage(config) + else: + raise ValueError(f"unsupported storage type{config.type}") + + +__all__ = [ + "make_storage", + "StorageType", + "Storage", + "S3Storage", + "LocalStorage", +] diff --git a/secretflow_spec/core/storage/base.py b/secretflow_spec/core/storage/base.py new file mode 100644 index 0000000..5aad63f --- /dev/null +++ b/secretflow_spec/core/storage/base.py @@ -0,0 +1,67 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from io import BufferedIOBase + +from secretflow_spec.core.types import StrEnum +from secretflow_spec.v1.data_pb2 import StorageConfig + + +class StorageType(StrEnum): + LOCAL_FS = "local_fs" + S3 = "s3" + + +class Storage(ABC): + def __init__(self, config: StorageConfig) -> None: + self.config = config + + @abstractmethod + def get_type(self) -> StorageType: + pass + + @abstractmethod + def get_size(self, path: str) -> int: + pass + + @abstractmethod + def get_full_path(self, path: str) -> str: + pass + + @abstractmethod + def get_reader(self, path: str) -> BufferedIOBase: + pass + + @abstractmethod + def get_writer(self, path: str) -> BufferedIOBase: + pass + + @abstractmethod + def remove(self, path: str) -> None: + pass + + @abstractmethod + def exists(self, path: str) -> bool: + pass + + @abstractmethod + def download_file(self, remote_path: str, local_path: str) -> None: + """blocked download whole file into local_path, overwrite if local_path exist""" + pass + + @abstractmethod + def upload_file(self, local_path: str, remote_path: str) -> None: + """blocked upload_file whole file into remote_path, overwrite if remote_path exist""" + pass diff --git a/secretflow_spec/core/storage/local.py b/secretflow_spec/core/storage/local.py new file mode 100644 index 0000000..75f9f29 --- /dev/null +++ b/secretflow_spec/core/storage/local.py @@ -0,0 +1,103 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +from io import BufferedIOBase +from pathlib import Path + +from secretflow_spec.v1.data_pb2 import StorageConfig + +from .base import Storage, StorageType + + +class LocalStorage(Storage): + def __init__(self, config: StorageConfig) -> None: + super().__init__(config) + assert config.type == "local_fs" + self._local_wd = config.local_fs.wd + + def get_full_path(self, remote_fn) -> str: + full_path = os.path.join(self._local_wd, remote_fn) + full_path = os.path.normpath(full_path) + full_path = os.path.abspath(full_path) + return full_path + + def get_type(self) -> StorageType: + return StorageType.LOCAL_FS + + def get_size(self, path: str) -> int: + full_path = self.get_full_path(path) + return os.path.getsize(full_path) + + def get_reader(self, path: str) -> BufferedIOBase: + return self.open(path, "rb") + + def get_writer(self, path: str) -> BufferedIOBase: + return self.open(path, "wb") + + def open(self, path: str, mode: str) -> BufferedIOBase: + full_path = self.get_full_path(path) + if "w" in mode: + Path(full_path).parent.mkdir(parents=True, exist_ok=True) + try: + return open(full_path, mode) + except FileNotFoundError: + raise FileNotFoundError(f"{full_path} not found") + except IsADirectoryError: + raise IsADirectoryError(f"{full_path} is a directory") + except Exception as e: + raise e + + def remove(self, path: str) -> None: + full_path = self.get_full_path(path) + if not os.path.exists(full_path): + raise ValueError(f"{full_path} not exist") + return os.remove(full_path) + + def exists(self, path: str) -> bool: + full_path = self.get_full_path(path) + return os.path.exists(full_path) + + def mkdir(self, path: str) -> bool: + Path(path).mkdir(parents=True, exist_ok=True) + + def download_file(self, remote_path: str, local_path: str) -> None: + full_remote_path = self.get_full_path(remote_path) + if not os.path.exists(full_remote_path): + raise ValueError(f"file not exist {full_remote_path}") + if not os.path.isfile(full_remote_path): + raise ValueError(f"{full_remote_path} is not a file") + if os.path.exists(local_path): + if not os.path.isfile(local_path): + raise ValueError(f"{local_path} is not a file") + if os.path.samefile(full_remote_path, local_path): + return + Path(local_path).parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(full_remote_path, local_path) + + def upload_file(self, local_path: str, remote_path: str) -> None: + if not os.path.exists(local_path): + raise ValueError(f"{local_path} not exist.") + if not os.path.isfile(local_path): + raise ValueError(f"{local_path} is not a file") + full_remote_path = self.get_full_path(remote_path) + + if os.path.exists(full_remote_path): + if not os.path.isfile(full_remote_path): + raise ValueError(f"{full_remote_path} is not a file") + if os.path.samefile(full_remote_path, local_path): + return + Path(full_remote_path).parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(local_path, full_remote_path) diff --git a/secretflow_spec/core/storage/s3.py b/secretflow_spec/core/storage/s3.py new file mode 100644 index 0000000..54ee1ab --- /dev/null +++ b/secretflow_spec/core/storage/s3.py @@ -0,0 +1,149 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from io import BufferedIOBase + +import s3fs +from botocore import exceptions as s3_exceptions + +from secretflow_spec.v1.data_pb2 import StorageConfig + +from .base import Storage, StorageType + + +class S3Storage(Storage): + """ + s3 storage, please refer to https://s3fs.readthedocs.io/en/latest/ + """ + + def __init__(self, config: StorageConfig) -> None: + super().__init__(config) + assert config.type == "s3" + s3_config: StorageConfig.S3Config = config.s3 + + if s3_config.version == "": + s3_config.version = "s3v4" + if s3_config.version not in ["s3v4", "s3v2"]: + raise ValueError(f"Not support s3 version {s3_config.version}") + + if not s3_config.endpoint.startswith(("https://", "http://")): + raise ValueError( + f"Please specify the scheme(http or https) of endpoint<{s3_config.endpoint}>" + ) + + self._prefix = s3_config.prefix + self._bucket = s3_config.bucket + self._s3_client = s3fs.S3FileSystem( + anon=False, + key=s3_config.access_key_id, + secret=s3_config.access_key_secret, + client_kwargs={"endpoint_url": s3_config.endpoint}, + config_kwargs={ + "signature_version": s3_config.version, + "s3": { + "addressing_style": "virtual" if s3_config.virtual_host else "path" + }, + }, + ) + + try: + self._s3_client.ls(self._bucket, detail=False) + except s3_exceptions.UnknownSignatureVersionError as e: + logging.exception( + f"config.version {s3_config.version} not support by server" + ) + raise + except Exception as e: + self._log_s3_error(e) + raise + + def _log_s3_error(self, e: Exception, file_name: str = None) -> None: + if isinstance(e, FileNotFoundError): + if file_name: + logging.exception( + f"The file {file_name} in bucket {self._bucket} does not exist" + ) + else: + logging.exception(f"The specified bucket {self._bucket} does not exist") + elif isinstance(e, PermissionError): + logging.exception("Access denied, Check your key and signing method") + else: + logging.exception("Unknown error") + + def get_full_path(self, path: str) -> str: + return f"s3://{os.path.join(self._bucket, self._prefix, path)}" + + def get_type(self) -> StorageType: + return StorageType.S3 + + def get_size(self, path: str) -> int: + full_path = self.get_full_path(path) + try: + info = self._s3_client.stat(full_path) + return info["size"] + except Exception as e: + self._log_s3_error(e) + raise + + def get_reader(self, path: str) -> BufferedIOBase: + return self.open(path, "rb") + + def get_writer(self, path: str) -> BufferedIOBase: + return self.open(path, "wb") + + def open(self, path: str, mode: str) -> BufferedIOBase: + full_path = self.get_full_path(path) + try: + return self._s3_client.open(full_path, mode) + except Exception as e: + self._log_s3_error(e) + raise + + def remove(self, path: str) -> None: + full_path = self.get_full_path(path) + try: + self._s3_client.rm(full_path) + except Exception as e: + self._log_s3_error(e, full_path) + raise + + def exists(self, path: str) -> bool: + full_path = self.get_full_path(path) + return self._s3_client.exists(full_path) + + def mkdir(self, path: str): + full_path = self.get_full_path(path) + try: + self._s3_client.mkdir(full_path) + except Exception as e: + self._log_s3_error(e, full_path) + raise + + def download_file(self, remote_path: str, local_path: str) -> None: + full_remote_path = self.get_full_path(remote_path) + try: + self._s3_client.download(full_remote_path, local_path) + except Exception as e: + self._log_s3_error(e, full_remote_path) + raise + + def upload_file(self, local_path: str, remote_path: str) -> None: + full_remote_fn = self.get_full_path(remote_path) + try: + self._s3_client.upload(local_path, full_remote_fn) + except Exception as e: + self._log_s3_error(e) + raise diff --git a/secretflow_spec/core/types.py b/secretflow_spec/core/types.py new file mode 100644 index 0000000..4d9ff88 --- /dev/null +++ b/secretflow_spec/core/types.py @@ -0,0 +1,79 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import sys +from dataclasses import dataclass + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + + class StrEnum(str, enum.Enum): + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError("too many arguments for str(): %r" % (values,)) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError("%r is not a string" % (values[0],)) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError("encoding must be a string, not %r" % (values[1],)) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + @staticmethod + def _generate_next_value_(name, start, count, last_values): + """ + Return the lower-cased version of the member name. + """ + return name.lower() + + def __repr__(self): + return self.value + + def __str__(self): + return self.value + + +@dataclass +class Version: + major: int + minor: int + + def __str__(self): + return f"{self.major}.{self.minor}" + + def __repr__(self): + return f"{self.major}.{self.minor}" + + @staticmethod + def from_str(v: str) -> "Version": + tokens = v.strip().split(".") + if len(tokens) != 2: + raise ValueError(f"version must be in format of x.y, but got {v}") + return Version(int(tokens[0]), int(tokens[1])) diff --git a/secretflow_spec/core/utils.py b/secretflow_spec/core/utils.py new file mode 100644 index 0000000..953eff8 --- /dev/null +++ b/secretflow_spec/core/utils.py @@ -0,0 +1,188 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any + +from secretflow_spec.core.dist_data.vtable import VTable +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.component_pb2 import Attribute +from secretflow_spec.v1.data_pb2 import DistData +from secretflow_spec.v1.evaluation_pb2 import NodeEvalParam + +LINEBREAK_REGEX = re.compile(r"((\r\n)|[\n\v])+") +TWO_LINEBREAK_REGEX = re.compile(r"((\r\n)|[\n\v])+((\r\n)|[\n\v])+") +MULTI_WHITESPACE_TO_ONE_REGEX = re.compile(r"\s+") +NONBREAKING_SPACE_REGEX = re.compile(r"(?!\n)\s+") + + +def normalize_whitespace( + text: str, no_line_breaks=False, strip_lines=True, keep_two_line_breaks=False +): + """ + Given ``text`` str, replace one or more spacings with a single space, and one + or more line breaks with a single newline. Also strip leading/trailing whitespace. + """ + if strip_lines: + text = "\n".join([x.strip() for x in text.splitlines()]) + + if no_line_breaks: + text = MULTI_WHITESPACE_TO_ONE_REGEX.sub(" ", text) + else: + if keep_two_line_breaks: + text = NONBREAKING_SPACE_REGEX.sub( + " ", TWO_LINEBREAK_REGEX.sub(r"\n\n", text) + ) + else: + text = NONBREAKING_SPACE_REGEX.sub(" ", LINEBREAK_REGEX.sub(r"\n", text)) + + return text.strip() + + +DOUBLE_QUOTE_REGEX = re.compile("|".join("«»“”„‟‹›❝❞❮❯〝〞〟"")) +SINGLE_QUOTE_REGEX = re.compile("|".join("`´‘‘’’‛❛❜")) + + +def fix_strange_quotes(text): + """ + Replace strange quotes, i.e., 〞with a single quote ' or a double quote " if it fits better. + """ + text = SINGLE_QUOTE_REGEX.sub("'", text) + text = DOUBLE_QUOTE_REGEX.sub('"', text) + return text + + +def clean_text(text: str, no_line_breaks: bool = True) -> str: + text = text.strip() + text = normalize_whitespace(text, no_line_breaks) + text = fix_strange_quotes(text) + return text + + +_type_mapping: dict[str, type] = { + "float": float, + "bool": bool, + "int": int, + "str": str, + # float + "float16": float, + "float32": float, + "float64": float, + # int + "int8": int, + "int16": int, + "int32": int, + "int64": int, + "uint": int, + "uint8": int, + "uint16": int, + "uint32": int, + "uint64": int, + # numpy specific type + "float_": float, + "bool_": bool, + "int_": int, + "str_": str, + "object_": str, + # others + "double": float, + "halffloat": float, +} + + +def to_type(dt) -> type: + if not isinstance(dt, type): + dt = type(dt) + + if dt.__name__ in _type_mapping: + return _type_mapping[dt.__name__] + else: + raise ValueError(f"unsupported primitive type {dt}") + + +def to_attribute(v) -> Attribute: + if isinstance(v, Attribute): + return v + + is_list = isinstance(v, list) + if is_list: + assert len(v) > 0, f"Type cannot be inferred from an empty list" + prim_type = type(v[0]) + else: + prim_type = type(v) + if prim_type not in [bool, int, float, str]: + if prim_type.__name__ not in _type_mapping: + raise ValueError(f"unsupported type {prim_type},{v}") + if hasattr(v, "as_py"): + method = getattr(v, "as_py") + assert callable(method) + v = method() + else: + prim_type = _type_mapping[prim_type.__name__] + v = prim_type(v) + + if prim_type == bool: + return Attribute(bs=v) if is_list else Attribute(b=v) + elif prim_type == int: + return Attribute(i64s=v) if is_list else Attribute(i64=v) + elif prim_type == float: + return Attribute(fs=v) if is_list else Attribute(f=v) + elif prim_type == str: + return Attribute(ss=v) if is_list else Attribute(s=v) + else: + raise ValueError(f"unsupported primitive type {prim_type}") + + +def build_node_eval_param( + domain: str, + name: str, + version: str, + attrs: dict[str, Any] = None, + inputs: list[DistData | VTable] = None, + output_uris: list[str] = None, + checkpoint_uri: str = None, +) -> NodeEvalParam: + """ + Used for constructing NodeEvalParam in unit tests. + """ + + attr_paths, attr_values = None, None + if attrs: + attr_paths, attr_values = [], [] + for k, v in attrs.items(): + attr_paths.append(k) + attr_values.append(to_attribute(v)) + + def _to_distdata(x) -> DistData: + if isinstance(x, DistData): + return x + elif isinstance(x, VTable): + return x.to_distdata() + else: + raise ValueError(f"invalid DistData type, {type(x)}") + + if inputs: + inputs = [_to_distdata(dd) for dd in inputs] + + comp_id = f"{domain}/{name}:{version}" + param = NodeEvalParam( + version=SPEC_VERSION, + comp_id=comp_id, + attr_paths=attr_paths, + attrs=attr_values, + inputs=inputs, + output_uris=output_uris, + checkpoint_uri=checkpoint_uri, + ) + return param diff --git a/secretflow_spec/core/version.py b/secretflow_spec/core/version.py new file mode 100644 index 0000000..322509b --- /dev/null +++ b/secretflow_spec/core/version.py @@ -0,0 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SPEC_VERSION_MAJOR = 1 +SPEC_VERSION_MINOR = 0 +SPEC_VERSION = f"{SPEC_VERSION_MAJOR}.{SPEC_VERSION_MINOR}" diff --git a/secretflow_spec/protos/api_linter_config.json b/secretflow_spec/protos/api_linter_config.json new file mode 100644 index 0000000..0bdc597 --- /dev/null +++ b/secretflow_spec/protos/api_linter_config.json @@ -0,0 +1,11 @@ +[ + { + "included_paths": ["secretflow_spec/v1/*.proto"], + "disabled_rules": [ + "core::0192::only-leading-comments", + "core::0192::has-comments", + "core::0146::any", + "core::0123::resource-annotation" + ] + } +] diff --git a/run_api_linter.sh b/secretflow_spec/protos/run_api_linter.sh similarity index 80% rename from run_api_linter.sh rename to secretflow_spec/protos/run_api_linter.sh index 1161f1c..33125eb 100755 --- a/run_api_linter.sh +++ b/secretflow_spec/protos/run_api_linter.sh @@ -1,12 +1,14 @@ #!/bin/bash -PROTO_FOLDER="secretflow/" +cd "$(dirname "$(readlink -f "$0")")" + +PROTO_FOLDER="secretflow_spec/" AIP_LINTER_PATH=$(which api-linter) failures=0 for file in $(find $PROTO_FOLDER -name '*.proto'); do - lint_output=$("$AIP_LINTER_PATH" "$file" --config api_linter_config.json) + lint_output=$("$AIP_LINTER_PATH" "$file" --config api_linter_config.json -I=.) if echo "$lint_output" | grep -iq "problems: \[\]"; then echo "[Scucess] $file." diff --git a/secretflow_spec/protos/run_protoc.sh b/secretflow_spec/protos/run_protoc.sh new file mode 100755 index 0000000..fa4eedd --- /dev/null +++ b/secretflow_spec/protos/run_protoc.sh @@ -0,0 +1,38 @@ +#!/bin/bash +cd "$(dirname "$(readlink -f "$0")")" + +# set -x +set -e + +# check to install protoc-26.1 +PROTOC_DIR="protoc-26.1" +PROTOC_BIN="${PROTOC_DIR}/bin/protoc" +PROTOC_ZIP_NAME="protoc-26.1-linux-x86_64.zip" +DOWNLOAD_URL="https://github.com/protocolbuffers/protobuf/releases/download/v26.1/${PROTOC_ZIP_NAME}" + +if [ ! -d "$PROTOC_DIR" ]; then + echo "start download protoc" + wget "$DOWNLOAD_URL" + unzip "$PROTOC_ZIP_NAME" -d "./$PROTOC_DIR" + rm -f $PROTOC_ZIP_NAME +fi + +# check to install mypy-protobuf +mypy_installed_version=$(pip show mypy-protobuf 2>/dev/null | grep Version | awk '{print $2}') +mypy_required_version="3.6.0" + +if [ "$mypy_installed_version" == "$mypy_required_version" ]; then + echo "mypy-protobuf<$mypy_required_version> has installed." +else + if [ -z "$mypy_installed_version" ]; then + echo "mypy-protobuf not found" + else + echo "mypy-protobuf version mismatch, current is $mypy_installed_version, but required is $mypy_required_version" + pip uninstall mypy-protobuf -y + fi + echo "start to install mypy-protobuf==$mypy_required_version" + pip install mypy-protobuf==$mypy_required_version +fi + +# build pb2.py +$PROTOC_BIN --proto_path="$PROTOC_DIR/include" --proto_path=. --python_out=../.. --mypy_out=../../ secretflow_spec/v1/*.proto \ No newline at end of file diff --git a/secretflow/spec/v1/component.proto b/secretflow_spec/protos/secretflow_spec/v1/component.proto similarity index 73% rename from secretflow/spec/v1/component.proto rename to secretflow_spec/protos/secretflow_spec/v1/component.proto index e3d9096..68096ab 100644 --- a/secretflow/spec/v1/component.proto +++ b/secretflow_spec/protos/secretflow_spec/v1/component.proto @@ -15,10 +15,10 @@ syntax = "proto3"; -package secretflow.spec.v1; +package secretflow_spec.v1; option java_outer_classname = "ComponentProto"; -option java_package = "com.secretflow.spec.v1"; +option java_package = "com.secretflow_spec.v1"; option java_multiple_files = true; // The value of an attribute @@ -44,6 +44,15 @@ message Attribute { } // Describe an attribute. +// There are three kinds of attribute. +// - Atomic Attributes: a solid field for users to fill-in. +// - Struct Attributes: a group of closely related attributes(including atomic, +// union and struct attributes). +// - Union Attributes: a group of mutually exlusive attributes(including union, +// group and dummy atomic attributes). Users should select only one children to +// fill-in. An atmoic attribute with ATTR_TYPE_UNSPECIFIED AttrType is regarded +// as dummy, which represents a selection of union without further +// configurations. message AttributeDef { // Indicates the ancestors of a node, // e.g. `[name_a, name_b, name_c]` means the path prefixes of current @@ -109,6 +118,15 @@ message AttributeDef { // Extras for custom protobuf attribute string custom_protobuf_cls = 7; + + // Extras for COL_PARAMS + string col_params_binded_table = 8; + + // The attribute can appear in NodeEvalParam only if current minor is in [minor_min, minor_max] + // if current minor < minor_min, it's not supported, + // if minor_max != -1 and current minor > minor_max, it's a deprecated attribute + int32 minor_min = 9; + int32 minor_max = 10; } // Define an input/output for component. @@ -169,36 +187,48 @@ message IoDef { // The attribute path for a TableAttrDef is `{input\|output}/{IoDef // name}/{TableAttrDef name}`. repeated TableAttrDef attrs = 4; + + bool is_optional = 5; + + // if the input io is variable, it must be the last one and size must be in [variable_min, variable_max] + bool is_variable = 6; + int32 variable_min = 7; + int32 variable_max = 8; + + // The io input/output can appear in NodeEvalParam only if current minor is in [minor_min, minor_max] + // if current minor < minor_min, it's not supported, + // if minor_max != -1 and current minor > minor_max, it's a deprecated io input/output + int32 minor_min = 9; + int32 minor_max = 10; } // The definition of a comp. message ComponentDef { // Namespace of the comp. string domain = 1; - // Should be unique among all comps of the same domain. string name = 2; - - string desc = 3; - // Version of the comp. - string version = 4; - - repeated AttributeDef attrs = 5; - - repeated IoDef inputs = 6; - - repeated IoDef outputs = 7; + string version = 3; + // Description of the comp. + string desc = 4; + // Static label infomations of the comp. + // e.g., {"sf.use.mpc":"true", "sf.multi.party.computation":"true"} + map labels = 5; + + repeated AttributeDef attrs = 10; + repeated IoDef inputs = 11; + repeated IoDef outputs = 12; } // A list of components message CompListDef { - string name = 1; - - string desc = 2; - - string version = 3; - + // The version of spec, format is {major}.{minor} + // the different major version are not compatible + // the different minor version should be forward compatible, In other words, you can only add fields and cannot modify or delete fields. + string version = 1; + string name = 2; + string desc = 3; repeated ComponentDef comps = 4; } @@ -228,4 +258,5 @@ enum AttrType { AT_UNION_GROUP = 10; AT_CUSTOM_PROTOBUF = 11; AT_PARTY = 12; // A specialized AT_STRINGS. + AT_COL_PARAMS = 13; // A specialized AT_STRINGS. } diff --git a/secretflow/spec/v1/data.proto b/secretflow_spec/protos/secretflow_spec/v1/data.proto similarity index 86% rename from secretflow/spec/v1/data.proto rename to secretflow_spec/protos/secretflow_spec/v1/data.proto index 0fd8e41..5254ef0 100644 --- a/secretflow/spec/v1/data.proto +++ b/secretflow_spec/protos/secretflow_spec/v1/data.proto @@ -14,10 +14,10 @@ syntax = "proto3"; -package secretflow.spec.v1; +package secretflow_spec.v1; option java_outer_classname = "DataProto"; -option java_package = "com.secretflow.spec.v1"; +option java_package = "com.secretflow_spec.v1"; option java_multiple_files = true; import "google/protobuf/any.proto"; @@ -79,20 +79,6 @@ message StorageConfig { // - sf.table.vertical_table represent a secretflow vertical table // - sf.table.individual_table represent a secretflow individual table message DistData { - // The name of this distributed data. - string name = 1; - - // Type. - string type = 2; - - // Describe the system information that used to generate this distributed - // data. - SystemInfo system_info = 3; - - // Public information, known to all parties. - // i.e. VerticalTable. - google.protobuf.Any meta = 4; - // A reference to a data that is stored in the remote path. message DataRef { // The path information relative to StorageConfig of the party. @@ -101,12 +87,33 @@ message DistData { // The owner party. string party = 2; - // The storage format, i.e. csv. + // The storage format, support: + // - csv represent a comma-separated value format file + // - orc represent a apache orc format file string format = 3; + + // A list of strings that represent NULL value. + // Only take effect when format is csv + repeated string null_strs = 4; } + // The version of spec + string version = 1; + // The name of this distributed data. + string name = 2; + // Type. + string type = 3; + + // Describe the system information that used to generate this distributed + // data. + SystemInfo system_info = 4; + + // Public information, known to all parties. + // i.e. VerticalTable. + google.protobuf.Any meta = 5; + // Remote data references. - repeated DataRef data_refs = 5; + repeated DataRef data_refs = 6; } // VerticalTable describes a virtual vertical partitioning table from multiple @@ -173,3 +180,9 @@ message TableSchema { // Len(labels) should match len(label_types). repeated string label_types = 6; } + +// ObjectFileInfo describes metadata for unstructured data file, such as Model +message ObjectFileInfo { + // Any public attributes + map attributes = 1; +} \ No newline at end of file diff --git a/secretflow/spec/v1/evaluation.proto b/secretflow_spec/protos/secretflow_spec/v1/evaluation.proto similarity index 76% rename from secretflow/spec/v1/evaluation.proto rename to secretflow_spec/protos/secretflow_spec/v1/evaluation.proto index f69241a..0268d8d 100644 --- a/secretflow/spec/v1/evaluation.proto +++ b/secretflow_spec/protos/secretflow_spec/v1/evaluation.proto @@ -14,14 +14,14 @@ syntax = "proto3"; -package secretflow.spec.v1; +package secretflow_spec.v1; option java_outer_classname = "EvaluationProto"; -option java_package = "com.secretflow.spec.v1"; +option java_package = "com.secretflow_spec.v1"; option java_multiple_files = true; -import "secretflow/spec/v1/component.proto"; -import "secretflow/spec/v1/data.proto"; +import "secretflow_spec/v1/component.proto"; +import "secretflow_spec/v1/data.proto"; // Evaluate a node. // - CompListDef + StorageConfig + NodeEvalParam + other extra configs -> @@ -29,33 +29,30 @@ import "secretflow/spec/v1/data.proto"; // // NodeEvalParam contains all the information to evaluate a component. message NodeEvalParam { - // Domain of the component. - string domain = 1; - - // Name of the component. - string name = 2; - - // Version of the component. - string version = 3; + // The version of spec + string version = 1; + + // The unique component id, the format is {domain}/{name}:{version} which is defined in ComponentDef + string comp_id = 2; // The path of attributes. // The attribute path for a TableAttrDef is // `(input\|output)/(IoDef name)/(TableAttrDef name)(/(column name)(/(extra // attributes))?)?`. - repeated string attr_paths = 4; + repeated string attr_paths = 3; // The value of the attribute. // Must match attr_paths. - repeated Attribute attrs = 5; + repeated Attribute attrs = 4; // The input data, the order of inputs must match inputs in ComponentDef. // NOTE: Names of DistData doesn't need to match those of inputs in // ComponentDef definition. - repeated DistData inputs = 6; + repeated DistData inputs = 5; // The output data uris, the order of output_uris must match outputs in // ComponentDef. - repeated string output_uris = 7; + repeated string output_uris = 6; // If not empty: // 1. Component will try to save checkpoint during training if the component @@ -64,11 +61,13 @@ message NodeEvalParam { // previous training. If the checkpoint does not exist or cannot be loaded, // training will be starting from scratch. // - string checkpoint_uri = 8; + string checkpoint_uri = 7; } // NodeEvalResult contains outputs of a component evaluation. message NodeEvalResult { + // The version of spec + string version = 1; // Output data. - repeated DistData outputs = 1; + repeated DistData outputs = 2; } diff --git a/secretflow/spec/v1/report.proto b/secretflow_spec/protos/secretflow_spec/v1/report.proto similarity index 94% rename from secretflow/spec/v1/report.proto rename to secretflow_spec/protos/secretflow_spec/v1/report.proto index ec937b7..2f39365 100644 --- a/secretflow/spec/v1/report.proto +++ b/secretflow_spec/protos/secretflow_spec/v1/report.proto @@ -14,12 +14,12 @@ syntax = "proto3"; -package secretflow.spec.v1; +package secretflow_spec.v1; -import "secretflow/spec/v1/component.proto"; +import "secretflow_spec/v1/component.proto"; option java_outer_classname = "ReportProto"; -option java_package = "com.secretflow.spec.v1"; +option java_package = "com.secretflow_spec.v1"; option java_multiple_files = true; // Displays multiple read-only fields in groups. diff --git a/secretflow_spec/v1/__init__.py b/secretflow_spec/v1/__init__.py new file mode 100644 index 0000000..086637a --- /dev/null +++ b/secretflow_spec/v1/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/secretflow_spec/v1/component_pb2.py b/secretflow_spec/v1/component_pb2.py new file mode 100644 index 0000000..edb690d --- /dev/null +++ b/secretflow_spec/v1/component_pb2.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: secretflow_spec/v1/component.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"secretflow_spec/v1/component.proto\x12\x12secretflow_spec.v1\"z\n\tAttribute\x12\t\n\x01\x66\x18\x01 \x01(\x02\x12\x0b\n\x03i64\x18\x02 \x01(\x03\x12\t\n\x01s\x18\x03 \x01(\t\x12\t\n\x01\x62\x18\x04 \x01(\x08\x12\n\n\x02\x66s\x18\x05 \x03(\x02\x12\x0c\n\x04i64s\x18\x06 \x03(\x03\x12\n\n\x02ss\x18\x07 \x03(\t\x12\n\n\x02\x62s\x18\x08 \x03(\x08\x12\r\n\x05is_na\x18\t \x01(\x08\"\xbd\x06\n\x0c\x41ttributeDef\x12\x10\n\x08prefixes\x18\x01 \x03(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x03 \x01(\t\x12*\n\x04type\x18\x04 \x01(\x0e\x32\x1c.secretflow_spec.v1.AttrType\x12?\n\x06\x61tomic\x18\x05 \x01(\x0b\x32/.secretflow_spec.v1.AttributeDef.AtomicAttrDesc\x12\x42\n\x05union\x18\x06 \x01(\x0b\x32\x33.secretflow_spec.v1.AttributeDef.UnionAttrGroupDesc\x12\x1b\n\x13\x63ustom_protobuf_cls\x18\x07 \x01(\t\x12\x1f\n\x17\x63ol_params_binded_table\x18\x08 \x01(\t\x12\x11\n\tminor_min\x18\t \x01(\x05\x12\x11\n\tminor_max\x18\n \x01(\x05\x1a\xb8\x03\n\x0e\x41tomicAttrDesc\x12!\n\x19list_min_length_inclusive\x18\x01 \x01(\x03\x12!\n\x19list_max_length_inclusive\x18\x02 \x01(\x03\x12\x13\n\x0bis_optional\x18\x03 \x01(\x08\x12\x34\n\rdefault_value\x18\x04 \x01(\x0b\x32\x1d.secretflow_spec.v1.Attribute\x12\x35\n\x0e\x61llowed_values\x18\x05 \x01(\x0b\x32\x1d.secretflow_spec.v1.Attribute\x12\x1b\n\x13lower_bound_enabled\x18\x06 \x01(\x08\x12\x32\n\x0blower_bound\x18\x07 \x01(\x0b\x32\x1d.secretflow_spec.v1.Attribute\x12\x1d\n\x15lower_bound_inclusive\x18\x08 \x01(\x08\x12\x1b\n\x13upper_bound_enabled\x18\t \x01(\x08\x12\x32\n\x0bupper_bound\x18\n \x01(\x0b\x32\x1d.secretflow_spec.v1.Attribute\x12\x1d\n\x15upper_bound_inclusive\x18\x0b \x01(\x08\x1a/\n\x12UnionAttrGroupDesc\x12\x19\n\x11\x64\x65\x66\x61ult_selection\x18\x01 \x01(\t\"\x96\x03\n\x05IoDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\r\n\x05types\x18\x03 \x03(\t\x12\x35\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32&.secretflow_spec.v1.IoDef.TableAttrDef\x12\x13\n\x0bis_optional\x18\x05 \x01(\x08\x12\x13\n\x0bis_variable\x18\x06 \x01(\x08\x12\x14\n\x0cvariable_min\x18\x07 \x01(\x05\x12\x14\n\x0cvariable_max\x18\x08 \x01(\x05\x12\x11\n\tminor_min\x18\t \x01(\x05\x12\x11\n\tminor_max\x18\n \x01(\x05\x1a\xae\x01\n\x0cTableAttrDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\r\n\x05types\x18\x03 \x03(\t\x12\x1d\n\x15\x63ol_min_cnt_inclusive\x18\x04 \x01(\x03\x12\x1d\n\x15\x63ol_max_cnt_inclusive\x18\x05 \x01(\x03\x12\x35\n\x0b\x65xtra_attrs\x18\x06 \x03(\x0b\x32 .secretflow_spec.v1.AttributeDef\"\xc0\x02\n\x0c\x43omponentDef\x12\x0e\n\x06\x64omain\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x04 \x01(\t\x12<\n\x06labels\x18\x05 \x03(\x0b\x32,.secretflow_spec.v1.ComponentDef.LabelsEntry\x12/\n\x05\x61ttrs\x18\n \x03(\x0b\x32 .secretflow_spec.v1.AttributeDef\x12)\n\x06inputs\x18\x0b \x03(\x0b\x32\x19.secretflow_spec.v1.IoDef\x12*\n\x07outputs\x18\x0c \x03(\x0b\x32\x19.secretflow_spec.v1.IoDef\x1a-\n\x0bLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"k\n\x0b\x43ompListDef\x12\x0f\n\x07version\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x03 \x01(\t\x12/\n\x05\x63omps\x18\x04 \x03(\x0b\x32 .secretflow_spec.v1.ComponentDef*\xf7\x01\n\x08\x41ttrType\x12\x19\n\x15\x41TTR_TYPE_UNSPECIFIED\x10\x00\x12\x0c\n\x08\x41T_FLOAT\x10\x01\x12\n\n\x06\x41T_INT\x10\x02\x12\r\n\tAT_STRING\x10\x03\x12\x0b\n\x07\x41T_BOOL\x10\x04\x12\r\n\tAT_FLOATS\x10\x05\x12\x0b\n\x07\x41T_INTS\x10\x06\x12\x0e\n\nAT_STRINGS\x10\x07\x12\x0c\n\x08\x41T_BOOLS\x10\x08\x12\x13\n\x0f\x41T_STRUCT_GROUP\x10\t\x12\x12\n\x0e\x41T_UNION_GROUP\x10\n\x12\x16\n\x12\x41T_CUSTOM_PROTOBUF\x10\x0b\x12\x0c\n\x08\x41T_PARTY\x10\x0c\x12\x11\n\rAT_COL_PARAMS\x10\rB*\n\x16\x63om.secretflow_spec.v1B\x0e\x43omponentProtoP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'secretflow_spec.v1.component_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\026com.secretflow_spec.v1B\016ComponentProtoP\001' + _globals['_COMPONENTDEF_LABELSENTRY']._loaded_options = None + _globals['_COMPONENTDEF_LABELSENTRY']._serialized_options = b'8\001' + _globals['_ATTRTYPE']._serialized_start=1856 + _globals['_ATTRTYPE']._serialized_end=2103 + _globals['_ATTRIBUTE']._serialized_start=58 + _globals['_ATTRIBUTE']._serialized_end=180 + _globals['_ATTRIBUTEDEF']._serialized_start=183 + _globals['_ATTRIBUTEDEF']._serialized_end=1012 + _globals['_ATTRIBUTEDEF_ATOMICATTRDESC']._serialized_start=523 + _globals['_ATTRIBUTEDEF_ATOMICATTRDESC']._serialized_end=963 + _globals['_ATTRIBUTEDEF_UNIONATTRGROUPDESC']._serialized_start=965 + _globals['_ATTRIBUTEDEF_UNIONATTRGROUPDESC']._serialized_end=1012 + _globals['_IODEF']._serialized_start=1015 + _globals['_IODEF']._serialized_end=1421 + _globals['_IODEF_TABLEATTRDEF']._serialized_start=1247 + _globals['_IODEF_TABLEATTRDEF']._serialized_end=1421 + _globals['_COMPONENTDEF']._serialized_start=1424 + _globals['_COMPONENTDEF']._serialized_end=1744 + _globals['_COMPONENTDEF_LABELSENTRY']._serialized_start=1699 + _globals['_COMPONENTDEF_LABELSENTRY']._serialized_end=1744 + _globals['_COMPLISTDEF']._serialized_start=1746 + _globals['_COMPLISTDEF']._serialized_end=1853 +# @@protoc_insertion_point(module_scope) diff --git a/secretflow_spec/v1/component_pb2.pyi b/secretflow_spec/v1/component_pb2.pyi new file mode 100644 index 0000000..f263a6d --- /dev/null +++ b/secretflow_spec/v1/component_pb2.pyi @@ -0,0 +1,563 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2023 Ant Group Co., Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _AttrType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _AttrTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_AttrType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + ATTR_TYPE_UNSPECIFIED: _AttrType.ValueType # 0 + """NOTE: ATTR_TYPE_UNSPECIFIED could be used as a child of a union struct + with no further attribute(s). + """ + AT_FLOAT: _AttrType.ValueType # 1 + """Scalar types + + FLOAT + """ + AT_INT: _AttrType.ValueType # 2 + """INT""" + AT_STRING: _AttrType.ValueType # 3 + """STRING""" + AT_BOOL: _AttrType.ValueType # 4 + """BOOL""" + AT_FLOATS: _AttrType.ValueType # 5 + """List types + + FLOATS + """ + AT_INTS: _AttrType.ValueType # 6 + """INTS""" + AT_STRINGS: _AttrType.ValueType # 7 + """STRINGS""" + AT_BOOLS: _AttrType.ValueType # 8 + """BOOLS""" + AT_STRUCT_GROUP: _AttrType.ValueType # 9 + """Special types.""" + AT_UNION_GROUP: _AttrType.ValueType # 10 + AT_CUSTOM_PROTOBUF: _AttrType.ValueType # 11 + AT_PARTY: _AttrType.ValueType # 12 + """A specialized AT_STRINGS.""" + AT_COL_PARAMS: _AttrType.ValueType # 13 + """A specialized AT_STRINGS.""" + +class AttrType(_AttrType, metaclass=_AttrTypeEnumTypeWrapper): + """Supported attribute types.""" + +ATTR_TYPE_UNSPECIFIED: AttrType.ValueType # 0 +"""NOTE: ATTR_TYPE_UNSPECIFIED could be used as a child of a union struct +with no further attribute(s). +""" +AT_FLOAT: AttrType.ValueType # 1 +"""Scalar types + +FLOAT +""" +AT_INT: AttrType.ValueType # 2 +"""INT""" +AT_STRING: AttrType.ValueType # 3 +"""STRING""" +AT_BOOL: AttrType.ValueType # 4 +"""BOOL""" +AT_FLOATS: AttrType.ValueType # 5 +"""List types + +FLOATS +""" +AT_INTS: AttrType.ValueType # 6 +"""INTS""" +AT_STRINGS: AttrType.ValueType # 7 +"""STRINGS""" +AT_BOOLS: AttrType.ValueType # 8 +"""BOOLS""" +AT_STRUCT_GROUP: AttrType.ValueType # 9 +"""Special types.""" +AT_UNION_GROUP: AttrType.ValueType # 10 +AT_CUSTOM_PROTOBUF: AttrType.ValueType # 11 +AT_PARTY: AttrType.ValueType # 12 +"""A specialized AT_STRINGS.""" +AT_COL_PARAMS: AttrType.ValueType # 13 +"""A specialized AT_STRINGS.""" +global___AttrType = AttrType + +@typing.final +class Attribute(google.protobuf.message.Message): + """The value of an attribute""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + F_FIELD_NUMBER: builtins.int + I64_FIELD_NUMBER: builtins.int + S_FIELD_NUMBER: builtins.int + B_FIELD_NUMBER: builtins.int + FS_FIELD_NUMBER: builtins.int + I64S_FIELD_NUMBER: builtins.int + SS_FIELD_NUMBER: builtins.int + BS_FIELD_NUMBER: builtins.int + IS_NA_FIELD_NUMBER: builtins.int + f: builtins.float + """FLOAT""" + i64: builtins.int + """INT + NOTE(junfeng): "is" is preserved by Python. Replaced with "i64". + """ + s: builtins.str + """STRING""" + b: builtins.bool + """BOOL""" + is_na: builtins.bool + """Indicates the value is missing explicitly.""" + @property + def fs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: + """lists + + FLOATS + """ + + @property + def i64s(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: + """INTS""" + + @property + def ss(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """STRINGS""" + + @property + def bs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: + """BOOLS""" + + def __init__( + self, + *, + f: builtins.float = ..., + i64: builtins.int = ..., + s: builtins.str = ..., + b: builtins.bool = ..., + fs: collections.abc.Iterable[builtins.float] | None = ..., + i64s: collections.abc.Iterable[builtins.int] | None = ..., + ss: collections.abc.Iterable[builtins.str] | None = ..., + bs: collections.abc.Iterable[builtins.bool] | None = ..., + is_na: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["b", b"b", "bs", b"bs", "f", b"f", "fs", b"fs", "i64", b"i64", "i64s", b"i64s", "is_na", b"is_na", "s", b"s", "ss", b"ss"]) -> None: ... + +global___Attribute = Attribute + +@typing.final +class AttributeDef(google.protobuf.message.Message): + """Describe an attribute. + There are three kinds of attribute. + - Atomic Attributes: a solid field for users to fill-in. + - Struct Attributes: a group of closely related attributes(including atomic, + union and struct attributes). + - Union Attributes: a group of mutually exlusive attributes(including union, + group and dummy atomic attributes). Users should select only one children to + fill-in. An atmoic attribute with ATTR_TYPE_UNSPECIFIED AttrType is regarded + as dummy, which represents a selection of union without further + configurations. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class AtomicAttrDesc(google.protobuf.message.Message): + """Extras for an atomic attribute. + Including: `AT_FLOAT | AT_INT | AT_STRING | AT_BOOL | AT_FLOATS | AT_INTS | + AT_STRINGS | AT_BOOLS`. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + LIST_MIN_LENGTH_INCLUSIVE_FIELD_NUMBER: builtins.int + LIST_MAX_LENGTH_INCLUSIVE_FIELD_NUMBER: builtins.int + IS_OPTIONAL_FIELD_NUMBER: builtins.int + DEFAULT_VALUE_FIELD_NUMBER: builtins.int + ALLOWED_VALUES_FIELD_NUMBER: builtins.int + LOWER_BOUND_ENABLED_FIELD_NUMBER: builtins.int + LOWER_BOUND_FIELD_NUMBER: builtins.int + LOWER_BOUND_INCLUSIVE_FIELD_NUMBER: builtins.int + UPPER_BOUND_ENABLED_FIELD_NUMBER: builtins.int + UPPER_BOUND_FIELD_NUMBER: builtins.int + UPPER_BOUND_INCLUSIVE_FIELD_NUMBER: builtins.int + list_min_length_inclusive: builtins.int + """Only valid when type is `AT_FLOATS \\| AT_INTS \\| AT_STRINGS \\| AT_BOOLS`.""" + list_max_length_inclusive: builtins.int + """Only valid when type is `AT_FLOATS \\| AT_INTS \\| AT_STRINGS \\| AT_BOOLS`.""" + is_optional: builtins.bool + """If True, when Atmoic Attr is not provided or is_na, default_value would + be used. Else, Atmoic Attr must be provided. + """ + lower_bound_enabled: builtins.bool + """Only valid when type is `AT_FLOAT \\| AT_INT \\| AT_FLOATS \\| AT_INTS `. + If the attribute is a list, lower_bound is applied to each element. + """ + lower_bound_inclusive: builtins.bool + upper_bound_enabled: builtins.bool + """Only valid when type is `AT_FLOAT \\| AT_INT \\| AT_FLOATS \\| AT_INTS `. + If the attribute is a list, upper_bound is applied to each element. + """ + upper_bound_inclusive: builtins.bool + @property + def default_value(self) -> global___Attribute: + """A reasonable default for this attribute if the user does not supply a + value. + """ + + @property + def allowed_values(self) -> global___Attribute: + """Only valid when type is `AT_FLOAT \\| AT_INT \\| AT_STRING \\| AT_FLOATS \\| + AT_INTS \\| AT_STRINGS`. + Please use list fields of AtomicParameter, i.e. `ss`, `i64s`, `fs`. + If the attribute is a list, allowed_values is applied to each element. + """ + + @property + def lower_bound(self) -> global___Attribute: ... + @property + def upper_bound(self) -> global___Attribute: ... + def __init__( + self, + *, + list_min_length_inclusive: builtins.int = ..., + list_max_length_inclusive: builtins.int = ..., + is_optional: builtins.bool = ..., + default_value: global___Attribute | None = ..., + allowed_values: global___Attribute | None = ..., + lower_bound_enabled: builtins.bool = ..., + lower_bound: global___Attribute | None = ..., + lower_bound_inclusive: builtins.bool = ..., + upper_bound_enabled: builtins.bool = ..., + upper_bound: global___Attribute | None = ..., + upper_bound_inclusive: builtins.bool = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["allowed_values", b"allowed_values", "default_value", b"default_value", "lower_bound", b"lower_bound", "upper_bound", b"upper_bound"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["allowed_values", b"allowed_values", "default_value", b"default_value", "is_optional", b"is_optional", "list_max_length_inclusive", b"list_max_length_inclusive", "list_min_length_inclusive", b"list_min_length_inclusive", "lower_bound", b"lower_bound", "lower_bound_enabled", b"lower_bound_enabled", "lower_bound_inclusive", b"lower_bound_inclusive", "upper_bound", b"upper_bound", "upper_bound_enabled", b"upper_bound_enabled", "upper_bound_inclusive", b"upper_bound_inclusive"]) -> None: ... + + @typing.final + class UnionAttrGroupDesc(google.protobuf.message.Message): + """Extras for a union attribute group.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DEFAULT_SELECTION_FIELD_NUMBER: builtins.int + default_selection: builtins.str + """The default selected child.""" + def __init__( + self, + *, + default_selection: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["default_selection", b"default_selection"]) -> None: ... + + PREFIXES_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + ATOMIC_FIELD_NUMBER: builtins.int + UNION_FIELD_NUMBER: builtins.int + CUSTOM_PROTOBUF_CLS_FIELD_NUMBER: builtins.int + COL_PARAMS_BINDED_TABLE_FIELD_NUMBER: builtins.int + MINOR_MIN_FIELD_NUMBER: builtins.int + MINOR_MAX_FIELD_NUMBER: builtins.int + name: builtins.str + """Must be unique in the same level just like Linux file systems. + Only `^[a-zA-Z0-9_.-]*$` is allowed. + `input` and `output` are reserved. + """ + desc: builtins.str + type: global___AttrType.ValueType + custom_protobuf_cls: builtins.str + """Extras for custom protobuf attribute""" + col_params_binded_table: builtins.str + """Extras for COL_PARAMS""" + minor_min: builtins.int + """The attribute can appear in NodeEvalParam only if current minor is in [minor_min, minor_max] + if current minor < minor_min, it's not supported, + if minor_max != -1 and current minor > minor_max, it's a deprecated attribute + """ + minor_max: builtins.int + @property + def prefixes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Indicates the ancestors of a node, + e.g. `[name_a, name_b, name_c]` means the path prefixes of current + Attribute is `name_a/name_b/name_c/`. + Only `^[a-zA-Z0-9_.-]*$` is allowed. + `input` and `output` are reserved. + """ + + @property + def atomic(self) -> global___AttributeDef.AtomicAttrDesc: ... + @property + def union(self) -> global___AttributeDef.UnionAttrGroupDesc: ... + def __init__( + self, + *, + prefixes: collections.abc.Iterable[builtins.str] | None = ..., + name: builtins.str = ..., + desc: builtins.str = ..., + type: global___AttrType.ValueType = ..., + atomic: global___AttributeDef.AtomicAttrDesc | None = ..., + union: global___AttributeDef.UnionAttrGroupDesc | None = ..., + custom_protobuf_cls: builtins.str = ..., + col_params_binded_table: builtins.str = ..., + minor_min: builtins.int = ..., + minor_max: builtins.int = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["atomic", b"atomic", "union", b"union"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["atomic", b"atomic", "col_params_binded_table", b"col_params_binded_table", "custom_protobuf_cls", b"custom_protobuf_cls", "desc", b"desc", "minor_max", b"minor_max", "minor_min", b"minor_min", "name", b"name", "prefixes", b"prefixes", "type", b"type", "union", b"union"]) -> None: ... + +global___AttributeDef = AttributeDef + +@typing.final +class IoDef(google.protobuf.message.Message): + """Define an input/output for component.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class TableAttrDef(google.protobuf.message.Message): + """An extra attribute for a table. + + If provided in a IoDef, e.g. + ```json + { + "name": "feature", + "types": [ + "int", + "float" + ], + "col_min_cnt_inclusive": 1, + "col_max_cnt": 3, + "attrs": [ + { + "name": "bucket_size", + "type": "AT_INT" + } + ] + } + ``` + means after a user provide a table as IO, they should also specify + cols as "feature": + - col_min_cnt_inclusive is 1: At least 1 col to be selected. + - col_max_cnt_inclusive is 3: At most 3 cols to be selected. + And afterwards, user have to fill an int attribute called bucket_size for + each selected cols. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TYPES_FIELD_NUMBER: builtins.int + COL_MIN_CNT_INCLUSIVE_FIELD_NUMBER: builtins.int + COL_MAX_CNT_INCLUSIVE_FIELD_NUMBER: builtins.int + EXTRA_ATTRS_FIELD_NUMBER: builtins.int + name: builtins.str + """Must be unique among all attributes for the table.""" + desc: builtins.str + col_min_cnt_inclusive: builtins.int + """inclusive""" + col_max_cnt_inclusive: builtins.int + @property + def types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Accepted col data types. + Please check comments of TableSchema in data.proto. + """ + + @property + def extra_attrs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AttributeDef]: + """extra attribute for specified col.""" + + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + types: collections.abc.Iterable[builtins.str] | None = ..., + col_min_cnt_inclusive: builtins.int = ..., + col_max_cnt_inclusive: builtins.int = ..., + extra_attrs: collections.abc.Iterable[global___AttributeDef] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["col_max_cnt_inclusive", b"col_max_cnt_inclusive", "col_min_cnt_inclusive", b"col_min_cnt_inclusive", "desc", b"desc", "extra_attrs", b"extra_attrs", "name", b"name", "types", b"types"]) -> None: ... + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TYPES_FIELD_NUMBER: builtins.int + ATTRS_FIELD_NUMBER: builtins.int + IS_OPTIONAL_FIELD_NUMBER: builtins.int + IS_VARIABLE_FIELD_NUMBER: builtins.int + VARIABLE_MIN_FIELD_NUMBER: builtins.int + VARIABLE_MAX_FIELD_NUMBER: builtins.int + MINOR_MIN_FIELD_NUMBER: builtins.int + MINOR_MAX_FIELD_NUMBER: builtins.int + name: builtins.str + """should be unique among all IOs of the component.""" + desc: builtins.str + is_optional: builtins.bool + is_variable: builtins.bool + """if the input io is variable, it must be the last one and size must be in [variable_min, variable_max]""" + variable_min: builtins.int + variable_max: builtins.int + minor_min: builtins.int + """The io input/output can appear in NodeEvalParam only if current minor is in [minor_min, minor_max] + if current minor < minor_min, it's not supported, + if minor_max != -1 and current minor > minor_max, it's a deprecated io input/output + """ + minor_max: builtins.int + @property + def types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Must be one of DistData.type in data.proto""" + + @property + def attrs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___IoDef.TableAttrDef]: + """Only valid for tables. + The attribute path for a TableAttrDef is `{input\\|output}/{IoDef + name}/{TableAttrDef name}`. + """ + + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + types: collections.abc.Iterable[builtins.str] | None = ..., + attrs: collections.abc.Iterable[global___IoDef.TableAttrDef] | None = ..., + is_optional: builtins.bool = ..., + is_variable: builtins.bool = ..., + variable_min: builtins.int = ..., + variable_max: builtins.int = ..., + minor_min: builtins.int = ..., + minor_max: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["attrs", b"attrs", "desc", b"desc", "is_optional", b"is_optional", "is_variable", b"is_variable", "minor_max", b"minor_max", "minor_min", b"minor_min", "name", b"name", "types", b"types", "variable_max", b"variable_max", "variable_min", b"variable_min"]) -> None: ... + +global___IoDef = IoDef + +@typing.final +class ComponentDef(google.protobuf.message.Message): + """The definition of a comp.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class LabelsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + DOMAIN_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + ATTRS_FIELD_NUMBER: builtins.int + INPUTS_FIELD_NUMBER: builtins.int + OUTPUTS_FIELD_NUMBER: builtins.int + domain: builtins.str + """Namespace of the comp.""" + name: builtins.str + """Should be unique among all comps of the same domain.""" + version: builtins.str + """Version of the comp.""" + desc: builtins.str + """Description of the comp.""" + @property + def labels(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """Static label infomations of the comp. + e.g., {"sf.use.mpc":"true", "sf.multi.party.computation":"true"} + """ + + @property + def attrs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AttributeDef]: ... + @property + def inputs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___IoDef]: ... + @property + def outputs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___IoDef]: ... + def __init__( + self, + *, + domain: builtins.str = ..., + name: builtins.str = ..., + version: builtins.str = ..., + desc: builtins.str = ..., + labels: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + attrs: collections.abc.Iterable[global___AttributeDef] | None = ..., + inputs: collections.abc.Iterable[global___IoDef] | None = ..., + outputs: collections.abc.Iterable[global___IoDef] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["attrs", b"attrs", "desc", b"desc", "domain", b"domain", "inputs", b"inputs", "labels", b"labels", "name", b"name", "outputs", b"outputs", "version", b"version"]) -> None: ... + +global___ComponentDef = ComponentDef + +@typing.final +class CompListDef(google.protobuf.message.Message): + """A list of components""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VERSION_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + COMPS_FIELD_NUMBER: builtins.int + version: builtins.str + """The version of spec, format is {major}.{minor} + the different major version are not compatible + the different minor version should be forward compatible, In other words, you can only add fields and cannot modify or delete fields. + """ + name: builtins.str + desc: builtins.str + @property + def comps(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ComponentDef]: ... + def __init__( + self, + *, + version: builtins.str = ..., + name: builtins.str = ..., + desc: builtins.str = ..., + comps: collections.abc.Iterable[global___ComponentDef] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["comps", b"comps", "desc", b"desc", "name", b"name", "version", b"version"]) -> None: ... + +global___CompListDef = CompListDef diff --git a/secretflow_spec/v1/data_pb2.py b/secretflow_spec/v1/data_pb2.py new file mode 100644 index 0000000..0f2a2ec --- /dev/null +++ b/secretflow_spec/v1/data_pb2.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: secretflow_spec/v1/data.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dsecretflow_spec/v1/data.proto\x12\x12secretflow_spec.v1\x1a\x19google/protobuf/any.proto\"A\n\nSystemInfo\x12\x0b\n\x03\x61pp\x18\x01 \x01(\t\x12&\n\x08\x61pp_meta\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any\"\xcd\x02\n\rStorageConfig\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x41\n\x08local_fs\x18\x02 \x01(\x0b\x32/.secretflow_spec.v1.StorageConfig.LocalFSConfig\x12\x36\n\x02s3\x18\x03 \x01(\x0b\x32*.secretflow_spec.v1.StorageConfig.S3Config\x1a\x1b\n\rLocalFSConfig\x12\n\n\x02wd\x18\x01 \x01(\t\x1a\x95\x01\n\x08S3Config\x12\x10\n\x08\x65ndpoint\x18\x01 \x01(\t\x12\x0e\n\x06\x62ucket\x18\x02 \x01(\t\x12\x0e\n\x06prefix\x18\x03 \x01(\t\x12\x15\n\raccess_key_id\x18\x04 \x01(\t\x12\x19\n\x11\x61\x63\x63\x65ss_key_secret\x18\x05 \x01(\t\x12\x14\n\x0cvirtual_host\x18\x06 \x01(\x08\x12\x0f\n\x07version\x18\x07 \x01(\t\"\x93\x02\n\x08\x44istData\x12\x0f\n\x07version\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x33\n\x0bsystem_info\x18\x04 \x01(\x0b\x32\x1e.secretflow_spec.v1.SystemInfo\x12\"\n\x04meta\x18\x05 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x37\n\tdata_refs\x18\x06 \x03(\x0b\x32$.secretflow_spec.v1.DistData.DataRef\x1aH\n\x07\x44\x61taRef\x12\x0b\n\x03uri\x18\x01 \x01(\t\x12\r\n\x05party\x18\x02 \x01(\t\x12\x0e\n\x06\x66ormat\x18\x03 \x01(\t\x12\x11\n\tnull_strs\x18\x04 \x03(\t\"U\n\rVerticalTable\x12\x30\n\x07schemas\x18\x01 \x03(\x0b\x32\x1f.secretflow_spec.v1.TableSchema\x12\x12\n\nline_count\x18\x02 \x01(\x03\"V\n\x0fIndividualTable\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x1f.secretflow_spec.v1.TableSchema\x12\x12\n\nline_count\x18\x02 \x01(\x03\"z\n\x0bTableSchema\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x10\n\x08\x66\x65\x61tures\x18\x02 \x03(\t\x12\x0e\n\x06labels\x18\x03 \x03(\t\x12\x10\n\x08id_types\x18\x04 \x03(\t\x12\x15\n\rfeature_types\x18\x05 \x03(\t\x12\x13\n\x0blabel_types\x18\x06 \x03(\t\"\x8b\x01\n\x0eObjectFileInfo\x12\x46\n\nattributes\x18\x01 \x03(\x0b\x32\x32.secretflow_spec.v1.ObjectFileInfo.AttributesEntry\x1a\x31\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42%\n\x16\x63om.secretflow_spec.v1B\tDataProtoP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'secretflow_spec.v1.data_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\026com.secretflow_spec.v1B\tDataProtoP\001' + _globals['_OBJECTFILEINFO_ATTRIBUTESENTRY']._loaded_options = None + _globals['_OBJECTFILEINFO_ATTRIBUTESENTRY']._serialized_options = b'8\001' + _globals['_SYSTEMINFO']._serialized_start=80 + _globals['_SYSTEMINFO']._serialized_end=145 + _globals['_STORAGECONFIG']._serialized_start=148 + _globals['_STORAGECONFIG']._serialized_end=481 + _globals['_STORAGECONFIG_LOCALFSCONFIG']._serialized_start=302 + _globals['_STORAGECONFIG_LOCALFSCONFIG']._serialized_end=329 + _globals['_STORAGECONFIG_S3CONFIG']._serialized_start=332 + _globals['_STORAGECONFIG_S3CONFIG']._serialized_end=481 + _globals['_DISTDATA']._serialized_start=484 + _globals['_DISTDATA']._serialized_end=759 + _globals['_DISTDATA_DATAREF']._serialized_start=687 + _globals['_DISTDATA_DATAREF']._serialized_end=759 + _globals['_VERTICALTABLE']._serialized_start=761 + _globals['_VERTICALTABLE']._serialized_end=846 + _globals['_INDIVIDUALTABLE']._serialized_start=848 + _globals['_INDIVIDUALTABLE']._serialized_end=934 + _globals['_TABLESCHEMA']._serialized_start=936 + _globals['_TABLESCHEMA']._serialized_end=1058 + _globals['_OBJECTFILEINFO']._serialized_start=1061 + _globals['_OBJECTFILEINFO']._serialized_end=1200 + _globals['_OBJECTFILEINFO_ATTRIBUTESENTRY']._serialized_start=1151 + _globals['_OBJECTFILEINFO_ATTRIBUTESENTRY']._serialized_end=1200 +# @@protoc_insertion_point(module_scope) diff --git a/secretflow_spec/v1/data_pb2.pyi b/secretflow_spec/v1/data_pb2.pyi new file mode 100644 index 0000000..9498d65 --- /dev/null +++ b/secretflow_spec/v1/data_pb2.pyi @@ -0,0 +1,407 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2023 Ant Group Co., Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import builtins +import collections.abc +import google.protobuf.any_pb2 +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class SystemInfo(google.protobuf.message.Message): + """Describe the application related to data.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + APP_FIELD_NUMBER: builtins.int + APP_META_FIELD_NUMBER: builtins.int + app: builtins.str + """The application name. + Supported: `secretflow` + """ + @property + def app_meta(self) -> google.protobuf.any_pb2.Any: + """Meta for application.""" + + def __init__( + self, + *, + app: builtins.str = ..., + app_meta: google.protobuf.any_pb2.Any | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["app_meta", b"app_meta"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["app", b"app", "app_meta", b"app_meta"]) -> None: ... + +global___SystemInfo = SystemInfo + +@typing.final +class StorageConfig(google.protobuf.message.Message): + """A StorageConfig specifies the root for all data for one party. + - At this moment, only local_fs / S3 compatible object storage is supported + - We would support databases like mysql etc. in future. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class LocalFSConfig(google.protobuf.message.Message): + """For local_fs.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + WD_FIELD_NUMBER: builtins.int + wd: builtins.str + """Working directory.""" + def __init__( + self, + *, + wd: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["wd", b"wd"]) -> None: ... + + @typing.final + class S3Config(google.protobuf.message.Message): + """For S3 compatible object storage""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ENDPOINT_FIELD_NUMBER: builtins.int + BUCKET_FIELD_NUMBER: builtins.int + PREFIX_FIELD_NUMBER: builtins.int + ACCESS_KEY_ID_FIELD_NUMBER: builtins.int + ACCESS_KEY_SECRET_FIELD_NUMBER: builtins.int + VIRTUAL_HOST_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + endpoint: builtins.str + """endpoint https://play.min.io or http://127.0.0.1:9000 with scheme""" + bucket: builtins.str + """the bucket name of the oss datasource""" + prefix: builtins.str + """the prefix of the oss datasource. e.g. data/traindata/""" + access_key_id: builtins.str + """access key""" + access_key_secret: builtins.str + """access secret""" + virtual_host: builtins.bool + """virtual_host is the same as AliyunOSS/AWS S3's virtualhost , default true""" + version: builtins.str + """optional enum[s3v2,s3v4]""" + def __init__( + self, + *, + endpoint: builtins.str = ..., + bucket: builtins.str = ..., + prefix: builtins.str = ..., + access_key_id: builtins.str = ..., + access_key_secret: builtins.str = ..., + virtual_host: builtins.bool = ..., + version: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["access_key_id", b"access_key_id", "access_key_secret", b"access_key_secret", "bucket", b"bucket", "endpoint", b"endpoint", "prefix", b"prefix", "version", b"version", "virtual_host", b"virtual_host"]) -> None: ... + + TYPE_FIELD_NUMBER: builtins.int + LOCAL_FS_FIELD_NUMBER: builtins.int + S3_FIELD_NUMBER: builtins.int + type: builtins.str + """enum[local_fs, s3]""" + @property + def local_fs(self) -> global___StorageConfig.LocalFSConfig: + """local_fs config.""" + + @property + def s3(self) -> global___StorageConfig.S3Config: + """s3 config""" + + def __init__( + self, + *, + type: builtins.str = ..., + local_fs: global___StorageConfig.LocalFSConfig | None = ..., + s3: global___StorageConfig.S3Config | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["local_fs", b"local_fs", "s3", b"s3"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["local_fs", b"local_fs", "s3", b"s3", "type", b"type"]) -> None: ... + +global___StorageConfig = StorageConfig + +@typing.final +class DistData(google.protobuf.message.Message): + """A public record for a general distributed data. + + The type of this distributed data, should be meaningful to components. + + The concrete data format (include public and private parts) is defined by + other protos. + + Suggested internal types, i.e. + - sf.table.vertical_table represent a secretflow vertical table + - sf.table.individual_table represent a secretflow individual table + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class DataRef(google.protobuf.message.Message): + """A reference to a data that is stored in the remote path.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + URI_FIELD_NUMBER: builtins.int + PARTY_FIELD_NUMBER: builtins.int + FORMAT_FIELD_NUMBER: builtins.int + NULL_STRS_FIELD_NUMBER: builtins.int + uri: builtins.str + """The path information relative to StorageConfig of the party.""" + party: builtins.str + """The owner party.""" + format: builtins.str + """The storage format, support: + - csv represent a comma-separated value format file + - orc represent a apache orc format file + """ + @property + def null_strs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """A list of strings that represent NULL value. + Only take effect when format is csv + """ + + def __init__( + self, + *, + uri: builtins.str = ..., + party: builtins.str = ..., + format: builtins.str = ..., + null_strs: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["format", b"format", "null_strs", b"null_strs", "party", b"party", "uri", b"uri"]) -> None: ... + + VERSION_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + SYSTEM_INFO_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + DATA_REFS_FIELD_NUMBER: builtins.int + version: builtins.str + """The version of spec""" + name: builtins.str + """The name of this distributed data.""" + type: builtins.str + """Type.""" + @property + def system_info(self) -> global___SystemInfo: + """Describe the system information that used to generate this distributed + data. + """ + + @property + def meta(self) -> google.protobuf.any_pb2.Any: + """Public information, known to all parties. + i.e. VerticalTable. + """ + + @property + def data_refs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___DistData.DataRef]: + """Remote data references.""" + + def __init__( + self, + *, + version: builtins.str = ..., + name: builtins.str = ..., + type: builtins.str = ..., + system_info: global___SystemInfo | None = ..., + meta: google.protobuf.any_pb2.Any | None = ..., + data_refs: collections.abc.Iterable[global___DistData.DataRef] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["meta", b"meta", "system_info", b"system_info"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["data_refs", b"data_refs", "meta", b"meta", "name", b"name", "system_info", b"system_info", "type", b"type", "version", b"version"]) -> None: ... + +global___DistData = DistData + +@typing.final +class VerticalTable(google.protobuf.message.Message): + """VerticalTable describes a virtual vertical partitioning table from multiple + parties. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SCHEMAS_FIELD_NUMBER: builtins.int + LINE_COUNT_FIELD_NUMBER: builtins.int + line_count: builtins.int + """If -1, the number is unknown.""" + @property + def schemas(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TableSchema]: + """The vertical partitioned slices' schema. + Must match data_refs in the parent DistData message. + """ + + def __init__( + self, + *, + schemas: collections.abc.Iterable[global___TableSchema] | None = ..., + line_count: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["line_count", b"line_count", "schemas", b"schemas"]) -> None: ... + +global___VerticalTable = VerticalTable + +@typing.final +class IndividualTable(google.protobuf.message.Message): + """IndividualTable describes a table owned by a single party.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SCHEMA_FIELD_NUMBER: builtins.int + LINE_COUNT_FIELD_NUMBER: builtins.int + line_count: builtins.int + """If -1, the number is unknown.""" + @property + def schema(self) -> global___TableSchema: + """Schema.""" + + def __init__( + self, + *, + schema: global___TableSchema | None = ..., + line_count: builtins.int = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["schema", b"schema"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["line_count", b"line_count", "schema", b"schema"]) -> None: ... + +global___IndividualTable = IndividualTable + +@typing.final +class TableSchema(google.protobuf.message.Message): + """The schema of a table. + - A col must be one of `id | feature | label`. By default, it should be a + feature. + - All names must match the regexp `[A-Za-z0-9.][A-Za-z0-9_>./]*`. + - All data type must be one of + * int8 + * int16 + * int32 + * int64 + * uint8 + * uint16 + * uint32 + * uint64 + * float16 + * float32 + * float64 + * bool + * int + * float + * str + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + IDS_FIELD_NUMBER: builtins.int + FEATURES_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + ID_TYPES_FIELD_NUMBER: builtins.int + FEATURE_TYPES_FIELD_NUMBER: builtins.int + LABEL_TYPES_FIELD_NUMBER: builtins.int + @property + def ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Id column name(s). + Optional, can be empty. + """ + + @property + def features(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Feature column name(s).""" + + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Label column name(s). + Optional, can be empty. + """ + + @property + def id_types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Id column data type(s). + Len(id) should match len(id_types). + """ + + @property + def feature_types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Feature column data type(s). + Len(features) should match len(feature_types). + """ + + @property + def label_types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Label column data type(s). + Len(labels) should match len(label_types). + """ + + def __init__( + self, + *, + ids: collections.abc.Iterable[builtins.str] | None = ..., + features: collections.abc.Iterable[builtins.str] | None = ..., + labels: collections.abc.Iterable[builtins.str] | None = ..., + id_types: collections.abc.Iterable[builtins.str] | None = ..., + feature_types: collections.abc.Iterable[builtins.str] | None = ..., + label_types: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["feature_types", b"feature_types", "features", b"features", "id_types", b"id_types", "ids", b"ids", "label_types", b"label_types", "labels", b"labels"]) -> None: ... + +global___TableSchema = TableSchema + +@typing.final +class ObjectFileInfo(google.protobuf.message.Message): + """ObjectFileInfo describes metadata for unstructured data file, such as Model""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class AttributesEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + ATTRIBUTES_FIELD_NUMBER: builtins.int + @property + def attributes(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """Any public attributes""" + + def __init__( + self, + *, + attributes: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["attributes", b"attributes"]) -> None: ... + +global___ObjectFileInfo = ObjectFileInfo diff --git a/secretflow_spec/v1/evaluation_pb2.py b/secretflow_spec/v1/evaluation_pb2.py new file mode 100644 index 0000000..a42d949 --- /dev/null +++ b/secretflow_spec/v1/evaluation_pb2.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: secretflow_spec/v1/evaluation.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from secretflow_spec.v1 import component_pb2 as secretflow__spec_dot_v1_dot_component__pb2 +from secretflow_spec.v1 import data_pb2 as secretflow__spec_dot_v1_dot_data__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#secretflow_spec/v1/evaluation.proto\x12\x12secretflow_spec.v1\x1a\"secretflow_spec/v1/component.proto\x1a\x1dsecretflow_spec/v1/data.proto\"\xce\x01\n\rNodeEvalParam\x12\x0f\n\x07version\x18\x01 \x01(\t\x12\x0f\n\x07\x63omp_id\x18\x02 \x01(\t\x12\x12\n\nattr_paths\x18\x03 \x03(\t\x12,\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32\x1d.secretflow_spec.v1.Attribute\x12,\n\x06inputs\x18\x05 \x03(\x0b\x32\x1c.secretflow_spec.v1.DistData\x12\x13\n\x0boutput_uris\x18\x06 \x03(\t\x12\x16\n\x0e\x63heckpoint_uri\x18\x07 \x01(\t\"P\n\x0eNodeEvalResult\x12\x0f\n\x07version\x18\x01 \x01(\t\x12-\n\x07outputs\x18\x02 \x03(\x0b\x32\x1c.secretflow_spec.v1.DistDataB+\n\x16\x63om.secretflow_spec.v1B\x0f\x45valuationProtoP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'secretflow_spec.v1.evaluation_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\026com.secretflow_spec.v1B\017EvaluationProtoP\001' + _globals['_NODEEVALPARAM']._serialized_start=127 + _globals['_NODEEVALPARAM']._serialized_end=333 + _globals['_NODEEVALRESULT']._serialized_start=335 + _globals['_NODEEVALRESULT']._serialized_end=415 +# @@protoc_insertion_point(module_scope) diff --git a/secretflow_spec/v1/evaluation_pb2.pyi b/secretflow_spec/v1/evaluation_pb2.pyi new file mode 100644 index 0000000..d722714 --- /dev/null +++ b/secretflow_spec/v1/evaluation_pb2.pyi @@ -0,0 +1,124 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2023 Ant Group Co., Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import secretflow_spec.v1.component_pb2 +import secretflow_spec.v1.data_pb2 +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class NodeEvalParam(google.protobuf.message.Message): + """Evaluate a node. + - CompListDef + StorageConfig + NodeEvalParam + other extra configs -> + NodeEvalResult + + NodeEvalParam contains all the information to evaluate a component. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VERSION_FIELD_NUMBER: builtins.int + COMP_ID_FIELD_NUMBER: builtins.int + ATTR_PATHS_FIELD_NUMBER: builtins.int + ATTRS_FIELD_NUMBER: builtins.int + INPUTS_FIELD_NUMBER: builtins.int + OUTPUT_URIS_FIELD_NUMBER: builtins.int + CHECKPOINT_URI_FIELD_NUMBER: builtins.int + version: builtins.str + """The version of spec""" + comp_id: builtins.str + """The unique component id, the format is {domain}/{name}:{version} which is defined in ComponentDef""" + checkpoint_uri: builtins.str + """If not empty: + 1. Component will try to save checkpoint during training if the component + supports it. + 2. Component will try to reload checkpoint when starting to continue the + previous training. If the checkpoint does not exist or cannot be loaded, + training will be starting from scratch. + """ + @property + def attr_paths(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """The path of attributes. + The attribute path for a TableAttrDef is + `(input\\|output)/(IoDef name)/(TableAttrDef name)(/(column name)(/(extra + attributes))?)?`. + """ + + @property + def attrs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[secretflow_spec.v1.component_pb2.Attribute]: + """The value of the attribute. + Must match attr_paths. + """ + + @property + def inputs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[secretflow_spec.v1.data_pb2.DistData]: + """The input data, the order of inputs must match inputs in ComponentDef. + NOTE: Names of DistData doesn't need to match those of inputs in + ComponentDef definition. + """ + + @property + def output_uris(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """The output data uris, the order of output_uris must match outputs in + ComponentDef. + """ + + def __init__( + self, + *, + version: builtins.str = ..., + comp_id: builtins.str = ..., + attr_paths: collections.abc.Iterable[builtins.str] | None = ..., + attrs: collections.abc.Iterable[secretflow_spec.v1.component_pb2.Attribute] | None = ..., + inputs: collections.abc.Iterable[secretflow_spec.v1.data_pb2.DistData] | None = ..., + output_uris: collections.abc.Iterable[builtins.str] | None = ..., + checkpoint_uri: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["attr_paths", b"attr_paths", "attrs", b"attrs", "checkpoint_uri", b"checkpoint_uri", "comp_id", b"comp_id", "inputs", b"inputs", "output_uris", b"output_uris", "version", b"version"]) -> None: ... + +global___NodeEvalParam = NodeEvalParam + +@typing.final +class NodeEvalResult(google.protobuf.message.Message): + """NodeEvalResult contains outputs of a component evaluation.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VERSION_FIELD_NUMBER: builtins.int + OUTPUTS_FIELD_NUMBER: builtins.int + version: builtins.str + """The version of spec""" + @property + def outputs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[secretflow_spec.v1.data_pb2.DistData]: + """Output data.""" + + def __init__( + self, + *, + version: builtins.str = ..., + outputs: collections.abc.Iterable[secretflow_spec.v1.data_pb2.DistData] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["outputs", b"outputs", "version", b"version"]) -> None: ... + +global___NodeEvalResult = NodeEvalResult diff --git a/secretflow_spec/v1/report_pb2.py b/secretflow_spec/v1/report_pb2.py new file mode 100644 index 0000000..d11dc38 --- /dev/null +++ b/secretflow_spec/v1/report_pb2.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: secretflow_spec/v1/report.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from secretflow_spec.v1 import component_pb2 as secretflow__spec_dot_v1_dot_component__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fsecretflow_spec/v1/report.proto\x12\x12secretflow_spec.v1\x1a\"secretflow_spec/v1/component.proto\"\xc0\x01\n\x0c\x44\x65scriptions\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\x34\n\x05items\x18\x03 \x03(\x0b\x32%.secretflow_spec.v1.Descriptions.Item\x1a^\n\x04Item\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12,\n\x05value\x18\x04 \x01(\x0b\x32\x1d.secretflow_spec.v1.Attribute\"\x90\x02\n\x05Table\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\x35\n\x07headers\x18\x03 \x03(\x0b\x32$.secretflow_spec.v1.Table.HeaderItem\x12+\n\x04rows\x18\x04 \x03(\x0b\x32\x1d.secretflow_spec.v1.Table.Row\x1a\x36\n\nHeaderItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x1aO\n\x03Row\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12,\n\x05items\x18\x03 \x03(\x0b\x32\x1d.secretflow_spec.v1.Attribute\"\xf2\x01\n\x03\x44iv\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12/\n\x08\x63hildren\x18\x03 \x03(\x0b\x32\x1d.secretflow_spec.v1.Div.Child\x1a\x9d\x01\n\x05\x43hild\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x36\n\x0c\x64\x65scriptions\x18\x02 \x01(\x0b\x32 .secretflow_spec.v1.Descriptions\x12(\n\x05table\x18\x03 \x01(\x0b\x32\x19.secretflow_spec.v1.Table\x12$\n\x03\x64iv\x18\x04 \x01(\x0b\x32\x17.secretflow_spec.v1.Div\"H\n\x03Tab\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12%\n\x04\x64ivs\x18\x03 \x03(\x0b\x32\x17.secretflow_spec.v1.Div\"q\n\x06Report\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x65sc\x18\x02 \x01(\t\x12%\n\x04tabs\x18\x03 \x03(\x0b\x32\x17.secretflow_spec.v1.Tab\x12\x10\n\x08\x65rr_code\x18\x04 \x01(\x05\x12\x12\n\nerr_detail\x18\x05 \x01(\tB\'\n\x16\x63om.secretflow_spec.v1B\x0bReportProtoP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'secretflow_spec.v1.report_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\026com.secretflow_spec.v1B\013ReportProtoP\001' + _globals['_DESCRIPTIONS']._serialized_start=92 + _globals['_DESCRIPTIONS']._serialized_end=284 + _globals['_DESCRIPTIONS_ITEM']._serialized_start=190 + _globals['_DESCRIPTIONS_ITEM']._serialized_end=284 + _globals['_TABLE']._serialized_start=287 + _globals['_TABLE']._serialized_end=559 + _globals['_TABLE_HEADERITEM']._serialized_start=424 + _globals['_TABLE_HEADERITEM']._serialized_end=478 + _globals['_TABLE_ROW']._serialized_start=480 + _globals['_TABLE_ROW']._serialized_end=559 + _globals['_DIV']._serialized_start=562 + _globals['_DIV']._serialized_end=804 + _globals['_DIV_CHILD']._serialized_start=647 + _globals['_DIV_CHILD']._serialized_end=804 + _globals['_TAB']._serialized_start=806 + _globals['_TAB']._serialized_end=878 + _globals['_REPORT']._serialized_start=880 + _globals['_REPORT']._serialized_end=993 +# @@protoc_insertion_point(module_scope) diff --git a/secretflow_spec/v1/report_pb2.pyi b/secretflow_spec/v1/report_pb2.pyi new file mode 100644 index 0000000..0b2380a --- /dev/null +++ b/secretflow_spec/v1/report_pb2.pyi @@ -0,0 +1,254 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2023 Ant Group Co., Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import secretflow_spec.v1.component_pb2 +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class Descriptions(google.protobuf.message.Message): + """Displays multiple read-only fields in groups.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class Item(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the field.""" + desc: builtins.str + type: builtins.str + """Must be one of bool/int/float/str""" + @property + def value(self) -> secretflow_spec.v1.component_pb2.Attribute: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + type: builtins.str = ..., + value: secretflow_spec.v1.component_pb2.Attribute | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "name", b"name", "type", b"type", "value", b"value"]) -> None: ... + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + ITEMS_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the Descriptions.""" + desc: builtins.str + @property + def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Descriptions.Item]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + items: collections.abc.Iterable[global___Descriptions.Item] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "items", b"items", "name", b"name"]) -> None: ... + +global___Descriptions = Descriptions + +@typing.final +class Table(google.protobuf.message.Message): + """Displays rows of data.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class HeaderItem(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + name: builtins.str + desc: builtins.str + type: builtins.str + """Must be one of bool/int/float/str""" + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "name", b"name", "type", b"type"]) -> None: ... + + @typing.final + class Row(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + ITEMS_FIELD_NUMBER: builtins.int + name: builtins.str + desc: builtins.str + @property + def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[secretflow_spec.v1.component_pb2.Attribute]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + items: collections.abc.Iterable[secretflow_spec.v1.component_pb2.Attribute] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "items", b"items", "name", b"name"]) -> None: ... + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + HEADERS_FIELD_NUMBER: builtins.int + ROWS_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the Table.""" + desc: builtins.str + @property + def headers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Table.HeaderItem]: ... + @property + def rows(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Table.Row]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + headers: collections.abc.Iterable[global___Table.HeaderItem] | None = ..., + rows: collections.abc.Iterable[global___Table.Row] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "headers", b"headers", "name", b"name", "rows", b"rows"]) -> None: ... + +global___Table = Table + +@typing.final +class Div(google.protobuf.message.Message): + """A division or a section of a page.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class Child(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TYPE_FIELD_NUMBER: builtins.int + DESCRIPTIONS_FIELD_NUMBER: builtins.int + TABLE_FIELD_NUMBER: builtins.int + DIV_FIELD_NUMBER: builtins.int + type: builtins.str + """Supported: descriptions, table, div.""" + @property + def descriptions(self) -> global___Descriptions: ... + @property + def table(self) -> global___Table: ... + @property + def div(self) -> global___Div: ... + def __init__( + self, + *, + type: builtins.str = ..., + descriptions: global___Descriptions | None = ..., + table: global___Table | None = ..., + div: global___Div | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["descriptions", b"descriptions", "div", b"div", "table", b"table"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["descriptions", b"descriptions", "div", b"div", "table", b"table", "type", b"type"]) -> None: ... + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + CHILDREN_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the Div.""" + desc: builtins.str + @property + def children(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Div.Child]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + children: collections.abc.Iterable[global___Div.Child] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["children", b"children", "desc", b"desc", "name", b"name"]) -> None: ... + +global___Div = Div + +@typing.final +class Tab(google.protobuf.message.Message): + """A page of a report.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + DIVS_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the Tab.""" + desc: builtins.str + @property + def divs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Div]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + divs: collections.abc.Iterable[global___Div] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "divs", b"divs", "name", b"name"]) -> None: ... + +global___Tab = Tab + +@typing.final +class Report(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + DESC_FIELD_NUMBER: builtins.int + TABS_FIELD_NUMBER: builtins.int + ERR_CODE_FIELD_NUMBER: builtins.int + ERR_DETAIL_FIELD_NUMBER: builtins.int + name: builtins.str + """Name of the Report.""" + desc: builtins.str + err_code: builtins.int + err_detail: builtins.str + """Structed error detail (JSON encoded message).""" + @property + def tabs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Tab]: ... + def __init__( + self, + *, + name: builtins.str = ..., + desc: builtins.str = ..., + tabs: collections.abc.Iterable[global___Tab] | None = ..., + err_code: builtins.int = ..., + err_detail: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["desc", b"desc", "err_code", b"err_code", "err_detail", b"err_detail", "name", b"name", "tabs", b"tabs"]) -> None: ... + +global___Report = Report diff --git a/secretflow_spec/version.py b/secretflow_spec/version.py new file mode 100644 index 0000000..01d0374 --- /dev/null +++ b/secretflow_spec/version.py @@ -0,0 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "1.1.0b0" +__commit_id__ = "$$COMMIT_ID$$" +__build_time__ = "$$BUILD_TIME$$" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..0d96b43 --- /dev/null +++ b/setup.py @@ -0,0 +1,116 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import shutil +import subprocess +import time +from datetime import date + +import setuptools +from setuptools import find_packages, setup + +this_directory = os.path.abspath(os.path.dirname(__file__)) + + +def long_description(): + with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: + return f.read() + + +def get_commit_id() -> str: + commit_id = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + ) + dirty = subprocess.check_output(["git", "diff", "--stat"]).decode("ascii").strip() + + if dirty: + commit_id = f"{commit_id}-dirty" + + return commit_id + + +def complete_version_file(*filepath): + today = date.today() + dstr = today.strftime("%Y%m%d") + with open(os.path.join(".", *filepath), "r") as fp: + content = fp.read() + + content = content.replace("$$DATE$$", dstr) + content = content.replace("$$BUILD_TIME$$", time.strftime("%b %d %Y, %X")) + try: + content = content.replace("$$COMMIT_ID$$", get_commit_id()) + except: + pass + + with open(os.path.join(".", *filepath), "w+") as fp: + fp.write(content) + + +def find_version(*filepath): + complete_version_file(*filepath) + # Extract version information from filepath + with open(os.path.join(".", *filepath)) as fp: + version_match = re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M + ) + if version_match: + return version_match.group(1) + print("Unable to find version string.") + exit(-1) + + +def read_requirements(): + with open("requirements.txt") as req_file: + return req_file.read().splitlines() + + +# [ref](https://github.com/perwin/pyimfit/blob/master/setup.py) +# Modified cleanup command to remove dist subdirectory +# Based on: https://stackoverflow.com/questions/1710839/custom-distutils-commands +class CleanCommand(setuptools.Command): + description = "custom clean command that forcefully removes dist directories" + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + directories_to_clean = ["./build"] + + for dir in directories_to_clean: + if os.path.exists(dir): + shutil.rmtree(dir) + + +if __name__ == "__main__": + setup( + name="secretflow_spec", + version=find_version("secretflow_spec", "version.py"), + license="Apache 2.0", + description="Secretflow spec", + long_description=long_description(), + long_description_content_type="text/markdown", + author="SCI Center", + author_email="secretflow-contact@service.alipay.com", + url="https://github.com/secretflow/spec", + packages=find_packages(exclude=["secretflow_spec.tests"]), + install_requires=read_requirements(), + extras_require={"dev": ["pylint"]}, + cmdclass=dict(clean=CleanCommand, cleanall=CleanCommand), + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..0e6db10 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/comps/__init__.py b/tests/comps/__init__.py new file mode 100644 index 0000000..0e6db10 --- /dev/null +++ b/tests/comps/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/comps/my_comp.py b/tests/comps/my_comp.py new file mode 100644 index 0000000..710dd48 --- /dev/null +++ b/tests/comps/my_comp.py @@ -0,0 +1,20 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from secretflow_spec import Component, Field, register + + +@register(domain="test", version="0.0.1", desc="xx") +class MyComponent(Component): + fa: int = Field.attr("field int") diff --git a/tests/spec/extend/__init__.py b/tests/spec/extend/__init__.py new file mode 100644 index 0000000..086637a --- /dev/null +++ b/tests/spec/extend/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/spec/extend/calculate_rules.proto b/tests/spec/extend/calculate_rules.proto new file mode 100644 index 0000000..adc98c7 --- /dev/null +++ b/tests/spec/extend/calculate_rules.proto @@ -0,0 +1,59 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package tests.spec.extend.extend; + +option java_package = "org.tests.spec.extend"; + +// input columns set by Component::io::col_params +message CalculateOpRules { + enum OpType { + // inval type + INVAL = 0; + // len(operands) == 0 + STANDARDIZE = 1; + // len(operands) == 0 + NORMALIZATION = 2; + // len(operands) == 2, [min, max] + RANGE_LIMIT = 3; + // len(operands) == 3, [(+ -), unary_op(+ - * /), value] + // if operandsp[0] == "+", column unary_op value + // if operandsp[0] == "-", value unary_op column + UNARY = 4; + // len(operands) == 0 + RECIPROCAL = 5; + // len(operands) == 0 + ROUND = 6; + // len(operands) == 1, [bias] + LOG_ROUND = 7; + // len(operands) == 0 + SQRT = 8; + // len(operands) == 2, [log_base, bias] + LOG = 9; + // len(operands) == 0 + EXP = 10; + // len(operands) == 0 + LENGTH = 11; + // len(operands) == 2, [start_pos, length] + SUBSTR = 12; + } + + OpType op = 1; + + repeated string operands = 2; + + string new_col_name = 3; +} \ No newline at end of file diff --git a/tests/spec/extend/calculate_rules_pb2.py b/tests/spec/extend/calculate_rules_pb2.py new file mode 100644 index 0000000..9afcc16 --- /dev/null +++ b/tests/spec/extend/calculate_rules_pb2.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tests/spec/extend/calculate_rules.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'tests/spec/extend/calculate_rules.proto\x12\x18tests.spec.extend.extend\"\xad\x02\n\x10\x43\x61lculateOpRules\x12=\n\x02op\x18\x01 \x01(\x0e\x32\x31.tests.spec.extend.extend.CalculateOpRules.OpType\x12\x10\n\x08operands\x18\x02 \x03(\t\x12\x14\n\x0cnew_col_name\x18\x03 \x01(\t\"\xb1\x01\n\x06OpType\x12\t\n\x05INVAL\x10\x00\x12\x0f\n\x0bSTANDARDIZE\x10\x01\x12\x11\n\rNORMALIZATION\x10\x02\x12\x0f\n\x0bRANGE_LIMIT\x10\x03\x12\t\n\x05UNARY\x10\x04\x12\x0e\n\nRECIPROCAL\x10\x05\x12\t\n\x05ROUND\x10\x06\x12\r\n\tLOG_ROUND\x10\x07\x12\x08\n\x04SQRT\x10\x08\x12\x07\n\x03LOG\x10\t\x12\x07\n\x03\x45XP\x10\n\x12\n\n\x06LENGTH\x10\x0b\x12\n\n\x06SUBSTR\x10\x0c\x42\x17\n\x15org.tests.spec.extendb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tests.spec.extend.calculate_rules_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\025org.tests.spec.extend' + _CALCULATEOPRULES._serialized_start=70 + _CALCULATEOPRULES._serialized_end=371 + _CALCULATEOPRULES_OPTYPE._serialized_start=194 + _CALCULATEOPRULES_OPTYPE._serialized_end=371 +# @@protoc_insertion_point(module_scope) diff --git a/tests/spec/extend/calculate_rules_pb2.pyi b/tests/spec/extend/calculate_rules_pb2.pyi new file mode 100644 index 0000000..5f436c1 --- /dev/null +++ b/tests/spec/extend/calculate_rules_pb2.pyi @@ -0,0 +1,124 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2023 Ant Group Co., Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class CalculateOpRules(google.protobuf.message.Message): + """input columns set by Component::io::col_params""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class _OpType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _OpTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[CalculateOpRules._OpType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + INVAL: CalculateOpRules._OpType.ValueType # 0 + """inval type""" + STANDARDIZE: CalculateOpRules._OpType.ValueType # 1 + """len(operands) == 0""" + NORMALIZATION: CalculateOpRules._OpType.ValueType # 2 + """len(operands) == 0""" + RANGE_LIMIT: CalculateOpRules._OpType.ValueType # 3 + """len(operands) == 2, [min, max]""" + UNARY: CalculateOpRules._OpType.ValueType # 4 + """len(operands) == 3, [(+ -), unary_op(+ - * /), value] + if operandsp[0] == "+", column unary_op value + if operandsp[0] == "-", value unary_op column + """ + RECIPROCAL: CalculateOpRules._OpType.ValueType # 5 + """len(operands) == 0""" + ROUND: CalculateOpRules._OpType.ValueType # 6 + """len(operands) == 0""" + LOG_ROUND: CalculateOpRules._OpType.ValueType # 7 + """len(operands) == 1, [bias]""" + SQRT: CalculateOpRules._OpType.ValueType # 8 + """len(operands) == 0""" + LOG: CalculateOpRules._OpType.ValueType # 9 + """len(operands) == 2, [log_base, bias]""" + EXP: CalculateOpRules._OpType.ValueType # 10 + """len(operands) == 0""" + LENGTH: CalculateOpRules._OpType.ValueType # 11 + """len(operands) == 0""" + SUBSTR: CalculateOpRules._OpType.ValueType # 12 + """len(operands) == 2, [start_pos, length]""" + + class OpType(_OpType, metaclass=_OpTypeEnumTypeWrapper): ... + INVAL: CalculateOpRules.OpType.ValueType # 0 + """inval type""" + STANDARDIZE: CalculateOpRules.OpType.ValueType # 1 + """len(operands) == 0""" + NORMALIZATION: CalculateOpRules.OpType.ValueType # 2 + """len(operands) == 0""" + RANGE_LIMIT: CalculateOpRules.OpType.ValueType # 3 + """len(operands) == 2, [min, max]""" + UNARY: CalculateOpRules.OpType.ValueType # 4 + """len(operands) == 3, [(+ -), unary_op(+ - * /), value] + if operandsp[0] == "+", column unary_op value + if operandsp[0] == "-", value unary_op column + """ + RECIPROCAL: CalculateOpRules.OpType.ValueType # 5 + """len(operands) == 0""" + ROUND: CalculateOpRules.OpType.ValueType # 6 + """len(operands) == 0""" + LOG_ROUND: CalculateOpRules.OpType.ValueType # 7 + """len(operands) == 1, [bias]""" + SQRT: CalculateOpRules.OpType.ValueType # 8 + """len(operands) == 0""" + LOG: CalculateOpRules.OpType.ValueType # 9 + """len(operands) == 2, [log_base, bias]""" + EXP: CalculateOpRules.OpType.ValueType # 10 + """len(operands) == 0""" + LENGTH: CalculateOpRules.OpType.ValueType # 11 + """len(operands) == 0""" + SUBSTR: CalculateOpRules.OpType.ValueType # 12 + """len(operands) == 2, [start_pos, length]""" + + OP_FIELD_NUMBER: builtins.int + OPERANDS_FIELD_NUMBER: builtins.int + NEW_COL_NAME_FIELD_NUMBER: builtins.int + op: global___CalculateOpRules.OpType.ValueType + new_col_name: builtins.str + @property + def operands(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + def __init__( + self, + *, + op: global___CalculateOpRules.OpType.ValueType = ..., + operands: collections.abc.Iterable[builtins.str] | None = ..., + new_col_name: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["new_col_name", b"new_col_name", "op", b"op", "operands", b"operands"]) -> None: ... + +global___CalculateOpRules = CalculateOpRules diff --git a/tests/test_definition.py b/tests/test_definition.py new file mode 100644 index 0000000..5109d7a --- /dev/null +++ b/tests/test_definition.py @@ -0,0 +1,325 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from dataclasses import dataclass + +import pytest + +from secretflow_spec import ( + Component, + Definition, + Field, + Input, + Interval, + Output, + UnionGroup, +) +from secretflow_spec.core.dist_data.base import DistDataType +from secretflow_spec.v1.evaluation_pb2 import NodeEvalParam + + +@dataclass +class UnionOption1: + state: int = Field.attr(desc="state", default=1) + + +@dataclass +class UnionOption2: + field1: int = Field.attr(desc="field1") + + +@dataclass +class UnionField(UnionGroup): + option1: UnionOption1 = Field.struct_attr(desc="option1") + option2: UnionOption2 = Field.struct_attr(desc="option2") + option3: str = Field.selection_attr(desc="option3, selection attr") + option4: int = Field.attr(desc="option4, int attr") + + +@dataclass +class StructField: + field1: int = Field.attr(default=1) + field2: float = Field.attr(default=1.0, bound_limit=Interval.closed(0, 100)) + + +class StructNoDataclass: + field1: int = Field.attr(default=1) + + +class DemoComponent(Component): + """ + demo component + """ + + class StructInner: + field1: int = Field.attr(default=1) + + int_fd: int = Field.attr(default=1, choices=[0, 1]) + float_fd: float = Field.attr(default=1.0, bound_limit=Interval.closed(0, 2)) + bool_fd: bool = Field.attr(default=True) + ints_fd: list[int] = Field.attr(default=[1, 2, 3], list_limit=Interval.closed(1, 3)) + floats_fd: list[float] = Field.attr( + default=[1.0, 2.0, 3.0], list_limit=Interval.closed(1, 3) + ) + bools_fd: list[bool] = Field.attr(default=[True, False]) + strs_fd: list[str] = Field.attr(default=["a", "b", "c"]) + struct_fd: StructField = Field.struct_attr(desc="struct field") + struct_no_dc: StructNoDataclass = Field.struct_attr() + struct_inner: StructInner = Field.struct_attr() + union_fd: UnionField = Field.union_attr(desc="union field") + column: str = Field.table_column_attr("input_fd", desc="input table column") + drop_first: bool = Field.attr(desc="drop_first", minor_max=0) + drop: str = Field.attr(choices=["first", "mod", "no_drop"], minor_min=1) + input_fd: Input = Field.input(desc="input", types=[DistDataType.INDIVIDUAL_TABLE]) + output_fd: Output = Field.output( + desc="output", types=[DistDataType.INDIVIDUAL_TABLE] + ) + + def __post__init__(self) -> None: + # upgrade param and verify + if self.is_supported(0): + self.drop = "first" if self.drop_first else "no_drop" + + def evaluate(self): + print(self.int_fd, self.float_fd) + + +def test_definition(): + comp1 = DemoComponent( + int_fd=1, + struct_fd=StructField(field1=2), + struct_no_dc=StructNoDataclass(field1=3), + struct_inner=DemoComponent.StructInner(field1=4), + ) + assert ( + comp1.int_fd == 1 + and comp1.struct_fd.field1 == 2 + and comp1.struct_no_dc.field1 == 3 + ) + + args = { + "int_fd": 1, + "float_fd": 2.0, + "bool_fd": True, + "struct_fd/field1": 11, + "struct_fd/field2": 12, + "struct_no_dc/field1": 13, + "struct_inner/field1": 14, + "union_fd": "option4", + "union_fd/option4": 21, + "drop_first": True, + "input/input_fd/column": "col1", + "input_fd": Input(type=str(DistDataType.INDIVIDUAL_TABLE)), + "output_fd": "test uri", + } + args["_minor"] = 0 + comp_def = Definition(DemoComponent, "test", "0.1.0") + comp: DemoComponent = comp_def.make_component(args) + assert ( + comp._minor == 0 + and comp.int_fd == 1 + and comp.float_fd == 2.0 + and comp.bool_fd == True + and comp.struct_fd.field1 == 11 + and comp.struct_fd.field2 == 12 + and comp.struct_no_dc.field1 == 13 + and comp.struct_inner.field1 == 14 + and comp.union_fd.is_selected("option4") + and comp.union_fd.option4 == 21 + and comp.column == "col1" + ) + + +def test_custom(): + from tests.spec.extend.calculate_rules_pb2 import CalculateOpRules + + class CustomComponent(Component): + rule: CalculateOpRules = Field.custom_attr(desc="xx") + + def evaluate(self): ... + + d = Definition(CustomComponent, "test", "0.0.1") + protobuf_cls = d.component_def.attrs[0].custom_protobuf_cls + assert protobuf_cls == "tests.spec.extend.calculate_rules_pb2.CalculateOpRules" + + +def test_union(): + class UnionField(UnionGroup): + f1: int = Field.attr() + f2: int = Field.attr() + + class UnionComponent(Component): + uf: UnionField = Field.union_attr() + + def evaluate(self): ... + + args = {"uf": "f1", "uf/f1": 1} + args["_minor"] = 0 + comp_def = Definition(UnionComponent, "test", "0.1.0") + comp: UnionComponent = comp_def.make_component(args) + assert comp.uf.is_selected("f1") and comp.uf.f1 == 1 + + # uf default use f1 + args = {"uf/f1": 1} + args["_minor"] = 0 + comp_def.make_component(args) + + # no uf/f1 + args = {"uf": "f1"} + args["_minor"] = 0 + with pytest.raises(Exception) as exc_info: + comp_def.make_component(args) + logging.info(f"Caught expected Exception: {exc_info}") + + # unused uf/f2 + args = {"uf": "f1", "uf/f1": 1, "uf/f2": 2} + args["_minor"] = 0 + with pytest.raises(Exception) as exc_info: + comp_def.make_component(args) + logging.info(f"Caught expected Exception: {exc_info}") + + +def test_version(): + class VersionComponent(Component): + v0: int = Field.attr() + v1: int = Field.attr(minor_min=1, minor_max=1) + v2: int = Field.attr(minor_min=2) + + input1: Input = Field.input(types=[DistDataType.NULL], minor_min=0, minor_max=0) + input2: Input = Field.input(types=[DistDataType.NULL], minor_min=1) + + output1: Output = Field.output( + types=[DistDataType.NULL], minor_min=0, minor_max=0 + ) + output2: Output = Field.output(types=[DistDataType.NULL], minor_min=1) + + def evaluate(self) -> None: ... + + definition = Definition(VersionComponent, "test", "0.2.0") + comp_def = definition.component_def + assert len(comp_def.attrs) == 3 and comp_def.attrs[2].name == "v2" + assert len(comp_def.inputs) == 2 and len(comp_def.outputs) == 2 + + input = Input(type=str(DistDataType.NULL)) + args = {"v0": 0, "input1": input, "output1": Output(uri="o1")} + args["_minor"] = 0 + comp: VersionComponent = definition.make_component(args) + assert comp.v0 == 0 and not comp.v1 and not comp.v2 + + args = {"v0": 0, "v1": 1, "input2": input, "output2": Output(uri="o1")} + args["_minor"] = 1 + comp: VersionComponent = definition.make_component(args) + assert comp.v0 == 0 and comp.v1 == 1 and not comp.v2 + + args = {"v0": 0, "v2": 2, "input2": input, "output2": Output(uri="o2")} + args["_minor"] = 2 + comp: VersionComponent = definition.make_component(args) + assert comp.v0 == 0 and comp.v2 == 2 and comp.output2.uri == "o2" + + bad_cases = [ + {"_minor": 0, "input1": input, "output1": Output(uri="o1")}, # no v0 + { + "_minor": 2, + "v0": 0, + "v1": 1, # deprecated field + "v2": 2, + "v3": 3, # unknown field + "input2": input, + "output2": Output(uri="o2"), + }, + ] + + for case in bad_cases: + with pytest.raises(Exception) as e: + definition.make_component(case) + logging.info(f"Caught expected Exception: {e}") + + +def test_inherit(): + class Parent(Component): + f1: int = Field.attr() + + class Child(Parent): + f2: int = Field.attr() + + def evaluate(self) -> None: + print(self.f1, self.f2) + + child = Child(f1=1, f2=2) + assert child.f1 == 1 and child.f2 == 2 + + comp_def = Definition(Child, "test", "0.0.1") + args = {"_minor": 0, "f1": 1, "f2": 2} + comp: Child = comp_def.make_component(args) + assert comp.f1 == 1 and comp.f2 == 2 + + +def test_default(): + class TestDefaultComponent(Component): + f1: int = Field.attr(default=1) + f2: int = Field.attr() + + def evaluate(self) -> None: ... + + dc = TestDefaultComponent(f2=2) + assert dc.f1 == 1 and dc.f2 == 2 + + comp_def = Definition(TestDefaultComponent, "test", "0.0.1") + args = {"_minor": 0, "f2": 2} + comp: TestDefaultComponent = comp_def.make_component(args) + assert comp.f1 == 1 and comp.f2 == 2 + + +def test_outdated_vertical_table(): + class TestOutdatedVerticalTable(Component): + input1: Input = Field.input("", types=[DistDataType.OUTDATED_VERTICAL_TABLE]) + + def evaluate(self) -> None: ... + + with pytest.raises(Exception) as e: + Definition(TestOutdatedVerticalTable, "test", "0.0.1") + logging.info(f"Caught expected Exception: {e}") + + class TestOutdatedVerticalTable1(Component): + input1: Input = Field.input("", types=[DistDataType.VERTICAL_TABLE]) + + def evaluate(self) -> None: ... + + d1 = Definition(TestOutdatedVerticalTable1, "test", "0.0.1") + args = {"_minor": 0, "input1": Input(type=DistDataType.OUTDATED_VERTICAL_TABLE)} + comp: TestOutdatedVerticalTable1 = d1.make_component(args) + assert comp.input1.type == DistDataType.VERTICAL_TABLE + + +def test_variable_input(): + class TestVariableInput(Component): + input1: list[Input] = Field.input( + "desc", + types=[DistDataType.VERTICAL_TABLE], + list_limit=Interval.closed(1, 2), + ) + + def evaluate(self) -> None: ... + + d1 = Definition(TestVariableInput, "test", "0.0.1", "test_variable_input") + in0 = d1.component_def.inputs[0] + assert in0.is_variable and in0.variable_min == 1 and in0.variable_max == 2 + param = NodeEvalParam( + comp_id="test/test_variable_input:0.0.1", + inputs=[Input(type=DistDataType.VERTICAL_TABLE)], + ) + comp: TestVariableInput = d1.make_component(param) + assert len(comp.input1) == 1 diff --git a/tests/test_discovery.py b/tests/test_discovery.py new file mode 100644 index 0000000..5082424 --- /dev/null +++ b/tests/test_discovery.py @@ -0,0 +1,25 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from secretflow_spec import Registry, load_component_modules + + +def test_load_component_modules(): + root_dir = os.path.dirname(__file__) + load_component_modules(root_dir) + d = Registry.get_definition_by_id("test/my_component:0.0.1") + assert d diff --git a/tests/test_dist_data.py b/tests/test_dist_data.py new file mode 100644 index 0000000..f08090f --- /dev/null +++ b/tests/test_dist_data.py @@ -0,0 +1,182 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from secretflow_spec import Reporter, VTable, VTableFieldKind, VTableFieldType +from secretflow_spec.core.dist_data.base import DistDataType +from secretflow_spec.core.dist_data.file import ObjectFile +from secretflow_spec.core.types import Version +from secretflow_spec.core.version import SPEC_VERSION +from secretflow_spec.v1.data_pb2 import ( + DistData, + IndividualTable, + TableSchema, + VerticalTable, +) +from secretflow_spec.v1.report_pb2 import Report + + +def test_vtable_field(): + assert VTableFieldType.BOOL.is_bool() + assert VTableFieldType.STR.is_string() + assert VTableFieldType.INT.is_integer() + assert VTableFieldType.INT32.is_integer() + assert VTableFieldType.UINT32.is_integer() + assert VTableFieldType.FLOAT.is_float() + assert VTableFieldType.FLOAT32.is_float() + assert VTableFieldType.is_same_type("int", "int64") + + assert VTableFieldKind.from_str("LABEL") == VTableFieldKind.LABEL + assert VTableFieldKind.from_str("LABEL|FEATURE") == VTableFieldKind.FEATURE_LABEL + + assert str(VTableFieldKind.LABEL) == "LABEL" + assert str(VTableFieldKind.FEATURE_LABEL) == "FEATURE|LABEL" + + +def test_vtable_indivitual(): + dd = DistData( + name="input_ds", + type=DistDataType.INDIVIDUAL_TABLE, + data_refs=[ + DistData.DataRef(uri="xx", party="alice", format="csv"), + ], + ) + meta = IndividualTable( + schema=TableSchema( + id_types=["str"], + ids=["id"], + label_types=["float"], + labels=["pred"], + feature_types=["int", "int", "int"], + features=["f1", "f2", "f3"], + ) + ) + dd.meta.Pack(meta) + + t = VTable.from_distdata(dd, ["id", "f2", "pred"]) + assert t.columns == ["id", "f2", "pred"] + + t = VTable.from_distdata(dd) + assert len(t.schemas) == 1 + + assert t.select(["f3", "f1"]).columns == ["f3", "f1"] + assert t.select_by_kinds(VTableFieldKind.LABEL).columns == ["pred"] + assert t.drop(["f2"]).columns == ["id", "f1", "f3", "pred"] + + pb = t.to_distdata() + assert pb.type == DistDataType.INDIVIDUAL_TABLE and len(pb.data_refs) == 1 + + +def test_vtable_vertical(): + dd = DistData( + name="input_ds", + type=str(DistDataType.VERTICAL_TABLE), + data_refs=[ + DistData.DataRef(uri="aa", party="alice", format="csv"), + DistData.DataRef(uri="bb", party="bob", format="csv"), + ], + ) + + meta = VerticalTable( + schemas=[ + TableSchema( + id_types=["str", "str"], + ids=["a1", "a2"], + feature_types=["float", "int"], + features=["a3", "a4"], + label_types=["int"], + labels=["a5"], + ), + TableSchema( + id_types=["str", "str"], + ids=["b1", "b2"], + feature_types=["float", "int"], + features=["b3", "b4"], + label_types=["int"], + labels=["b5"], + ), + ] + ) + dd.meta.Pack(meta) + + t = VTable.from_distdata(dd, columns=["a1", "a5", "a3"]) + assert t.columns == ["a1", "a5", "a3"] + + t = VTable.from_distdata(dd, columns=["a1", "a5", "a3", "b2", "b4"]) + assert t.columns == ["a1", "a5", "a3", "b2", "b4"] + + t = VTable.from_distdata(dd) + assert set(t.columns) == set( + [f"a{i+1}" for i in range(5)] + [f"b{i+1}" for i in range(5)] + ) + t1 = t.select(["a2", "a1"]) + assert t1.columns == ["a2", "a1"] + t2 = t.select(["a3", "a1", "b2", "b5"]) + assert t2.columns == ["a3", "a1", "b2", "b5"] + t3 = t.drop(["a2", "a3", "b2", "b5"]) + assert set(t3.columns) == set(["a1", "a4", "a5", "b1", "b3", "b4"]) + t4 = t.drop(["a1"]) + columns = t.columns + columns.remove("a1") + assert t4.columns == columns + t5 = t.select_by_kinds(VTableFieldKind.FEATURE) + assert t5.columns == ["a3", "a4", "b3", "b4"] + t6 = t.select_by_kinds(VTableFieldKind.FEATURE_LABEL) + assert t6.columns == ["a3", "a4", "a5", "b3", "b4", "b5"] + + orders = ["bob", "alice"] + assert list(t.sort_partitions(orders).schemas.keys()) == orders + + assert ["a1", "a2"] in t.schemas["alice"] + assert "a1" in t.schemas["alice"] + + pb = t.to_distdata() + assert pb.type == DistDataType.VERTICAL_TABLE and len(pb.data_refs) == 2 + assert pb.version == SPEC_VERSION + + +def test_report(): + r = Reporter(name="test_name", desc="test_desc") + # add descriptions + r.add_tab({"a": "a", "b": 1, "c": "0.1"}) + # add table + r.add_tab({"a": [1, 2], "b": ["b1", "b2"], "c": [0.1, 0.2]}) + + dd = r.to_distdata() + assert len(dd.data_refs) == 0 + assert dd.version == SPEC_VERSION + meta = Report() + assert dd.meta.Unpack(meta) + assert meta.name == "test_name" and meta.desc == "test_desc" and len(meta.tabs) == 2 + assert meta.tabs[0].divs[0].children[0].type == "descriptions" + assert meta.tabs[1].divs[0].children[0].type == "table" + + +def test_file(): + file_version = Version(1, 0) + public_info = {"a": 1, "b": 2} + df = ObjectFile( + name="xx", + type="sf.model.sgb", + data_refs=[ + DistData.DataRef(party="alice", uri="aa", format="pickle"), + DistData.DataRef(party="bob", uri="aa", format="pickle"), + ], + version=file_version, + public_info=public_info, + ) + dd = df.to_distdata() + assert dd.version == SPEC_VERSION + nfile = ObjectFile.from_distdata(dd) + assert nfile.version == file_version, nfile.public_info == public_info diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..fcc5dd1 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,57 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from secretflow_spec import Component, Registry, register +from secretflow_spec.core.definition import Definition +from secretflow_spec.core.version import SPEC_VERSION + + +@register( + domain="test", + version="1.0.0", + name="test_comp", + labels={ + "sf.use.mpc": True, + "sf.multi.party.computation": "true", + }, +) +class DemoCompnent(Component): + def evaluate(self): + print("eval") + + +def test_registry(): + definitions = Registry.get_definitions() + keys = list(Registry.get_definition_keys()) + assert len(definitions) > 0 and len(definitions) == len(keys) + first = Registry.get_definition_by_key(keys[0]) + assert first is not None + assert Registry.get_definition(first.domain, first.name, first.version) + assert Registry.get_definition_by_class(first.component_cls) + assert Registry.get_definition_by_id(first.component_id) + + test_comp_def = Registry.get_definition_by_id("test/test_comp:1.0.0") + assert test_comp_def + labels = test_comp_def.component_def.labels + assert ( + labels["sf.use.mpc"] == "true" + and labels["sf.multi.party.computation"] == "true" + ) + + comp_list = Registry.build_comp_list_def("test", "test", definitions) + comp_ids = [ + Definition.build_id(c.domain, c.name, c.version) for c in comp_list.comps + ] + assert sorted(comp_ids) == comp_ids + assert comp_list.version == SPEC_VERSION diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..45eaeca --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,49 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import pytest + +from secretflow_spec.core.storage import make_storage +from secretflow_spec.v1.data_pb2 import StorageConfig + + +def test_local(): + root_dir = os.path.dirname(__file__) + s = make_storage( + StorageConfig( + type="local_fs", local_fs=StorageConfig.LocalFSConfig(wd=root_dir) + ) + ) + file = "test_storage.py" + p = s.get_full_path(file) + assert p == __file__ + size = s.get_size(file) + assert size > 0 + with s.get_reader(file) as r: + data = r.read() + assert len(data) > 0 + + assert s.exists(file) + + not_exists_file = "not_exists_file" + assert s.exists(not_exists_file) == False + with pytest.raises(FileNotFoundError) as e: + s.get_reader(not_exists_file) + with pytest.raises(IsADirectoryError) as e: + s.get_writer(root_dir) + with pytest.raises(IsADirectoryError) as e: + s.get_reader(root_dir) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..73690b6 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,31 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from secretflow_spec import Definition, build_node_eval_param + + +def test_build_node_eval_param(): + param = build_node_eval_param( + domain="test_domain", + name="test_name", + version="1.0.0", + attrs={"a": 1, "b": "s"}, + ) + + domain, name, version = Definition.parse_id(param.comp_id) + assert domain == "test_domain" + assert name == "test_name" + assert version == "1.0.0" + assert Definition.parse_minor(version) == 0 From a312a6c821a64d2f63ee7a447c6e53194894a1bf Mon Sep 17 00:00:00 2001 From: danqi Date: Wed, 26 Feb 2025 19:53:28 +0800 Subject: [PATCH 2/2] repo-sync-2025-02-26T19:53:15+0800 --- .github/workflows/publish_pypi.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 .github/workflows/publish_pypi.yml diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml new file mode 100644 index 0000000..16415fa --- /dev/null +++ b/.github/workflows/publish_pypi.yml @@ -0,0 +1,15 @@ +name: "Publish PyPI Package" +on: + workflow_dispatch: + +jobs: + trigger-circleci: + runs-on: ubuntu-latest + steps: + - name: spec-deploy + id: spec-deploy + uses: CircleCI-Public/trigger-circleci-pipeline-action@v1.0.5 + with: + GHA_Meta: "publish_pypi" + env: + CCI_TOKEN: ${{ secrets.CCI_TOKEN }}