From 6271fb80a7b7e6fe7b053ce8dabbafdb903a759e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 23 Dec 2024 19:59:35 +0300 Subject: [PATCH 1/3] [CVAT] Improve dev mocks (#2957) * Align eo and ro webhook validation code * Move log formatting function * Remove dev mocks from real webhook validation code * Fix configs for dev variables * Refactor and update dev mocks * Improve request data logging --- .../examples/cvat/exchange-oracle/debug.py | 286 +++++++++++------- .../cvat/exchange-oracle/src/.env.template | 2 +- .../cvat/exchange-oracle/src/chain/escrow.py | 10 +- .../cvat/exchange-oracle/src/core/config.py | 16 +- .../src/crons/cvat/state_trackers.py | 2 +- .../src/endpoints/middleware.py | 4 +- .../src/handlers/job_creation.py | 4 +- .../examples/cvat/exchange-oracle/src/log.py | 8 - .../cvat/exchange-oracle/src/utils/logging.py | 9 +- .../examples/cvat/recording-oracle/debug.py | 232 ++++++++------ .../cvat/recording-oracle/src/.env.template | 1 + .../cvat/recording-oracle/src/chain/escrow.py | 12 +- .../cvat/recording-oracle/src/core/config.py | 2 + .../recording-oracle/src/utils/logging.py | 9 +- .../src/validators/signature.py | 22 +- .../tests/integration/chain/test_escrow.py | 25 -- 16 files changed, 365 insertions(+), 279 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/debug.py b/packages/examples/cvat/exchange-oracle/debug.py index aa8052cdd1..66168f9c65 100644 --- a/packages/examples/cvat/exchange-oracle/debug.py +++ b/packages/examples/cvat/exchange-oracle/debug.py @@ -1,35 +1,28 @@ import datetime -import hashlib import json +from collections.abc import Generator +from contextlib import ExitStack, contextmanager +from logging import Logger from pathlib import Path from typing import Any +from unittest import mock import uvicorn from httpx import URL from src.chain.kvstore import register_in_kvstore from src.core.config import Config +from src.db import SessionLocal from src.services import cloud +from src.services import cvat as cvat_service from src.services.cloud import BucketAccessInfo -from src.utils.logging import get_function_logger +from src.utils.logging import format_sequence, get_function_logger -def apply_local_development_patches(): - """ - Applies local development patches to avoid manual source code modification.: - - Overrides `EscrowUtils.get_escrow` to return local escrow data with mock values if the escrow - address matches a local manifest. - - Loads local manifest files from cloud storage into `LOCAL_MANIFEST_FILES`. - - Disables address validation by overriding `validate_address`. - - Replaces `validate_oracle_webhook_signature` with a lenient version that uses - partial signature parsing. - - Generates ECDSA keys if not already present for local JWT signing, - and sets the public key in `Config.human_app_config`. - """ - import src.handlers.job_creation - - original_make_cvat_cloud_storage_params = ( - src.handlers.job_creation._make_cvat_cloud_storage_params +@contextmanager +def _mock_cvat_cloud_storage_params(logger: Logger) -> Generator[None, None, None]: + from src.handlers.job_creation import ( + _make_cvat_cloud_storage_params as original_make_cvat_cloud_storage_params, ) def patched_make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> dict: @@ -37,114 +30,172 @@ def patched_make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> dic if Config.development_config.cvat_in_docker: bucket_info.host_url = str( - URL(original_host_url).copy_with(host=Config.development_config.cvat_local_host) + URL(original_host_url).copy_with( + host=Config.development_config.exchange_oracle_host + ) ) logger.info( f"DEV: Changed {original_host_url} to {bucket_info.host_url} for CVAT storage" ) + try: return original_make_cvat_cloud_storage_params(bucket_info) finally: bucket_info.host_url = original_host_url - src.handlers.job_creation._make_cvat_cloud_storage_params = ( - patched_make_cvat_cloud_storage_params - ) + with mock.patch( + "src.handlers.job_creation._make_cvat_cloud_storage_params", + patched_make_cvat_cloud_storage_params, + ): + yield - def prepare_signed_message( - escrow_address, - chain_id, - message: str | None = None, - body: dict | None = None, - ) -> tuple[None, str]: - digest = hashlib.sha256( - (escrow_address + ":".join(map(str, (chain_id, message, body)))).encode() - ).hexdigest() - signature = f"{OracleWebhookTypes.recording_oracle}:{digest}" - logger.info(f"DEV: Generated patched signature {signature}") - return None, signature - - src.utils.webhooks.prepare_signed_message = prepare_signed_message +@contextmanager +def _mock_get_manifests_from_minio(logger: Logger) -> Generator[None, None, None]: from human_protocol_sdk.constants import ChainId from human_protocol_sdk.escrow import EscrowData, EscrowUtils - logger = get_function_logger(apply_local_development_patches.__name__) - minio_client = cloud.make_client(BucketAccessInfo.parse_obj(Config.storage_config)) + original_get_escrow = EscrowUtils.get_escrow - def get_local_escrow(chain_id: int, escrow_address: str) -> EscrowData: - possible_manifest_name = escrow_address.split(":")[0] - local_manifests = minio_client.list_files(bucket="manifests") - logger.info(f"DEV: Local manifests: {local_manifests}") - if possible_manifest_name in local_manifests: - logger.info(f"DEV: Using local manifest {escrow_address}") - return EscrowData( - chain_id=ChainId(chain_id), - id="test", - address=escrow_address, - amount_paid=10, - balance=10, - count=1, - factory_address="", - launcher="", - status="Pending", - token="HMT", # noqa: S106 - total_funded_amount=10, - created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), - manifest_url=( - f"http://{Config.storage_config.endpoint_url}/manifests/{possible_manifest_name}" - ), + def patched_get_escrow(chain_id: int, escrow_address: str) -> EscrowData: + minio_manifests = minio_client.list_files(bucket="manifests") + logger.debug(f"DEV: Local manifests: {format_sequence(minio_manifests)}") + + candidate_files = [fn for fn in minio_manifests if f"{escrow_address}.json" in fn] + if not candidate_files: + return original_get_escrow(ChainId(chain_id), escrow_address) + elif len(candidate_files) != 1: + raise Exception( + "Can't select local manifest to be used for escrow '{}'" + " - several manifests math: {}".format( + escrow_address, format_sequence(candidate_files) + ) ) - return original_get_escrow(ChainId(chain_id), escrow_address) - original_get_escrow = EscrowUtils.get_escrow - EscrowUtils.get_escrow = get_local_escrow + manifest_file = candidate_files[0] + escrow = EscrowData( + chain_id=ChainId(chain_id), + id="test", + address=escrow_address, + amount_paid=10, + balance=10, + count=1, + factory_address="", + launcher="", + status="Pending", + token="HMT", # noqa: S106 + total_funded_amount=10, + created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), + manifest_url=(f"http://{Config.storage_config.endpoint_url}/manifests/{manifest_file}"), + ) + + logger.info(f"DEV: Using local manifest '{manifest_file}' for escrow '{escrow_address}'") + return escrow + + with mock.patch.object(EscrowUtils, "get_escrow", patched_get_escrow): + yield + + +@contextmanager +def _mock_webhook_signature_checking(_: Logger) -> Generator[None, None, None]: + """ + Allows to receive webhooks from other services: + - from launcher - with signature "job_launcher" + - from recording oracle - + encoded with Config.localhost.recording_oracle_address wallet address + or signature "recording_oracle" + - from reputation oracle - + encoded with Config.localhost.reputation_oracle_address wallet address + or signature "reputation_oracle" + """ - import src.schemas.webhook + from src.chain.escrow import ( + get_available_webhook_types as original_get_available_webhook_types, + ) from src.core.types import OracleWebhookTypes + from src.validators.signature import ( + validate_oracle_webhook_signature as original_validate_oracle_webhook_signature, + ) - src.schemas.webhook.validate_address = lambda x: x + async def patched_validate_oracle_webhook_signature(request, signature, webhook): + for webhook_type in OracleWebhookTypes: + if signature.startswith(webhook_type.value.lower()): + return webhook_type + + return await original_validate_oracle_webhook_signature(request, signature, webhook) + + def patched_get_available_webhook_types(chain_id, escrow_address): + d = dict(original_get_available_webhook_types(chain_id, escrow_address)) + d[Config.localhost.recording_oracle_address.lower()] = OracleWebhookTypes.recording_oracle + d[Config.localhost.reputation_oracle_address.lower()] = OracleWebhookTypes.reputation_oracle + return d + + with ( + mock.patch("src.schemas.webhook.validate_address", lambda x: x), + mock.patch( + "src.validators.signature.get_available_webhook_types", + patched_get_available_webhook_types, + ), + mock.patch( + "src.endpoints.webhook.validate_oracle_webhook_signature", + patched_validate_oracle_webhook_signature, + ), + ): + yield + + +@contextmanager +def _mock_endpoint_auth(logger: Logger) -> Generator[None, None, None]: + """ + Allows simplified authentication: + - Bearer {"wallet_address": "...", "email": "..."} + - Bearer {"role": "human_app"} + """ - async def lenient_validate_oracle_webhook_signature(request, signature, webhook): - from src.validators.signature import validate_oracle_webhook_signature + from src.endpoints.authentication import HUMAN_APP_ROLE, TokenAuthenticator - try: - parsed_type = OracleWebhookTypes(signature.split(":")[0]) - logger.info(f"DEV: Recovered {parsed_type} from the signature {signature}") - except (ValueError, TypeError): - return await validate_oracle_webhook_signature(request, signature, webhook) + original_decode_token = TokenAuthenticator._decode_token - import src.endpoints.webhook + def decode_plain_json_token(self, token) -> dict[str, Any]: + try: + token_data = json.loads(token) - src.endpoints.webhook.validate_oracle_webhook_signature = ( - lenient_validate_oracle_webhook_signature - ) + if (user_wallet := token_data.get("wallet_address")) and not token_data.get("email"): + with SessionLocal.begin() as session: + user = cvat_service.get_user_by_id(session, user_wallet) + if not user: + raise Exception(f"Could not find user with wallet address '{user_wallet}'") - import src.endpoints.authentication + token_data["email"] = user.cvat_email - original_decode_token = src.endpoints.authentication.TokenAuthenticator._decode_token + if token_data.get("role") == HUMAN_APP_ROLE: + token_data["wallet_address"] = None + token_data["email"] = "" - def decode_plain_json_token(self, token) -> dict[str, Any]: - """ - Allows Authentication: Bearer {"wallet_address": "...", "email": "..."} - """ - try: - decoded = json.loads(token) - logger.info(f"DEV: Decoded plain JSON auth token: {decoded}") + logger.info(f"DEV: Decoded plain JSON auth token: {token_data}") + return token_data except (ValueError, TypeError): return original_decode_token(self, token) - src.endpoints.authentication.TokenAuthenticator._decode_token = decode_plain_json_token + with mock.patch.object(TokenAuthenticator, "_decode_token", decode_plain_json_token): + yield + + +@contextmanager +def _mock_human_app_keys(_: Logger) -> Generator[None, None, None]: + "Creates or uses local Human App JWT keys" from tests.api.test_exchange_api import generate_ecdsa_keys # generating keys for local development repo_root = Path(__file__).parent - human_app_private_key_file, human_app_public_key_file = ( - repo_root / "human_app_private_key.pem", - repo_root / "human_app_public_key.pem", - ) + dev_dir = repo_root / "dev" + dev_dir.mkdir(exist_ok=True) + + human_app_private_key_file = dev_dir / "human_app_private_key.pem" + human_app_public_key_file = dev_dir / "human_app_public_key.pem" + if not (human_app_public_key_file.exists() and human_app_private_key_file.exists()): private_key, public_key = generate_ecdsa_keys() human_app_private_key_file.write_text(private_key) @@ -154,20 +205,47 @@ def decode_plain_json_token(self, token) -> dict[str, Any]: Config.human_app_config.jwt_public_key = public_key - logger.warning("DEV: Local development patches applied.") + yield + + +@contextmanager +def apply_local_development_patches() -> Generator[None, None, None]: + """ + Applies local development patches to avoid manual source code modification + """ + + logger = get_function_logger(apply_local_development_patches.__name__) + + logger.warning("DEV: Applying local development patches") + + with ExitStack() as es: + for mock_callback in ( + _mock_cvat_cloud_storage_params, + _mock_get_manifests_from_minio, + _mock_webhook_signature_checking, + _mock_endpoint_auth, + _mock_human_app_keys, + ): + logger.warning(f"DEV: applying patch {mock_callback.__name__}...") + es.enter_context(mock_callback(logger)) + + logger.warning("DEV: Local development patches applied.") + + yield if __name__ == "__main__": - is_dev = Config.environment == "development" - if is_dev: - apply_local_development_patches() - - Config.validate() - register_in_kvstore() - - uvicorn.run( - app="src:app", - host="0.0.0.0", # noqa: S104 - port=int(Config.port), - workers=Config.workers_amount, - ) + with ExitStack() as es: + is_dev = Config.environment == "development" + if is_dev: + es.enter_context(apply_local_development_patches()) + + Config.validate() + register_in_kvstore() + + uvicorn.run( + app="src:app", + host="0.0.0.0", # noqa: S104 + port=int(Config.port), + workers=Config.workers_amount, + ) diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index ae4b9dfb05..cd3f86a064 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -129,4 +129,4 @@ PGP_PUBLIC_KEY_URL= # Development DEV_CVAT_IN_DOCKER= -DEV_CVAT_LOCAL_HOST= +DEV_EXCHANGE_ORACLE_HOST= diff --git a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py index bc4ac9b51e..c5679fc7f7 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py @@ -73,11 +73,7 @@ def get_available_webhook_types( ) -> dict[str, OracleWebhookTypes]: escrow = get_escrow(chain_id, escrow_address) return { - escrow.launcher.lower(): OracleWebhookTypes.job_launcher, - ( - Config.localhost.recording_oracle_address or escrow.recording_oracle - ).lower(): OracleWebhookTypes.recording_oracle, - ( - Config.localhost.reputation_oracle_address or escrow.reputation_oracle - ).lower(): OracleWebhookTypes.reputation_oracle, + (escrow.launcher or "").lower(): OracleWebhookTypes.job_launcher, + (escrow.recording_oracle or "").lower(): OracleWebhookTypes.recording_oracle, + (escrow.reputation_oracle or "").lower(): OracleWebhookTypes.reputation_oracle, } diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index cc5d560093..60562e8eab 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -301,10 +301,16 @@ def validate(cls) -> None: raise Exception(" ".join([ex_prefix, str(ex)])) -class Development: - cvat_in_docker = bool(int(os.environ.get("DEV_CVAT_IN_DOCKER", "0"))) - # might be `host.docker.internal` or `172.22.0.1` if CVAT is running in docker - cvat_local_host = os.environ.get("DEV_CVAT_LOCAL_HOST", "localhost") +class DevelopmentConfig: + cvat_in_docker = bool(int(os.environ.get("DEV_CVAT_IN_DOCKER", "1"))) + + exchange_oracle_host = os.environ.get("DEV_EXCHANGE_ORACLE_HOST", "172.22.0.1") + """ + Might be `host.docker.internal` or `172.22.0.1` if CVAT is running in Docker. + + Remember to allow this host via: + SMOKESCREEN_OPTS="--allow-address=" docker compose ... + """ class Environment(str, Enum): @@ -346,7 +352,7 @@ class Config: features = FeaturesConfig core_config = CoreConfig encryption_config = EncryptionConfig - development_config = Development + development_config = DevelopmentConfig @classmethod def is_development_mode(cls) -> bool: diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py index cf0e74c410..e4b689aaa6 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py @@ -15,7 +15,7 @@ from src.db import errors as db_errors from src.db.utils import ForUpdateParams from src.handlers.completed_escrows import handle_escrows_validations -from src.log import format_sequence +from src.utils.logging import format_sequence @cron_job diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index e111f1e2f3..64fedbdfce 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -136,7 +136,9 @@ async def _log_request(self, request: Request) -> dict[str, Any]: if raw_body: body = body.decode(errors="ignore") - body = body[: self.max_displayed_body_size] + + if len(body) > self.max_displayed_body_size: + body = body[: self.max_displayed_body_size - 3] + "..." request_logging["body"] = body diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 2385d00774..891f827ec4 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -33,13 +33,13 @@ from src.core.storage import compose_data_bucket_filename from src.core.types import CvatLabelTypes, TaskStatuses, TaskTypes from src.db import SessionLocal -from src.log import ROOT_LOGGER_NAME, format_sequence +from src.log import ROOT_LOGGER_NAME from src.models.cvat import Project from src.services.cloud import CloudProviders, StorageClient from src.services.cloud.utils import BucketAccessInfo, compose_bucket_url from src.utils.annotations import InstanceSegmentsToBbox, ProjectLabels, is_point_in_bbox from src.utils.assignments import parse_manifest -from src.utils.logging import NullLogger, get_function_logger +from src.utils.logging import NullLogger, format_sequence, get_function_logger from src.utils.zip_archive import write_dir_to_zip_archive if TYPE_CHECKING: diff --git a/packages/examples/cvat/exchange-oracle/src/log.py b/packages/examples/cvat/exchange-oracle/src/log.py index 406f243e43..7a26a6a287 100644 --- a/packages/examples/cvat/exchange-oracle/src/log.py +++ b/packages/examples/cvat/exchange-oracle/src/log.py @@ -1,9 +1,7 @@ """Config for the application logger""" import logging -from collections.abc import Sequence from logging.config import dictConfig -from typing import Any from src.core.config import Config @@ -50,9 +48,3 @@ def setup_logging(): def get_root_logger() -> logging.Logger: return logging.getLogger(ROOT_LOGGER_NAME) - - -def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: - remainder_count = len(items) - max_items - tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" - return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/exchange-oracle/src/utils/logging.py b/packages/examples/cvat/exchange-oracle/src/utils/logging.py index e7660eb0d7..a8d8c94daf 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/logging.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/logging.py @@ -1,5 +1,6 @@ import logging -from typing import NewType +from collections.abc import Sequence +from typing import Any, NewType from src.utils.stack import current_function_name @@ -30,3 +31,9 @@ class NullLogger(logging.Logger): def __init__(self, name: str = "", level=0) -> None: super().__init__(name, level) self.disabled = True + + +def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: + remainder_count = len(items) - max_items + tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" + return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/recording-oracle/debug.py b/packages/examples/cvat/recording-oracle/debug.py index 3915e79929..78d63c2a28 100644 --- a/packages/examples/cvat/recording-oracle/debug.py +++ b/packages/examples/cvat/recording-oracle/debug.py @@ -1,5 +1,8 @@ import datetime -import hashlib +from collections.abc import Generator +from contextlib import ExitStack, contextmanager +from logging import Logger +from unittest import mock import uvicorn @@ -7,123 +10,152 @@ from src.core.config import Config from src.services import cloud from src.services.cloud import BucketAccessInfo -from src.utils.logging import get_function_logger +from src.utils.logging import format_sequence, get_function_logger -def apply_local_development_patches(): - """ - Applies local development patches to bypass direct source code modifications: - - Overrides `EscrowUtils.get_escrow` to retrieve local escrow data for specific addresses, - using mock data if the address corresponds to a local manifest. - - Updates local manifest files from cloud storage. - - Overrides `validate_address` to disable address validation. - - Replaces `validate_oracle_webhook_signature` with a lenient version for oracle signature - validation in development. - - Replaces `src.chain.escrow.store_results` to avoid attempting to store results on chain. - - Replaces `src.validators.signature.validate_oracle_webhook_signature` to always return - `OracleWebhookTypes.exchange_oracle`. - """ - logger = get_function_logger(apply_local_development_patches.__name__) +@contextmanager +def _mock_get_manifests_from_minio(logger: Logger) -> Generator[None, None, None]: + from human_protocol_sdk.constants import ChainId + from human_protocol_sdk.escrow import EscrowData, EscrowUtils + + minio_client = cloud.make_client( + BucketAccessInfo.parse_obj(Config.exchange_oracle_storage_config) + ) + original_get_escrow = EscrowUtils.get_escrow - import src.crons._utils + def patched_get_escrow(chain_id: int, escrow_address: str) -> EscrowData: + minio_manifests = minio_client.list_files(bucket="manifests") + logger.debug(f"DEV: Local manifests: {format_sequence(minio_manifests)}") + + candidate_files = [fn for fn in minio_manifests if f"{escrow_address}.json" in fn] + if not candidate_files: + return original_get_escrow(ChainId(chain_id), escrow_address) + if len(candidate_files) != 1: + raise Exception( + f"Can't select local manifest to be used for escrow '{escrow_address}'" + f" - several manifests math: {format_sequence(candidate_files)}" + ) - def prepare_signed_message( - escrow_address, + manifest_file = candidate_files[0] + escrow = EscrowData( + chain_id=ChainId(chain_id), + id="test", + address=escrow_address, + amount_paid=10, + balance=10, + count=1, + factory_address="", + launcher="", + status="Pending", + token="HMT", # noqa: S106 + total_funded_amount=10, + created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), + manifest_url=(f"http://{Config.storage_config.endpoint_url}/manifests/{manifest_file}"), + ) + + logger.info(f"DEV: Using local manifest '{manifest_file}' for escrow '{escrow_address}'") + return escrow + + with mock.patch.object(EscrowUtils, "get_escrow", patched_get_escrow): + yield + + +@contextmanager +def _mock_escrow_results_saving(logger: Logger) -> Generator[None, None, None]: + def patched_store_results( chain_id, - message: str | None = None, - body: dict | None = None, - ) -> tuple[None, str]: - digest = hashlib.sha256( - (escrow_address + ":".join(map(str, (chain_id, message, body)))).encode() - ).hexdigest() - signature = f"{OracleWebhookTypes.recording_oracle}:{digest}" - logger.info(f"DEV: Generated patched signature {signature}") - return None, signature - - src.crons._utils.prepare_signed_message = prepare_signed_message + escrow_address, + url, + hash, + ) -> None: + logger.info( + f"DEV: Would store results for escrow '{escrow_address}@{chain_id}' " + f"on chain: {url}, {hash}" + ) - from human_protocol_sdk.constants import ChainId - from human_protocol_sdk.escrow import EscrowData, EscrowUtils + with mock.patch("src.chain.escrow.store_results", patched_store_results): + yield - minio_client = cloud.make_client(BucketAccessInfo.parse_obj(Config.storage_config)) - - def get_local_escrow(chain_id: int, escrow_address: str) -> EscrowData: - possible_manifest_name = escrow_address.split(":")[0] - local_manifests = minio_client.list_files(bucket="manifests") - logger.info(f"Local manifests: {local_manifests}") - if possible_manifest_name in local_manifests: - logger.info(f"DEV: Using local manifest {escrow_address}") - return EscrowData( - chain_id=ChainId(chain_id), - id="test", - address=escrow_address, - amount_paid=10, - balance=10, - count=1, - factory_address="", - launcher="", - status="Pending", - token="HMT", # noqa: S106 - total_funded_amount=10, - created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), - manifest_url=( - f"http://{Config.storage_config.endpoint_url}/manifests/{possible_manifest_name}" - ), - ) - return original_get_escrow(ChainId(chain_id), escrow_address) - original_get_escrow = EscrowUtils.get_escrow - EscrowUtils.get_escrow = get_local_escrow +@contextmanager +def _mock_webhook_signature_checking(_: Logger) -> Generator[None, None, None]: + """ + Allows to receive webhooks from other services: + - from exchange oracle - + signed with Config.localhost.exchange_oracle_address + or with signature "exchange_oracle" + """ - import src.schemas.webhook + from src.chain.escrow import ( + get_available_webhook_types as original_get_available_webhook_types, + ) from src.core.types import OracleWebhookTypes + from src.validators.signature import ( + validate_oracle_webhook_signature as original_validate_oracle_webhook_signature, + ) - src.schemas.webhook.validate_address = lambda x: x - - async def lenient_validate_oracle_webhook_signature( - request, # noqa: ARG001 (not relevant here) - signature, - webhook, # noqa: ARG001 (not relevant here) + async def patched_validate_oracle_webhook_signature(request, signature, webhook): + for webhook_type in OracleWebhookTypes: + if signature.startswith(webhook_type.value.lower()): + return webhook_type + + return await original_validate_oracle_webhook_signature(request, signature, webhook) + + def patched_get_available_webhook_types(chain_id, escrow_address): + d = dict(original_get_available_webhook_types(chain_id, escrow_address)) + d[Config.localhost.exchange_oracle_address.lower()] = OracleWebhookTypes.exchange_oracle + return d + + with ( + mock.patch("src.schemas.webhook.validate_address", lambda x: x), + mock.patch( + "src.validators.signature.get_available_webhook_types", + patched_get_available_webhook_types, + ), + mock.patch( + "src.endpoints.webhook.validate_oracle_webhook_signature", + patched_validate_oracle_webhook_signature, + ), ): - try: - parsed_type = OracleWebhookTypes(signature.split(":")[0]) - logger.info(f"DEV: Recovered {parsed_type} from the signature {signature}") - except (ValueError, TypeError): - logger.info(f"DEV: Falling back to {OracleWebhookTypes.exchange_oracle} webhook sender") - return OracleWebhookTypes.exchange_oracle + yield - import src.endpoints.webhook - src.endpoints.webhook.validate_oracle_webhook_signature = ( - lenient_validate_oracle_webhook_signature - ) +@contextmanager +def apply_local_development_patches() -> Generator[None, None, None]: + """ + Applies local development patches to avoid manual source code modification + """ - import src.chain.escrow + logger = get_function_logger(apply_local_development_patches.__name__) - def store_results( - chain_id: int, # noqa: ARG001 (not relevant here) - escrow_address: str, - url: str, - hash: str, - ) -> None: - logger.info(f"Would store results for escrow {escrow_address} on chain: {url}, {hash}") + logger.warning("DEV: Applying local development patches") + + with ExitStack() as es: + for mock_callback in ( + _mock_get_manifests_from_minio, + _mock_webhook_signature_checking, + _mock_escrow_results_saving, + ): + logger.warning(f"DEV: applying patch {mock_callback.__name__}...") + es.enter_context(mock_callback(logger)) - src.chain.escrow.store_results = store_results + logger.warning("DEV: Local development patches applied.") - logger.warning("Local development patches applied.") + yield if __name__ == "__main__": - is_dev = Config.environment == "development" - if is_dev: - apply_local_development_patches() - - Config.validate() - register_in_kvstore() - - uvicorn.run( - app="src:app", - host="0.0.0.0", # noqa: S104 - port=int(Config.port), - workers=Config.workers_amount, - ) + with ExitStack() as es: + is_dev = Config.environment == "development" + if is_dev: + es.enter_context(apply_local_development_patches()) + + Config.validate() + register_in_kvstore() + + uvicorn.run( + app="src:app", + host="0.0.0.0", # noqa: S104 + port=int(Config.port), + workers=Config.workers_amount, + ) diff --git a/packages/examples/cvat/recording-oracle/src/.env.template b/packages/examples/cvat/recording-oracle/src/.env.template index e4889a7ddb..689119519b 100644 --- a/packages/examples/cvat/recording-oracle/src/.env.template +++ b/packages/examples/cvat/recording-oracle/src/.env.template @@ -66,6 +66,7 @@ CVAT_QUALITY_CHECK_INTERVAL= # Localhost +LOCALHOST_EXCHANGE_ORACLE_ADDRESS= LOCALHOST_EXCHANGE_ORACLE_URL= LOCALHOST_REPUTATION_ORACLE_URL= diff --git a/packages/examples/cvat/recording-oracle/src/chain/escrow.py b/packages/examples/cvat/recording-oracle/src/chain/escrow.py index 33b8c78a85..bc6e89da60 100644 --- a/packages/examples/cvat/recording-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/recording-oracle/src/chain/escrow.py @@ -7,6 +7,7 @@ from src.chain.web3 import get_web3 from src.core.config import Config +from src.core.types import OracleWebhookTypes def get_escrow(chain_id: int, escrow_address: str) -> EscrowData: @@ -64,9 +65,8 @@ def store_results(chain_id: int, escrow_address: str, url: str, hash: str) -> No escrow_client.store_results(escrow_address, url, hash) -def get_reputation_oracle_address(chain_id: int, escrow_address: str) -> str: - return get_escrow(chain_id, escrow_address).reputation_oracle - - -def get_exchange_oracle_address(chain_id: int, escrow_address: str) -> str: - return get_escrow(chain_id, escrow_address).exchange_oracle +def get_available_webhook_types( + chain_id: int, escrow_address: str +) -> dict[str, OracleWebhookTypes]: + escrow = get_escrow(chain_id, escrow_address) + return {(escrow.exchange_oracle or "").lower(): OracleWebhookTypes.exchange_oracle} diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index d2d539340a..029d5e364b 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -79,7 +79,9 @@ class LocalhostConfig(_NetworkConfig): ) addr = os.environ.get("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") + exchange_oracle_address = os.environ.get("LOCALHOST_EXCHANGE_ORACLE_ADDRESS") exchange_oracle_url = os.environ.get("LOCALHOST_EXCHANGE_ORACLE_URL") + reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") diff --git a/packages/examples/cvat/recording-oracle/src/utils/logging.py b/packages/examples/cvat/recording-oracle/src/utils/logging.py index e7660eb0d7..a8d8c94daf 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/logging.py +++ b/packages/examples/cvat/recording-oracle/src/utils/logging.py @@ -1,5 +1,6 @@ import logging -from typing import NewType +from collections.abc import Sequence +from typing import Any, NewType from src.utils.stack import current_function_name @@ -30,3 +31,9 @@ class NullLogger(logging.Logger): def __init__(self, name: str = "", level=0) -> None: super().__init__(name, level) self.disabled = True + + +def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: + remainder_count = len(items) - max_items + tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" + return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/recording-oracle/src/validators/signature.py b/packages/examples/cvat/recording-oracle/src/validators/signature.py index deb3364a70..bd93a83f3b 100644 --- a/packages/examples/cvat/recording-oracle/src/validators/signature.py +++ b/packages/examples/cvat/recording-oracle/src/validators/signature.py @@ -3,7 +3,7 @@ from fastapi import HTTPException, Request -from src.chain.escrow import get_exchange_oracle_address +from src.chain.escrow import get_available_webhook_types from src.chain.web3 import recover_signer from src.core.types import OracleWebhookTypes from src.schemas.webhook import OracleWebhook @@ -15,22 +15,10 @@ async def validate_oracle_webhook_signature( data: bytes = await request.body() message: dict = literal_eval(data.decode("utf-8")) - signer = recover_signer(webhook.chain_id, message, signature) + signer = recover_signer(webhook.chain_id, message, signature).lower() + webhook_types = get_available_webhook_types(webhook.chain_id, webhook.escrow_address) - exchange_oracle_address = get_exchange_oracle_address(webhook.chain_id, webhook.escrow_address) - possible_signers = { - OracleWebhookTypes.exchange_oracle: exchange_oracle_address, - } - - matched_signer = next( - ( - s_type - for s_type in possible_signers - if signer.lower() == possible_signers[s_type].lower() - ), - None, - ) - if not matched_signer: + if not (webhook_sender := webhook_types.get(signer)): raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED) - return matched_signer + return webhook_sender diff --git a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py index 9420567794..ed41bd41b4 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py @@ -11,7 +11,6 @@ from src.chain.escrow import ( get_escrow_manifest, - get_reputation_oracle_address, store_results, validate_escrow, ) @@ -163,27 +162,3 @@ def test_store_results_invalid_hash(self): mock_function.return_value = self.w3 with pytest.raises(EscrowClientError, match="Invalid empty hash"): store_results(self.w3.eth.chain_id, escrow_address, DEFAULT_MANIFEST_URL, "") - - def test_get_reputation_oracle_address(self): - escrow_address = create_escrow(self.w3) - with ( - patch("src.chain.escrow.get_web3") as mock_get_web3, - patch("src.chain.escrow.get_escrow") as mock_get_escrow, - ): - mock_get_web3.return_value = self.w3 - mock_escrow = MagicMock() - mock_escrow.reputation_oracle = REPUTATION_ORACLE_ADDRESS - mock_get_escrow.return_value = mock_escrow - address = get_reputation_oracle_address(self.w3.eth.chain_id, escrow_address) - assert isinstance(address, str) - assert address is not None - - def test_get_reputation_oracle_address_invalid_address(self): - with patch("src.chain.escrow.get_web3") as mock_function: - mock_function.return_value = self.w3 - with pytest.raises(EscrowClientError, match="Invalid escrow address:"): - get_reputation_oracle_address(self.w3.eth.chain_id, "invalid_address") - - def test_get_reputation_oracle_address_invalid_chain_id(self): - with pytest.raises(Exception, match="Can't find escrow"): - get_reputation_oracle_address(1, "0x1234567890123456789012345678901234567890") From 408d284f9261fc58019bbfbf3eb65a98a08c16bc Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 24 Dec 2024 16:43:31 +0300 Subject: [PATCH 2/3] [CVAT] Update honeypot use logic (#2909) * Update honeypot reroll algorithm * Add bag-based honeypot caching * Update gt reroll approach * Allow to limit max gt size in recording oracle * Update poetry.lock * Update tests, refactor code, add handling for gt size limits * Bump human sdk to 3.0.8b0 * Add annotation speed checks * Fix rng use * fix warmup progress check * Fix warmup checks * Allow 1 warmup iterations * Update .env template * Don't fail honeypot reroll in unsupported cases with multiple labels, skip honeypot reroll instead * Disable gt limiting for unsupported cases instead of disabling honeypot changes * Refactor grouped() --- .../versions/76f0bc042477_update_gt_stats.py | 60 +++ .../cvat/recording-oracle/src/.env.template | 7 +- .../cvat/recording-oracle/src/core/config.py | 42 +- .../recording-oracle/src/core/gt_stats.py | 11 +- .../src/core/validation_errors.py | 13 + .../handlers/process_intermediate_results.py | 494 +++++++++++++----- .../recording-oracle/src/models/validation.py | 6 +- .../src/services/validation.py | 2 + .../recording-oracle/src/utils/__init__.py | 32 ++ .../recording-oracle/src/utils/formatting.py | 2 + .../src/validation/__init__.py | 0 .../src/validation/annotation_matching.py | 144 ----- .../services/test_validation_service.py | 348 ++++++++---- 13 files changed, 758 insertions(+), 403 deletions(-) create mode 100644 packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py create mode 100644 packages/examples/cvat/recording-oracle/src/utils/formatting.py delete mode 100644 packages/examples/cvat/recording-oracle/src/validation/__init__.py delete mode 100644 packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py diff --git a/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py b/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py new file mode 100644 index 0000000000..3b27fc4ef0 --- /dev/null +++ b/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py @@ -0,0 +1,60 @@ +"""Update GT stats with total_uses field + +Revision ID: 76f0bc042477 +Revises: 9d4367899f90 +Create Date: 2024-12-12 18:14:43.885249 + +""" + +import sqlalchemy as sa +from sqlalchemy import Column, ForeignKey, Integer, String, update +from sqlalchemy.orm import declarative_base + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "76f0bc042477" +down_revision = "9d4367899f90" +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class GtStats(Base): + __tablename__ = "gt_stats" + + # A composite primary key is used + task_id = Column( + String, ForeignKey("tasks.id", ondelete="CASCADE"), primary_key=True, nullable=False + ) + gt_frame_name = Column(String, primary_key=True, nullable=False) + + failed_attempts = Column(Integer, default=0, nullable=False) + accepted_attempts = Column(Integer, default=0, nullable=False) + total_uses = Column(Integer, default=0, nullable=False) + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "gt_stats", sa.Column("total_uses", sa.Integer(), nullable=False, server_default="0") + ) + op.add_column( + "gt_stats", sa.Column("enabled", sa.Boolean(), nullable=False, server_default="True") + ) + # ### end Alembic commands ### + + op.execute( + update(GtStats).values(total_uses=GtStats.accepted_attempts + GtStats.failed_attempts) + ) + + op.alter_column("gt_stats", "total_uses", server_default=None) + op.alter_column("gt_stats", "enabled", server_default=None) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("gt_stats", "total_uses") + op.drop_column("gt_stats", "enabled") + # ### end Alembic commands ### diff --git a/packages/examples/cvat/recording-oracle/src/.env.template b/packages/examples/cvat/recording-oracle/src/.env.template index 689119519b..be2d19ad79 100644 --- a/packages/examples/cvat/recording-oracle/src/.env.template +++ b/packages/examples/cvat/recording-oracle/src/.env.template @@ -76,12 +76,13 @@ ENABLE_CUSTOM_CLOUD_HOST= # Validation -DEFAULT_POINT_VALIDITY_RELATIVE_RADIUS= -DEFAULT_OKS_SIGMA= -GT_FAILURE_THRESHOLD= +MIN_AVAILABLE_GT_THRESHOLD= +MAX_USABLE_GT_SHARE= GT_BAN_THRESHOLD= UNVERIFIABLE_ASSIGNMENTS_THRESHOLD= MAX_ESCROW_ITERATIONS= +WARMUP_ITERATIONS= +MIN_WARMUP_PROGRESS= # Encryption PGP_PRIVATE_KEY= diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index 029d5e364b..d7ca733765 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -165,43 +165,51 @@ class FeaturesConfig: class ValidationConfig: - default_point_validity_relative_radius = float( - os.environ.get("DEFAULT_POINT_VALIDITY_RELATIVE_RADIUS", 0.9) - ) - - default_oks_sigma = float( - os.environ.get("DEFAULT_OKS_SIGMA", 0.1) # average value for COCO points - ) - "Default OKS sigma for GT skeleton points validation. Valid range is (0; 1]" + min_available_gt_threshold = float(os.environ.get("MIN_AVAILABLE_GT_THRESHOLD", "0.3")) + """ + The minimum required share of available GT frames required to continue annotation attempts. + When there is no enough GT left, annotation stops. + """ - gt_failure_threshold = float(os.environ.get("GT_FAILURE_THRESHOLD", 0.9)) + max_gt_share = float(os.environ.get("MAX_USABLE_GT_SHARE", "0.05")) """ - The maximum allowed fraction of failed assignments per GT sample, - before it's considered failed for the current validation iteration. - v = 0 -> any GT failure leads to image failure - v = 1 -> any GT failures do not lead to image failure + The maximum share of the dataset to be used for validation. If the available GT share is + greater than this number, the extra frames will not be used. It's recommended to keep this + value small enough for faster convergence rate of the annotation process. """ - gt_ban_threshold = int(os.environ.get("GT_BAN_THRESHOLD", 3)) + gt_ban_threshold = float(os.environ.get("GT_BAN_THRESHOLD", "0.03")) """ - The maximum allowed number of failures per GT sample before it's excluded from validation + The minimum allowed rating (annotation probability) per GT sample, + before it's considered bad and banned for further use. """ unverifiable_assignments_threshold = float( - os.environ.get("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", 0.1) + os.environ.get("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", "0.1") ) """ The maximum allowed fraction of jobs with insufficient GT available for validation. Each such job will be accepted "blindly", as we can't validate the annotations. """ - max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", "0")) + max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", "50")) """ Maximum escrow annotation-validation iterations. After this, the escrow is finished automatically. Supposed only for testing. Use 0 to disable. """ + warmup_iterations = int(os.getenv("WARMUP_ITERATIONS", "1")) + """ + The first escrow iterations where the annotation speed is checked to be big enough. + """ + + min_warmup_progress = float(os.getenv("MIN_WARMUP_PROGRESS", "10")) + """ + Minimum percent of the accepted jobs in an escrow after the first WARMUP iterations. + If the value is lower, the escrow annotation is paused for manual investigation. + """ + class EncryptionConfig(_BaseConfig): pgp_passphrase = os.environ.get("PGP_PASSPHRASE", "") diff --git a/packages/examples/cvat/recording-oracle/src/core/gt_stats.py b/packages/examples/cvat/recording-oracle/src/core/gt_stats.py index a94c8dfb54..74c8c80647 100644 --- a/packages/examples/cvat/recording-oracle/src/core/gt_stats.py +++ b/packages/examples/cvat/recording-oracle/src/core/gt_stats.py @@ -6,12 +6,13 @@ class ValidationFrameStats: accumulated_quality: float = 0.0 failed_attempts: int = 0 accepted_attempts: int = 0 + total_uses: int = 0 + enabled: bool = True @property - def average_quality(self) -> float: - return self.accumulated_quality / ((self.failed_attempts + self.accepted_attempts) or 1) + def rating(self) -> float: + return (self.accepted_attempts + 1) / (self.total_uses + 1) -_TaskIdValFrameIdPair = tuple[int, int] - -GtStats = dict[_TaskIdValFrameIdPair, ValidationFrameStats] +GtKey = str +GtStats = dict[GtKey, ValidationFrameStats] diff --git a/packages/examples/cvat/recording-oracle/src/core/validation_errors.py b/packages/examples/cvat/recording-oracle/src/core/validation_errors.py index b2d4b68a47..40ad8c1d39 100644 --- a/packages/examples/cvat/recording-oracle/src/core/validation_errors.py +++ b/packages/examples/cvat/recording-oracle/src/core/validation_errors.py @@ -13,3 +13,16 @@ def __str__(self) -> str: class LowAccuracyError(DatasetValidationError): pass + + +class TooSlowAnnotationError(DatasetValidationError): + def __init__(self, current_progress: float, current_iteration: int): + super().__init__() + self.current_progress = current_progress + self.current_iteration = current_iteration + + def __str__(self): + return ( + f"Escrow annotation progress is too small: {self.current_progress:.2f}% " + f"at the {self.current_iteration} iterations" + ) diff --git a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py index 725c54501d..7d090dd6b5 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py +++ b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py @@ -3,7 +3,9 @@ import io import logging import os +from collections import Counter from dataclasses import dataclass +from functools import cached_property from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, TypeVar @@ -12,21 +14,31 @@ import numpy as np import src.cvat.api_calls as cvat_api +import src.models.validation as db_models import src.services.validation as db_service from src.core.annotation_meta import AnnotationMeta from src.core.config import Config -from src.core.gt_stats import GtStats, ValidationFrameStats +from src.core.gt_stats import GtKey, GtStats, ValidationFrameStats from src.core.types import TaskTypes -from src.core.validation_errors import DatasetValidationError, LowAccuracyError, TooFewGtError +from src.core.validation_errors import ( + DatasetValidationError, + LowAccuracyError, + TooFewGtError, + TooSlowAnnotationError, +) from src.core.validation_meta import JobMeta, ResultMeta, ValidationMeta from src.core.validation_results import ValidationFailure, ValidationSuccess from src.db.utils import ForUpdateParams from src.services.cloud import make_client as make_cloud_client from src.services.cloud.utils import BucketAccessInfo +from src.utils import grouped from src.utils.annotations import ProjectLabels +from src.utils.formatting import value_and_percent from src.utils.zip_archive import extract_zip_archive, write_dir_to_zip_archive if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from sqlalchemy.orm import Session from src.core.manifest import TaskManifest @@ -187,9 +199,9 @@ def _validate_jobs(self): # assess quality of the job's honeypots task_quality_report_data = task_id_to_quality_report_data[cvat_task_id] - sorted_task_frame_names = task_id_to_sequence_of_frame_names[cvat_task_id] - task_honeypots = {int(frame) for frame in task_quality_report_data.frame_results} - honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id] + task_frame_names = task_id_to_sequence_of_frame_names[cvat_task_id] + task_honeypots = set(task_id_to_val_layout[cvat_task_id].honeypot_frames) + task_honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id] job_honeypots = task_honeypots & set(job_meta.job_frame_range) if not job_honeypots: @@ -198,8 +210,8 @@ def _validate_jobs(self): continue for honeypot in job_honeypots: - val_frame = honeypots_mapping[honeypot] - val_frame_name = sorted_task_frame_names[val_frame] + val_frame = task_honeypots_mapping[honeypot] + val_frame_name = task_frame_names[val_frame] result = task_quality_report_data.frame_results[str(honeypot)] self._gt_stats.setdefault(val_frame_name, ValidationFrameStats()) @@ -220,6 +232,14 @@ def _validate_jobs(self): if accuracy < min_quality: rejected_jobs[cvat_job_id] = LowAccuracyError() + for gt_stat in self._gt_stats.values(): + gt_stat.total_uses = max( + gt_stat.total_uses, + gt_stat.failed_attempts + gt_stat.accepted_attempts, + # at the first iteration we have no information on total uses + # from previous iterations, so we derive it from the validation results + ) + self._job_results = job_results self._rejected_jobs = rejected_jobs self._task_id_to_val_layout = task_id_to_val_layout @@ -351,6 +371,305 @@ def validate(self) -> _ValidationResult: ) +@dataclass +class _HoneypotUpdateResult: + updated_gt_stats: GtStats + can_continue_annotation: bool + + +_K = TypeVar("_K") + + +class _TaskHoneypotManager: + def __init__( + self, + task: db_models.Task, + manifest: TaskManifest, + *, + annotation_meta: AnnotationMeta, + gt_stats: GtStats, + validation_result: _ValidationResult, + logger: logging.Logger, + rng: np.random.Generator | None = None, + ): + self.task = task + self.logger = logger + self.annotation_meta = annotation_meta + self.gt_stats = gt_stats + self.validation_result = validation_result + self.manifest = manifest + + self._job_annotation_meta_by_job_id = { + meta.job_id: meta for meta in self.annotation_meta.jobs + } + + if not rng: + rng = np.random.default_rng() + self.rng = rng + + def _make_gt_key(self, validation_frame_name: str) -> GtKey: + return validation_frame_name + + @cached_property + def _get_gt_frame_uses(self) -> dict[GtKey, int]: + return {gt_key: gt_stat.total_uses for gt_key, gt_stat in self.gt_stats.items()} + + def _select_random_least_used( + self, + items: Sequence[_K], + count: int, + *, + key: Callable[[_K], int] | None = None, + rng: np.random.Generator | None = None, + ) -> Sequence[_K]: + """ + Selects 'count' least used items randomly, without repetition. + 'key' can be used to provide a custom item count function. + """ + if not rng: + rng = self.rng + + if not key: + item_counts = Counter(items) + key = item_counts.__getitem__ + + pick = set() + for randval in rng.random(count): + # TODO: try to optimize item counting on each iteration + # maybe by using a bagged data structure + least_use_count = min(key(item) for item in items if item not in pick) + least_used_items = [ + item for item in items if key(item) == least_use_count if item not in pick + ] + pick.add(least_used_items[int(randval * len(least_used_items))]) + + return pick + + def _get_available_gt_frames(self): + if max_gt_share := Config.validation.max_gt_share: + # Limit maximum used GT frames + regular_frames_count = 0 + for task_id, task_val_layout in self.validation_result.task_id_to_val_layout.items(): + # Safety check for the next operations. Here we assume + # that all the tasks use the same GT frames. + task_validation_frames = task_val_layout.validation_frames + task_frame_names = self.validation_result.task_id_to_frame_names[task_id] + task_gt_keys = { + self._make_gt_key(task_frame_names[f]) for f in task_validation_frames + } + + # Populate missing entries for unused GT frames + for gt_key in task_gt_keys: + if gt_key not in self.gt_stats: + self.gt_stats[gt_key] = ValidationFrameStats() + + regular_frames_count += ( + len(task_frame_names) + - len(task_validation_frames) + - task_val_layout.honeypot_count + ) + + if len(self.manifest.annotation.labels) != 1: + # TODO: count GT frames per label set to avoid situations with empty GT sets + # for some labels or tasks. + # Note that different task types can have different label setups. + self.logger.warning( + "Tasks with multiple labels are not supported yet." + " Honeypots in tasks will not be limited" + ) + else: + total_frames_count = regular_frames_count + len(self.gt_stats) + enabled_gt_keys = {k for k, gt_stat in self.gt_stats.items() if gt_stat.enabled} + current_gt_share = len(enabled_gt_keys) / (total_frames_count or 1) + max_usable_gt_share = min( + len(self.gt_stats) / (total_frames_count or 1), max_gt_share + ) + max_gt_count = min(int(max_gt_share * total_frames_count), len(self.gt_stats)) + has_updates = False + if max_gt_count < len(enabled_gt_keys): + # disable some validation frames, take the least used ones + pick = self._select_random_least_used( + enabled_gt_keys, + count=len(enabled_gt_keys) - max_gt_count, + key=lambda k: self.gt_stats[k].total_uses, + ) + + enabled_gt_keys.difference_update(pick) + has_updates = True + elif ( + # Allow restoring GT frames on max limit config changes + current_gt_share < max_usable_gt_share + ): + # add more validation frames, take the most used ones + pick = self._select_random_least_used( + enabled_gt_keys, + count=max_gt_count - len(enabled_gt_keys), + key=lambda k: -self.gt_stats[k].total_uses, + ) + + enabled_gt_keys.update(pick) + has_updates = True + + if has_updates: + for gt_key, gt_stat in self.gt_stats.items(): + gt_stat.enabled = gt_key in enabled_gt_keys + + return { + gt_key + for gt_key, gt_stat in self.gt_stats.items() + if gt_stat.enabled + if gt_stat.rating > Config.validation.gt_ban_threshold + } + + def _check_warmup_annotation_speed(self): + validation_result = self.validation_result + rejected_jobs = validation_result.rejected_jobs + + current_iteration = self.task.iteration + 1 + total_jobs_count = len(validation_result.job_results) + completed_jobs_count = total_jobs_count - len(rejected_jobs) + current_progress = completed_jobs_count / (total_jobs_count or 1) * 100 + if ( + (Config.validation.warmup_iterations > 0) + and (Config.validation.min_warmup_progress > 0) + and (Config.validation.warmup_iterations <= current_iteration) + and (current_progress < Config.validation.min_warmup_progress) + ): + self.logger.warning( + f"Escrow validation failed for escrow_address={self.task.escrow_address}:" + f" progress is too slow. Min required {Config.validation.min_warmup_progress:.2f}%" + f" after the first {Config.validation.warmup_iterations} iterations," + f" got {current_progress:2f} after the {current_iteration} iteration." + " Annotation will be stopped for a manual review." + ) + raise TooSlowAnnotationError( + current_progress=current_progress, current_iteration=current_iteration + ) + + def update_honeypots(self) -> _HoneypotUpdateResult: + gt_stats = self.gt_stats + validation_result = self.validation_result + rejected_jobs = validation_result.rejected_jobs + + # Update honeypots in jobs + available_gt_frames = self._get_available_gt_frames() + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "Escrow validation for escrow_address={}: iteration: {}" + ", available GT count: {} ({}%, banned {})" + ", remaining jobs count: {} ({}%)".format( + self.task.escrow_address, + self.task.iteration, + *value_and_percent(len(available_gt_frames), len(gt_stats)), + len(gt_stats) - len(available_gt_frames), + *value_and_percent(len(rejected_jobs), len(self.annotation_meta.jobs)), + ), + ) + + should_complete = False + + self._check_warmup_annotation_speed() + + if len(available_gt_frames) / len(gt_stats) < Config.validation.min_available_gt_threshold: + self.logger.debug("Not enough available GT to continue, stopping") + return _HoneypotUpdateResult(updated_gt_stats=gt_stats, can_continue_annotation=False) + + gt_frame_uses = self._get_gt_frame_uses + + tasks_with_rejected_jobs = grouped( + rejected_jobs, key=lambda jid: self._job_annotation_meta_by_job_id[jid].task_id + ) + + # Update honeypots in rejected jobs + for cvat_task_id, task_rejected_jobs in tasks_with_rejected_jobs.items(): + if not task_rejected_jobs: + continue + + task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id] + task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id] + + task_validation_frame_to_gt_key = { + # TODO: maybe switch to per GT case stats for GT frame stats + # e.g. per skeleton point (all skeleton points use the same GT frame names) + validation_frame: self._make_gt_key(task_frame_names[validation_frame]) + for validation_frame in task_validation_layout.validation_frames + } + + task_available_validation_frames = { + validation_frame + for validation_frame in task_validation_layout.validation_frames + if task_frame_names[validation_frame] in available_gt_frames + } + + if len(task_available_validation_frames) < task_validation_layout.frames_per_job_count: + # TODO: value from the manifest can be different from what's in the task + # because exchange oracle can use size multipliers for tasks + # Need to sync these values later (maybe by removing it from the manifest) + should_complete = True + self.logger.info( + f"Validation for escrow_address={self.task.escrow_address}: " + "Too few validation frames left " + f"(required: {task_validation_layout.frames_per_job_count}, " + f"left: {len(task_available_validation_frames)}) for the task({cvat_task_id}), " + "stopping annotation" + ) + break + + task_updated_disabled_frames = [ + validation_frame + for validation_frame in task_validation_layout.validation_frames + if validation_frame not in task_available_validation_frames + ] + + task_honeypot_to_index: dict[int, int] = { + honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames) + } # honeypot -> honeypot list index + + task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy() + + for job_id in task_rejected_jobs: + job_frame_range = self._job_annotation_meta_by_job_id[job_id].job_frame_range + job_honeypots = sorted( + set(task_validation_layout.honeypot_frames).intersection(job_frame_range) + ) + + # Choose new unique validation frames for the job + job_validation_frames = self._select_random_least_used( + task_available_validation_frames, + count=len(job_honeypots), + key=lambda k: gt_frame_uses[task_validation_frame_to_gt_key[k]], + ) + for job_honeypot, job_validation_frame in zip( + job_honeypots, job_validation_frames, strict=False + ): + gt_frame_uses[task_validation_frame_to_gt_key[job_validation_frame]] += 1 + honeypot_index = task_honeypot_to_index[job_honeypot] + task_updated_honeypot_real_frames[honeypot_index] = job_validation_frame + + # Make sure honeypots do not repeat in jobs + assert len( + { + task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]] + for honeypot in job_honeypots + } + ) == len(job_honeypots) + + cvat_api.update_task_validation_layout( + cvat_task_id, + disabled_frames=task_updated_disabled_frames, + honeypot_real_frames=task_updated_honeypot_real_frames, + ) + + # Update GT use counts + for gt_key, gt_stat in gt_stats.items(): + gt_stat.total_uses = gt_frame_uses[gt_key] + + return _HoneypotUpdateResult( + updated_gt_stats=gt_stats, can_continue_annotation=not should_complete + ) + + def process_intermediate_results( # noqa: PLR0912 session: Session, *, @@ -392,6 +711,8 @@ def process_intermediate_results( # noqa: PLR0912 failed_attempts=gt_image_stat.failed_attempts, accepted_attempts=gt_image_stat.accepted_attempts, accumulated_quality=gt_image_stat.accumulated_quality, + total_uses=gt_image_stat.total_uses, + enabled=gt_image_stat.enabled, ) for gt_image_stat in db_service.get_task_gt_stats(session, task.id) } @@ -408,7 +729,6 @@ def process_intermediate_results( # noqa: PLR0912 validation_result = validator.validate() job_results = validation_result.job_results rejected_jobs = validation_result.rejected_jobs - updated_merged_dataset_archive = validation_result.updated_merged_dataset if logger.isEnabledFor(logging.DEBUG): logger.debug("Validation results %s", validation_result) @@ -419,135 +739,32 @@ def process_intermediate_results( # noqa: PLR0912 ) gt_stats = validation_result.gt_stats - if gt_stats: - # cvat_task_id: {val_frame_id, ...} - cvat_task_id_to_failed_val_frames: dict[int, set[int]] = {} - rejected_job_ids = rejected_jobs.keys() - - if rejected_job_ids: - job_id_to_task_id = {j.job_id: j.task_id for j in unchecked_jobs_meta.jobs} - job_id_to_frame_range = {j.job_id: j.job_frame_range for j in unchecked_jobs_meta.jobs} - - # find validation frames to be disabled - for rejected_job_id in rejected_job_ids: - job_frame_range = job_id_to_frame_range[rejected_job_id] - cvat_task_id = job_id_to_task_id[rejected_job_id] - task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[ - cvat_task_id - ] - job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range)) - validation_frames = [ - val_frame - for honeypot, val_frame in task_honeypots_mapping.items() - if honeypot in job_honeypots - ] - sorted_task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id] - - for val_frame in validation_frames: - val_frame_name = sorted_task_frame_names[val_frame] - val_frame_stats = gt_stats[val_frame_name] - if ( - val_frame_stats.failed_attempts >= Config.validation.gt_ban_threshold - and not val_frame_stats.accepted_attempts - ): - cvat_task_id_to_failed_val_frames.setdefault(cvat_task_id, set()).add( - val_frame - ) - - for cvat_task_id, task_bad_validation_frames in cvat_task_id_to_failed_val_frames.items(): - task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id] - - task_disabled_bad_frames = ( - set(task_validation_layout.disabled_frames) & task_bad_validation_frames - ) - if task_disabled_bad_frames: - logger.error( - "Logical error occurred while disabling validation frames " - f"for the task({task_id}). Frames {task_disabled_bad_frames} " - "are already disabled." - ) - task_updated_disabled_frames = list( - set(task_validation_layout.disabled_frames) | set(task_bad_validation_frames) - ) - task_good_validation_frames = list( - set(task_validation_layout.validation_frames) - set(task_updated_disabled_frames) - ) - - if len(task_good_validation_frames) < task_validation_layout.frames_per_job_count: - should_complete = True - logger.info( - f"Validation for escrow_address={escrow_address}: " - "Too few validation frames left " - f"(required: {task_validation_layout.frames_per_job_count}, " - f"left: {len(task_good_validation_frames)}) for the task({cvat_task_id}), " - "stopping annotation" - ) - break - - task_honeypot_to_index: dict[int, int] = { - honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames) - } # honeypot -> list index - - task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id] - - task_rejected_jobs = [ - j - for j in unchecked_jobs_meta.jobs - if j.job_id in rejected_job_ids and j.task_id == cvat_task_id - ] - - task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy() - for job in task_rejected_jobs: - job_frame_range = job.job_frame_range - job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range)) - - job_honeypots_to_replace = [] - job_validation_frames_to_replace = [] - job_validation_frames_to_keep = [] - for honeypot in job_honeypots: - validation_frame = task_honeypots_mapping[honeypot] - if validation_frame in task_bad_validation_frames: - job_honeypots_to_replace.append(honeypot) - job_validation_frames_to_replace.append(validation_frame) - else: - job_validation_frames_to_keep.append(validation_frame) - - # choose new unique validation frames for the job - assert not ( - set(job_validation_frames_to_replace) & set(task_good_validation_frames) - ) - job_available_validation_frames = list( - set(task_good_validation_frames) - set(job_validation_frames_to_keep) - ) - - rng = np.random.Generator(np.random.MT19937()) - new_job_validation_frames = rng.choice( - job_available_validation_frames, - replace=False, - size=len(job_validation_frames_to_replace), - ).tolist() - - for honeypot, new_validation_frame in zip( - job_honeypots_to_replace, new_job_validation_frames, strict=True - ): - honeypot_index = task_honeypot_to_index[honeypot] - task_updated_honeypot_real_frames[honeypot_index] = new_validation_frame + if (Config.validation.max_escrow_iterations > 0) and ( + Config.validation.max_escrow_iterations <= task.iteration + ): + logger.info( + f"Validation for escrow_address={escrow_address}:" + f" too many iterations, stopping annotation" + ) + should_complete = True + elif rejected_jobs and gt_stats: + honeypot_manager = _TaskHoneypotManager( + task, + manifest, + annotation_meta=meta, + gt_stats=gt_stats, + validation_result=validation_result, + logger=logger, + ) - # Make sure honeypots do not repeat in jobs - assert len( - { - task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]] - for honeypot in job_honeypots - } - ) == len(job_honeypots) + honeypot_update_result = honeypot_manager.update_honeypots() + if not honeypot_update_result.can_continue_annotation: + should_complete = True - cvat_api.update_task_validation_layout( - cvat_task_id, - disabled_frames=task_updated_disabled_frames, - honeypot_real_frames=task_updated_honeypot_real_frames, - ) + gt_stats = honeypot_update_result.updated_gt_stats + if gt_stats: if logger.isEnabledFor(logging.DEBUG): logger.debug("Updating GT stats: %s", gt_stats) @@ -589,15 +806,6 @@ def process_intermediate_results( # noqa: PLR0912 task_jobs = task.jobs - if Config.validation.max_escrow_iterations > 0: - escrow_iteration = task.iteration - if escrow_iteration and Config.validation.max_escrow_iterations <= escrow_iteration: - logger.info( - f"Validation for escrow_address={escrow_address}:" - f" too many iterations, stopping annotation" - ) - should_complete = True - db_service.update_escrow_iteration(session, escrow_address, chain_id, task.iteration + 1) if not should_complete: @@ -657,7 +865,7 @@ def process_intermediate_results( # noqa: PLR0912 return ValidationSuccess( job_results=job_results, validation_meta=validation_meta, - resulting_annotations=updated_merged_dataset_archive.getvalue(), + resulting_annotations=validation_result.updated_merged_dataset.getvalue(), average_quality=np.mean( [v for v in job_results.values() if v != _TaskValidator.UNKNOWN_QUALITY and v >= 0] or [0] diff --git a/packages/examples/cvat/recording-oracle/src/models/validation.py b/packages/examples/cvat/recording-oracle/src/models/validation.py index bda079c86e..273db23c39 100644 --- a/packages/examples/cvat/recording-oracle/src/models/validation.py +++ b/packages/examples/cvat/recording-oracle/src/models/validation.py @@ -1,7 +1,7 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String +from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, relationship from sqlalchemy.sql import func @@ -61,6 +61,10 @@ class GtStats(Base): failed_attempts = Column(Integer, default=0, nullable=False) accepted_attempts = Column(Integer, default=0, nullable=False) + total_uses = Column(Integer, default=0, nullable=False) + accumulated_quality = Column(Float, default=0.0, nullable=False) + enabled = Column(Boolean, default=True, nullable=False) + task: Mapped[Task] = relationship(back_populates="gt_stats") diff --git a/packages/examples/cvat/recording-oracle/src/services/validation.py b/packages/examples/cvat/recording-oracle/src/services/validation.py index 11ee52070a..ee5428e3f2 100644 --- a/packages/examples/cvat/recording-oracle/src/services/validation.py +++ b/packages/examples/cvat/recording-oracle/src/services/validation.py @@ -156,6 +156,8 @@ def update_gt_stats( "failed_attempts": val_frame_stats.failed_attempts, "accepted_attempts": val_frame_stats.accepted_attempts, "accumulated_quality": val_frame_stats.accumulated_quality, + "total_uses": val_frame_stats.total_uses, + "enabled": val_frame_stats.enabled, } for gt_frame_name, val_frame_stats in updated_gt_stats.items() ], diff --git a/packages/examples/cvat/recording-oracle/src/utils/__init__.py b/packages/examples/cvat/recording-oracle/src/utils/__init__.py index e69de29bb2..ff791bf80a 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/utils/__init__.py @@ -0,0 +1,32 @@ +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from typing import TypeVar + +_K = TypeVar("_K") +_V = TypeVar("_V") + + +def grouped( + items: Iterator[_V] | Iterable[_V], *, key: Callable[[_V], _K] +) -> Mapping[_K, Sequence[_V]]: + """ + Returns a mapping with input iterable elements grouped by key, for example: + + grouped( + [("apple1", "red"), ("apple2", "green"), ("apple3", "red")], + key=lambda v: v[1] + ) + -> + { + "red": [("apple1", "red"), ("apple3", "red")], + "green": [("apple2", "green")] + } + + Similar to itertools.groupby, but allows reiteration on resulting groups. + """ + + # Can be implemented with itertools.groupby, but it requires extra sorting for input elements + grouped_items = {} + for item in items: + grouped_items.setdefault(key(item), []).append(item) + + return grouped_items diff --git a/packages/examples/cvat/recording-oracle/src/utils/formatting.py b/packages/examples/cvat/recording-oracle/src/utils/formatting.py new file mode 100644 index 0000000000..8d9cfa3e98 --- /dev/null +++ b/packages/examples/cvat/recording-oracle/src/utils/formatting.py @@ -0,0 +1,2 @@ +def value_and_percent(numerator: float, denominator: float) -> tuple[float, float]: + return (numerator, numerator / (denominator or 1) * 100) diff --git a/packages/examples/cvat/recording-oracle/src/validation/__init__.py b/packages/examples/cvat/recording-oracle/src/validation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py b/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py deleted file mode 100644 index b5d7855a85..0000000000 --- a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py +++ /dev/null @@ -1,144 +0,0 @@ -import itertools -from collections.abc import Callable, Sequence -from typing import NamedTuple, TypeVar - -import numpy as np -from scipy.optimize import linear_sum_assignment -from scipy.stats import gmean - -from src.core.config import Config - -Annotation = TypeVar("Annotation") - - -class Bbox(NamedTuple): - x: float - y: float - w: float - h: float - label: int - - -class Point(NamedTuple): - x: float - y: float - label: int - - -def bbox_iou(a_bbox: Bbox, b_bbox: Bbox) -> float: - """ - IoU computation for simple cases with axis-aligned bounding boxes - """ - - a_x, a_y, a_w, a_h = a_bbox[:4] - b_x, b_y, b_w, b_h = b_bbox[:4] - int_right = min(a_x + a_w, b_x + b_w) - int_left = max(a_x, b_x) - int_top = max(a_y, b_y) - int_bottom = min(a_y + a_h, b_y + b_h) - - int_w = max(0, int_right - int_left) - int_h = max(0, int_bottom - int_top) - intersection = int_w * int_h - if not intersection: - return 0 - - a_area = a_w * a_h - b_area = b_w * b_h - union = a_area + b_area - intersection - return intersection / union - - -def point_to_bbox_cmp( - bbox: Bbox, - point: Point, - *, - rel_sigma: float = Config.validation.default_point_validity_relative_radius, -) -> float: - """ - Checks that the point is within the axis-aligned bbox, - then measures the distance to the bbox center. - - rel_sigma: - Expected sigma for human point placement within a bbox - the value is relative to the bbox sides size - e.g. 0.5 = the point is likely to be within the smaller bbox with sides 0.5w x 0.5h - around the GT bbox center - """ - # bbox filter + 2d Gaussian + geomean - - if not ((bbox.x <= point.x <= bbox.x + bbox.w) and (bbox.y <= point.y <= bbox.y + bbox.h)): - return 0 - - bbox_cx = bbox.x + bbox.w / 2 - bbox_cy = bbox.y + bbox.h / 2 - scale2sq = (rel_sigma**2) * 0.5 * np.array((bbox.w**2, bbox.h**2)) - dists = np.abs((bbox_cx - point.x, bbox_cy - point.y)) - return gmean(np.exp(-(dists**2) / scale2sq)) - - -class MatchResult(NamedTuple): - matches: list[tuple[Annotation, Annotation]] - mispred: list[tuple[Annotation, Annotation]] - a_extra: list[Annotation] - b_extra: list[Annotation] - - -def match_annotations( - a_anns: Sequence[Annotation], - b_anns: Sequence[Annotation], - similarity: Callable[[Annotation, Annotation], float] = bbox_iou, - min_similarity: float = 1.0, - label_matcher: Callable[[Annotation, Annotation], bool] = lambda a, b: a.label == b.label, -) -> MatchResult: - assert callable(similarity), similarity - assert callable(label_matcher), label_matcher - - max_ann_count = max(len(a_anns), len(b_anns)) - distances = np.array( - [ - [ - 1 - similarity(a, b) if a is not None and b is not None else 1 - for b, _ in itertools.zip_longest(b_anns, range(max_ann_count), fillvalue=None) - ] - for a, _ in itertools.zip_longest(a_anns, range(max_ann_count), fillvalue=None) - ] - ) - - distances[~np.isfinite(distances)] = 1 - distances[distances > 1 - min_similarity] = 1 - - if a_anns and b_anns: - a_matches, b_matches = linear_sum_assignment(distances) - else: - a_matches = [] - b_matches = [] - - # matches: annotations we succeeded to match completely - # mispred: annotations we succeeded to match, having label mismatch - matches = [] - mispred = [] - # *_umatched: annotations of (*) we failed to match - a_unmatched = [] - b_unmatched = [] - - for a_idx, b_idx in zip(a_matches, b_matches, strict=False): - dist = distances[a_idx, b_idx] - if dist > 1 - min_similarity or dist == 1: - if a_idx < len(a_anns): - a_unmatched.append(a_anns[a_idx]) - if b_idx < len(b_anns): - b_unmatched.append(b_anns[b_idx]) - else: - a_ann = a_anns[a_idx] - b_ann = b_anns[b_idx] - if label_matcher(a_ann, b_ann): - matches.append((a_ann, b_ann)) - else: - mispred.append((a_ann, b_ann)) - - if not len(a_matches) and not len(b_matches): - a_unmatched = list(a_anns) - b_unmatched = list(b_anns) - - return MatchResult(matches, mispred, a_unmatched, b_unmatched) diff --git a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py index ca238bb777..269227a063 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py @@ -3,6 +3,8 @@ import random import unittest import uuid +from collections import Counter +from collections.abc import Sequence from contextlib import ExitStack from logging import Logger from types import SimpleNamespace @@ -15,6 +17,7 @@ from src.core.annotation_meta import AnnotationMeta, JobMeta from src.core.types import Networks +from src.core.validation_errors import TooSlowAnnotationError from src.core.validation_results import ValidationFailure, ValidationSuccess from src.cvat import api_calls as cvat_api from src.db import SessionLocal @@ -119,15 +122,28 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess with ExitStack() as common_lock_es: logger = mock.Mock(Logger) - mock_make_cloud_client = common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") ) - mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 0), ) + mock_make_cloud_client = common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + ) + mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") + mock_get_task_validation_layout = common_lock_es.enter_context( mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" @@ -135,8 +151,11 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess ) mock_get_task_validation_layout.return_value = mock.Mock( cvat_api.models.ITaskValidationLayoutRead, + validation_frames=[2, 3], + honeypot_count=2, honeypot_frames=[0, 1], - honeypot_real_frames=[0, 1], + honeypot_real_frames=[2, 3], + frames_per_job_count=2, ) mock_get_task_data_meta = common_lock_es.enter_context( @@ -147,16 +166,6 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)], ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") - ) - def patched_prepare_merged_dataset(self): self._updated_merged_dataset_archive = io.BytesIO() @@ -191,6 +200,11 @@ def patched_prepare_merged_dataset(self): mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports" ) as mock_get_jobs_quality_reports, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" + ), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), ): mock_get_task_quality_report.return_value = mock.Mock( cvat_api.models.IQualityReport, id=1 @@ -280,7 +294,7 @@ def patched_prepare_merged_dataset(self): class TestValidationLogic: - @pytest.mark.parametrize("seed", range(50)) + @pytest.mark.parametrize("seed", [41]) # range(50)) def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): escrow_address = ESCROW_ADDRESS chain_id = Networks.localhost @@ -305,7 +319,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): ) ( - _, task_frame_names, task_validation_frames, task_honeypots, @@ -331,9 +344,16 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): logger = mock.Mock(Logger) common_lock_es.enter_context( - mock.patch( - "src.core.config.Config.validation.gt_ban_threshold", max_validation_frame_uses - ) + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") ) mock_make_cloud_client = common_lock_es.enter_context( @@ -341,10 +361,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): ) mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") - ) - mock_get_task_validation_layout = common_lock_es.enter_context( mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" @@ -368,16 +384,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): frames=[SimpleNamespace(name=name) for name in task_frame_names], ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") - ) - def patched_prepare_merged_dataset(self): self._updated_merged_dataset_archive = io.BytesIO() @@ -419,6 +425,10 @@ def patched_prepare_merged_dataset(self): mock.patch( "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" ) as mock_update_task_validation_layout, + mock.patch("src.core.config.ValidationConfig.gt_ban_threshold", 0.35), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 0), ): mock_get_task_quality_report.return_value = mock.Mock( cvat_api.models.IQualityReport, id=1 @@ -451,31 +461,40 @@ def patched_prepare_merged_dataset(self): logger=logger, ) - assert isinstance(vr, ValidationFailure) + assert isinstance(vr, ValidationFailure) + assert {j.job_id for j in annotation_meta.jobs} == set(vr.rejected_jobs) - assert mock_update_task_validation_layout.call_count == 1 + assert mock_update_task_validation_layout.call_count == 1 - updated_disabled_frames = mock_update_task_validation_layout.call_args.kwargs[ - "disabled_frames" - ] - assert all(v in task_validation_frames for v in updated_disabled_frames) + updated_disabled_frames = mock_update_task_validation_layout.call_args.kwargs[ + "disabled_frames" + ] + assert all(v in task_validation_frames for v in updated_disabled_frames) + assert { + f + for f, c in Counter(task_honeypot_real_frames).items() + if c == max_validation_frame_uses + } == set(updated_disabled_frames), Counter(task_honeypot_real_frames) + + updated_honeypot_real_frames = mock_update_task_validation_layout.call_args.kwargs[ + "honeypot_real_frames" + ] - updated_honeypot_real_frames = mock_update_task_validation_layout.call_args.kwargs[ - "honeypot_real_frames" + for job_start, job_stop in job_frame_ranges: + job_honeypot_positions = [ + i for i, v in enumerate(task_honeypots) if v in range(job_start, job_stop + 1) ] + job_updated_honeypots = [ + updated_honeypot_real_frames[i] for i in job_honeypot_positions + ] + + # Check that the frames do not repeat + assert sorted(job_updated_honeypots) == sorted(set(job_updated_honeypots)) + + # Check that the new frames are not from the excluded set + assert set(job_updated_honeypots).isdisjoint(updated_disabled_frames) - for job_start, job_stop in job_frame_ranges: - job_honeypot_positions = [ - i - for i, v in enumerate(task_honeypots) - if v in range(job_start, job_stop + 1) - ] - job_updated_honeypots = [ - updated_honeypot_real_frames[i] for i in job_honeypot_positions - ] - assert sorted(job_updated_honeypots) == sorted(set(job_updated_honeypots)) - - def _get_job_frame_ranges(self, jobs) -> list[tuple[int, int]]: + def _get_job_frame_ranges(self, jobs: Sequence[Sequence[str]]) -> list[tuple[int, int]]: job_frame_ranges = [] job_start = 0 for job_frames in jobs: @@ -487,53 +506,202 @@ def _get_job_frame_ranges(self, jobs) -> list[tuple[int, int]]: def _generate_task_frames( self, - frame_count, - validation_frames_count, - job_size, - validation_frames_per_job, - seed, - max_validation_frame_uses=None, - ): + frame_count: int, + validation_frames_count: int, + job_size: int, + validation_frames_per_job: int, + *, + seed: int | None = None, + max_validation_frame_uses: int | None = None, + ) -> tuple[Sequence[str], Sequence[int], Sequence[int], Sequence[int], Sequence[Sequence[str]]]: rng = np.random.Generator(np.random.MT19937(seed)) - task_frames = list(range(frame_count)) - task_validation_frames = task_frames[-validation_frames_count:] - task_real_frames = [] - task_honeypots = [] - task_honeypot_real_frames = [] - validation_frame_uses = {vf: 0 for vf in task_validation_frames} + + task_frame_names = list(map(str, range(frame_count))) + task_validation_frame_names = task_frame_names[-validation_frames_count:] + task_honeypot_real_frame_names = [] + + output_task_frame_names = [] + output_task_honeypots = [] + output_task_honeypot_real_frames = [] + + validation_frame_uses = {fn: 0 for fn in task_validation_frame_names} jobs = [] - for job_real_frames in take_by( - task_frames[: frame_count - validation_frames_count], job_size + for job_frame_names in take_by( + task_frame_names[: frame_count - validation_frames_count], job_size ): - available_validation_frames = [ - vf - for vf in task_validation_frames + available_validation_frame_names = [ + fn + for fn in task_validation_frame_names if not max_validation_frame_uses - or validation_frame_uses[vf] < max_validation_frame_uses + or validation_frame_uses[fn] < max_validation_frame_uses ] - job_validation_frames = rng.choice( - available_validation_frames, validation_frames_per_job, replace=False + job_validation_frame_names = rng.choice( + available_validation_frame_names, validation_frames_per_job, replace=False ).tolist() - job_real_frames = job_real_frames + job_validation_frames - rng.shuffle(job_real_frames) + for fn in job_validation_frame_names: + validation_frame_uses[fn] += 1 + + job_frame_names = job_frame_names + job_validation_frame_names + rng.shuffle(job_frame_names) + + jobs.append(job_frame_names) - jobs.append(job_real_frames) + job_start_frame = len(output_task_frame_names) + output_task_frame_names.extend(job_frame_names) - job_start_frame = len(task_real_frames) - task_real_frames.extend(job_real_frames) + for i, v in enumerate(job_frame_names): + if v in job_validation_frame_names: + output_task_honeypots.append(i + job_start_frame) + task_honeypot_real_frame_names.append(v) - for i, v in enumerate(job_real_frames): - if v in job_validation_frames: - task_honeypots.append(i + job_start_frame) - task_honeypot_real_frames.append(v) + validation_frame_name_to_idx = {} + for fn in task_validation_frame_names: + validation_frame_name_to_idx[fn] = len(output_task_frame_names) + output_task_frame_names.append(fn) + + output_task_honeypot_real_frames = [ + validation_frame_name_to_idx[fn] for fn in task_honeypot_real_frame_names + ] + + output_task_validation_frames = [ + validation_frame_name_to_idx[fn] for fn in task_validation_frame_names + ] - task_frame_names = list(map(str, task_real_frames)) return ( - task_real_frames, - task_frame_names, - task_validation_frames, - task_honeypots, - task_honeypot_real_frames, + output_task_frame_names, + output_task_validation_frames, + output_task_honeypots, + output_task_honeypot_real_frames, jobs, ) + + def test_can_stop_on_slow_annotation_after_warmup_iterations(self, session: Session): + escrow_address = ESCROW_ADDRESS + chain_id = Networks.localhost + + frame_count = 10 + + manifest = generate_manifest() + + cvat_task_id = 1 + cvat_job_id = 1 + annotator1 = WALLET_ADDRESS1 + + assignment1_id = f"0x{0:040d}" + assignment1_quality = 0 + + # create a validation input + with ExitStack() as common_lock_es: + logger = mock.Mock(Logger) + + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") + ) + + mock_make_cloud_client = common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + ) + mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") + + mock_get_task_validation_layout = common_lock_es.enter_context( + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" + ) + ) + mock_get_task_validation_layout.return_value = mock.Mock( + cvat_api.models.ITaskValidationLayoutRead, + validation_frames=[2, 3], + honeypot_count=2, + honeypot_frames=[0, 1], + honeypot_real_frames=[2, 3], + frames_per_job_count=2, + ) + + mock_get_task_data_meta = common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.cvat_api.get_task_data_meta") + ) + mock_get_task_data_meta.return_value = mock.Mock( + cvat_api.models.IDataMetaRead, + frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)], + ) + + def patched_prepare_merged_dataset(self): + self._updated_merged_dataset_archive = io.BytesIO() + + common_lock_es.enter_context( + mock.patch( + "src.handlers.process_intermediate_results._TaskValidator._prepare_merged_dataset", + patched_prepare_merged_dataset, + ) + ) + + annotation_meta = AnnotationMeta( + jobs=[ + JobMeta( + job_id=cvat_job_id, + task_id=cvat_task_id, + annotation_filename="", + annotator_wallet_address=annotator1, + assignment_id=assignment1_id, + start_frame=0, + stop_frame=manifest.annotation.job_size + manifest.validation.val_size, + ) + ] + ) + + with ( + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_task_quality_report" + ) as mock_get_task_quality_report, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_quality_report_data" + ) as mock_get_quality_report_data, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports" + ) as mock_get_jobs_quality_reports, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" + ), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), + mock.patch("src.core.config.ValidationConfig.warmup_iterations", 1), + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 20), + ): + mock_get_task_quality_report.return_value = mock.Mock( + cvat_api.models.IQualityReport, id=1 + ) + mock_get_quality_report_data.return_value = mock.Mock( + cvat_api.QualityReportData, + frame_results={ + "0": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)), + "1": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)), + }, + ) + mock_get_jobs_quality_reports.return_value = [ + mock.Mock( + cvat_api.models.IQualityReport, + job_id=1, + summary=mock.Mock(accuracy=assignment1_quality), + ), + ] + + with pytest.raises(TooSlowAnnotationError): + process_intermediate_results( + session, + escrow_address=escrow_address, + chain_id=chain_id, + meta=annotation_meta, + merged_annotations=io.BytesIO(), + manifest=manifest, + logger=logger, + ) From aae6f4b95ecd367200116bc9c3405d239117e26f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 26 Dec 2024 14:16:39 +0300 Subject: [PATCH 3/3] [CVAT] Refactor oracle configs (#2780) * Remove extra files * Remove outdated validation settings * Remove unused aws storage region * Remove cvat_ prefix for cvat settings * Refactor env var access * Update exchange oracle cron settings * Add comment * Add missing template vars * Format code * Merge * Add missing type conversions * Fix invalid webhook handling * Fix merge * Remove extra config parameter for polygons iou * Update code formatting --- .../cvat/exchange-oracle/pyproject.toml | 1 + .../cvat/exchange-oracle/src/.env.template | 11 +- .../cvat/exchange-oracle/src/core/config.py | 218 ++++++++---------- .../src/crons/webhooks/recording_oracle.py | 4 +- .../exchange-oracle/src/cvat/api_calls.py | 34 +-- .../src/handlers/job_creation.py | 14 +- .../exchange-oracle/src/utils/assignments.py | 2 +- .../src/validators/signature.py | 2 +- .../test_process_job_launcher_webhooks.py | 6 +- .../exchange-oracle/tests/utils/setup_cvat.py | 2 +- .../cvat/recording-oracle/pyproject.toml | 1 + .../cvat/recording-oracle/src/.env.template | 2 - .../cvat/recording-oracle/src/core/config.py | 124 +++++----- .../recording-oracle/src/cvat/api_calls.py | 16 +- 14 files changed, 211 insertions(+), 226 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index b08e32026f..93b64e3858 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -114,6 +114,7 @@ ignore = [ "ANN002", # Missing type annotation for `*args` "TRY300", # Consider moving this statement to an `else` block "C901", # Function is too complex + "PLW1508", # invalid-envvar-default. Alerts only for os.getenv(), but not for os.environ.get() "PLW2901", # Variable overwritten by assignment target "PTH118", # Prefer pathlib instead of os.path "PTH119", # `os.path.basename()` should be replaced by `Path.name` diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index cd3f86a064..65a0e3966f 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -44,9 +44,7 @@ PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE= PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE= PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT= TRACK_COMPLETED_PROJECTS_INT= -TRACK_COMPLETED_PROJECTS_CHUNK_SIZE= TRACK_COMPLETED_TASKS_INT= -TRACK_COMPLETED_TASKS_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_INT= TRACK_COMPLETED_ESCROWS_CHUNK_SIZE= TRACK_ESCROW_VALIDATIONS_INT= @@ -79,13 +77,11 @@ CVAT_IOU_THRESHOLD= CVAT_OKS_SIGMA= CVAT_EXPORT_TIMEOUT= CVAT_IMPORT_TIMEOUT= -CVAT_POLYGONS_IOU_THRESHOLD= # Storage Config (S3/GCS) STORAGE_PROVIDER= STORAGE_ENDPOINT_URL= -STORAGE_REGION= STORAGE_ACCESS_KEY= STORAGE_SECRET_KEY= STORAGE_KEY_FILE_PATH= @@ -96,6 +92,7 @@ STORAGE_USE_SSL= ENABLE_CUSTOM_CLOUD_HOST= REQUEST_LOGGING_ENABLED= +PROFILING_ENABLED= MANIFEST_CACHE_TTL= # Core @@ -112,9 +109,14 @@ HUMAN_APP_JWT_KEY= # API config DEFAULT_API_PAGE_SIZE= +MIN_API_PAGE_SIZE= +MAX_API_PAGE_SIZE= +STATS_RPS_LIMIT= # Localhost +LOCALHOST_RPC_API_URL= +LOCALHOST_AMOY_ADDR= LOCALHOST_RECORDING_ORACLE_ADDRESS= LOCALHOST_RECORDING_ORACLE_URL= LOCALHOST_JOB_LAUNCHER_URL= @@ -122,6 +124,7 @@ LOCALHOST_REPUTATION_ORACLE_ADDRESS= LOCALHOST_REPUTATION_ORACLE_URL= # Encryption + PGP_PRIVATE_KEY= PGP_PASSPHRASE= PGP_PUBLIC_KEY_URL= diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 60562e8eab..951d51f2f9 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -5,6 +5,7 @@ import os from collections.abc import Iterable from enum import Enum +from os import getenv from typing import ClassVar, Optional from attrs.converters import to_bool @@ -16,13 +17,16 @@ from src.utils.logging import parse_log_level from src.utils.net import is_ipv4 -dotenv_path = os.getenv("DOTENV_PATH", None) +dotenv_path = getenv("DOTENV_PATH", None) if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) load_dotenv(dotenv_path) +# TODO: add some logic to report unused/deprecated env vars on startup + + class _BaseConfig: @classmethod def validate(cls) -> None: @@ -30,12 +34,12 @@ def validate(cls) -> None: class PostgresConfig: - port = os.environ.get("PG_PORT", "5432") - host = os.environ.get("PG_HOST", "0.0.0.0") # noqa: S104 - user = os.environ.get("PG_USER", "admin") - password = os.environ.get("PG_PASSWORD", "admin") - database = os.environ.get("PG_DB", "exchange_oracle") - lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds + port = getenv("PG_PORT", "5432") + host = getenv("PG_HOST", "0.0.0.0") # noqa: S104 + user = getenv("PG_USER", "admin") + password = getenv("PG_PASSWORD", "admin") + database = getenv("PG_DB", "exchange_oracle") + lock_timeout = int(getenv("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod def connection_url(cls) -> str: @@ -43,12 +47,12 @@ def connection_url(cls) -> str: class RedisConfig: - port = int(os.environ.get("REDIS_PORT", "6379")) - host = os.environ.get("REDIS_HOST", "0.0.0.0") # noqa: S104 - database = int(os.environ.get("REDIS_DB", "0")) - user = os.environ.get("REDIS_USER", "") - password = os.environ.get("REDIS_PASSWORD", "") - use_ssl = to_bool(os.environ.get("REDIS_USE_SSL", "false")) + port = int(getenv("REDIS_PORT", "6379")) + host = getenv("REDIS_HOST", "0.0.0.0") # noqa: S104 + database = int(getenv("REDIS_DB", "0")) + user = getenv("REDIS_USER", "") + password = getenv("REDIS_PASSWORD", "") + use_ssl = to_bool(getenv("REDIS_USE_SSL", "false")) @classmethod def connection_url(cls) -> str: @@ -80,144 +84,126 @@ def is_configured(cls) -> bool: class PolygonMainnetConfig(_NetworkConfig): chain_id = 137 - rpc_api = os.environ.get("POLYGON_MAINNET_RPC_API_URL") - private_key = os.environ.get("POLYGON_MAINNET_PRIVATE_KEY") - addr = os.environ.get("POLYGON_MAINNET_ADDR") + rpc_api = getenv("POLYGON_MAINNET_RPC_API_URL") + private_key = getenv("POLYGON_MAINNET_PRIVATE_KEY") + addr = getenv("POLYGON_MAINNET_ADDR") class PolygonAmoyConfig(_NetworkConfig): chain_id = 80002 - rpc_api = os.environ.get("POLYGON_AMOY_RPC_API_URL") - private_key = os.environ.get("POLYGON_AMOY_PRIVATE_KEY") - addr = os.environ.get("POLYGON_AMOY_ADDR") + rpc_api = getenv("POLYGON_AMOY_RPC_API_URL") + private_key = getenv("POLYGON_AMOY_PRIVATE_KEY") + addr = getenv("POLYGON_AMOY_ADDR") class LocalhostConfig(_NetworkConfig): chain_id = 1338 - rpc_api = os.environ.get("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") - private_key = os.environ.get( + rpc_api = getenv("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") + private_key = getenv( "LOCALHOST_PRIVATE_KEY", "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", ) - addr = os.environ.get("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") + addr = getenv("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") - job_launcher_url = os.environ.get("LOCALHOST_JOB_LAUNCHER_URL") + job_launcher_url = getenv("LOCALHOST_JOB_LAUNCHER_URL") - recording_oracle_address = os.environ.get("LOCALHOST_RECORDING_ORACLE_ADDRESS") - recording_oracle_url = os.environ.get("LOCALHOST_RECORDING_ORACLE_URL") + recording_oracle_address = getenv("LOCALHOST_RECORDING_ORACLE_ADDRESS") + recording_oracle_url = getenv("LOCALHOST_RECORDING_ORACLE_URL") - reputation_oracle_address = os.environ.get("LOCALHOST_REPUTATION_ORACLE_ADDRESS") - reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") + reputation_oracle_address = getenv("LOCALHOST_REPUTATION_ORACLE_ADDRESS") + reputation_oracle_url = getenv("LOCALHOST_REPUTATION_ORACLE_URL") class CronConfig: - process_job_launcher_webhooks_int = int(os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_INT", 30)) + process_job_launcher_webhooks_int = int(getenv("PROCESS_JOB_LAUNCHER_WEBHOOKS_INT", 30)) process_job_launcher_webhooks_chunk_size = int( - os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5) - ) - process_recording_oracle_webhooks_int = int( - os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_INT", 30) + getenv("PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5) ) + process_recording_oracle_webhooks_int = int(getenv("PROCESS_RECORDING_ORACLE_WEBHOOKS_INT", 30)) process_recording_oracle_webhooks_chunk_size = int( - os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + getenv("PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_chunk_size = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_int = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 5) - ) - track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30)) - track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5) - track_completed_tasks_int = int(os.environ.get("TRACK_COMPLETED_TASKS_INT", 30)) - track_completed_tasks_chunk_size = os.environ.get("TRACK_COMPLETED_TASKS_CHUNK_SIZE", 20) - track_creating_tasks_int = int(os.environ.get("TRACK_CREATING_TASKS_INT", 300)) - track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) - track_assignments_int = int(os.environ.get("TRACK_ASSIGNMENTS_INT", 5)) - track_assignments_chunk_size = os.environ.get("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10) - - track_completed_escrows_int = int( - # backward compatibility - os.environ.get( - "TRACK_COMPLETED_ESCROWS_INT", os.environ.get("RETRIEVE_ANNOTATIONS_INT", 60) - ) - ) - track_completed_escrows_chunk_size = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_CHUNK_SIZE", 100) - ) - track_escrow_validations_int = int(os.environ.get("TRACK_COMPLETED_ESCROWS_INT", 60)) - track_escrow_validations_chunk_size = int( - os.environ.get("TRACK_ESCROW_VALIDATIONS_CHUNK_SIZE", 1) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 5) ) + track_completed_projects_int = int(getenv("TRACK_COMPLETED_PROJECTS_INT", 30)) + track_completed_tasks_int = int(getenv("TRACK_COMPLETED_TASKS_INT", 30)) + track_creating_tasks_int = int(getenv("TRACK_CREATING_TASKS_INT", 300)) + track_creating_tasks_chunk_size = getenv("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) + track_assignments_int = int(getenv("TRACK_ASSIGNMENTS_INT", 5)) + track_assignments_chunk_size = int(getenv("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10)) + + track_completed_escrows_int = int(getenv("TRACK_COMPLETED_ESCROWS_INT", 60)) + track_completed_escrows_chunk_size = int(getenv("TRACK_COMPLETED_ESCROWS_CHUNK_SIZE", 100)) + track_escrow_validations_int = int(getenv("TRACK_ESCROW_VALIDATIONS_INT", 60)) + track_escrow_validations_chunk_size = int(getenv("TRACK_ESCROW_VALIDATIONS_CHUNK_SIZE", 1)) track_completed_escrows_max_downloading_retries = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES", 10) + getenv("TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES", 10) ) "Maximum number of downloading attempts per job or project during results downloading" track_completed_escrows_jobs_downloading_batch_size = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 500) + getenv("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 500) ) "Maximum number of parallel downloading requests during results downloading" - rejected_projects_chunk_size = os.environ.get("REJECTED_PROJECTS_CHUNK_SIZE", 20) - accepted_projects_chunk_size = os.environ.get("ACCEPTED_PROJECTS_CHUNK_SIZE", 20) + process_rejected_projects_chunk_size = int(getenv("REJECTED_PROJECTS_CHUNK_SIZE", 20)) + process_accepted_projects_chunk_size = int(getenv("ACCEPTED_PROJECTS_CHUNK_SIZE", 20)) - track_escrow_creation_chunk_size = os.environ.get("TRACK_ESCROW_CREATION_CHUNK_SIZE", 20) - track_escrow_creation_int = int(os.environ.get("TRACK_ESCROW_CREATION_INT", 300)) + track_escrow_creation_chunk_size = int(getenv("TRACK_ESCROW_CREATION_CHUNK_SIZE", 20)) + track_escrow_creation_int = int(getenv("TRACK_ESCROW_CREATION_INT", 300)) class CvatConfig: - # TODO: remove cvat_ prefix - cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080") - cvat_admin = os.environ.get("CVAT_ADMIN", "admin") - cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin") - cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "") - - cvat_job_overlap = int(os.environ.get("CVAT_JOB_OVERLAP", 0)) - cvat_task_segment_size = int(os.environ.get("CVAT_TASK_SEGMENT_SIZE", 150)) - cvat_default_image_quality = int(os.environ.get("CVAT_DEFAULT_IMAGE_QUALITY", 70)) - cvat_max_jobs_per_task = int(os.environ.get("CVAT_MAX_JOBS_PER_TASK", 1000)) - cvat_task_creation_check_interval = int(os.environ.get("CVAT_TASK_CREATION_CHECK_INTERVAL", 5)) - - cvat_export_timeout = int(os.environ.get("CVAT_EXPORT_TIMEOUT", 5 * 60)) + host_url = getenv("CVAT_URL", "http://localhost:8080") + admin_login = getenv("CVAT_ADMIN", "admin") + admin_pass = getenv("CVAT_ADMIN_PASS", "admin") + org_slug = getenv("CVAT_ORG_SLUG", "") + + job_overlap = int(getenv("CVAT_JOB_OVERLAP", 0)) + task_segment_size = int(getenv("CVAT_TASK_SEGMENT_SIZE", 150)) + default_image_quality = int(getenv("CVAT_DEFAULT_IMAGE_QUALITY", 70)) + max_jobs_per_task = int(getenv("CVAT_MAX_JOBS_PER_TASK", 1000)) + task_creation_check_interval = int(getenv("CVAT_TASK_CREATION_CHECK_INTERVAL", 5)) + + export_timeout = int(getenv("CVAT_EXPORT_TIMEOUT", 5 * 60)) "Timeout, in seconds, for annotations or dataset export waiting" - cvat_import_timeout = int(os.environ.get("CVAT_IMPORT_TIMEOUT", 60 * 60)) + import_timeout = int(getenv("CVAT_IMPORT_TIMEOUT", 60 * 60)) "Timeout, in seconds, for waiting on GT annotations import" # quality control settings - cvat_max_validation_checks = int(os.environ.get("CVAT_MAX_VALIDATION_CHECKS", 3)) + max_validation_checks = int(getenv("CVAT_MAX_VALIDATION_CHECKS", 3)) "Maximum number of attempts to run a validation check on a job after completing annotation" - cvat_iou_threshold = float(os.environ.get("CVAT_IOU_THRESHOLD", 0.8)) - cvat_oks_sigma = float(os.environ.get("CVAT_OKS_SIGMA", 0.1)) - - cvat_polygons_iou_threshold = float(os.environ.get("CVAT_POLYGONS_IOU_THRESHOLD", 0.5)) - "`iou_threshold` parameter for quality settings in polygons tasks" + iou_threshold = float(getenv("CVAT_IOU_THRESHOLD", 0.8)) + oks_sigma = float(getenv("CVAT_OKS_SIGMA", 0.1)) - cvat_incoming_webhooks_url = os.environ.get("CVAT_INCOMING_WEBHOOKS_URL") - cvat_webhook_secret = os.environ.get("CVAT_WEBHOOK_SECRET", "thisisasamplesecret") + incoming_webhooks_url = getenv("CVAT_INCOMING_WEBHOOKS_URL") + webhook_secret = getenv("CVAT_WEBHOOK_SECRET", "thisisasamplesecret") class StorageConfig: provider: ClassVar[str] = os.environ["STORAGE_PROVIDER"].lower() data_bucket_name: ClassVar[str] = ( - os.environ.get("STORAGE_RESULTS_BUCKET_NAME") # backward compatibility + getenv("STORAGE_RESULTS_BUCKET_NAME") # backward compatibility or os.environ["STORAGE_BUCKET_NAME"] ) endpoint_url: ClassVar[str] = os.environ[ "STORAGE_ENDPOINT_URL" ] # TODO: probably should be optional - region: ClassVar[str | None] = os.environ.get("STORAGE_REGION") - results_dir_suffix: ClassVar[str] = os.environ.get("STORAGE_RESULTS_DIR_SUFFIX", "-results") - secure: ClassVar[bool] = to_bool(os.environ.get("STORAGE_USE_SSL", "true")) + results_dir_suffix: ClassVar[str] = getenv("STORAGE_RESULTS_DIR_SUFFIX", "-results") + secure: ClassVar[bool] = to_bool(getenv("STORAGE_USE_SSL", "true")) # S3 specific attributes - access_key: ClassVar[str | None] = os.environ.get("STORAGE_ACCESS_KEY") - secret_key: ClassVar[str | None] = os.environ.get("STORAGE_SECRET_KEY") + access_key: ClassVar[str | None] = getenv("STORAGE_ACCESS_KEY") + secret_key: ClassVar[str | None] = getenv("STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path: ClassVar[str | None] = os.environ.get("STORAGE_KEY_FILE_PATH") + key_file_path: ClassVar[str | None] = getenv("STORAGE_KEY_FILE_PATH") @classmethod def get_scheme(cls) -> str: @@ -236,13 +222,13 @@ def bucket_url(cls) -> str: class FeaturesConfig: - enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) + enable_custom_cloud_host = to_bool(getenv("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket urls" - request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0")) + request_logging_enabled = to_bool(getenv("REQUEST_LOGGING_ENABLED", "0")) "Allow to log request details for each request" - profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", "0")) + profiling_enabled = to_bool(getenv("PROFILING_ENABLED", "0")) "Allow to profile specific requests" manifest_cache_ttl = int(os.getenv("MANIFEST_CACHE_TTL", str(2 * 24 * 60 * 60))) @@ -250,15 +236,15 @@ class FeaturesConfig: class CoreConfig: - default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800)) + default_assignment_time = int(getenv("DEFAULT_ASSIGNMENT_TIME", 1800)) - skeleton_assignment_size_mult = int(os.environ.get("SKELETON_ASSIGNMENT_SIZE_MULT", 1)) + skeleton_assignment_size_mult = int(getenv("SKELETON_ASSIGNMENT_SIZE_MULT", 1)) "Assignment size multiplier for image_skeletons_from_boxes tasks" - min_roi_size_w = int(os.environ.get("MIN_ROI_SIZE_W", 350)) + min_roi_size_w = int(getenv("MIN_ROI_SIZE_W", 350)) "Minimum absolute ROI size for image_boxes_from_points and image_skeletons_from_boxes tasks" - min_roi_size_h = int(os.environ.get("MIN_ROI_SIZE_H", 300)) + min_roi_size_h = int(getenv("MIN_ROI_SIZE_H", 300)) "Minimum absolute ROI size for image_boxes_from_points and image_skeletons_from_boxes tasks" @@ -268,21 +254,21 @@ class HumanAppConfig: # openssl ecparam -name prime256v1 -genkey -noout -out ec_private.pem # openssl ec -in ec_private.pem -pubout -out ec_public.pem # HUMAN_APP_JWT_KEY=$(cat ec_public.pem) - jwt_public_key = os.environ.get("HUMAN_APP_JWT_KEY") + jwt_public_key = getenv("HUMAN_APP_JWT_KEY") class ApiConfig: - default_page_size = int(os.environ.get("DEFAULT_API_PAGE_SIZE", 5)) - min_page_size = int(os.environ.get("MIN_API_PAGE_SIZE", 1)) - max_page_size = int(os.environ.get("MAX_API_PAGE_SIZE", 10)) + default_page_size = int(getenv("DEFAULT_API_PAGE_SIZE", 5)) + min_page_size = int(getenv("MIN_API_PAGE_SIZE", 1)) + max_page_size = int(getenv("MAX_API_PAGE_SIZE", 10)) - stats_rps_limit = int(os.environ.get("STATS_RPS_LIMIT", 4)) + stats_rps_limit = int(getenv("STATS_RPS_LIMIT", 4)) class EncryptionConfig(_BaseConfig): - pgp_passphrase = os.environ.get("PGP_PASSPHRASE", "") - pgp_private_key = os.environ.get("PGP_PRIVATE_KEY", "") - pgp_public_key_url = os.environ.get("PGP_PUBLIC_KEY_URL", "") + pgp_passphrase = getenv("PGP_PASSPHRASE", "") + pgp_private_key = getenv("PGP_PRIVATE_KEY", "") + pgp_public_key_url = getenv("PGP_PUBLIC_KEY_URL", "") @classmethod def validate(cls) -> None: @@ -302,9 +288,9 @@ def validate(cls) -> None: class DevelopmentConfig: - cvat_in_docker = bool(int(os.environ.get("DEV_CVAT_IN_DOCKER", "1"))) + cvat_in_docker = bool(int(getenv("DEV_CVAT_IN_DOCKER", "1"))) - exchange_oracle_host = os.environ.get("DEV_EXCHANGE_ORACLE_HOST", "172.22.0.1") + exchange_oracle_host = getenv("DEV_EXCHANGE_ORACLE_HOST", "172.22.0.1") """ Might be `host.docker.internal` or `172.22.0.1` if CVAT is running in Docker. @@ -329,13 +315,13 @@ def _missing_(cls, value: str) -> Optional["Environment"]: class Config: - debug = to_bool(os.environ.get("DEBUG", "false")) - port = int(os.environ.get("PORT", 8000)) - environment = Environment(os.environ.get("ENVIRONMENT", Environment.DEVELOPMENT.value)) - workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1)) - webhook_max_retries = int(os.environ.get("WEBHOOK_MAX_RETRIES", 5)) - webhook_delay_if_failed = int(os.environ.get("WEBHOOK_DELAY_IF_FAILED", 60)) - loglevel = parse_log_level(os.environ.get("LOGLEVEL", "info")) + debug = to_bool(getenv("DEBUG", "false")) + port = int(getenv("PORT", 8000)) + environment = Environment(getenv("ENVIRONMENT", Environment.DEVELOPMENT.value)) + workers_amount = int(getenv("WORKERS_AMOUNT", 1)) + webhook_max_retries = int(getenv("WEBHOOK_MAX_RETRIES", 5)) + webhook_delay_if_failed = int(getenv("WEBHOOK_DELAY_IF_FAILED", 60)) + loglevel = parse_log_level(getenv("LOGLEVEL", "info")) polygon_mainnet = PolygonMainnetConfig polygon_amoy = PolygonAmoyConfig diff --git a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py index eeb4a3581b..0315cac825 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py @@ -60,7 +60,6 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) return - chunk_size = CronConfig.accepted_projects_chunk_size project_ids = cvat_db_service.get_project_cvat_ids_by_escrow_address( db_session, webhook.escrow_address ) @@ -71,6 +70,7 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) return + chunk_size = CronConfig.process_accepted_projects_chunk_size for ids_chunk in take_by(project_ids, chunk_size): projects_chunk = cvat_db_service.get_projects_by_cvat_ids( db_session, ids_chunk, for_update=True, limit=chunk_size @@ -138,7 +138,7 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) rejected_project_cvat_ids = set(j.cvat_project_id for j in rejected_jobs) - chunk_size = CronConfig.rejected_projects_chunk_size + chunk_size = CronConfig.process_rejected_projects_chunk_size for chunk_ids in take_by(rejected_project_cvat_ids, chunk_size): projects_chunk = cvat_db_service.get_projects_by_cvat_ids( db_session, chunk_ids, for_update=True, limit=chunk_size diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index e40b4af209..cd1ad4520d 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -77,7 +77,7 @@ def _get_annotations( time_begin = utcnow() if timeout is _NOTSET: - timeout = Config.cvat_config.cvat_export_timeout + timeout = Config.cvat_config.export_timeout while True: (_, response) = endpoint.call_with_http_info( @@ -118,23 +118,23 @@ def get_api_client() -> ApiClient: return current_api_client configuration = Configuration( - host=Config.cvat_config.cvat_url, - username=Config.cvat_config.cvat_admin, - password=Config.cvat_config.cvat_admin_pass, + host=Config.cvat_config.host_url, + username=Config.cvat_config.admin_login, + password=Config.cvat_config.admin_pass, ) api_client = ApiClient(configuration=configuration) - api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug) + api_client.set_default_header("X-organization", Config.cvat_config.org_slug) return api_client def get_sdk_client() -> Client: client = make_client( - host=Config.cvat_config.cvat_url, - credentials=(Config.cvat_config.cvat_admin, Config.cvat_config.cvat_admin_pass), + host=Config.cvat_config.host_url, + credentials=(Config.cvat_config.admin_login, Config.cvat_config.admin_pass), ) - client.organization_slug = Config.cvat_config.cvat_org_slug + client.organization_slug = Config.cvat_config.org_slug return client @@ -291,11 +291,11 @@ def create_cvat_webhook(project_id: int) -> models.WebhookRead: logger = logging.getLogger("app") with get_api_client() as api_client: webhook_write_request = models.WebhookWriteRequest( - target_url=Config.cvat_config.cvat_incoming_webhooks_url, + target_url=Config.cvat_config.incoming_webhooks_url, description="Exchange Oracle notification", type=models.WebhookType("project"), content_type=models.WebhookContentType("application/json"), - secret=Config.cvat_config.cvat_webhook_secret, + secret=Config.cvat_config.webhook_secret, is_active=True, # enable_ssl=True, project_id=project_id, @@ -403,7 +403,7 @@ def put_task_data( data_request = models.DataRequest( chunk_size=chunk_size, cloud_storage_id=cloudstorage_id, - image_quality=Config.cvat_config.cvat_default_image_quality, + image_quality=Config.cvat_config.image_quality, use_cache=True, use_zip_chunks=True, sorting_method=sorting_method, @@ -641,7 +641,7 @@ def upload_gt_annotations( *, format_name: str, sleep_interval: int = 5, - timeout: int | None = Config.cvat_config.cvat_import_timeout, + timeout: int | None = Config.cvat_config.import_timeout, ) -> None: # FUTURE-TODO: use job.import_annotations when CVAT supports a waiting timeout start_time = datetime.now(timezone.utc) @@ -728,8 +728,8 @@ def update_quality_control_settings( *, target_metric_threshold: float, target_metric: str = "accuracy", - max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks, - iou_threshold: float = Config.cvat_config.cvat_iou_threshold, + max_validations_per_job: int = Config.cvat_config.max_validation_checks, + iou_threshold: float = Config.cvat_config.iou_threshold, oks_sigma: float | None = None, point_size_base: str | None = None, match_empty_frames: bool | None = None, @@ -823,7 +823,7 @@ def get_user_id(user_email: str) -> int: try: (invitation, _) = api_client.invitations_api.create( models.InvitationWriteRequest(role="worker", email=user_email), - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) except exceptions.ApiException as e: logger.exception(f"Exception when calling get_user_id(): {e}\n") @@ -839,7 +839,7 @@ def remove_user_from_org(user_id: int): try: (page, _) = api_client.users_api.list( filter='{"==":[{"var":"id"},"%s"]}' % user_id, # noqa: UP031 - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) if not page.results: return @@ -849,7 +849,7 @@ def remove_user_from_org(user_id: int): (page, _) = api_client.memberships_api.list( user=user.username, - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) if page.results: api_client.memberships_api.destroy(page.results[0].id) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 891f827ec4..f2d940ad4c 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -203,7 +203,7 @@ def _wait_task_creation(self, task_id: int) -> cvat_api.UploadStatus: if task_status not in [cvat_api.UploadStatus.STARTED, cvat_api.UploadStatus.QUEUED]: return task_status - sleep(Config.cvat_config.cvat_task_creation_check_interval) + sleep(Config.cvat_config.task_creation_check_interval) def _setup_gt_job_for_cvat_task( self, task_id: int, gt_dataset: dm.Dataset, *, dm_export_format: str = "coco" @@ -409,7 +409,7 @@ def build(self): for data_subset in self._split_dataset_per_task( data_to_be_annotated, - subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size, + subset_size=Config.cvat_config.max_jobs_per_task * segment_size, ): cvat_task = cvat_api.create_task( cvat_project.id, escrow_address, segment_size=segment_size @@ -552,7 +552,7 @@ def build(self): class PolygonTaskBuilder(SimpleTaskBuilder): def _setup_quality_settings(self, task_id: int, **overrides) -> None: - values = {"iou_threshold": Config.cvat_config.cvat_polygons_iou_threshold, **overrides} + values = {"iou_threshold": Config.cvat_config.iou_threshold, **overrides} super()._setup_quality_settings(task_id, **values) @@ -1554,7 +1554,7 @@ def _create_on_cvat(self): for data_subset in self._split_dataset_per_task( self._roi_filenames_to_be_annotated, - subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size, + subset_size=Config.cvat_config.max_jobs_per_task * segment_size, ): cvat_task = cvat_api.create_task( cvat_project.id, self.escrow_address, segment_size=segment_size @@ -2354,7 +2354,7 @@ def _prepare_task_params(self): label_id=label_id, roi_ids=task_data_roi_ids, roi_gt_ids=label_gt_roi_ids ) for task_data_roi_ids in take_by( - label_data_roi_ids, Config.cvat_config.cvat_max_jobs_per_task * segment_size + label_data_roi_ids, Config.cvat_config.max_jobs_per_task * segment_size ) ] ) @@ -2603,7 +2603,7 @@ def _save_cvat_gt_dataset_to_oracle_bucket( def _setup_quality_settings(self, task_id: int, **overrides) -> None: values = { - "oks_sigma": Config.cvat_config.cvat_oks_sigma, + "oks_sigma": Config.cvat_config.oks_sigma, "point_size_base": "image_size", # we don't expect any boxes or groups, so ignore them } values.update(overrides) @@ -2767,7 +2767,7 @@ def _task_params_label_key(ts): cvat_task.id, gt_point_dataset, dm_export_format="cvat" ) self._setup_quality_settings( - cvat_task.id, oks_sigma=Config.cvat_config.cvat_oks_sigma + cvat_task.id, oks_sigma=Config.cvat_config.oks_sigma ) db_service.create_data_upload(session, cvat_task.id) diff --git a/packages/examples/cvat/exchange-oracle/src/utils/assignments.py b/packages/examples/cvat/exchange-oracle/src/utils/assignments.py index ceeebfc150..51c8fe9c7b 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/assignments.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/assignments.py @@ -24,7 +24,7 @@ def compose_assignment_url(task_id: int, job_id: int, *, project: Project) -> st if project.job_type == TaskTypes.image_skeletons_from_boxes: query_params += "&defaultPointsCount=1" - return urljoin(Config.cvat_config.cvat_url, f"/tasks/{task_id}/jobs/{job_id}{query_params}") + return urljoin(Config.cvat_config.host_url, f"/tasks/{task_id}/jobs/{job_id}{query_params}") def get_default_assignment_timeout(task_type: TaskTypes) -> int: diff --git a/packages/examples/cvat/exchange-oracle/src/validators/signature.py b/packages/examples/cvat/exchange-oracle/src/validators/signature.py index 077177b80a..2f42d019e5 100644 --- a/packages/examples/cvat/exchange-oracle/src/validators/signature.py +++ b/packages/examples/cvat/exchange-oracle/src/validators/signature.py @@ -32,7 +32,7 @@ async def validate_cvat_signature(request: Request, x_signature_256: str): signature = ( "sha256=" + hmac.new( - Config.cvat_config.cvat_webhook_secret.encode("utf-8"), + Config.cvat_config.webhook_secret.encode("utf-8"), data, digestmod=sha256, ).hexdigest() diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py index e4336b1da8..14607ea7fa 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py @@ -43,9 +43,9 @@ def tearDown(self): self.session.close() def test_process_incoming_job_launcher_webhooks_escrow_created_type(self): - webhok_id = str(uuid.uuid4()) + webhook_id = str(uuid.uuid4()) webhook = Webhook( - id=webhok_id, + id=webhook_id, signature="signature", escrow_address=escrow_address, chain_id=chain_id, @@ -93,7 +93,7 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type(self): process_incoming_job_launcher_webhooks() updated_webhook = ( - self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() + self.session.execute(select(Webhook).where(Webhook.id == webhook_id)).scalars().first() ) assert updated_webhook.status == OracleWebhookStatuses.completed.value diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py index 549a177552..8bc1a01515 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py @@ -18,7 +18,7 @@ def generate_cvat_signature(data: dict): return ( "sha256=" + hmac.new( - CvatConfig.cvat_webhook_secret.encode("utf-8"), + CvatConfig.webhook_secret.encode("utf-8"), b_data, digestmod=sha256, ).hexdigest() diff --git a/packages/examples/cvat/recording-oracle/pyproject.toml b/packages/examples/cvat/recording-oracle/pyproject.toml index 6b42e915e8..01be8d38b3 100644 --- a/packages/examples/cvat/recording-oracle/pyproject.toml +++ b/packages/examples/cvat/recording-oracle/pyproject.toml @@ -106,6 +106,7 @@ ignore = [ "ANN002", # Missing type annotation for `*args` "TRY300", # Consider moving this statement to an `else` block "C901", # Function is too complex + "PLW1508", # invalid-envvar-default. Alerts only for os.getenv(), but not for os.environ.get() "PLW2901", # Variable overwritten by assignment target "PTH118", # Prefer pathlib instead of os.path "PTH119", # `os.path.basename()` should be replaced by `Path.name` diff --git a/packages/examples/cvat/recording-oracle/src/.env.template b/packages/examples/cvat/recording-oracle/src/.env.template index be2d19ad79..729027cc6d 100644 --- a/packages/examples/cvat/recording-oracle/src/.env.template +++ b/packages/examples/cvat/recording-oracle/src/.env.template @@ -37,7 +37,6 @@ PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE= STORAGE_PROVIDER= STORAGE_ENDPOINT_URL= -STORAGE_REGION= STORAGE_ACCESS_KEY= STORAGE_SECRET_KEY= STORAGE_RESULTS_BUCKET_NAME= @@ -48,7 +47,6 @@ STORAGE_KEY_FILE_PATH= EXCHANGE_ORACLE_STORAGE_PROVIDER= EXCHANGE_ORACLE_STORAGE_ENDPOINT_URL= -EXCHANGE_ORACLE_STORAGE_REGION= EXCHANGE_ORACLE_STORAGE_ACCESS_KEY= EXCHANGE_ORACLE_STORAGE_SECRET_KEY= EXCHANGE_ORACLE_STORAGE_RESULTS_BUCKET_NAME= diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index d7ca733765..687a9cf5b1 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -4,6 +4,7 @@ import inspect import os from collections.abc import Iterable +from os import getenv from typing import ClassVar from attrs.converters import to_bool @@ -15,7 +16,7 @@ from src.utils.logging import parse_log_level from src.utils.net import is_ipv4 -dotenv_path = os.getenv("DOTENV_PATH", None) +dotenv_path = getenv("DOTENV_PATH", None) if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) @@ -29,12 +30,12 @@ def validate(cls) -> None: class Postgres: - port = os.environ.get("PG_PORT", "5434") - host = os.environ.get("PG_HOST", "0.0.0.0") # noqa: S104 - user = os.environ.get("PG_USER", "admin") - password = os.environ.get("PG_PASSWORD", "admin") - database = os.environ.get("PG_DB", "recording_oracle") - lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds + port = getenv("PG_PORT", "5434") + host = getenv("PG_HOST", "0.0.0.0") # noqa: S104 + user = getenv("PG_USER", "admin") + password = getenv("PG_PASSWORD", "admin") + database = getenv("PG_DB", "recording_oracle") + lock_timeout = int(getenv("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod def connection_url(cls) -> str: @@ -58,45 +59,43 @@ def is_configured(cls) -> bool: class PolygonMainnetConfig(_NetworkConfig): chain_id = 137 - rpc_api = os.environ.get("POLYGON_MAINNET_RPC_API_URL") - private_key = os.environ.get("POLYGON_MAINNET_PRIVATE_KEY") - addr = os.environ.get("POLYGON_MAINNET_ADDR") + rpc_api = getenv("POLYGON_MAINNET_RPC_API_URL") + private_key = getenv("POLYGON_MAINNET_PRIVATE_KEY") + addr = getenv("POLYGON_MAINNET_ADDR") class PolygonAmoyConfig(_NetworkConfig): chain_id = 80002 - rpc_api = os.environ.get("POLYGON_AMOY_RPC_API_URL") - private_key = os.environ.get("POLYGON_AMOY_PRIVATE_KEY") - addr = os.environ.get("POLYGON_AMOY_ADDR") + rpc_api = getenv("POLYGON_AMOY_RPC_API_URL") + private_key = getenv("POLYGON_AMOY_PRIVATE_KEY") + addr = getenv("POLYGON_AMOY_ADDR") class LocalhostConfig(_NetworkConfig): chain_id = 1338 - rpc_api = os.environ.get("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") - private_key = os.environ.get( + rpc_api = getenv("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") + private_key = getenv( "LOCALHOST_PRIVATE_KEY", "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", ) - addr = os.environ.get("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") + addr = getenv("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") - exchange_oracle_address = os.environ.get("LOCALHOST_EXCHANGE_ORACLE_ADDRESS") - exchange_oracle_url = os.environ.get("LOCALHOST_EXCHANGE_ORACLE_URL") + exchange_oracle_address = getenv("LOCALHOST_EXCHANGE_ORACLE_ADDRESS") + exchange_oracle_url = getenv("LOCALHOST_EXCHANGE_ORACLE_URL") - reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") + reputation_oracle_url = getenv("LOCALHOST_REPUTATION_ORACLE_URL") class CronConfig: - process_exchange_oracle_webhooks_int = int( - os.environ.get("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_INT", 3000) - ) - process_exchange_oracle_webhooks_chunk_size = os.environ.get( - "PROCESS_EXCHANGE_ORACLE_WEBHOOKS_CHUNK_SIZE", 5 + process_exchange_oracle_webhooks_int = int(getenv("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_INT", 3000)) + process_exchange_oracle_webhooks_chunk_size = int( + getenv("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_int = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 3000) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 3000) ) - process_reputation_oracle_webhooks_chunk_size = os.environ.get( - "PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5 + process_reputation_oracle_webhooks_chunk_size = int( + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) @@ -105,7 +104,6 @@ class IStorageConfig: data_bucket_name: ClassVar[str] secure: ClassVar[bool] endpoint_url: ClassVar[str] # TODO: probably should be optional - region: ClassVar[str | None] # AWS S3 specific attributes access_key: ClassVar[str | None] secret_key: ClassVar[str | None] @@ -130,16 +128,15 @@ def bucket_url(cls) -> str: class StorageConfig(IStorageConfig): provider = os.environ["STORAGE_PROVIDER"].lower() endpoint_url = os.environ["STORAGE_ENDPOINT_URL"] # TODO: probably should be optional - region = os.environ.get("STORAGE_REGION") data_bucket_name = os.environ["STORAGE_RESULTS_BUCKET_NAME"] - secure = to_bool(os.environ.get("STORAGE_USE_SSL", "true")) + secure = to_bool(getenv("STORAGE_USE_SSL", "true")) # AWS S3 specific attributes - access_key = os.environ.get("STORAGE_ACCESS_KEY") - secret_key = os.environ.get("STORAGE_SECRET_KEY") + access_key = getenv("STORAGE_ACCESS_KEY") + secret_key = getenv("STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path = os.environ.get("STORAGE_KEY_FILE_PATH") + key_file_path = getenv("STORAGE_KEY_FILE_PATH") class ExchangeOracleStorageConfig(IStorageConfig): @@ -148,63 +145,62 @@ class ExchangeOracleStorageConfig(IStorageConfig): endpoint_url = os.environ[ "EXCHANGE_ORACLE_STORAGE_ENDPOINT_URL" ] # TODO: probably should be optional - region = os.environ.get("EXCHANGE_ORACLE_STORAGE_REGION") data_bucket_name = os.environ["EXCHANGE_ORACLE_STORAGE_RESULTS_BUCKET_NAME"] - results_dir_suffix = os.environ.get("STORAGE_RESULTS_DIR_SUFFIX", "-results") - secure = to_bool(os.environ.get("EXCHANGE_ORACLE_STORAGE_USE_SSL", "true")) + results_dir_suffix = getenv("STORAGE_RESULTS_DIR_SUFFIX", "-results") + secure = to_bool(getenv("EXCHANGE_ORACLE_STORAGE_USE_SSL", "true")) # AWS S3 specific attributes - access_key = os.environ.get("EXCHANGE_ORACLE_STORAGE_ACCESS_KEY") - secret_key = os.environ.get("EXCHANGE_ORACLE_STORAGE_SECRET_KEY") + access_key = getenv("EXCHANGE_ORACLE_STORAGE_ACCESS_KEY") + secret_key = getenv("EXCHANGE_ORACLE_STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path = os.environ.get("EXCHANGE_ORACLE_STORAGE_KEY_FILE_PATH") + key_file_path = getenv("EXCHANGE_ORACLE_STORAGE_KEY_FILE_PATH") class FeaturesConfig: - enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) + enable_custom_cloud_host = to_bool(getenv("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket urls" class ValidationConfig: - min_available_gt_threshold = float(os.environ.get("MIN_AVAILABLE_GT_THRESHOLD", "0.3")) + min_available_gt_threshold = float(getenv("MIN_AVAILABLE_GT_THRESHOLD", "0.3")) """ The minimum required share of available GT frames required to continue annotation attempts. When there is no enough GT left, annotation stops. """ - max_gt_share = float(os.environ.get("MAX_USABLE_GT_SHARE", "0.05")) + max_gt_share = float(getenv("MAX_USABLE_GT_SHARE", "0.05")) """ The maximum share of the dataset to be used for validation. If the available GT share is greater than this number, the extra frames will not be used. It's recommended to keep this value small enough for faster convergence rate of the annotation process. """ - gt_ban_threshold = float(os.environ.get("GT_BAN_THRESHOLD", "0.03")) + gt_ban_threshold = float(getenv("GT_BAN_THRESHOLD", "0.03")) """ The minimum allowed rating (annotation probability) per GT sample, before it's considered bad and banned for further use. """ - unverifiable_assignments_threshold = float( - os.environ.get("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", "0.1") - ) + unverifiable_assignments_threshold = float(getenv("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", "0.1")) """ + Deprecated. Not expected to happen in practice, kept only as a safety fallback rule. + The maximum allowed fraction of jobs with insufficient GT available for validation. Each such job will be accepted "blindly", as we can't validate the annotations. """ - max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", "50")) + max_escrow_iterations = int(getenv("MAX_ESCROW_ITERATIONS", "50")) """ Maximum escrow annotation-validation iterations. After this, the escrow is finished automatically. Supposed only for testing. Use 0 to disable. """ - warmup_iterations = int(os.getenv("WARMUP_ITERATIONS", "1")) + warmup_iterations = int(getenv("WARMUP_ITERATIONS", "1")) """ The first escrow iterations where the annotation speed is checked to be big enough. """ - min_warmup_progress = float(os.getenv("MIN_WARMUP_PROGRESS", "10")) + min_warmup_progress = float(getenv("MIN_WARMUP_PROGRESS", "10")) """ Minimum percent of the accepted jobs in an escrow after the first WARMUP iterations. If the value is lower, the escrow annotation is paused for manual investigation. @@ -212,9 +208,9 @@ class ValidationConfig: class EncryptionConfig(_BaseConfig): - pgp_passphrase = os.environ.get("PGP_PASSPHRASE", "") - pgp_private_key = os.environ.get("PGP_PRIVATE_KEY", "") - pgp_public_key_url = os.environ.get("PGP_PUBLIC_KEY_URL", "") + pgp_passphrase = getenv("PGP_PASSPHRASE", "") + pgp_private_key = getenv("PGP_PRIVATE_KEY", "") + pgp_public_key_url = getenv("PGP_PUBLIC_KEY_URL", "") @classmethod def validate(cls) -> None: @@ -234,22 +230,22 @@ def validate(cls) -> None: class CvatConfig: - cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080") - cvat_admin = os.environ.get("CVAT_ADMIN", "admin") - cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin") - cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "org1") + host_url = getenv("CVAT_URL", "http://localhost:8080") + admin_login = getenv("CVAT_ADMIN", "admin") + admin_pass = getenv("CVAT_ADMIN_PASS", "admin") + org_slug = getenv("CVAT_ORG_SLUG", "org1") - cvat_quality_retrieval_timeout = int(os.environ.get("CVAT_QUALITY_RETRIEVAL_TIMEOUT", 60 * 60)) - cvat_quality_check_interval = int(os.environ.get("CVAT_QUALITY_CHECK_INTERVAL", 5)) + quality_retrieval_timeout = int(getenv("CVAT_QUALITY_RETRIEVAL_TIMEOUT", 60 * 60)) + quality_check_interval = int(getenv("CVAT_QUALITY_CHECK_INTERVAL", 5)) class Config: - port = int(os.environ.get("PORT", 8000)) - environment = os.environ.get("ENVIRONMENT", "development") - workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1)) - webhook_max_retries = int(os.environ.get("WEBHOOK_MAX_RETRIES", 5)) - webhook_delay_if_failed = int(os.environ.get("WEBHOOK_DELAY_IF_FAILED", 60)) - loglevel = parse_log_level(os.environ.get("LOGLEVEL", "info")) + port = int(getenv("PORT", 8000)) + environment = getenv("ENVIRONMENT", "development") + workers_amount = int(getenv("WORKERS_AMOUNT", 1)) + webhook_max_retries = int(getenv("WEBHOOK_MAX_RETRIES", 5)) + webhook_delay_if_failed = int(getenv("WEBHOOK_DELAY_IF_FAILED", 60)) + loglevel = parse_log_level(getenv("LOGLEVEL", "info")) polygon_mainnet = PolygonMainnetConfig polygon_amoy = PolygonAmoyConfig diff --git a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py index ab52816957..72d61c4a2b 100644 --- a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py @@ -14,13 +14,13 @@ def get_api_client() -> ApiClient: configuration = Configuration( - host=Config.cvat_config.cvat_url, - username=Config.cvat_config.cvat_admin, - password=Config.cvat_config.cvat_admin_pass, + host=Config.cvat_config.host_url, + username=Config.cvat_config.admin_login, + password=Config.cvat_config.admin_pass, ) api_client = ApiClient(configuration=configuration) - api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug) + api_client.set_default_header("X-organization", Config.cvat_config.org_slug) return api_client @@ -40,8 +40,8 @@ def get_last_task_quality_report(task_id: int) -> models.QualityReport | None: def compute_task_quality_report( task_id: int, *, - timeout: int = Config.cvat_config.cvat_quality_retrieval_timeout, - check_interval: float = Config.cvat_config.cvat_quality_check_interval, + timeout: int = Config.cvat_config.quality_retrieval_timeout, + check_interval: float = Config.cvat_config.quality_check_interval, ) -> models.QualityReport: logger = logging.getLogger("app") start_time = utcnow() @@ -91,8 +91,8 @@ def get_task(task_id: int) -> models.TaskRead: def get_task_quality_report( task_id: int, *, - timeout: int = Config.cvat_config.cvat_quality_retrieval_timeout, - check_interval: float = Config.cvat_config.cvat_quality_check_interval, + timeout: int = Config.cvat_config.quality_retrieval_timeout, + check_interval: float = Config.cvat_config.quality_check_interval, ) -> models.QualityReport: logger = logging.getLogger("app")