diff --git a/packages/examples/cvat/exchange-oracle/debug.py b/packages/examples/cvat/exchange-oracle/debug.py index aa8052cdd1..66168f9c65 100644 --- a/packages/examples/cvat/exchange-oracle/debug.py +++ b/packages/examples/cvat/exchange-oracle/debug.py @@ -1,35 +1,28 @@ import datetime -import hashlib import json +from collections.abc import Generator +from contextlib import ExitStack, contextmanager +from logging import Logger from pathlib import Path from typing import Any +from unittest import mock import uvicorn from httpx import URL from src.chain.kvstore import register_in_kvstore from src.core.config import Config +from src.db import SessionLocal from src.services import cloud +from src.services import cvat as cvat_service from src.services.cloud import BucketAccessInfo -from src.utils.logging import get_function_logger +from src.utils.logging import format_sequence, get_function_logger -def apply_local_development_patches(): - """ - Applies local development patches to avoid manual source code modification.: - - Overrides `EscrowUtils.get_escrow` to return local escrow data with mock values if the escrow - address matches a local manifest. - - Loads local manifest files from cloud storage into `LOCAL_MANIFEST_FILES`. - - Disables address validation by overriding `validate_address`. - - Replaces `validate_oracle_webhook_signature` with a lenient version that uses - partial signature parsing. - - Generates ECDSA keys if not already present for local JWT signing, - and sets the public key in `Config.human_app_config`. 
- """ - import src.handlers.job_creation - - original_make_cvat_cloud_storage_params = ( - src.handlers.job_creation._make_cvat_cloud_storage_params +@contextmanager +def _mock_cvat_cloud_storage_params(logger: Logger) -> Generator[None, None, None]: + from src.handlers.job_creation import ( + _make_cvat_cloud_storage_params as original_make_cvat_cloud_storage_params, ) def patched_make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> dict: @@ -37,114 +30,172 @@ def patched_make_cvat_cloud_storage_params(bucket_info: BucketAccessInfo) -> dic if Config.development_config.cvat_in_docker: bucket_info.host_url = str( - URL(original_host_url).copy_with(host=Config.development_config.cvat_local_host) + URL(original_host_url).copy_with( + host=Config.development_config.exchange_oracle_host + ) ) logger.info( f"DEV: Changed {original_host_url} to {bucket_info.host_url} for CVAT storage" ) + try: return original_make_cvat_cloud_storage_params(bucket_info) finally: bucket_info.host_url = original_host_url - src.handlers.job_creation._make_cvat_cloud_storage_params = ( - patched_make_cvat_cloud_storage_params - ) + with mock.patch( + "src.handlers.job_creation._make_cvat_cloud_storage_params", + patched_make_cvat_cloud_storage_params, + ): + yield - def prepare_signed_message( - escrow_address, - chain_id, - message: str | None = None, - body: dict | None = None, - ) -> tuple[None, str]: - digest = hashlib.sha256( - (escrow_address + ":".join(map(str, (chain_id, message, body)))).encode() - ).hexdigest() - signature = f"{OracleWebhookTypes.recording_oracle}:{digest}" - logger.info(f"DEV: Generated patched signature {signature}") - return None, signature - - src.utils.webhooks.prepare_signed_message = prepare_signed_message +@contextmanager +def _mock_get_manifests_from_minio(logger: Logger) -> Generator[None, None, None]: from human_protocol_sdk.constants import ChainId from human_protocol_sdk.escrow import EscrowData, EscrowUtils - logger = 
get_function_logger(apply_local_development_patches.__name__)
-    minio_client = cloud.make_client(BucketAccessInfo.parse_obj(Config.storage_config))
+    original_get_escrow = EscrowUtils.get_escrow
 
-    def get_local_escrow(chain_id: int, escrow_address: str) -> EscrowData:
-        possible_manifest_name = escrow_address.split(":")[0]
-        local_manifests = minio_client.list_files(bucket="manifests")
-        logger.info(f"DEV: Local manifests: {local_manifests}")
-        if possible_manifest_name in local_manifests:
-            logger.info(f"DEV: Using local manifest {escrow_address}")
-            return EscrowData(
-                chain_id=ChainId(chain_id),
-                id="test",
-                address=escrow_address,
-                amount_paid=10,
-                balance=10,
-                count=1,
-                factory_address="",
-                launcher="",
-                status="Pending",
-                token="HMT",  # noqa: S106
-                total_funded_amount=10,
-                created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc),
-                manifest_url=(
-                    f"http://{Config.storage_config.endpoint_url}/manifests/{possible_manifest_name}"
-                ),
+    def patched_get_escrow(chain_id: int, escrow_address: str) -> EscrowData:
+        minio_manifests = minio_client.list_files(bucket="manifests")
+        logger.debug(f"DEV: Local manifests: {format_sequence(minio_manifests)}")
+
+        candidate_files = [fn for fn in minio_manifests if f"{escrow_address}.json" in fn]
+        if not candidate_files:
+            return original_get_escrow(ChainId(chain_id), escrow_address)
+        elif len(candidate_files) != 1:
+            raise Exception(
+                "Can't select local manifest to be used for escrow '{}'"
+                " - several manifests match: {}".format(
+                    escrow_address, format_sequence(candidate_files)
+                )
             )
-        return original_get_escrow(ChainId(chain_id), escrow_address)
 
-    original_get_escrow = EscrowUtils.get_escrow
-    EscrowUtils.get_escrow = get_local_escrow
+        manifest_file = candidate_files[0]
+        escrow = EscrowData(
+            chain_id=ChainId(chain_id),
+            id="test",
+            address=escrow_address,
+            amount_paid=10,
+            balance=10,
+            count=1,
+            factory_address="",
+            launcher="",
+            status="Pending",
+            token="HMT",  # noqa: S106
+            
total_funded_amount=10, + created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), + manifest_url=(f"http://{Config.storage_config.endpoint_url}/manifests/{manifest_file}"), + ) + + logger.info(f"DEV: Using local manifest '{manifest_file}' for escrow '{escrow_address}'") + return escrow + + with mock.patch.object(EscrowUtils, "get_escrow", patched_get_escrow): + yield + + +@contextmanager +def _mock_webhook_signature_checking(_: Logger) -> Generator[None, None, None]: + """ + Allows to receive webhooks from other services: + - from launcher - with signature "job_launcher" + - from recording oracle - + encoded with Config.localhost.recording_oracle_address wallet address + or signature "recording_oracle" + - from reputation oracle - + encoded with Config.localhost.reputation_oracle_address wallet address + or signature "reputation_oracle" + """ - import src.schemas.webhook + from src.chain.escrow import ( + get_available_webhook_types as original_get_available_webhook_types, + ) from src.core.types import OracleWebhookTypes + from src.validators.signature import ( + validate_oracle_webhook_signature as original_validate_oracle_webhook_signature, + ) - src.schemas.webhook.validate_address = lambda x: x + async def patched_validate_oracle_webhook_signature(request, signature, webhook): + for webhook_type in OracleWebhookTypes: + if signature.startswith(webhook_type.value.lower()): + return webhook_type + + return await original_validate_oracle_webhook_signature(request, signature, webhook) + + def patched_get_available_webhook_types(chain_id, escrow_address): + d = dict(original_get_available_webhook_types(chain_id, escrow_address)) + d[Config.localhost.recording_oracle_address.lower()] = OracleWebhookTypes.recording_oracle + d[Config.localhost.reputation_oracle_address.lower()] = OracleWebhookTypes.reputation_oracle + return d + + with ( + mock.patch("src.schemas.webhook.validate_address", lambda x: x), + mock.patch( + 
"src.validators.signature.get_available_webhook_types", + patched_get_available_webhook_types, + ), + mock.patch( + "src.endpoints.webhook.validate_oracle_webhook_signature", + patched_validate_oracle_webhook_signature, + ), + ): + yield + + +@contextmanager +def _mock_endpoint_auth(logger: Logger) -> Generator[None, None, None]: + """ + Allows simplified authentication: + - Bearer {"wallet_address": "...", "email": "..."} + - Bearer {"role": "human_app"} + """ - async def lenient_validate_oracle_webhook_signature(request, signature, webhook): - from src.validators.signature import validate_oracle_webhook_signature + from src.endpoints.authentication import HUMAN_APP_ROLE, TokenAuthenticator - try: - parsed_type = OracleWebhookTypes(signature.split(":")[0]) - logger.info(f"DEV: Recovered {parsed_type} from the signature {signature}") - except (ValueError, TypeError): - return await validate_oracle_webhook_signature(request, signature, webhook) + original_decode_token = TokenAuthenticator._decode_token - import src.endpoints.webhook + def decode_plain_json_token(self, token) -> dict[str, Any]: + try: + token_data = json.loads(token) - src.endpoints.webhook.validate_oracle_webhook_signature = ( - lenient_validate_oracle_webhook_signature - ) + if (user_wallet := token_data.get("wallet_address")) and not token_data.get("email"): + with SessionLocal.begin() as session: + user = cvat_service.get_user_by_id(session, user_wallet) + if not user: + raise Exception(f"Could not find user with wallet address '{user_wallet}'") - import src.endpoints.authentication + token_data["email"] = user.cvat_email - original_decode_token = src.endpoints.authentication.TokenAuthenticator._decode_token + if token_data.get("role") == HUMAN_APP_ROLE: + token_data["wallet_address"] = None + token_data["email"] = "" - def decode_plain_json_token(self, token) -> dict[str, Any]: - """ - Allows Authentication: Bearer {"wallet_address": "...", "email": "..."} - """ - try: - decoded = 
json.loads(token) - logger.info(f"DEV: Decoded plain JSON auth token: {decoded}") + logger.info(f"DEV: Decoded plain JSON auth token: {token_data}") + return token_data except (ValueError, TypeError): return original_decode_token(self, token) - src.endpoints.authentication.TokenAuthenticator._decode_token = decode_plain_json_token + with mock.patch.object(TokenAuthenticator, "_decode_token", decode_plain_json_token): + yield + + +@contextmanager +def _mock_human_app_keys(_: Logger) -> Generator[None, None, None]: + "Creates or uses local Human App JWT keys" from tests.api.test_exchange_api import generate_ecdsa_keys # generating keys for local development repo_root = Path(__file__).parent - human_app_private_key_file, human_app_public_key_file = ( - repo_root / "human_app_private_key.pem", - repo_root / "human_app_public_key.pem", - ) + dev_dir = repo_root / "dev" + dev_dir.mkdir(exist_ok=True) + + human_app_private_key_file = dev_dir / "human_app_private_key.pem" + human_app_public_key_file = dev_dir / "human_app_public_key.pem" + if not (human_app_public_key_file.exists() and human_app_private_key_file.exists()): private_key, public_key = generate_ecdsa_keys() human_app_private_key_file.write_text(private_key) @@ -154,20 +205,47 @@ def decode_plain_json_token(self, token) -> dict[str, Any]: Config.human_app_config.jwt_public_key = public_key - logger.warning("DEV: Local development patches applied.") + yield + + +@contextmanager +def apply_local_development_patches() -> Generator[None, None, None]: + """ + Applies local development patches to avoid manual source code modification + """ + + logger = get_function_logger(apply_local_development_patches.__name__) + + logger.warning("DEV: Applying local development patches") + + with ExitStack() as es: + for mock_callback in ( + _mock_cvat_cloud_storage_params, + _mock_get_manifests_from_minio, + _mock_webhook_signature_checking, + _mock_endpoint_auth, + _mock_human_app_keys, + ): + logger.warning(f"DEV: applying 
patch {mock_callback.__name__}...") + es.enter_context(mock_callback(logger)) + + logger.warning("DEV: Local development patches applied.") + + yield if __name__ == "__main__": - is_dev = Config.environment == "development" - if is_dev: - apply_local_development_patches() - - Config.validate() - register_in_kvstore() - - uvicorn.run( - app="src:app", - host="0.0.0.0", # noqa: S104 - port=int(Config.port), - workers=Config.workers_amount, - ) + with ExitStack() as es: + is_dev = Config.environment == "development" + if is_dev: + es.enter_context(apply_local_development_patches()) + + Config.validate() + register_in_kvstore() + + uvicorn.run( + app="src:app", + host="0.0.0.0", # noqa: S104 + port=int(Config.port), + workers=Config.workers_amount, + ) diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index b08e32026f..93b64e3858 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -114,6 +114,7 @@ ignore = [ "ANN002", # Missing type annotation for `*args` "TRY300", # Consider moving this statement to an `else` block "C901", # Function is too complex + "PLW1508", # invalid-envvar-default. 
Alerts only for os.getenv(), but not for os.environ.get() "PLW2901", # Variable overwritten by assignment target "PTH118", # Prefer pathlib instead of os.path "PTH119", # `os.path.basename()` should be replaced by `Path.name` diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index ae4b9dfb05..65a0e3966f 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -44,9 +44,7 @@ PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE= PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE= PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT= TRACK_COMPLETED_PROJECTS_INT= -TRACK_COMPLETED_PROJECTS_CHUNK_SIZE= TRACK_COMPLETED_TASKS_INT= -TRACK_COMPLETED_TASKS_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_INT= TRACK_COMPLETED_ESCROWS_CHUNK_SIZE= TRACK_ESCROW_VALIDATIONS_INT= @@ -79,13 +77,11 @@ CVAT_IOU_THRESHOLD= CVAT_OKS_SIGMA= CVAT_EXPORT_TIMEOUT= CVAT_IMPORT_TIMEOUT= -CVAT_POLYGONS_IOU_THRESHOLD= # Storage Config (S3/GCS) STORAGE_PROVIDER= STORAGE_ENDPOINT_URL= -STORAGE_REGION= STORAGE_ACCESS_KEY= STORAGE_SECRET_KEY= STORAGE_KEY_FILE_PATH= @@ -96,6 +92,7 @@ STORAGE_USE_SSL= ENABLE_CUSTOM_CLOUD_HOST= REQUEST_LOGGING_ENABLED= +PROFILING_ENABLED= MANIFEST_CACHE_TTL= # Core @@ -112,9 +109,14 @@ HUMAN_APP_JWT_KEY= # API config DEFAULT_API_PAGE_SIZE= +MIN_API_PAGE_SIZE= +MAX_API_PAGE_SIZE= +STATS_RPS_LIMIT= # Localhost +LOCALHOST_RPC_API_URL= +LOCALHOST_AMOY_ADDR= LOCALHOST_RECORDING_ORACLE_ADDRESS= LOCALHOST_RECORDING_ORACLE_URL= LOCALHOST_JOB_LAUNCHER_URL= @@ -122,6 +124,7 @@ LOCALHOST_REPUTATION_ORACLE_ADDRESS= LOCALHOST_REPUTATION_ORACLE_URL= # Encryption + PGP_PRIVATE_KEY= PGP_PASSPHRASE= PGP_PUBLIC_KEY_URL= @@ -129,4 +132,4 @@ PGP_PUBLIC_KEY_URL= # Development DEV_CVAT_IN_DOCKER= -DEV_CVAT_LOCAL_HOST= +DEV_EXCHANGE_ORACLE_HOST= diff --git a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py 
b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py index bc4ac9b51e..c5679fc7f7 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py @@ -73,11 +73,7 @@ def get_available_webhook_types( ) -> dict[str, OracleWebhookTypes]: escrow = get_escrow(chain_id, escrow_address) return { - escrow.launcher.lower(): OracleWebhookTypes.job_launcher, - ( - Config.localhost.recording_oracle_address or escrow.recording_oracle - ).lower(): OracleWebhookTypes.recording_oracle, - ( - Config.localhost.reputation_oracle_address or escrow.reputation_oracle - ).lower(): OracleWebhookTypes.reputation_oracle, + (escrow.launcher or "").lower(): OracleWebhookTypes.job_launcher, + (escrow.recording_oracle or "").lower(): OracleWebhookTypes.recording_oracle, + (escrow.reputation_oracle or "").lower(): OracleWebhookTypes.reputation_oracle, } diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index cc5d560093..951d51f2f9 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -5,6 +5,7 @@ import os from collections.abc import Iterable from enum import Enum +from os import getenv from typing import ClassVar, Optional from attrs.converters import to_bool @@ -16,13 +17,16 @@ from src.utils.logging import parse_log_level from src.utils.net import is_ipv4 -dotenv_path = os.getenv("DOTENV_PATH", None) +dotenv_path = getenv("DOTENV_PATH", None) if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) load_dotenv(dotenv_path) +# TODO: add some logic to report unused/deprecated env vars on startup + + class _BaseConfig: @classmethod def validate(cls) -> None: @@ -30,12 +34,12 @@ def validate(cls) -> None: class PostgresConfig: - port = os.environ.get("PG_PORT", "5432") - host = os.environ.get("PG_HOST", 
"0.0.0.0") # noqa: S104 - user = os.environ.get("PG_USER", "admin") - password = os.environ.get("PG_PASSWORD", "admin") - database = os.environ.get("PG_DB", "exchange_oracle") - lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds + port = getenv("PG_PORT", "5432") + host = getenv("PG_HOST", "0.0.0.0") # noqa: S104 + user = getenv("PG_USER", "admin") + password = getenv("PG_PASSWORD", "admin") + database = getenv("PG_DB", "exchange_oracle") + lock_timeout = int(getenv("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod def connection_url(cls) -> str: @@ -43,12 +47,12 @@ def connection_url(cls) -> str: class RedisConfig: - port = int(os.environ.get("REDIS_PORT", "6379")) - host = os.environ.get("REDIS_HOST", "0.0.0.0") # noqa: S104 - database = int(os.environ.get("REDIS_DB", "0")) - user = os.environ.get("REDIS_USER", "") - password = os.environ.get("REDIS_PASSWORD", "") - use_ssl = to_bool(os.environ.get("REDIS_USE_SSL", "false")) + port = int(getenv("REDIS_PORT", "6379")) + host = getenv("REDIS_HOST", "0.0.0.0") # noqa: S104 + database = int(getenv("REDIS_DB", "0")) + user = getenv("REDIS_USER", "") + password = getenv("REDIS_PASSWORD", "") + use_ssl = to_bool(getenv("REDIS_USE_SSL", "false")) @classmethod def connection_url(cls) -> str: @@ -80,144 +84,126 @@ def is_configured(cls) -> bool: class PolygonMainnetConfig(_NetworkConfig): chain_id = 137 - rpc_api = os.environ.get("POLYGON_MAINNET_RPC_API_URL") - private_key = os.environ.get("POLYGON_MAINNET_PRIVATE_KEY") - addr = os.environ.get("POLYGON_MAINNET_ADDR") + rpc_api = getenv("POLYGON_MAINNET_RPC_API_URL") + private_key = getenv("POLYGON_MAINNET_PRIVATE_KEY") + addr = getenv("POLYGON_MAINNET_ADDR") class PolygonAmoyConfig(_NetworkConfig): chain_id = 80002 - rpc_api = os.environ.get("POLYGON_AMOY_RPC_API_URL") - private_key = os.environ.get("POLYGON_AMOY_PRIVATE_KEY") - addr = os.environ.get("POLYGON_AMOY_ADDR") + rpc_api = getenv("POLYGON_AMOY_RPC_API_URL") + private_key = 
getenv("POLYGON_AMOY_PRIVATE_KEY") + addr = getenv("POLYGON_AMOY_ADDR") class LocalhostConfig(_NetworkConfig): chain_id = 1338 - rpc_api = os.environ.get("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") - private_key = os.environ.get( + rpc_api = getenv("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") + private_key = getenv( "LOCALHOST_PRIVATE_KEY", "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", ) - addr = os.environ.get("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") + addr = getenv("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") - job_launcher_url = os.environ.get("LOCALHOST_JOB_LAUNCHER_URL") + job_launcher_url = getenv("LOCALHOST_JOB_LAUNCHER_URL") - recording_oracle_address = os.environ.get("LOCALHOST_RECORDING_ORACLE_ADDRESS") - recording_oracle_url = os.environ.get("LOCALHOST_RECORDING_ORACLE_URL") + recording_oracle_address = getenv("LOCALHOST_RECORDING_ORACLE_ADDRESS") + recording_oracle_url = getenv("LOCALHOST_RECORDING_ORACLE_URL") - reputation_oracle_address = os.environ.get("LOCALHOST_REPUTATION_ORACLE_ADDRESS") - reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") + reputation_oracle_address = getenv("LOCALHOST_REPUTATION_ORACLE_ADDRESS") + reputation_oracle_url = getenv("LOCALHOST_REPUTATION_ORACLE_URL") class CronConfig: - process_job_launcher_webhooks_int = int(os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_INT", 30)) + process_job_launcher_webhooks_int = int(getenv("PROCESS_JOB_LAUNCHER_WEBHOOKS_INT", 30)) process_job_launcher_webhooks_chunk_size = int( - os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5) - ) - process_recording_oracle_webhooks_int = int( - os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_INT", 30) + getenv("PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5) ) + process_recording_oracle_webhooks_int = int(getenv("PROCESS_RECORDING_ORACLE_WEBHOOKS_INT", 30)) process_recording_oracle_webhooks_chunk_size = int( - 
os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + getenv("PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_chunk_size = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_int = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 5) - ) - track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30)) - track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5) - track_completed_tasks_int = int(os.environ.get("TRACK_COMPLETED_TASKS_INT", 30)) - track_completed_tasks_chunk_size = os.environ.get("TRACK_COMPLETED_TASKS_CHUNK_SIZE", 20) - track_creating_tasks_int = int(os.environ.get("TRACK_CREATING_TASKS_INT", 300)) - track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) - track_assignments_int = int(os.environ.get("TRACK_ASSIGNMENTS_INT", 5)) - track_assignments_chunk_size = os.environ.get("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10) - - track_completed_escrows_int = int( - # backward compatibility - os.environ.get( - "TRACK_COMPLETED_ESCROWS_INT", os.environ.get("RETRIEVE_ANNOTATIONS_INT", 60) - ) - ) - track_completed_escrows_chunk_size = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_CHUNK_SIZE", 100) - ) - track_escrow_validations_int = int(os.environ.get("TRACK_COMPLETED_ESCROWS_INT", 60)) - track_escrow_validations_chunk_size = int( - os.environ.get("TRACK_ESCROW_VALIDATIONS_CHUNK_SIZE", 1) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 5) ) + track_completed_projects_int = int(getenv("TRACK_COMPLETED_PROJECTS_INT", 30)) + track_completed_tasks_int = int(getenv("TRACK_COMPLETED_TASKS_INT", 30)) + track_creating_tasks_int = int(getenv("TRACK_CREATING_TASKS_INT", 300)) + track_creating_tasks_chunk_size = getenv("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) + track_assignments_int = 
int(getenv("TRACK_ASSIGNMENTS_INT", 5)) + track_assignments_chunk_size = int(getenv("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10)) + + track_completed_escrows_int = int(getenv("TRACK_COMPLETED_ESCROWS_INT", 60)) + track_completed_escrows_chunk_size = int(getenv("TRACK_COMPLETED_ESCROWS_CHUNK_SIZE", 100)) + track_escrow_validations_int = int(getenv("TRACK_ESCROW_VALIDATIONS_INT", 60)) + track_escrow_validations_chunk_size = int(getenv("TRACK_ESCROW_VALIDATIONS_CHUNK_SIZE", 1)) track_completed_escrows_max_downloading_retries = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES", 10) + getenv("TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES", 10) ) "Maximum number of downloading attempts per job or project during results downloading" track_completed_escrows_jobs_downloading_batch_size = int( - os.environ.get("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 500) + getenv("TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE", 500) ) "Maximum number of parallel downloading requests during results downloading" - rejected_projects_chunk_size = os.environ.get("REJECTED_PROJECTS_CHUNK_SIZE", 20) - accepted_projects_chunk_size = os.environ.get("ACCEPTED_PROJECTS_CHUNK_SIZE", 20) + process_rejected_projects_chunk_size = int(getenv("REJECTED_PROJECTS_CHUNK_SIZE", 20)) + process_accepted_projects_chunk_size = int(getenv("ACCEPTED_PROJECTS_CHUNK_SIZE", 20)) - track_escrow_creation_chunk_size = os.environ.get("TRACK_ESCROW_CREATION_CHUNK_SIZE", 20) - track_escrow_creation_int = int(os.environ.get("TRACK_ESCROW_CREATION_INT", 300)) + track_escrow_creation_chunk_size = int(getenv("TRACK_ESCROW_CREATION_CHUNK_SIZE", 20)) + track_escrow_creation_int = int(getenv("TRACK_ESCROW_CREATION_INT", 300)) class CvatConfig: - # TODO: remove cvat_ prefix - cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080") - cvat_admin = os.environ.get("CVAT_ADMIN", "admin") - cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin") - cvat_org_slug = 
os.environ.get("CVAT_ORG_SLUG", "") - - cvat_job_overlap = int(os.environ.get("CVAT_JOB_OVERLAP", 0)) - cvat_task_segment_size = int(os.environ.get("CVAT_TASK_SEGMENT_SIZE", 150)) - cvat_default_image_quality = int(os.environ.get("CVAT_DEFAULT_IMAGE_QUALITY", 70)) - cvat_max_jobs_per_task = int(os.environ.get("CVAT_MAX_JOBS_PER_TASK", 1000)) - cvat_task_creation_check_interval = int(os.environ.get("CVAT_TASK_CREATION_CHECK_INTERVAL", 5)) - - cvat_export_timeout = int(os.environ.get("CVAT_EXPORT_TIMEOUT", 5 * 60)) + host_url = getenv("CVAT_URL", "http://localhost:8080") + admin_login = getenv("CVAT_ADMIN", "admin") + admin_pass = getenv("CVAT_ADMIN_PASS", "admin") + org_slug = getenv("CVAT_ORG_SLUG", "") + + job_overlap = int(getenv("CVAT_JOB_OVERLAP", 0)) + task_segment_size = int(getenv("CVAT_TASK_SEGMENT_SIZE", 150)) + default_image_quality = int(getenv("CVAT_DEFAULT_IMAGE_QUALITY", 70)) + max_jobs_per_task = int(getenv("CVAT_MAX_JOBS_PER_TASK", 1000)) + task_creation_check_interval = int(getenv("CVAT_TASK_CREATION_CHECK_INTERVAL", 5)) + + export_timeout = int(getenv("CVAT_EXPORT_TIMEOUT", 5 * 60)) "Timeout, in seconds, for annotations or dataset export waiting" - cvat_import_timeout = int(os.environ.get("CVAT_IMPORT_TIMEOUT", 60 * 60)) + import_timeout = int(getenv("CVAT_IMPORT_TIMEOUT", 60 * 60)) "Timeout, in seconds, for waiting on GT annotations import" # quality control settings - cvat_max_validation_checks = int(os.environ.get("CVAT_MAX_VALIDATION_CHECKS", 3)) + max_validation_checks = int(getenv("CVAT_MAX_VALIDATION_CHECKS", 3)) "Maximum number of attempts to run a validation check on a job after completing annotation" - cvat_iou_threshold = float(os.environ.get("CVAT_IOU_THRESHOLD", 0.8)) - cvat_oks_sigma = float(os.environ.get("CVAT_OKS_SIGMA", 0.1)) - - cvat_polygons_iou_threshold = float(os.environ.get("CVAT_POLYGONS_IOU_THRESHOLD", 0.5)) - "`iou_threshold` parameter for quality settings in polygons tasks" + iou_threshold = 
float(getenv("CVAT_IOU_THRESHOLD", 0.8)) + oks_sigma = float(getenv("CVAT_OKS_SIGMA", 0.1)) - cvat_incoming_webhooks_url = os.environ.get("CVAT_INCOMING_WEBHOOKS_URL") - cvat_webhook_secret = os.environ.get("CVAT_WEBHOOK_SECRET", "thisisasamplesecret") + incoming_webhooks_url = getenv("CVAT_INCOMING_WEBHOOKS_URL") + webhook_secret = getenv("CVAT_WEBHOOK_SECRET", "thisisasamplesecret") class StorageConfig: provider: ClassVar[str] = os.environ["STORAGE_PROVIDER"].lower() data_bucket_name: ClassVar[str] = ( - os.environ.get("STORAGE_RESULTS_BUCKET_NAME") # backward compatibility + getenv("STORAGE_RESULTS_BUCKET_NAME") # backward compatibility or os.environ["STORAGE_BUCKET_NAME"] ) endpoint_url: ClassVar[str] = os.environ[ "STORAGE_ENDPOINT_URL" ] # TODO: probably should be optional - region: ClassVar[str | None] = os.environ.get("STORAGE_REGION") - results_dir_suffix: ClassVar[str] = os.environ.get("STORAGE_RESULTS_DIR_SUFFIX", "-results") - secure: ClassVar[bool] = to_bool(os.environ.get("STORAGE_USE_SSL", "true")) + results_dir_suffix: ClassVar[str] = getenv("STORAGE_RESULTS_DIR_SUFFIX", "-results") + secure: ClassVar[bool] = to_bool(getenv("STORAGE_USE_SSL", "true")) # S3 specific attributes - access_key: ClassVar[str | None] = os.environ.get("STORAGE_ACCESS_KEY") - secret_key: ClassVar[str | None] = os.environ.get("STORAGE_SECRET_KEY") + access_key: ClassVar[str | None] = getenv("STORAGE_ACCESS_KEY") + secret_key: ClassVar[str | None] = getenv("STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path: ClassVar[str | None] = os.environ.get("STORAGE_KEY_FILE_PATH") + key_file_path: ClassVar[str | None] = getenv("STORAGE_KEY_FILE_PATH") @classmethod def get_scheme(cls) -> str: @@ -236,13 +222,13 @@ def bucket_url(cls) -> str: class FeaturesConfig: - enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) + enable_custom_cloud_host = to_bool(getenv("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket 
urls" - request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0")) + request_logging_enabled = to_bool(getenv("REQUEST_LOGGING_ENABLED", "0")) "Allow to log request details for each request" - profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", "0")) + profiling_enabled = to_bool(getenv("PROFILING_ENABLED", "0")) "Allow to profile specific requests" manifest_cache_ttl = int(os.getenv("MANIFEST_CACHE_TTL", str(2 * 24 * 60 * 60))) @@ -250,15 +236,15 @@ class FeaturesConfig: class CoreConfig: - default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800)) + default_assignment_time = int(getenv("DEFAULT_ASSIGNMENT_TIME", 1800)) - skeleton_assignment_size_mult = int(os.environ.get("SKELETON_ASSIGNMENT_SIZE_MULT", 1)) + skeleton_assignment_size_mult = int(getenv("SKELETON_ASSIGNMENT_SIZE_MULT", 1)) "Assignment size multiplier for image_skeletons_from_boxes tasks" - min_roi_size_w = int(os.environ.get("MIN_ROI_SIZE_W", 350)) + min_roi_size_w = int(getenv("MIN_ROI_SIZE_W", 350)) "Minimum absolute ROI size for image_boxes_from_points and image_skeletons_from_boxes tasks" - min_roi_size_h = int(os.environ.get("MIN_ROI_SIZE_H", 300)) + min_roi_size_h = int(getenv("MIN_ROI_SIZE_H", 300)) "Minimum absolute ROI size for image_boxes_from_points and image_skeletons_from_boxes tasks" @@ -268,21 +254,21 @@ class HumanAppConfig: # openssl ecparam -name prime256v1 -genkey -noout -out ec_private.pem # openssl ec -in ec_private.pem -pubout -out ec_public.pem # HUMAN_APP_JWT_KEY=$(cat ec_public.pem) - jwt_public_key = os.environ.get("HUMAN_APP_JWT_KEY") + jwt_public_key = getenv("HUMAN_APP_JWT_KEY") class ApiConfig: - default_page_size = int(os.environ.get("DEFAULT_API_PAGE_SIZE", 5)) - min_page_size = int(os.environ.get("MIN_API_PAGE_SIZE", 1)) - max_page_size = int(os.environ.get("MAX_API_PAGE_SIZE", 10)) + default_page_size = int(getenv("DEFAULT_API_PAGE_SIZE", 5)) + min_page_size = int(getenv("MIN_API_PAGE_SIZE", 1)) + max_page_size = 
int(getenv("MAX_API_PAGE_SIZE", 10)) - stats_rps_limit = int(os.environ.get("STATS_RPS_LIMIT", 4)) + stats_rps_limit = int(getenv("STATS_RPS_LIMIT", 4)) class EncryptionConfig(_BaseConfig): - pgp_passphrase = os.environ.get("PGP_PASSPHRASE", "") - pgp_private_key = os.environ.get("PGP_PRIVATE_KEY", "") - pgp_public_key_url = os.environ.get("PGP_PUBLIC_KEY_URL", "") + pgp_passphrase = getenv("PGP_PASSPHRASE", "") + pgp_private_key = getenv("PGP_PRIVATE_KEY", "") + pgp_public_key_url = getenv("PGP_PUBLIC_KEY_URL", "") @classmethod def validate(cls) -> None: @@ -301,10 +287,16 @@ def validate(cls) -> None: raise Exception(" ".join([ex_prefix, str(ex)])) -class Development: - cvat_in_docker = bool(int(os.environ.get("DEV_CVAT_IN_DOCKER", "0"))) - # might be `host.docker.internal` or `172.22.0.1` if CVAT is running in docker - cvat_local_host = os.environ.get("DEV_CVAT_LOCAL_HOST", "localhost") +class DevelopmentConfig: + cvat_in_docker = bool(int(getenv("DEV_CVAT_IN_DOCKER", "1"))) + + exchange_oracle_host = getenv("DEV_EXCHANGE_ORACLE_HOST", "172.22.0.1") + """ + Might be `host.docker.internal` or `172.22.0.1` if CVAT is running in Docker. + + Remember to allow this host via: + SMOKESCREEN_OPTS="--allow-address=" docker compose ... 
+ """ class Environment(str, Enum): @@ -323,13 +315,13 @@ def _missing_(cls, value: str) -> Optional["Environment"]: class Config: - debug = to_bool(os.environ.get("DEBUG", "false")) - port = int(os.environ.get("PORT", 8000)) - environment = Environment(os.environ.get("ENVIRONMENT", Environment.DEVELOPMENT.value)) - workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1)) - webhook_max_retries = int(os.environ.get("WEBHOOK_MAX_RETRIES", 5)) - webhook_delay_if_failed = int(os.environ.get("WEBHOOK_DELAY_IF_FAILED", 60)) - loglevel = parse_log_level(os.environ.get("LOGLEVEL", "info")) + debug = to_bool(getenv("DEBUG", "false")) + port = int(getenv("PORT", 8000)) + environment = Environment(getenv("ENVIRONMENT", Environment.DEVELOPMENT.value)) + workers_amount = int(getenv("WORKERS_AMOUNT", 1)) + webhook_max_retries = int(getenv("WEBHOOK_MAX_RETRIES", 5)) + webhook_delay_if_failed = int(getenv("WEBHOOK_DELAY_IF_FAILED", 60)) + loglevel = parse_log_level(getenv("LOGLEVEL", "info")) polygon_mainnet = PolygonMainnetConfig polygon_amoy = PolygonAmoyConfig @@ -346,7 +338,7 @@ class Config: features = FeaturesConfig core_config = CoreConfig encryption_config = EncryptionConfig - development_config = Development + development_config = DevelopmentConfig @classmethod def is_development_mode(cls) -> bool: diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py index cf0e74c410..e4b689aaa6 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py @@ -15,7 +15,7 @@ from src.db import errors as db_errors from src.db.utils import ForUpdateParams from src.handlers.completed_escrows import handle_escrows_validations -from src.log import format_sequence +from src.utils.logging import format_sequence @cron_job diff --git 
a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py index eeb4a3581b..0315cac825 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py @@ -60,7 +60,6 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) return - chunk_size = CronConfig.accepted_projects_chunk_size project_ids = cvat_db_service.get_project_cvat_ids_by_escrow_address( db_session, webhook.escrow_address ) @@ -71,6 +70,7 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) return + chunk_size = CronConfig.process_accepted_projects_chunk_size for ids_chunk in take_by(project_ids, chunk_size): projects_chunk = cvat_db_service.get_projects_by_cvat_ids( db_session, ids_chunk, for_update=True, limit=chunk_size @@ -138,7 +138,7 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg ) rejected_project_cvat_ids = set(j.cvat_project_id for j in rejected_jobs) - chunk_size = CronConfig.rejected_projects_chunk_size + chunk_size = CronConfig.process_rejected_projects_chunk_size for chunk_ids in take_by(rejected_project_cvat_ids, chunk_size): projects_chunk = cvat_db_service.get_projects_by_cvat_ids( db_session, chunk_ids, for_update=True, limit=chunk_size diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index e40b4af209..cd1ad4520d 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -77,7 +77,7 @@ def _get_annotations( time_begin = utcnow() if timeout is _NOTSET: - timeout = Config.cvat_config.cvat_export_timeout + timeout = Config.cvat_config.export_timeout while True: (_, response) = endpoint.call_with_http_info( @@ 
-118,23 +118,23 @@ def get_api_client() -> ApiClient: return current_api_client configuration = Configuration( - host=Config.cvat_config.cvat_url, - username=Config.cvat_config.cvat_admin, - password=Config.cvat_config.cvat_admin_pass, + host=Config.cvat_config.host_url, + username=Config.cvat_config.admin_login, + password=Config.cvat_config.admin_pass, ) api_client = ApiClient(configuration=configuration) - api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug) + api_client.set_default_header("X-organization", Config.cvat_config.org_slug) return api_client def get_sdk_client() -> Client: client = make_client( - host=Config.cvat_config.cvat_url, - credentials=(Config.cvat_config.cvat_admin, Config.cvat_config.cvat_admin_pass), + host=Config.cvat_config.host_url, + credentials=(Config.cvat_config.admin_login, Config.cvat_config.admin_pass), ) - client.organization_slug = Config.cvat_config.cvat_org_slug + client.organization_slug = Config.cvat_config.org_slug return client @@ -291,11 +291,11 @@ def create_cvat_webhook(project_id: int) -> models.WebhookRead: logger = logging.getLogger("app") with get_api_client() as api_client: webhook_write_request = models.WebhookWriteRequest( - target_url=Config.cvat_config.cvat_incoming_webhooks_url, + target_url=Config.cvat_config.incoming_webhooks_url, description="Exchange Oracle notification", type=models.WebhookType("project"), content_type=models.WebhookContentType("application/json"), - secret=Config.cvat_config.cvat_webhook_secret, + secret=Config.cvat_config.webhook_secret, is_active=True, # enable_ssl=True, project_id=project_id, @@ -403,7 +403,7 @@ def put_task_data( data_request = models.DataRequest( chunk_size=chunk_size, cloud_storage_id=cloudstorage_id, - image_quality=Config.cvat_config.cvat_default_image_quality, + image_quality=Config.cvat_config.image_quality, use_cache=True, use_zip_chunks=True, sorting_method=sorting_method, @@ -641,7 +641,7 @@ def upload_gt_annotations( *, 
format_name: str, sleep_interval: int = 5, - timeout: int | None = Config.cvat_config.cvat_import_timeout, + timeout: int | None = Config.cvat_config.import_timeout, ) -> None: # FUTURE-TODO: use job.import_annotations when CVAT supports a waiting timeout start_time = datetime.now(timezone.utc) @@ -728,8 +728,8 @@ def update_quality_control_settings( *, target_metric_threshold: float, target_metric: str = "accuracy", - max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks, - iou_threshold: float = Config.cvat_config.cvat_iou_threshold, + max_validations_per_job: int = Config.cvat_config.max_validation_checks, + iou_threshold: float = Config.cvat_config.iou_threshold, oks_sigma: float | None = None, point_size_base: str | None = None, match_empty_frames: bool | None = None, @@ -823,7 +823,7 @@ def get_user_id(user_email: str) -> int: try: (invitation, _) = api_client.invitations_api.create( models.InvitationWriteRequest(role="worker", email=user_email), - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) except exceptions.ApiException as e: logger.exception(f"Exception when calling get_user_id(): {e}\n") @@ -839,7 +839,7 @@ def remove_user_from_org(user_id: int): try: (page, _) = api_client.users_api.list( filter='{"==":[{"var":"id"},"%s"]}' % user_id, # noqa: UP031 - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) if not page.results: return @@ -849,7 +849,7 @@ def remove_user_from_org(user_id: int): (page, _) = api_client.memberships_api.list( user=user.username, - org=Config.cvat_config.cvat_org_slug, + org=Config.cvat_config.org_slug, ) if page.results: api_client.memberships_api.destroy(page.results[0].id) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index e111f1e2f3..64fedbdfce 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ 
b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -136,7 +136,9 @@ async def _log_request(self, request: Request) -> dict[str, Any]: if raw_body: body = body.decode(errors="ignore") - body = body[: self.max_displayed_body_size] + + if len(body) > self.max_displayed_body_size: + body = body[: self.max_displayed_body_size - 3] + "..." request_logging["body"] = body diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 2385d00774..f2d940ad4c 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -33,13 +33,13 @@ from src.core.storage import compose_data_bucket_filename from src.core.types import CvatLabelTypes, TaskStatuses, TaskTypes from src.db import SessionLocal -from src.log import ROOT_LOGGER_NAME, format_sequence +from src.log import ROOT_LOGGER_NAME from src.models.cvat import Project from src.services.cloud import CloudProviders, StorageClient from src.services.cloud.utils import BucketAccessInfo, compose_bucket_url from src.utils.annotations import InstanceSegmentsToBbox, ProjectLabels, is_point_in_bbox from src.utils.assignments import parse_manifest -from src.utils.logging import NullLogger, get_function_logger +from src.utils.logging import NullLogger, format_sequence, get_function_logger from src.utils.zip_archive import write_dir_to_zip_archive if TYPE_CHECKING: @@ -203,7 +203,7 @@ def _wait_task_creation(self, task_id: int) -> cvat_api.UploadStatus: if task_status not in [cvat_api.UploadStatus.STARTED, cvat_api.UploadStatus.QUEUED]: return task_status - sleep(Config.cvat_config.cvat_task_creation_check_interval) + sleep(Config.cvat_config.task_creation_check_interval) def _setup_gt_job_for_cvat_task( self, task_id: int, gt_dataset: dm.Dataset, *, dm_export_format: str = "coco" @@ -409,7 +409,7 @@ def build(self): for data_subset 
in self._split_dataset_per_task( data_to_be_annotated, - subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size, + subset_size=Config.cvat_config.max_jobs_per_task * segment_size, ): cvat_task = cvat_api.create_task( cvat_project.id, escrow_address, segment_size=segment_size @@ -552,7 +552,7 @@ def build(self): class PolygonTaskBuilder(SimpleTaskBuilder): def _setup_quality_settings(self, task_id: int, **overrides) -> None: - values = {"iou_threshold": Config.cvat_config.cvat_polygons_iou_threshold, **overrides} + values = {"iou_threshold": Config.cvat_config.iou_threshold, **overrides} super()._setup_quality_settings(task_id, **values) @@ -1554,7 +1554,7 @@ def _create_on_cvat(self): for data_subset in self._split_dataset_per_task( self._roi_filenames_to_be_annotated, - subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size, + subset_size=Config.cvat_config.max_jobs_per_task * segment_size, ): cvat_task = cvat_api.create_task( cvat_project.id, self.escrow_address, segment_size=segment_size @@ -2354,7 +2354,7 @@ def _prepare_task_params(self): label_id=label_id, roi_ids=task_data_roi_ids, roi_gt_ids=label_gt_roi_ids ) for task_data_roi_ids in take_by( - label_data_roi_ids, Config.cvat_config.cvat_max_jobs_per_task * segment_size + label_data_roi_ids, Config.cvat_config.max_jobs_per_task * segment_size ) ] ) @@ -2603,7 +2603,7 @@ def _save_cvat_gt_dataset_to_oracle_bucket( def _setup_quality_settings(self, task_id: int, **overrides) -> None: values = { - "oks_sigma": Config.cvat_config.cvat_oks_sigma, + "oks_sigma": Config.cvat_config.oks_sigma, "point_size_base": "image_size", # we don't expect any boxes or groups, so ignore them } values.update(overrides) @@ -2767,7 +2767,7 @@ def _task_params_label_key(ts): cvat_task.id, gt_point_dataset, dm_export_format="cvat" ) self._setup_quality_settings( - cvat_task.id, oks_sigma=Config.cvat_config.cvat_oks_sigma + cvat_task.id, oks_sigma=Config.cvat_config.oks_sigma ) 
db_service.create_data_upload(session, cvat_task.id) diff --git a/packages/examples/cvat/exchange-oracle/src/log.py b/packages/examples/cvat/exchange-oracle/src/log.py index 406f243e43..7a26a6a287 100644 --- a/packages/examples/cvat/exchange-oracle/src/log.py +++ b/packages/examples/cvat/exchange-oracle/src/log.py @@ -1,9 +1,7 @@ """Config for the application logger""" import logging -from collections.abc import Sequence from logging.config import dictConfig -from typing import Any from src.core.config import Config @@ -50,9 +48,3 @@ def setup_logging(): def get_root_logger() -> logging.Logger: return logging.getLogger(ROOT_LOGGER_NAME) - - -def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: - remainder_count = len(items) - max_items - tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" - return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/exchange-oracle/src/utils/assignments.py b/packages/examples/cvat/exchange-oracle/src/utils/assignments.py index ceeebfc150..51c8fe9c7b 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/assignments.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/assignments.py @@ -24,7 +24,7 @@ def compose_assignment_url(task_id: int, job_id: int, *, project: Project) -> st if project.job_type == TaskTypes.image_skeletons_from_boxes: query_params += "&defaultPointsCount=1" - return urljoin(Config.cvat_config.cvat_url, f"/tasks/{task_id}/jobs/{job_id}{query_params}") + return urljoin(Config.cvat_config.host_url, f"/tasks/{task_id}/jobs/{job_id}{query_params}") def get_default_assignment_timeout(task_type: TaskTypes) -> int: diff --git a/packages/examples/cvat/exchange-oracle/src/utils/logging.py b/packages/examples/cvat/exchange-oracle/src/utils/logging.py index e7660eb0d7..a8d8c94daf 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/logging.py +++ 
b/packages/examples/cvat/exchange-oracle/src/utils/logging.py @@ -1,5 +1,6 @@ import logging -from typing import NewType +from collections.abc import Sequence +from typing import Any, NewType from src.utils.stack import current_function_name @@ -30,3 +31,9 @@ class NullLogger(logging.Logger): def __init__(self, name: str = "", level=0) -> None: super().__init__(name, level) self.disabled = True + + +def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: + remainder_count = len(items) - max_items + tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" + return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/exchange-oracle/src/validators/signature.py b/packages/examples/cvat/exchange-oracle/src/validators/signature.py index 077177b80a..2f42d019e5 100644 --- a/packages/examples/cvat/exchange-oracle/src/validators/signature.py +++ b/packages/examples/cvat/exchange-oracle/src/validators/signature.py @@ -32,7 +32,7 @@ async def validate_cvat_signature(request: Request, x_signature_256: str): signature = ( "sha256=" + hmac.new( - Config.cvat_config.cvat_webhook_secret.encode("utf-8"), + Config.cvat_config.webhook_secret.encode("utf-8"), data, digestmod=sha256, ).hexdigest() diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py index e4336b1da8..14607ea7fa 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py @@ -43,9 +43,9 @@ def tearDown(self): self.session.close() def test_process_incoming_job_launcher_webhooks_escrow_created_type(self): - webhok_id = str(uuid.uuid4()) + webhook_id = str(uuid.uuid4()) webhook = Webhook( - id=webhok_id, + 
id=webhook_id, signature="signature", escrow_address=escrow_address, chain_id=chain_id, @@ -93,7 +93,7 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type(self): process_incoming_job_launcher_webhooks() updated_webhook = ( - self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first() + self.session.execute(select(Webhook).where(Webhook.id == webhook_id)).scalars().first() ) assert updated_webhook.status == OracleWebhookStatuses.completed.value diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py index 549a177552..8bc1a01515 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/setup_cvat.py @@ -18,7 +18,7 @@ def generate_cvat_signature(data: dict): return ( "sha256=" + hmac.new( - CvatConfig.cvat_webhook_secret.encode("utf-8"), + CvatConfig.webhook_secret.encode("utf-8"), b_data, digestmod=sha256, ).hexdigest() diff --git a/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py b/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py new file mode 100644 index 0000000000..3b27fc4ef0 --- /dev/null +++ b/packages/examples/cvat/recording-oracle/alembic/versions/76f0bc042477_update_gt_stats.py @@ -0,0 +1,60 @@ +"""Update GT stats with total_uses field + +Revision ID: 76f0bc042477 +Revises: 9d4367899f90 +Create Date: 2024-12-12 18:14:43.885249 + +""" + +import sqlalchemy as sa +from sqlalchemy import Column, ForeignKey, Integer, String, update +from sqlalchemy.orm import declarative_base + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "76f0bc042477" +down_revision = "9d4367899f90" +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class GtStats(Base): + __tablename__ = "gt_stats" + + # A composite primary key is used + task_id = Column( + String, ForeignKey("tasks.id", ondelete="CASCADE"), primary_key=True, nullable=False + ) + gt_frame_name = Column(String, primary_key=True, nullable=False) + + failed_attempts = Column(Integer, default=0, nullable=False) + accepted_attempts = Column(Integer, default=0, nullable=False) + total_uses = Column(Integer, default=0, nullable=False) + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "gt_stats", sa.Column("total_uses", sa.Integer(), nullable=False, server_default="0") + ) + op.add_column( + "gt_stats", sa.Column("enabled", sa.Boolean(), nullable=False, server_default="True") + ) + # ### end Alembic commands ### + + op.execute( + update(GtStats).values(total_uses=GtStats.accepted_attempts + GtStats.failed_attempts) + ) + + op.alter_column("gt_stats", "total_uses", server_default=None) + op.alter_column("gt_stats", "enabled", server_default=None) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("gt_stats", "total_uses") + op.drop_column("gt_stats", "enabled") + # ### end Alembic commands ### diff --git a/packages/examples/cvat/recording-oracle/debug.py b/packages/examples/cvat/recording-oracle/debug.py index 3915e79929..78d63c2a28 100644 --- a/packages/examples/cvat/recording-oracle/debug.py +++ b/packages/examples/cvat/recording-oracle/debug.py @@ -1,5 +1,8 @@ import datetime -import hashlib +from collections.abc import Generator +from contextlib import ExitStack, contextmanager +from logging import Logger +from unittest import mock import uvicorn @@ -7,123 +10,152 @@ from src.core.config import Config from src.services import cloud from src.services.cloud import BucketAccessInfo -from src.utils.logging import get_function_logger +from src.utils.logging import format_sequence, get_function_logger -def apply_local_development_patches(): - """ - Applies local development patches to bypass direct source code modifications: - - Overrides `EscrowUtils.get_escrow` to retrieve local escrow data for specific addresses, - using mock data if the address corresponds to a local manifest. - - Updates local manifest files from cloud storage. - - Overrides `validate_address` to disable address validation. - - Replaces `validate_oracle_webhook_signature` with a lenient version for oracle signature - validation in development. - - Replaces `src.chain.escrow.store_results` to avoid attempting to store results on chain. - - Replaces `src.validators.signature.validate_oracle_webhook_signature` to always return - `OracleWebhookTypes.exchange_oracle`. 
- """ - logger = get_function_logger(apply_local_development_patches.__name__) +@contextmanager +def _mock_get_manifests_from_minio(logger: Logger) -> Generator[None, None, None]: + from human_protocol_sdk.constants import ChainId + from human_protocol_sdk.escrow import EscrowData, EscrowUtils + + minio_client = cloud.make_client( + BucketAccessInfo.parse_obj(Config.exchange_oracle_storage_config) + ) + original_get_escrow = EscrowUtils.get_escrow - import src.crons._utils + def patched_get_escrow(chain_id: int, escrow_address: str) -> EscrowData: + minio_manifests = minio_client.list_files(bucket="manifests") + logger.debug(f"DEV: Local manifests: {format_sequence(minio_manifests)}") + + candidate_files = [fn for fn in minio_manifests if f"{escrow_address}.json" in fn] + if not candidate_files: + return original_get_escrow(ChainId(chain_id), escrow_address) + if len(candidate_files) != 1: + raise Exception( + f"Can't select local manifest to be used for escrow '{escrow_address}'" + f" - several manifests math: {format_sequence(candidate_files)}" + ) - def prepare_signed_message( - escrow_address, + manifest_file = candidate_files[0] + escrow = EscrowData( + chain_id=ChainId(chain_id), + id="test", + address=escrow_address, + amount_paid=10, + balance=10, + count=1, + factory_address="", + launcher="", + status="Pending", + token="HMT", # noqa: S106 + total_funded_amount=10, + created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), + manifest_url=(f"http://{Config.storage_config.endpoint_url}/manifests/{manifest_file}"), + ) + + logger.info(f"DEV: Using local manifest '{manifest_file}' for escrow '{escrow_address}'") + return escrow + + with mock.patch.object(EscrowUtils, "get_escrow", patched_get_escrow): + yield + + +@contextmanager +def _mock_escrow_results_saving(logger: Logger) -> Generator[None, None, None]: + def patched_store_results( chain_id, - message: str | None = None, - body: dict | None = None, - ) -> tuple[None, str]: - digest = 
hashlib.sha256( - (escrow_address + ":".join(map(str, (chain_id, message, body)))).encode() - ).hexdigest() - signature = f"{OracleWebhookTypes.recording_oracle}:{digest}" - logger.info(f"DEV: Generated patched signature {signature}") - return None, signature - - src.crons._utils.prepare_signed_message = prepare_signed_message + escrow_address, + url, + hash, + ) -> None: + logger.info( + f"DEV: Would store results for escrow '{escrow_address}@{chain_id}' " + f"on chain: {url}, {hash}" + ) - from human_protocol_sdk.constants import ChainId - from human_protocol_sdk.escrow import EscrowData, EscrowUtils + with mock.patch("src.chain.escrow.store_results", patched_store_results): + yield - minio_client = cloud.make_client(BucketAccessInfo.parse_obj(Config.storage_config)) - - def get_local_escrow(chain_id: int, escrow_address: str) -> EscrowData: - possible_manifest_name = escrow_address.split(":")[0] - local_manifests = minio_client.list_files(bucket="manifests") - logger.info(f"Local manifests: {local_manifests}") - if possible_manifest_name in local_manifests: - logger.info(f"DEV: Using local manifest {escrow_address}") - return EscrowData( - chain_id=ChainId(chain_id), - id="test", - address=escrow_address, - amount_paid=10, - balance=10, - count=1, - factory_address="", - launcher="", - status="Pending", - token="HMT", # noqa: S106 - total_funded_amount=10, - created_at=datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), - manifest_url=( - f"http://{Config.storage_config.endpoint_url}/manifests/{possible_manifest_name}" - ), - ) - return original_get_escrow(ChainId(chain_id), escrow_address) - original_get_escrow = EscrowUtils.get_escrow - EscrowUtils.get_escrow = get_local_escrow +@contextmanager +def _mock_webhook_signature_checking(_: Logger) -> Generator[None, None, None]: + """ + Allows to receive webhooks from other services: + - from exchange oracle - + signed with Config.localhost.exchange_oracle_address + or with signature "exchange_oracle" + 
""" - import src.schemas.webhook + from src.chain.escrow import ( + get_available_webhook_types as original_get_available_webhook_types, + ) from src.core.types import OracleWebhookTypes + from src.validators.signature import ( + validate_oracle_webhook_signature as original_validate_oracle_webhook_signature, + ) - src.schemas.webhook.validate_address = lambda x: x - - async def lenient_validate_oracle_webhook_signature( - request, # noqa: ARG001 (not relevant here) - signature, - webhook, # noqa: ARG001 (not relevant here) + async def patched_validate_oracle_webhook_signature(request, signature, webhook): + for webhook_type in OracleWebhookTypes: + if signature.startswith(webhook_type.value.lower()): + return webhook_type + + return await original_validate_oracle_webhook_signature(request, signature, webhook) + + def patched_get_available_webhook_types(chain_id, escrow_address): + d = dict(original_get_available_webhook_types(chain_id, escrow_address)) + d[Config.localhost.exchange_oracle_address.lower()] = OracleWebhookTypes.exchange_oracle + return d + + with ( + mock.patch("src.schemas.webhook.validate_address", lambda x: x), + mock.patch( + "src.validators.signature.get_available_webhook_types", + patched_get_available_webhook_types, + ), + mock.patch( + "src.endpoints.webhook.validate_oracle_webhook_signature", + patched_validate_oracle_webhook_signature, + ), ): - try: - parsed_type = OracleWebhookTypes(signature.split(":")[0]) - logger.info(f"DEV: Recovered {parsed_type} from the signature {signature}") - except (ValueError, TypeError): - logger.info(f"DEV: Falling back to {OracleWebhookTypes.exchange_oracle} webhook sender") - return OracleWebhookTypes.exchange_oracle + yield - import src.endpoints.webhook - src.endpoints.webhook.validate_oracle_webhook_signature = ( - lenient_validate_oracle_webhook_signature - ) +@contextmanager +def apply_local_development_patches() -> Generator[None, None, None]: + """ + Applies local development patches to avoid 
manual source code modification + """ - import src.chain.escrow + logger = get_function_logger(apply_local_development_patches.__name__) - def store_results( - chain_id: int, # noqa: ARG001 (not relevant here) - escrow_address: str, - url: str, - hash: str, - ) -> None: - logger.info(f"Would store results for escrow {escrow_address} on chain: {url}, {hash}") + logger.warning("DEV: Applying local development patches") + + with ExitStack() as es: + for mock_callback in ( + _mock_get_manifests_from_minio, + _mock_webhook_signature_checking, + _mock_escrow_results_saving, + ): + logger.warning(f"DEV: applying patch {mock_callback.__name__}...") + es.enter_context(mock_callback(logger)) - src.chain.escrow.store_results = store_results + logger.warning("DEV: Local development patches applied.") - logger.warning("Local development patches applied.") + yield if __name__ == "__main__": - is_dev = Config.environment == "development" - if is_dev: - apply_local_development_patches() - - Config.validate() - register_in_kvstore() - - uvicorn.run( - app="src:app", - host="0.0.0.0", # noqa: S104 - port=int(Config.port), - workers=Config.workers_amount, - ) + with ExitStack() as es: + is_dev = Config.environment == "development" + if is_dev: + es.enter_context(apply_local_development_patches()) + + Config.validate() + register_in_kvstore() + + uvicorn.run( + app="src:app", + host="0.0.0.0", # noqa: S104 + port=int(Config.port), + workers=Config.workers_amount, + ) diff --git a/packages/examples/cvat/recording-oracle/pyproject.toml b/packages/examples/cvat/recording-oracle/pyproject.toml index 6b42e915e8..01be8d38b3 100644 --- a/packages/examples/cvat/recording-oracle/pyproject.toml +++ b/packages/examples/cvat/recording-oracle/pyproject.toml @@ -106,6 +106,7 @@ ignore = [ "ANN002", # Missing type annotation for `*args` "TRY300", # Consider moving this statement to an `else` block "C901", # Function is too complex + "PLW1508", # invalid-envvar-default. 
Alerts only for os.getenv(), but not for os.environ.get() "PLW2901", # Variable overwritten by assignment target "PTH118", # Prefer pathlib instead of os.path "PTH119", # `os.path.basename()` should be replaced by `Path.name` diff --git a/packages/examples/cvat/recording-oracle/src/.env.template b/packages/examples/cvat/recording-oracle/src/.env.template index e4889a7ddb..729027cc6d 100644 --- a/packages/examples/cvat/recording-oracle/src/.env.template +++ b/packages/examples/cvat/recording-oracle/src/.env.template @@ -37,7 +37,6 @@ PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE= STORAGE_PROVIDER= STORAGE_ENDPOINT_URL= -STORAGE_REGION= STORAGE_ACCESS_KEY= STORAGE_SECRET_KEY= STORAGE_RESULTS_BUCKET_NAME= @@ -48,7 +47,6 @@ STORAGE_KEY_FILE_PATH= EXCHANGE_ORACLE_STORAGE_PROVIDER= EXCHANGE_ORACLE_STORAGE_ENDPOINT_URL= -EXCHANGE_ORACLE_STORAGE_REGION= EXCHANGE_ORACLE_STORAGE_ACCESS_KEY= EXCHANGE_ORACLE_STORAGE_SECRET_KEY= EXCHANGE_ORACLE_STORAGE_RESULTS_BUCKET_NAME= @@ -66,6 +64,7 @@ CVAT_QUALITY_CHECK_INTERVAL= # Localhost +LOCALHOST_EXCHANGE_ORACLE_ADDRESS= LOCALHOST_EXCHANGE_ORACLE_URL= LOCALHOST_REPUTATION_ORACLE_URL= @@ -75,12 +74,13 @@ ENABLE_CUSTOM_CLOUD_HOST= # Validation -DEFAULT_POINT_VALIDITY_RELATIVE_RADIUS= -DEFAULT_OKS_SIGMA= -GT_FAILURE_THRESHOLD= +MIN_AVAILABLE_GT_THRESHOLD= +MAX_USABLE_GT_SHARE= GT_BAN_THRESHOLD= UNVERIFIABLE_ASSIGNMENTS_THRESHOLD= MAX_ESCROW_ITERATIONS= +WARMUP_ITERATIONS= +MIN_WARMUP_PROGRESS= # Encryption PGP_PRIVATE_KEY= diff --git a/packages/examples/cvat/recording-oracle/src/chain/escrow.py b/packages/examples/cvat/recording-oracle/src/chain/escrow.py index 33b8c78a85..bc6e89da60 100644 --- a/packages/examples/cvat/recording-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/recording-oracle/src/chain/escrow.py @@ -7,6 +7,7 @@ from src.chain.web3 import get_web3 from src.core.config import Config +from src.core.types import OracleWebhookTypes def get_escrow(chain_id: int, escrow_address: str) -> EscrowData: @@ -64,9 +65,8 @@ 
def store_results(chain_id: int, escrow_address: str, url: str, hash: str) -> No escrow_client.store_results(escrow_address, url, hash) -def get_reputation_oracle_address(chain_id: int, escrow_address: str) -> str: - return get_escrow(chain_id, escrow_address).reputation_oracle - - -def get_exchange_oracle_address(chain_id: int, escrow_address: str) -> str: - return get_escrow(chain_id, escrow_address).exchange_oracle +def get_available_webhook_types( + chain_id: int, escrow_address: str +) -> dict[str, OracleWebhookTypes]: + escrow = get_escrow(chain_id, escrow_address) + return {(escrow.exchange_oracle or "").lower(): OracleWebhookTypes.exchange_oracle} diff --git a/packages/examples/cvat/recording-oracle/src/core/config.py b/packages/examples/cvat/recording-oracle/src/core/config.py index d2d539340a..687a9cf5b1 100644 --- a/packages/examples/cvat/recording-oracle/src/core/config.py +++ b/packages/examples/cvat/recording-oracle/src/core/config.py @@ -4,6 +4,7 @@ import inspect import os from collections.abc import Iterable +from os import getenv from typing import ClassVar from attrs.converters import to_bool @@ -15,7 +16,7 @@ from src.utils.logging import parse_log_level from src.utils.net import is_ipv4 -dotenv_path = os.getenv("DOTENV_PATH", None) +dotenv_path = getenv("DOTENV_PATH", None) if dotenv_path and not os.path.exists(dotenv_path): # noqa: PTH110 raise FileNotFoundError(dotenv_path) @@ -29,12 +30,12 @@ def validate(cls) -> None: class Postgres: - port = os.environ.get("PG_PORT", "5434") - host = os.environ.get("PG_HOST", "0.0.0.0") # noqa: S104 - user = os.environ.get("PG_USER", "admin") - password = os.environ.get("PG_PASSWORD", "admin") - database = os.environ.get("PG_DB", "recording_oracle") - lock_timeout = int(os.environ.get("PG_LOCK_TIMEOUT", "3000")) # milliseconds + port = getenv("PG_PORT", "5434") + host = getenv("PG_HOST", "0.0.0.0") # noqa: S104 + user = getenv("PG_USER", "admin") + password = getenv("PG_PASSWORD", "admin") + database = 
getenv("PG_DB", "recording_oracle") + lock_timeout = int(getenv("PG_LOCK_TIMEOUT", "3000")) # milliseconds @classmethod def connection_url(cls) -> str: @@ -58,43 +59,43 @@ def is_configured(cls) -> bool: class PolygonMainnetConfig(_NetworkConfig): chain_id = 137 - rpc_api = os.environ.get("POLYGON_MAINNET_RPC_API_URL") - private_key = os.environ.get("POLYGON_MAINNET_PRIVATE_KEY") - addr = os.environ.get("POLYGON_MAINNET_ADDR") + rpc_api = getenv("POLYGON_MAINNET_RPC_API_URL") + private_key = getenv("POLYGON_MAINNET_PRIVATE_KEY") + addr = getenv("POLYGON_MAINNET_ADDR") class PolygonAmoyConfig(_NetworkConfig): chain_id = 80002 - rpc_api = os.environ.get("POLYGON_AMOY_RPC_API_URL") - private_key = os.environ.get("POLYGON_AMOY_PRIVATE_KEY") - addr = os.environ.get("POLYGON_AMOY_ADDR") + rpc_api = getenv("POLYGON_AMOY_RPC_API_URL") + private_key = getenv("POLYGON_AMOY_PRIVATE_KEY") + addr = getenv("POLYGON_AMOY_ADDR") class LocalhostConfig(_NetworkConfig): chain_id = 1338 - rpc_api = os.environ.get("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") - private_key = os.environ.get( + rpc_api = getenv("LOCALHOST_RPC_API_URL", "http://blockchain-node:8545") + private_key = getenv( "LOCALHOST_PRIVATE_KEY", "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", ) - addr = os.environ.get("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") + addr = getenv("LOCALHOST_AMOY_ADDR", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266") - exchange_oracle_url = os.environ.get("LOCALHOST_EXCHANGE_ORACLE_URL") - reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") + exchange_oracle_address = getenv("LOCALHOST_EXCHANGE_ORACLE_ADDRESS") + exchange_oracle_url = getenv("LOCALHOST_EXCHANGE_ORACLE_URL") + + reputation_oracle_url = getenv("LOCALHOST_REPUTATION_ORACLE_URL") class CronConfig: - process_exchange_oracle_webhooks_int = int( - os.environ.get("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_INT", 3000) - ) - process_exchange_oracle_webhooks_chunk_size 
= os.environ.get( - "PROCESS_EXCHANGE_ORACLE_WEBHOOKS_CHUNK_SIZE", 5 + process_exchange_oracle_webhooks_int = int(getenv("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_INT", 3000)) + process_exchange_oracle_webhooks_chunk_size = int( + getenv("PROCESS_EXCHANGE_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) process_reputation_oracle_webhooks_int = int( - os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 3000) + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 3000) ) - process_reputation_oracle_webhooks_chunk_size = os.environ.get( - "PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5 + process_reputation_oracle_webhooks_chunk_size = int( + getenv("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) ) @@ -103,7 +104,6 @@ class IStorageConfig: data_bucket_name: ClassVar[str] secure: ClassVar[bool] endpoint_url: ClassVar[str] # TODO: probably should be optional - region: ClassVar[str | None] # AWS S3 specific attributes access_key: ClassVar[str | None] secret_key: ClassVar[str | None] @@ -128,16 +128,15 @@ def bucket_url(cls) -> str: class StorageConfig(IStorageConfig): provider = os.environ["STORAGE_PROVIDER"].lower() endpoint_url = os.environ["STORAGE_ENDPOINT_URL"] # TODO: probably should be optional - region = os.environ.get("STORAGE_REGION") data_bucket_name = os.environ["STORAGE_RESULTS_BUCKET_NAME"] - secure = to_bool(os.environ.get("STORAGE_USE_SSL", "true")) + secure = to_bool(getenv("STORAGE_USE_SSL", "true")) # AWS S3 specific attributes - access_key = os.environ.get("STORAGE_ACCESS_KEY") - secret_key = os.environ.get("STORAGE_SECRET_KEY") + access_key = getenv("STORAGE_ACCESS_KEY") + secret_key = getenv("STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path = os.environ.get("STORAGE_KEY_FILE_PATH") + key_file_path = getenv("STORAGE_KEY_FILE_PATH") class ExchangeOracleStorageConfig(IStorageConfig): @@ -146,65 +145,72 @@ class ExchangeOracleStorageConfig(IStorageConfig): endpoint_url = os.environ[ "EXCHANGE_ORACLE_STORAGE_ENDPOINT_URL" ] # TODO: probably should be 
optional - region = os.environ.get("EXCHANGE_ORACLE_STORAGE_REGION") data_bucket_name = os.environ["EXCHANGE_ORACLE_STORAGE_RESULTS_BUCKET_NAME"] - results_dir_suffix = os.environ.get("STORAGE_RESULTS_DIR_SUFFIX", "-results") - secure = to_bool(os.environ.get("EXCHANGE_ORACLE_STORAGE_USE_SSL", "true")) + results_dir_suffix = getenv("STORAGE_RESULTS_DIR_SUFFIX", "-results") + secure = to_bool(getenv("EXCHANGE_ORACLE_STORAGE_USE_SSL", "true")) # AWS S3 specific attributes - access_key = os.environ.get("EXCHANGE_ORACLE_STORAGE_ACCESS_KEY") - secret_key = os.environ.get("EXCHANGE_ORACLE_STORAGE_SECRET_KEY") + access_key = getenv("EXCHANGE_ORACLE_STORAGE_ACCESS_KEY") + secret_key = getenv("EXCHANGE_ORACLE_STORAGE_SECRET_KEY") # GCS specific attributes - key_file_path = os.environ.get("EXCHANGE_ORACLE_STORAGE_KEY_FILE_PATH") + key_file_path = getenv("EXCHANGE_ORACLE_STORAGE_KEY_FILE_PATH") class FeaturesConfig: - enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) + enable_custom_cloud_host = to_bool(getenv("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket urls" class ValidationConfig: - default_point_validity_relative_radius = float( - os.environ.get("DEFAULT_POINT_VALIDITY_RELATIVE_RADIUS", 0.9) - ) - - default_oks_sigma = float( - os.environ.get("DEFAULT_OKS_SIGMA", 0.1) # average value for COCO points - ) - "Default OKS sigma for GT skeleton points validation. Valid range is (0; 1]" + min_available_gt_threshold = float(getenv("MIN_AVAILABLE_GT_THRESHOLD", "0.3")) + """ + The minimum required share of available GT frames required to continue annotation attempts. + When there is no enough GT left, annotation stops. + """ - gt_failure_threshold = float(os.environ.get("GT_FAILURE_THRESHOLD", 0.9)) + max_gt_share = float(getenv("MAX_USABLE_GT_SHARE", "0.05")) """ - The maximum allowed fraction of failed assignments per GT sample, - before it's considered failed for the current validation iteration. 
- v = 0 -> any GT failure leads to image failure - v = 1 -> any GT failures do not lead to image failure + The maximum share of the dataset to be used for validation. If the available GT share is + greater than this number, the extra frames will not be used. It's recommended to keep this + value small enough for faster convergence rate of the annotation process. """ - gt_ban_threshold = int(os.environ.get("GT_BAN_THRESHOLD", 3)) + gt_ban_threshold = float(getenv("GT_BAN_THRESHOLD", "0.03")) """ - The maximum allowed number of failures per GT sample before it's excluded from validation + The minimum allowed rating (annotation probability) per GT sample, + before it's considered bad and banned for further use. """ - unverifiable_assignments_threshold = float( - os.environ.get("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", 0.1) - ) + unverifiable_assignments_threshold = float(getenv("UNVERIFIABLE_ASSIGNMENTS_THRESHOLD", "0.1")) """ + Deprecated. Not expected to happen in practice, kept only as a safety fallback rule. + The maximum allowed fraction of jobs with insufficient GT available for validation. Each such job will be accepted "blindly", as we can't validate the annotations. """ - max_escrow_iterations = int(os.getenv("MAX_ESCROW_ITERATIONS", "0")) + max_escrow_iterations = int(getenv("MAX_ESCROW_ITERATIONS", "50")) """ Maximum escrow annotation-validation iterations. After this, the escrow is finished automatically. Supposed only for testing. Use 0 to disable. """ + warmup_iterations = int(getenv("WARMUP_ITERATIONS", "1")) + """ + The first escrow iterations where the annotation speed is checked to be big enough. + """ + + min_warmup_progress = float(getenv("MIN_WARMUP_PROGRESS", "10")) + """ + Minimum percent of the accepted jobs in an escrow after the first WARMUP iterations. + If the value is lower, the escrow annotation is paused for manual investigation. 
+ """ + class EncryptionConfig(_BaseConfig): - pgp_passphrase = os.environ.get("PGP_PASSPHRASE", "") - pgp_private_key = os.environ.get("PGP_PRIVATE_KEY", "") - pgp_public_key_url = os.environ.get("PGP_PUBLIC_KEY_URL", "") + pgp_passphrase = getenv("PGP_PASSPHRASE", "") + pgp_private_key = getenv("PGP_PRIVATE_KEY", "") + pgp_public_key_url = getenv("PGP_PUBLIC_KEY_URL", "") @classmethod def validate(cls) -> None: @@ -224,22 +230,22 @@ def validate(cls) -> None: class CvatConfig: - cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080") - cvat_admin = os.environ.get("CVAT_ADMIN", "admin") - cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin") - cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "org1") + host_url = getenv("CVAT_URL", "http://localhost:8080") + admin_login = getenv("CVAT_ADMIN", "admin") + admin_pass = getenv("CVAT_ADMIN_PASS", "admin") + org_slug = getenv("CVAT_ORG_SLUG", "org1") - cvat_quality_retrieval_timeout = int(os.environ.get("CVAT_QUALITY_RETRIEVAL_TIMEOUT", 60 * 60)) - cvat_quality_check_interval = int(os.environ.get("CVAT_QUALITY_CHECK_INTERVAL", 5)) + quality_retrieval_timeout = int(getenv("CVAT_QUALITY_RETRIEVAL_TIMEOUT", 60 * 60)) + quality_check_interval = int(getenv("CVAT_QUALITY_CHECK_INTERVAL", 5)) class Config: - port = int(os.environ.get("PORT", 8000)) - environment = os.environ.get("ENVIRONMENT", "development") - workers_amount = int(os.environ.get("WORKERS_AMOUNT", 1)) - webhook_max_retries = int(os.environ.get("WEBHOOK_MAX_RETRIES", 5)) - webhook_delay_if_failed = int(os.environ.get("WEBHOOK_DELAY_IF_FAILED", 60)) - loglevel = parse_log_level(os.environ.get("LOGLEVEL", "info")) + port = int(getenv("PORT", 8000)) + environment = getenv("ENVIRONMENT", "development") + workers_amount = int(getenv("WORKERS_AMOUNT", 1)) + webhook_max_retries = int(getenv("WEBHOOK_MAX_RETRIES", 5)) + webhook_delay_if_failed = int(getenv("WEBHOOK_DELAY_IF_FAILED", 60)) + loglevel = parse_log_level(getenv("LOGLEVEL", "info")) 
polygon_mainnet = PolygonMainnetConfig polygon_amoy = PolygonAmoyConfig diff --git a/packages/examples/cvat/recording-oracle/src/core/gt_stats.py b/packages/examples/cvat/recording-oracle/src/core/gt_stats.py index a94c8dfb54..74c8c80647 100644 --- a/packages/examples/cvat/recording-oracle/src/core/gt_stats.py +++ b/packages/examples/cvat/recording-oracle/src/core/gt_stats.py @@ -6,12 +6,13 @@ class ValidationFrameStats: accumulated_quality: float = 0.0 failed_attempts: int = 0 accepted_attempts: int = 0 + total_uses: int = 0 + enabled: bool = True @property - def average_quality(self) -> float: - return self.accumulated_quality / ((self.failed_attempts + self.accepted_attempts) or 1) + def rating(self) -> float: + return (self.accepted_attempts + 1) / (self.total_uses + 1) -_TaskIdValFrameIdPair = tuple[int, int] - -GtStats = dict[_TaskIdValFrameIdPair, ValidationFrameStats] +GtKey = str +GtStats = dict[GtKey, ValidationFrameStats] diff --git a/packages/examples/cvat/recording-oracle/src/core/validation_errors.py b/packages/examples/cvat/recording-oracle/src/core/validation_errors.py index b2d4b68a47..40ad8c1d39 100644 --- a/packages/examples/cvat/recording-oracle/src/core/validation_errors.py +++ b/packages/examples/cvat/recording-oracle/src/core/validation_errors.py @@ -13,3 +13,16 @@ def __str__(self) -> str: class LowAccuracyError(DatasetValidationError): pass + + +class TooSlowAnnotationError(DatasetValidationError): + def __init__(self, current_progress: float, current_iteration: int): + super().__init__() + self.current_progress = current_progress + self.current_iteration = current_iteration + + def __str__(self): + return ( + f"Escrow annotation progress is too small: {self.current_progress:.2f}% " + f"at the {self.current_iteration} iterations" + ) diff --git a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py index ab52816957..72d61c4a2b 100644 --- 
a/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/recording-oracle/src/cvat/api_calls.py @@ -14,13 +14,13 @@ def get_api_client() -> ApiClient: configuration = Configuration( - host=Config.cvat_config.cvat_url, - username=Config.cvat_config.cvat_admin, - password=Config.cvat_config.cvat_admin_pass, + host=Config.cvat_config.host_url, + username=Config.cvat_config.admin_login, + password=Config.cvat_config.admin_pass, ) api_client = ApiClient(configuration=configuration) - api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug) + api_client.set_default_header("X-organization", Config.cvat_config.org_slug) return api_client @@ -40,8 +40,8 @@ def get_last_task_quality_report(task_id: int) -> models.QualityReport | None: def compute_task_quality_report( task_id: int, *, - timeout: int = Config.cvat_config.cvat_quality_retrieval_timeout, - check_interval: float = Config.cvat_config.cvat_quality_check_interval, + timeout: int = Config.cvat_config.quality_retrieval_timeout, + check_interval: float = Config.cvat_config.quality_check_interval, ) -> models.QualityReport: logger = logging.getLogger("app") start_time = utcnow() @@ -91,8 +91,8 @@ def get_task(task_id: int) -> models.TaskRead: def get_task_quality_report( task_id: int, *, - timeout: int = Config.cvat_config.cvat_quality_retrieval_timeout, - check_interval: float = Config.cvat_config.cvat_quality_check_interval, + timeout: int = Config.cvat_config.quality_retrieval_timeout, + check_interval: float = Config.cvat_config.quality_check_interval, ) -> models.QualityReport: logger = logging.getLogger("app") diff --git a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py index 725c54501d..7d090dd6b5 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py +++ 
b/packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py @@ -3,7 +3,9 @@ import io import logging import os +from collections import Counter from dataclasses import dataclass +from functools import cached_property from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, TypeVar @@ -12,21 +14,31 @@ import numpy as np import src.cvat.api_calls as cvat_api +import src.models.validation as db_models import src.services.validation as db_service from src.core.annotation_meta import AnnotationMeta from src.core.config import Config -from src.core.gt_stats import GtStats, ValidationFrameStats +from src.core.gt_stats import GtKey, GtStats, ValidationFrameStats from src.core.types import TaskTypes -from src.core.validation_errors import DatasetValidationError, LowAccuracyError, TooFewGtError +from src.core.validation_errors import ( + DatasetValidationError, + LowAccuracyError, + TooFewGtError, + TooSlowAnnotationError, +) from src.core.validation_meta import JobMeta, ResultMeta, ValidationMeta from src.core.validation_results import ValidationFailure, ValidationSuccess from src.db.utils import ForUpdateParams from src.services.cloud import make_client as make_cloud_client from src.services.cloud.utils import BucketAccessInfo +from src.utils import grouped from src.utils.annotations import ProjectLabels +from src.utils.formatting import value_and_percent from src.utils.zip_archive import extract_zip_archive, write_dir_to_zip_archive if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from sqlalchemy.orm import Session from src.core.manifest import TaskManifest @@ -187,9 +199,9 @@ def _validate_jobs(self): # assess quality of the job's honeypots task_quality_report_data = task_id_to_quality_report_data[cvat_task_id] - sorted_task_frame_names = task_id_to_sequence_of_frame_names[cvat_task_id] - task_honeypots = {int(frame) for frame in task_quality_report_data.frame_results} - 
honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id] + task_frame_names = task_id_to_sequence_of_frame_names[cvat_task_id] + task_honeypots = set(task_id_to_val_layout[cvat_task_id].honeypot_frames) + task_honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id] job_honeypots = task_honeypots & set(job_meta.job_frame_range) if not job_honeypots: @@ -198,8 +210,8 @@ def _validate_jobs(self): continue for honeypot in job_honeypots: - val_frame = honeypots_mapping[honeypot] - val_frame_name = sorted_task_frame_names[val_frame] + val_frame = task_honeypots_mapping[honeypot] + val_frame_name = task_frame_names[val_frame] result = task_quality_report_data.frame_results[str(honeypot)] self._gt_stats.setdefault(val_frame_name, ValidationFrameStats()) @@ -220,6 +232,14 @@ def _validate_jobs(self): if accuracy < min_quality: rejected_jobs[cvat_job_id] = LowAccuracyError() + for gt_stat in self._gt_stats.values(): + gt_stat.total_uses = max( + gt_stat.total_uses, + gt_stat.failed_attempts + gt_stat.accepted_attempts, + # at the first iteration we have no information on total uses + # from previous iterations, so we derive it from the validation results + ) + self._job_results = job_results self._rejected_jobs = rejected_jobs self._task_id_to_val_layout = task_id_to_val_layout @@ -351,6 +371,305 @@ def validate(self) -> _ValidationResult: ) +@dataclass +class _HoneypotUpdateResult: + updated_gt_stats: GtStats + can_continue_annotation: bool + + +_K = TypeVar("_K") + + +class _TaskHoneypotManager: + def __init__( + self, + task: db_models.Task, + manifest: TaskManifest, + *, + annotation_meta: AnnotationMeta, + gt_stats: GtStats, + validation_result: _ValidationResult, + logger: logging.Logger, + rng: np.random.Generator | None = None, + ): + self.task = task + self.logger = logger + self.annotation_meta = annotation_meta + self.gt_stats = gt_stats + self.validation_result = validation_result + self.manifest = manifest + + self._job_annotation_meta_by_job_id = { 
+ meta.job_id: meta for meta in self.annotation_meta.jobs + } + + if not rng: + rng = np.random.default_rng() + self.rng = rng + + def _make_gt_key(self, validation_frame_name: str) -> GtKey: + return validation_frame_name + + @cached_property + def _get_gt_frame_uses(self) -> dict[GtKey, int]: + return {gt_key: gt_stat.total_uses for gt_key, gt_stat in self.gt_stats.items()} + + def _select_random_least_used( + self, + items: Sequence[_K], + count: int, + *, + key: Callable[[_K], int] | None = None, + rng: np.random.Generator | None = None, + ) -> Sequence[_K]: + """ + Selects 'count' least used items randomly, without repetition. + 'key' can be used to provide a custom item count function. + """ + if not rng: + rng = self.rng + + if not key: + item_counts = Counter(items) + key = item_counts.__getitem__ + + pick = set() + for randval in rng.random(count): + # TODO: try to optimize item counting on each iteration + # maybe by using a bagged data structure + least_use_count = min(key(item) for item in items if item not in pick) + least_used_items = [ + item for item in items if key(item) == least_use_count if item not in pick + ] + pick.add(least_used_items[int(randval * len(least_used_items))]) + + return pick + + def _get_available_gt_frames(self): + if max_gt_share := Config.validation.max_gt_share: + # Limit maximum used GT frames + regular_frames_count = 0 + for task_id, task_val_layout in self.validation_result.task_id_to_val_layout.items(): + # Safety check for the next operations. Here we assume + # that all the tasks use the same GT frames. 
+ task_validation_frames = task_val_layout.validation_frames + task_frame_names = self.validation_result.task_id_to_frame_names[task_id] + task_gt_keys = { + self._make_gt_key(task_frame_names[f]) for f in task_validation_frames + } + + # Populate missing entries for unused GT frames + for gt_key in task_gt_keys: + if gt_key not in self.gt_stats: + self.gt_stats[gt_key] = ValidationFrameStats() + + regular_frames_count += ( + len(task_frame_names) + - len(task_validation_frames) + - task_val_layout.honeypot_count + ) + + if len(self.manifest.annotation.labels) != 1: + # TODO: count GT frames per label set to avoid situations with empty GT sets + # for some labels or tasks. + # Note that different task types can have different label setups. + self.logger.warning( + "Tasks with multiple labels are not supported yet." + " Honeypots in tasks will not be limited" + ) + else: + total_frames_count = regular_frames_count + len(self.gt_stats) + enabled_gt_keys = {k for k, gt_stat in self.gt_stats.items() if gt_stat.enabled} + current_gt_share = len(enabled_gt_keys) / (total_frames_count or 1) + max_usable_gt_share = min( + len(self.gt_stats) / (total_frames_count or 1), max_gt_share + ) + max_gt_count = min(int(max_gt_share * total_frames_count), len(self.gt_stats)) + has_updates = False + if max_gt_count < len(enabled_gt_keys): + # disable some validation frames, take the least used ones + pick = self._select_random_least_used( + enabled_gt_keys, + count=len(enabled_gt_keys) - max_gt_count, + key=lambda k: self.gt_stats[k].total_uses, + ) + + enabled_gt_keys.difference_update(pick) + has_updates = True + elif ( + # Allow restoring GT frames on max limit config changes + current_gt_share < max_usable_gt_share + ): + # add more validation frames, take the most used ones + pick = self._select_random_least_used( + enabled_gt_keys, + count=max_gt_count - len(enabled_gt_keys), + key=lambda k: -self.gt_stats[k].total_uses, + ) + + enabled_gt_keys.update(pick) + has_updates = True 
+ + if has_updates: + for gt_key, gt_stat in self.gt_stats.items(): + gt_stat.enabled = gt_key in enabled_gt_keys + + return { + gt_key + for gt_key, gt_stat in self.gt_stats.items() + if gt_stat.enabled + if gt_stat.rating > Config.validation.gt_ban_threshold + } + + def _check_warmup_annotation_speed(self): + validation_result = self.validation_result + rejected_jobs = validation_result.rejected_jobs + + current_iteration = self.task.iteration + 1 + total_jobs_count = len(validation_result.job_results) + completed_jobs_count = total_jobs_count - len(rejected_jobs) + current_progress = completed_jobs_count / (total_jobs_count or 1) * 100 + if ( + (Config.validation.warmup_iterations > 0) + and (Config.validation.min_warmup_progress > 0) + and (Config.validation.warmup_iterations <= current_iteration) + and (current_progress < Config.validation.min_warmup_progress) + ): + self.logger.warning( + f"Escrow validation failed for escrow_address={self.task.escrow_address}:" + f" progress is too slow. Min required {Config.validation.min_warmup_progress:.2f}%" + f" after the first {Config.validation.warmup_iterations} iterations," + f" got {current_progress:.2f} after the {current_iteration} iteration." + " Annotation will be stopped for a manual review." 
+ ) + raise TooSlowAnnotationError( + current_progress=current_progress, current_iteration=current_iteration + ) + + def update_honeypots(self) -> _HoneypotUpdateResult: + gt_stats = self.gt_stats + validation_result = self.validation_result + rejected_jobs = validation_result.rejected_jobs + + # Update honeypots in jobs + available_gt_frames = self._get_available_gt_frames() + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "Escrow validation for escrow_address={}: iteration: {}" + ", available GT count: {} ({}%, banned {})" + ", remaining jobs count: {} ({}%)".format( + self.task.escrow_address, + self.task.iteration, + *value_and_percent(len(available_gt_frames), len(gt_stats)), + len(gt_stats) - len(available_gt_frames), + *value_and_percent(len(rejected_jobs), len(self.annotation_meta.jobs)), + ), + ) + + should_complete = False + + self._check_warmup_annotation_speed() + + if len(available_gt_frames) / len(gt_stats) < Config.validation.min_available_gt_threshold: + self.logger.debug("Not enough available GT to continue, stopping") + return _HoneypotUpdateResult(updated_gt_stats=gt_stats, can_continue_annotation=False) + + gt_frame_uses = self._get_gt_frame_uses + + tasks_with_rejected_jobs = grouped( + rejected_jobs, key=lambda jid: self._job_annotation_meta_by_job_id[jid].task_id + ) + + # Update honeypots in rejected jobs + for cvat_task_id, task_rejected_jobs in tasks_with_rejected_jobs.items(): + if not task_rejected_jobs: + continue + + task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id] + task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id] + + task_validation_frame_to_gt_key = { + # TODO: maybe switch to per GT case stats for GT frame stats + # e.g. 
per skeleton point (all skeleton points use the same GT frame names) + validation_frame: self._make_gt_key(task_frame_names[validation_frame]) + for validation_frame in task_validation_layout.validation_frames + } + + task_available_validation_frames = { + validation_frame + for validation_frame in task_validation_layout.validation_frames + if task_frame_names[validation_frame] in available_gt_frames + } + + if len(task_available_validation_frames) < task_validation_layout.frames_per_job_count: + # TODO: value from the manifest can be different from what's in the task + # because exchange oracle can use size multipliers for tasks + # Need to sync these values later (maybe by removing it from the manifest) + should_complete = True + self.logger.info( + f"Validation for escrow_address={self.task.escrow_address}: " + "Too few validation frames left " + f"(required: {task_validation_layout.frames_per_job_count}, " + f"left: {len(task_available_validation_frames)}) for the task({cvat_task_id}), " + "stopping annotation" + ) + break + + task_updated_disabled_frames = [ + validation_frame + for validation_frame in task_validation_layout.validation_frames + if validation_frame not in task_available_validation_frames + ] + + task_honeypot_to_index: dict[int, int] = { + honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames) + } # honeypot -> honeypot list index + + task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy() + + for job_id in task_rejected_jobs: + job_frame_range = self._job_annotation_meta_by_job_id[job_id].job_frame_range + job_honeypots = sorted( + set(task_validation_layout.honeypot_frames).intersection(job_frame_range) + ) + + # Choose new unique validation frames for the job + job_validation_frames = self._select_random_least_used( + task_available_validation_frames, + count=len(job_honeypots), + key=lambda k: gt_frame_uses[task_validation_frame_to_gt_key[k]], + ) + for job_honeypot, 
job_validation_frame in zip( + job_honeypots, job_validation_frames, strict=False + ): + gt_frame_uses[task_validation_frame_to_gt_key[job_validation_frame]] += 1 + honeypot_index = task_honeypot_to_index[job_honeypot] + task_updated_honeypot_real_frames[honeypot_index] = job_validation_frame + + # Make sure honeypots do not repeat in jobs + assert len( + { + task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]] + for honeypot in job_honeypots + } + ) == len(job_honeypots) + + cvat_api.update_task_validation_layout( + cvat_task_id, + disabled_frames=task_updated_disabled_frames, + honeypot_real_frames=task_updated_honeypot_real_frames, + ) + + # Update GT use counts + for gt_key, gt_stat in gt_stats.items(): + gt_stat.total_uses = gt_frame_uses[gt_key] + + return _HoneypotUpdateResult( + updated_gt_stats=gt_stats, can_continue_annotation=not should_complete + ) + + def process_intermediate_results( # noqa: PLR0912 session: Session, *, @@ -392,6 +711,8 @@ def process_intermediate_results( # noqa: PLR0912 failed_attempts=gt_image_stat.failed_attempts, accepted_attempts=gt_image_stat.accepted_attempts, accumulated_quality=gt_image_stat.accumulated_quality, + total_uses=gt_image_stat.total_uses, + enabled=gt_image_stat.enabled, ) for gt_image_stat in db_service.get_task_gt_stats(session, task.id) } @@ -408,7 +729,6 @@ def process_intermediate_results( # noqa: PLR0912 validation_result = validator.validate() job_results = validation_result.job_results rejected_jobs = validation_result.rejected_jobs - updated_merged_dataset_archive = validation_result.updated_merged_dataset if logger.isEnabledFor(logging.DEBUG): logger.debug("Validation results %s", validation_result) @@ -419,135 +739,32 @@ def process_intermediate_results( # noqa: PLR0912 ) gt_stats = validation_result.gt_stats - if gt_stats: - # cvat_task_id: {val_frame_id, ...} - cvat_task_id_to_failed_val_frames: dict[int, set[int]] = {} - rejected_job_ids = rejected_jobs.keys() - - if rejected_job_ids: 
- job_id_to_task_id = {j.job_id: j.task_id for j in unchecked_jobs_meta.jobs} - job_id_to_frame_range = {j.job_id: j.job_frame_range for j in unchecked_jobs_meta.jobs} - - # find validation frames to be disabled - for rejected_job_id in rejected_job_ids: - job_frame_range = job_id_to_frame_range[rejected_job_id] - cvat_task_id = job_id_to_task_id[rejected_job_id] - task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[ - cvat_task_id - ] - job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range)) - validation_frames = [ - val_frame - for honeypot, val_frame in task_honeypots_mapping.items() - if honeypot in job_honeypots - ] - sorted_task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id] - - for val_frame in validation_frames: - val_frame_name = sorted_task_frame_names[val_frame] - val_frame_stats = gt_stats[val_frame_name] - if ( - val_frame_stats.failed_attempts >= Config.validation.gt_ban_threshold - and not val_frame_stats.accepted_attempts - ): - cvat_task_id_to_failed_val_frames.setdefault(cvat_task_id, set()).add( - val_frame - ) - - for cvat_task_id, task_bad_validation_frames in cvat_task_id_to_failed_val_frames.items(): - task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id] - - task_disabled_bad_frames = ( - set(task_validation_layout.disabled_frames) & task_bad_validation_frames - ) - if task_disabled_bad_frames: - logger.error( - "Logical error occurred while disabling validation frames " - f"for the task({task_id}). Frames {task_disabled_bad_frames} " - "are already disabled." 
- ) - task_updated_disabled_frames = list( - set(task_validation_layout.disabled_frames) | set(task_bad_validation_frames) - ) - task_good_validation_frames = list( - set(task_validation_layout.validation_frames) - set(task_updated_disabled_frames) - ) - - if len(task_good_validation_frames) < task_validation_layout.frames_per_job_count: - should_complete = True - logger.info( - f"Validation for escrow_address={escrow_address}: " - "Too few validation frames left " - f"(required: {task_validation_layout.frames_per_job_count}, " - f"left: {len(task_good_validation_frames)}) for the task({cvat_task_id}), " - "stopping annotation" - ) - break - - task_honeypot_to_index: dict[int, int] = { - honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames) - } # honeypot -> list index - - task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id] - - task_rejected_jobs = [ - j - for j in unchecked_jobs_meta.jobs - if j.job_id in rejected_job_ids and j.task_id == cvat_task_id - ] - - task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy() - for job in task_rejected_jobs: - job_frame_range = job.job_frame_range - job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range)) - - job_honeypots_to_replace = [] - job_validation_frames_to_replace = [] - job_validation_frames_to_keep = [] - for honeypot in job_honeypots: - validation_frame = task_honeypots_mapping[honeypot] - if validation_frame in task_bad_validation_frames: - job_honeypots_to_replace.append(honeypot) - job_validation_frames_to_replace.append(validation_frame) - else: - job_validation_frames_to_keep.append(validation_frame) - - # choose new unique validation frames for the job - assert not ( - set(job_validation_frames_to_replace) & set(task_good_validation_frames) - ) - job_available_validation_frames = list( - set(task_good_validation_frames) - set(job_validation_frames_to_keep) - ) - - rng = 
np.random.Generator(np.random.MT19937()) - new_job_validation_frames = rng.choice( - job_available_validation_frames, - replace=False, - size=len(job_validation_frames_to_replace), - ).tolist() - - for honeypot, new_validation_frame in zip( - job_honeypots_to_replace, new_job_validation_frames, strict=True - ): - honeypot_index = task_honeypot_to_index[honeypot] - task_updated_honeypot_real_frames[honeypot_index] = new_validation_frame + if (Config.validation.max_escrow_iterations > 0) and ( + Config.validation.max_escrow_iterations <= task.iteration + ): + logger.info( + f"Validation for escrow_address={escrow_address}:" + f" too many iterations, stopping annotation" + ) + should_complete = True + elif rejected_jobs and gt_stats: + honeypot_manager = _TaskHoneypotManager( + task, + manifest, + annotation_meta=meta, + gt_stats=gt_stats, + validation_result=validation_result, + logger=logger, + ) - # Make sure honeypots do not repeat in jobs - assert len( - { - task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]] - for honeypot in job_honeypots - } - ) == len(job_honeypots) + honeypot_update_result = honeypot_manager.update_honeypots() + if not honeypot_update_result.can_continue_annotation: + should_complete = True - cvat_api.update_task_validation_layout( - cvat_task_id, - disabled_frames=task_updated_disabled_frames, - honeypot_real_frames=task_updated_honeypot_real_frames, - ) + gt_stats = honeypot_update_result.updated_gt_stats + if gt_stats: if logger.isEnabledFor(logging.DEBUG): logger.debug("Updating GT stats: %s", gt_stats) @@ -589,15 +806,6 @@ def process_intermediate_results( # noqa: PLR0912 task_jobs = task.jobs - if Config.validation.max_escrow_iterations > 0: - escrow_iteration = task.iteration - if escrow_iteration and Config.validation.max_escrow_iterations <= escrow_iteration: - logger.info( - f"Validation for escrow_address={escrow_address}:" - f" too many iterations, stopping annotation" - ) - should_complete = True - 
db_service.update_escrow_iteration(session, escrow_address, chain_id, task.iteration + 1) if not should_complete: @@ -657,7 +865,7 @@ def process_intermediate_results( # noqa: PLR0912 return ValidationSuccess( job_results=job_results, validation_meta=validation_meta, - resulting_annotations=updated_merged_dataset_archive.getvalue(), + resulting_annotations=validation_result.updated_merged_dataset.getvalue(), average_quality=np.mean( [v for v in job_results.values() if v != _TaskValidator.UNKNOWN_QUALITY and v >= 0] or [0] diff --git a/packages/examples/cvat/recording-oracle/src/models/validation.py b/packages/examples/cvat/recording-oracle/src/models/validation.py index bda079c86e..273db23c39 100644 --- a/packages/examples/cvat/recording-oracle/src/models/validation.py +++ b/packages/examples/cvat/recording-oracle/src/models/validation.py @@ -1,7 +1,7 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String +from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, relationship from sqlalchemy.sql import func @@ -61,6 +61,10 @@ class GtStats(Base): failed_attempts = Column(Integer, default=0, nullable=False) accepted_attempts = Column(Integer, default=0, nullable=False) + total_uses = Column(Integer, default=0, nullable=False) + accumulated_quality = Column(Float, default=0.0, nullable=False) + enabled = Column(Boolean, default=True, nullable=False) + task: Mapped[Task] = relationship(back_populates="gt_stats") diff --git a/packages/examples/cvat/recording-oracle/src/services/validation.py b/packages/examples/cvat/recording-oracle/src/services/validation.py index 11ee52070a..ee5428e3f2 100644 --- a/packages/examples/cvat/recording-oracle/src/services/validation.py +++ b/packages/examples/cvat/recording-oracle/src/services/validation.py @@ -156,6 +156,8 @@ def update_gt_stats( "failed_attempts": 
val_frame_stats.failed_attempts, "accepted_attempts": val_frame_stats.accepted_attempts, "accumulated_quality": val_frame_stats.accumulated_quality, + "total_uses": val_frame_stats.total_uses, + "enabled": val_frame_stats.enabled, } for gt_frame_name, val_frame_stats in updated_gt_stats.items() ], diff --git a/packages/examples/cvat/recording-oracle/src/utils/__init__.py b/packages/examples/cvat/recording-oracle/src/utils/__init__.py index e69de29bb2..ff791bf80a 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/__init__.py +++ b/packages/examples/cvat/recording-oracle/src/utils/__init__.py @@ -0,0 +1,32 @@ +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from typing import TypeVar + +_K = TypeVar("_K") +_V = TypeVar("_V") + + +def grouped( + items: Iterator[_V] | Iterable[_V], *, key: Callable[[_V], _K] +) -> Mapping[_K, Sequence[_V]]: + """ + Returns a mapping with input iterable elements grouped by key, for example: + + grouped( + [("apple1", "red"), ("apple2", "green"), ("apple3", "red")], + key=lambda v: v[1] + ) + -> + { + "red": [("apple1", "red"), ("apple3", "red")], + "green": [("apple2", "green")] + } + + Similar to itertools.groupby, but allows reiteration on resulting groups. 
+ """ + + # Can be implemented with itertools.groupby, but it requires extra sorting for input elements + grouped_items = {} + for item in items: + grouped_items.setdefault(key(item), []).append(item) + + return grouped_items diff --git a/packages/examples/cvat/recording-oracle/src/utils/formatting.py b/packages/examples/cvat/recording-oracle/src/utils/formatting.py new file mode 100644 index 0000000000..8d9cfa3e98 --- /dev/null +++ b/packages/examples/cvat/recording-oracle/src/utils/formatting.py @@ -0,0 +1,2 @@ +def value_and_percent(numerator: float, denominator: float) -> tuple[float, float]: + return (numerator, numerator / (denominator or 1) * 100) diff --git a/packages/examples/cvat/recording-oracle/src/utils/logging.py b/packages/examples/cvat/recording-oracle/src/utils/logging.py index e7660eb0d7..a8d8c94daf 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/logging.py +++ b/packages/examples/cvat/recording-oracle/src/utils/logging.py @@ -1,5 +1,6 @@ import logging -from typing import NewType +from collections.abc import Sequence +from typing import Any, NewType from src.utils.stack import current_function_name @@ -30,3 +31,9 @@ class NullLogger(logging.Logger): def __init__(self, name: str = "", level=0) -> None: super().__init__(name, level) self.disabled = True + + +def format_sequence(items: Sequence[Any], *, max_items: int = 5, separator: str = ", ") -> str: + remainder_count = len(items) - max_items + tail = f" (and {remainder_count} more)" if remainder_count > 0 else "" + return f"{separator.join(map(str, items[:max_items]))}{tail}" diff --git a/packages/examples/cvat/recording-oracle/src/validation/__init__.py b/packages/examples/cvat/recording-oracle/src/validation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py b/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py deleted file mode 100644 index 
b5d7855a85..0000000000 --- a/packages/examples/cvat/recording-oracle/src/validation/annotation_matching.py +++ /dev/null @@ -1,144 +0,0 @@ -import itertools -from collections.abc import Callable, Sequence -from typing import NamedTuple, TypeVar - -import numpy as np -from scipy.optimize import linear_sum_assignment -from scipy.stats import gmean - -from src.core.config import Config - -Annotation = TypeVar("Annotation") - - -class Bbox(NamedTuple): - x: float - y: float - w: float - h: float - label: int - - -class Point(NamedTuple): - x: float - y: float - label: int - - -def bbox_iou(a_bbox: Bbox, b_bbox: Bbox) -> float: - """ - IoU computation for simple cases with axis-aligned bounding boxes - """ - - a_x, a_y, a_w, a_h = a_bbox[:4] - b_x, b_y, b_w, b_h = b_bbox[:4] - int_right = min(a_x + a_w, b_x + b_w) - int_left = max(a_x, b_x) - int_top = max(a_y, b_y) - int_bottom = min(a_y + a_h, b_y + b_h) - - int_w = max(0, int_right - int_left) - int_h = max(0, int_bottom - int_top) - intersection = int_w * int_h - if not intersection: - return 0 - - a_area = a_w * a_h - b_area = b_w * b_h - union = a_area + b_area - intersection - return intersection / union - - -def point_to_bbox_cmp( - bbox: Bbox, - point: Point, - *, - rel_sigma: float = Config.validation.default_point_validity_relative_radius, -) -> float: - """ - Checks that the point is within the axis-aligned bbox, - then measures the distance to the bbox center. - - rel_sigma: - Expected sigma for human point placement within a bbox - the value is relative to the bbox sides size - e.g. 
0.5 = the point is likely to be within the smaller bbox with sides 0.5w x 0.5h - around the GT bbox center - """ - # bbox filter + 2d Gaussian + geomean - - if not ((bbox.x <= point.x <= bbox.x + bbox.w) and (bbox.y <= point.y <= bbox.y + bbox.h)): - return 0 - - bbox_cx = bbox.x + bbox.w / 2 - bbox_cy = bbox.y + bbox.h / 2 - scale2sq = (rel_sigma**2) * 0.5 * np.array((bbox.w**2, bbox.h**2)) - dists = np.abs((bbox_cx - point.x, bbox_cy - point.y)) - return gmean(np.exp(-(dists**2) / scale2sq)) - - -class MatchResult(NamedTuple): - matches: list[tuple[Annotation, Annotation]] - mispred: list[tuple[Annotation, Annotation]] - a_extra: list[Annotation] - b_extra: list[Annotation] - - -def match_annotations( - a_anns: Sequence[Annotation], - b_anns: Sequence[Annotation], - similarity: Callable[[Annotation, Annotation], float] = bbox_iou, - min_similarity: float = 1.0, - label_matcher: Callable[[Annotation, Annotation], bool] = lambda a, b: a.label == b.label, -) -> MatchResult: - assert callable(similarity), similarity - assert callable(label_matcher), label_matcher - - max_ann_count = max(len(a_anns), len(b_anns)) - distances = np.array( - [ - [ - 1 - similarity(a, b) if a is not None and b is not None else 1 - for b, _ in itertools.zip_longest(b_anns, range(max_ann_count), fillvalue=None) - ] - for a, _ in itertools.zip_longest(a_anns, range(max_ann_count), fillvalue=None) - ] - ) - - distances[~np.isfinite(distances)] = 1 - distances[distances > 1 - min_similarity] = 1 - - if a_anns and b_anns: - a_matches, b_matches = linear_sum_assignment(distances) - else: - a_matches = [] - b_matches = [] - - # matches: annotations we succeeded to match completely - # mispred: annotations we succeeded to match, having label mismatch - matches = [] - mispred = [] - # *_umatched: annotations of (*) we failed to match - a_unmatched = [] - b_unmatched = [] - - for a_idx, b_idx in zip(a_matches, b_matches, strict=False): - dist = distances[a_idx, b_idx] - if dist > 1 - min_similarity 
or dist == 1: - if a_idx < len(a_anns): - a_unmatched.append(a_anns[a_idx]) - if b_idx < len(b_anns): - b_unmatched.append(b_anns[b_idx]) - else: - a_ann = a_anns[a_idx] - b_ann = b_anns[b_idx] - if label_matcher(a_ann, b_ann): - matches.append((a_ann, b_ann)) - else: - mispred.append((a_ann, b_ann)) - - if not len(a_matches) and not len(b_matches): - a_unmatched = list(a_anns) - b_unmatched = list(b_anns) - - return MatchResult(matches, mispred, a_unmatched, b_unmatched) diff --git a/packages/examples/cvat/recording-oracle/src/validators/signature.py b/packages/examples/cvat/recording-oracle/src/validators/signature.py index deb3364a70..bd93a83f3b 100644 --- a/packages/examples/cvat/recording-oracle/src/validators/signature.py +++ b/packages/examples/cvat/recording-oracle/src/validators/signature.py @@ -3,7 +3,7 @@ from fastapi import HTTPException, Request -from src.chain.escrow import get_exchange_oracle_address +from src.chain.escrow import get_available_webhook_types from src.chain.web3 import recover_signer from src.core.types import OracleWebhookTypes from src.schemas.webhook import OracleWebhook @@ -15,22 +15,10 @@ async def validate_oracle_webhook_signature( data: bytes = await request.body() message: dict = literal_eval(data.decode("utf-8")) - signer = recover_signer(webhook.chain_id, message, signature) + signer = recover_signer(webhook.chain_id, message, signature).lower() + webhook_types = get_available_webhook_types(webhook.chain_id, webhook.escrow_address) - exchange_oracle_address = get_exchange_oracle_address(webhook.chain_id, webhook.escrow_address) - possible_signers = { - OracleWebhookTypes.exchange_oracle: exchange_oracle_address, - } - - matched_signer = next( - ( - s_type - for s_type in possible_signers - if signer.lower() == possible_signers[s_type].lower() - ), - None, - ) - if not matched_signer: + if not (webhook_sender := webhook_types.get(signer)): raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED) - return matched_signer + 
return webhook_sender diff --git a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py index 9420567794..ed41bd41b4 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/chain/test_escrow.py @@ -11,7 +11,6 @@ from src.chain.escrow import ( get_escrow_manifest, - get_reputation_oracle_address, store_results, validate_escrow, ) @@ -163,27 +162,3 @@ def test_store_results_invalid_hash(self): mock_function.return_value = self.w3 with pytest.raises(EscrowClientError, match="Invalid empty hash"): store_results(self.w3.eth.chain_id, escrow_address, DEFAULT_MANIFEST_URL, "") - - def test_get_reputation_oracle_address(self): - escrow_address = create_escrow(self.w3) - with ( - patch("src.chain.escrow.get_web3") as mock_get_web3, - patch("src.chain.escrow.get_escrow") as mock_get_escrow, - ): - mock_get_web3.return_value = self.w3 - mock_escrow = MagicMock() - mock_escrow.reputation_oracle = REPUTATION_ORACLE_ADDRESS - mock_get_escrow.return_value = mock_escrow - address = get_reputation_oracle_address(self.w3.eth.chain_id, escrow_address) - assert isinstance(address, str) - assert address is not None - - def test_get_reputation_oracle_address_invalid_address(self): - with patch("src.chain.escrow.get_web3") as mock_function: - mock_function.return_value = self.w3 - with pytest.raises(EscrowClientError, match="Invalid escrow address:"): - get_reputation_oracle_address(self.w3.eth.chain_id, "invalid_address") - - def test_get_reputation_oracle_address_invalid_chain_id(self): - with pytest.raises(Exception, match="Can't find escrow"): - get_reputation_oracle_address(1, "0x1234567890123456789012345678901234567890") diff --git a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py 
b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py index ca238bb777..269227a063 100644 --- a/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py +++ b/packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py @@ -3,6 +3,8 @@ import random import unittest import uuid +from collections import Counter +from collections.abc import Sequence from contextlib import ExitStack from logging import Logger from types import SimpleNamespace @@ -15,6 +17,7 @@ from src.core.annotation_meta import AnnotationMeta, JobMeta from src.core.types import Networks +from src.core.validation_errors import TooSlowAnnotationError from src.core.validation_results import ValidationFailure, ValidationSuccess from src.cvat import api_calls as cvat_api from src.db import SessionLocal @@ -119,15 +122,28 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess with ExitStack() as common_lock_es: logger = mock.Mock(Logger) - mock_make_cloud_client = common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") ) - mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 0), ) + mock_make_cloud_client = common_lock_es.enter_context( + 
mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + ) + mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") + mock_get_task_validation_layout = common_lock_es.enter_context( mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" @@ -135,8 +151,11 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess ) mock_get_task_validation_layout.return_value = mock.Mock( cvat_api.models.ITaskValidationLayoutRead, + validation_frames=[2, 3], + honeypot_count=2, honeypot_frames=[0, 1], - honeypot_real_frames=[0, 1], + honeypot_real_frames=[2, 3], + frames_per_job_count=2, ) mock_get_task_data_meta = common_lock_es.enter_context( @@ -147,16 +166,6 @@ def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Sess frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)], ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") - ) - def patched_prepare_merged_dataset(self): self._updated_merged_dataset_archive = io.BytesIO() @@ -191,6 +200,11 @@ def patched_prepare_merged_dataset(self): mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports" ) as mock_get_jobs_quality_reports, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" + ), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), ): mock_get_task_quality_report.return_value = mock.Mock( cvat_api.models.IQualityReport, id=1 @@ -280,7 +294,7 @@ def patched_prepare_merged_dataset(self): class TestValidationLogic: - 
@pytest.mark.parametrize("seed", range(50)) + @pytest.mark.parametrize("seed", [41]) # range(50)) def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): escrow_address = ESCROW_ADDRESS chain_id = Networks.localhost @@ -305,7 +319,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): ) ( - _, task_frame_names, task_validation_frames, task_honeypots, @@ -331,9 +344,16 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): logger = mock.Mock(Logger) common_lock_es.enter_context( - mock.patch( - "src.core.config.Config.validation.gt_ban_threshold", max_validation_frame_uses - ) + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") ) mock_make_cloud_client = common_lock_es.enter_context( @@ -341,10 +361,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): ) mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") - ) - mock_get_task_validation_layout = common_lock_es.enter_context( mock.patch( "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" @@ -368,16 +384,6 @@ def test_can_change_bad_honeypots_in_jobs(self, session: Session, seed: int): frames=[SimpleNamespace(name=name) for name in task_frame_names], ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") - ) - common_lock_es.enter_context( - 
mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") - ) - common_lock_es.enter_context( - mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") - ) - def patched_prepare_merged_dataset(self): self._updated_merged_dataset_archive = io.BytesIO() @@ -419,6 +425,10 @@ def patched_prepare_merged_dataset(self): mock.patch( "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" ) as mock_update_task_validation_layout, + mock.patch("src.core.config.ValidationConfig.gt_ban_threshold", 0.35), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 0), ): mock_get_task_quality_report.return_value = mock.Mock( cvat_api.models.IQualityReport, id=1 @@ -451,31 +461,40 @@ def patched_prepare_merged_dataset(self): logger=logger, ) - assert isinstance(vr, ValidationFailure) + assert isinstance(vr, ValidationFailure) + assert {j.job_id for j in annotation_meta.jobs} == set(vr.rejected_jobs) - assert mock_update_task_validation_layout.call_count == 1 + assert mock_update_task_validation_layout.call_count == 1 - updated_disabled_frames = mock_update_task_validation_layout.call_args.kwargs[ - "disabled_frames" - ] - assert all(v in task_validation_frames for v in updated_disabled_frames) + updated_disabled_frames = mock_update_task_validation_layout.call_args.kwargs[ + "disabled_frames" + ] + assert all(v in task_validation_frames for v in updated_disabled_frames) + assert { + f + for f, c in Counter(task_honeypot_real_frames).items() + if c == max_validation_frame_uses + } == set(updated_disabled_frames), Counter(task_honeypot_real_frames) + + updated_honeypot_real_frames = mock_update_task_validation_layout.call_args.kwargs[ + "honeypot_real_frames" + ] - updated_honeypot_real_frames = mock_update_task_validation_layout.call_args.kwargs[ - 
"honeypot_real_frames" + for job_start, job_stop in job_frame_ranges: + job_honeypot_positions = [ + i for i, v in enumerate(task_honeypots) if v in range(job_start, job_stop + 1) ] + job_updated_honeypots = [ + updated_honeypot_real_frames[i] for i in job_honeypot_positions + ] + + # Check that the frames do not repeat + assert sorted(job_updated_honeypots) == sorted(set(job_updated_honeypots)) + + # Check that the new frames are not from the excluded set + assert set(job_updated_honeypots).isdisjoint(updated_disabled_frames) - for job_start, job_stop in job_frame_ranges: - job_honeypot_positions = [ - i - for i, v in enumerate(task_honeypots) - if v in range(job_start, job_stop + 1) - ] - job_updated_honeypots = [ - updated_honeypot_real_frames[i] for i in job_honeypot_positions - ] - assert sorted(job_updated_honeypots) == sorted(set(job_updated_honeypots)) - - def _get_job_frame_ranges(self, jobs) -> list[tuple[int, int]]: + def _get_job_frame_ranges(self, jobs: Sequence[Sequence[str]]) -> list[tuple[int, int]]: job_frame_ranges = [] job_start = 0 for job_frames in jobs: @@ -487,53 +506,202 @@ def _get_job_frame_ranges(self, jobs) -> list[tuple[int, int]]: def _generate_task_frames( self, - frame_count, - validation_frames_count, - job_size, - validation_frames_per_job, - seed, - max_validation_frame_uses=None, - ): + frame_count: int, + validation_frames_count: int, + job_size: int, + validation_frames_per_job: int, + *, + seed: int | None = None, + max_validation_frame_uses: int | None = None, + ) -> tuple[Sequence[str], Sequence[int], Sequence[int], Sequence[int], Sequence[Sequence[str]]]: rng = np.random.Generator(np.random.MT19937(seed)) - task_frames = list(range(frame_count)) - task_validation_frames = task_frames[-validation_frames_count:] - task_real_frames = [] - task_honeypots = [] - task_honeypot_real_frames = [] - validation_frame_uses = {vf: 0 for vf in task_validation_frames} + + task_frame_names = list(map(str, range(frame_count))) + 
task_validation_frame_names = task_frame_names[-validation_frames_count:] + task_honeypot_real_frame_names = [] + + output_task_frame_names = [] + output_task_honeypots = [] + output_task_honeypot_real_frames = [] + + validation_frame_uses = {fn: 0 for fn in task_validation_frame_names} jobs = [] - for job_real_frames in take_by( - task_frames[: frame_count - validation_frames_count], job_size + for job_frame_names in take_by( + task_frame_names[: frame_count - validation_frames_count], job_size ): - available_validation_frames = [ - vf - for vf in task_validation_frames + available_validation_frame_names = [ + fn + for fn in task_validation_frame_names if not max_validation_frame_uses - or validation_frame_uses[vf] < max_validation_frame_uses + or validation_frame_uses[fn] < max_validation_frame_uses ] - job_validation_frames = rng.choice( - available_validation_frames, validation_frames_per_job, replace=False + job_validation_frame_names = rng.choice( + available_validation_frame_names, validation_frames_per_job, replace=False ).tolist() - job_real_frames = job_real_frames + job_validation_frames - rng.shuffle(job_real_frames) + for fn in job_validation_frame_names: + validation_frame_uses[fn] += 1 + + job_frame_names = job_frame_names + job_validation_frame_names + rng.shuffle(job_frame_names) + + jobs.append(job_frame_names) - jobs.append(job_real_frames) + job_start_frame = len(output_task_frame_names) + output_task_frame_names.extend(job_frame_names) - job_start_frame = len(task_real_frames) - task_real_frames.extend(job_real_frames) + for i, v in enumerate(job_frame_names): + if v in job_validation_frame_names: + output_task_honeypots.append(i + job_start_frame) + task_honeypot_real_frame_names.append(v) - for i, v in enumerate(job_real_frames): - if v in job_validation_frames: - task_honeypots.append(i + job_start_frame) - task_honeypot_real_frames.append(v) + validation_frame_name_to_idx = {} + for fn in task_validation_frame_names: + 
validation_frame_name_to_idx[fn] = len(output_task_frame_names) + output_task_frame_names.append(fn) + + output_task_honeypot_real_frames = [ + validation_frame_name_to_idx[fn] for fn in task_honeypot_real_frame_names + ] + + output_task_validation_frames = [ + validation_frame_name_to_idx[fn] for fn in task_validation_frame_names + ] - task_frame_names = list(map(str, task_real_frames)) return ( - task_real_frames, - task_frame_names, - task_validation_frames, - task_honeypots, - task_honeypot_real_frames, + output_task_frame_names, + output_task_validation_frames, + output_task_honeypots, + output_task_honeypot_real_frames, jobs, ) + + def test_can_stop_on_slow_annotation_after_warmup_iterations(self, session: Session): + escrow_address = ESCROW_ADDRESS + chain_id = Networks.localhost + + frame_count = 10 + + manifest = generate_manifest() + + cvat_task_id = 1 + cvat_job_id = 1 + annotator1 = WALLET_ADDRESS1 + + assignment1_id = f"0x{0:040d}" + assignment1_quality = 0 + + # create a validation input + with ExitStack() as common_lock_es: + logger = mock.Mock(Logger) + + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.extract_zip_archive") + ) + common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive") + ) + + mock_make_cloud_client = common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.make_cloud_client") + ) + mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"") + + mock_get_task_validation_layout = common_lock_es.enter_context( + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout" + ) + ) + mock_get_task_validation_layout.return_value = 
mock.Mock( + cvat_api.models.ITaskValidationLayoutRead, + validation_frames=[2, 3], + honeypot_count=2, + honeypot_frames=[0, 1], + honeypot_real_frames=[2, 3], + frames_per_job_count=2, + ) + + mock_get_task_data_meta = common_lock_es.enter_context( + mock.patch("src.handlers.process_intermediate_results.cvat_api.get_task_data_meta") + ) + mock_get_task_data_meta.return_value = mock.Mock( + cvat_api.models.IDataMetaRead, + frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)], + ) + + def patched_prepare_merged_dataset(self): + self._updated_merged_dataset_archive = io.BytesIO() + + common_lock_es.enter_context( + mock.patch( + "src.handlers.process_intermediate_results._TaskValidator._prepare_merged_dataset", + patched_prepare_merged_dataset, + ) + ) + + annotation_meta = AnnotationMeta( + jobs=[ + JobMeta( + job_id=cvat_job_id, + task_id=cvat_task_id, + annotation_filename="", + annotator_wallet_address=annotator1, + assignment_id=assignment1_id, + start_frame=0, + stop_frame=manifest.annotation.job_size + manifest.validation.val_size, + ) + ] + ) + + with ( + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_task_quality_report" + ) as mock_get_task_quality_report, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_quality_report_data" + ) as mock_get_quality_report_data, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports" + ) as mock_get_jobs_quality_reports, + mock.patch( + "src.handlers.process_intermediate_results.cvat_api.update_task_validation_layout" + ), + mock.patch("src.core.config.ValidationConfig.min_available_gt_threshold", 0), + mock.patch("src.core.config.ValidationConfig.max_gt_share", 1), + mock.patch("src.core.config.ValidationConfig.warmup_iterations", 1), + mock.patch("src.core.config.ValidationConfig.min_warmup_progress", 20), + ): + mock_get_task_quality_report.return_value = mock.Mock( + cvat_api.models.IQualityReport, id=1 + ) + 
mock_get_quality_report_data.return_value = mock.Mock( + cvat_api.QualityReportData, + frame_results={ + "0": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)), + "1": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)), + }, + ) + mock_get_jobs_quality_reports.return_value = [ + mock.Mock( + cvat_api.models.IQualityReport, + job_id=1, + summary=mock.Mock(accuracy=assignment1_quality), + ), + ] + + with pytest.raises(TooSlowAnnotationError): + process_intermediate_results( + session, + escrow_address=escrow_address, + chain_id=chain_id, + meta=annotation_meta, + merged_annotations=io.BytesIO(), + manifest=manifest, + logger=logger, + )