From f4cd737d401d27ad7335dba9767d46ff7b31d2b6 Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu <164478159+cristibleotiu@users.noreply.github.com> Date: Mon, 26 Jan 2026 14:14:12 +0200 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20base=5Finference=5Fapi=20+=20llm=5F?= =?UTF-8?q?inference=5Fapi=20refactor=20+=20draft=20for=20cv=5F=E2=80=A6?= =?UTF-8?q?=20(#346)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: base_inference_api + llm_inference_api refactor + draft for cv_inference_api * chore: added docstrings and removed artifacts from previous base_inference_api implementation * feat: split cv_inference_api.py in cv_inference_api.py and cerviguard_api.py * feat: ora sync improvement + unit tests --- .devcontainer/devcontainer.json | 61 + .devcontainer/sparse-exclude.txt | 4 + .devcontainer/sparse-include.txt | 2 + .dockerignore | 26 +- .gitignore | 17 +- constants.py | 2 +- extensions/business/deeploy/deeploy_mixin.py | 23 +- .../edge_inference_api/base_inference_api.py | 999 ++++++++++++++++ .../edge_inference_api/cv_inference_api.py | 486 ++++++++ .../edge_inference_api/llm_inference_api.py | 684 +++++++++++ .../edge_inference_api/sd_inference_api.py | 532 +++++++++ .../inference_api/base_inference_api.py | 357 ------ extensions/business/jeeves/jeeves_api.py | 2 + .../business/mixins/base_agent_mixin.py | 93 ++ extensions/business/mixins/nlp_agent_mixin.py | 64 +- .../business/nlp/doc_embedding_agent.py | 4 +- extensions/business/nlp/vllm_agent.py | 6 +- .../business/oracle_management/oracle_api.py | 4 +- .../business/oracle_sync/oracle_sync_01.py | 47 +- .../sync_mixins/ora_sync_constants.py | 2 +- .../sync_mixins/ora_sync_states_mixin.py | 57 +- .../sync_mixins/ora_sync_utils_mixin.py | 116 +- .../data/default/jeeves/jeeves_listener.py | 46 +- extensions/serving/base/base_llm_serving.py | 21 +- .../default_inference/nlp/llama_cpp_base.py | 28 +- extensions/serving/mixins_llm/llm_utils.py | 1 + plugins/business/cerviguard/cerviguard_api.py | 178 +++ .../business/cerviguard/local_serving_api.py | 0 plugins/business/llm/code_assist_01.py | 9 +- ver.py | 2 +- .../oracle_sync/oracle_sync_test_plan.md | 343 ++++++ xperimental/oracle_sync/test_ora_sync.py | 1000 +++++++++++++++++ 32 files changed, 4685 insertions(+), 531 deletions(-) create mode 100644 .devcontainer/sparse-exclude.txt create mode 100644 .devcontainer/sparse-include.txt create mode 100644 extensions/business/edge_inference_api/base_inference_api.py create mode 100644 extensions/business/edge_inference_api/cv_inference_api.py create mode 100644 extensions/business/edge_inference_api/llm_inference_api.py create mode 100644 extensions/business/edge_inference_api/sd_inference_api.py delete mode 100644 extensions/business/inference_api/base_inference_api.py create mode 100644 extensions/business/mixins/base_agent_mixin.py create mode 100644 plugins/business/cerviguard/cerviguard_api.py rename {extensions => plugins}/business/cerviguard/local_serving_api.py (100%) create mode 100644 xperimental/oracle_sync/oracle_sync_test_plan.md create mode 100644 xperimental/oracle_sync/test_ora_sync.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 67f82a05..e6ebcdd7 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -17,9 +17,70 @@ "--privileged" ], + // mount your Windows ~/.ssh read-only into the container + "mounts": [ + "source=${localEnv:USERPROFILE}/.ssh,target=/host-ssh,type=bind,consistency=cached,readonly" + ], + + // Wait 
for the clone+sparse step to finish before the IDE connects + "waitFor": "onCreateCommand", + + // Set your repo URL once here (or pass it via env) + "containerEnv": { + "REPO_URL1": "git@github.com-cristibleotiu:Ratio1/edge_node", + "REPO_URL": "https://github.com/Ratio1/edge_node", + "R1_SSH_HOST": "github.com", + + "R1_SSH_ALIAS": "github.com-cristibleotiu", + + // where the mounted keys live in the container + "R1_SSH_KEY_SOURCE_DIR": "${localEnv:USERPROFILE}/.ssh", + // filename of your private key on Windows (e.g. id_ed25519 or my_github_key) + "R1_SSH_KEY_FILENAME": "github-cristibleotiu", + + // where to place the key inside the container + "R1_SSH_DIR": "/root/.ssh", + "R1_SSH_KEY_DEST": "github-cristibleotiu", + + // sparse rule files inside the repo (adjust if you keep them elsewhere) + "R1_SPARSE_EXCLUDE_FILE": ".devcontainer/sparse-exclude.txt", + "R1_SPARSE_INCLUDE_FILE": ".devcontainer/sparse-include.txt" + }, + +// "onCreateCommand": [ +// "bash ls -all", +// // strict shell + ensure ~/.ssh with safe perms +// "bash -lc \"set -euo pipefail; mkdir -p \\\"$R1_SSH_DIR\\\"; chmod 700 \\\"$R1_SSH_DIR\\\"\"", +// +// // copy the Windows key into place with correct perms (public key optional) +// "bash -lc \"cp \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME\\\" \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST\\\"; chmod 600 \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST\\\"; if [ -f \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME.pub\\\" ]; then cp \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME.pub\\\" \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST.pub\\\"; fi\"", +// +// // write SSH config for the alias host (e.g., github.com-) → real host + key +// "bash -lc \"cat > \\\"$R1_SSH_DIR/config\\\" <> \\\"$R1_SSH_DIR/known_hosts\\\" || true\"", +// +// // blobless (partial) clone via the alias URL (e.g., git@github.com-:org/repo) +// "bash -lc \"cd ${containerWorkspaceFolder} && git clone --filter=blob:none --sparse \\\"$R1_REPO_SSH_URL\\\" .\"", +// +// // non-cone sparse mode so we can use excludes and fine-grained files +// "bash -lc \"git -C ${containerWorkspaceFolder} sparse-checkout init --no-cone\"", +// +// // include everything, then apply excludes from file, then optional re-includes +// "bash -lc \"git -C ${containerWorkspaceFolder} sparse-checkout set --no-cone '/*' $(sed -E '/^\\s*(#|$)/d; s/^/!/' \\\"${containerWorkspaceFolder}/$R1_SPARSE_EXCLUDE_FILE\\\" 2>/dev/null || true) $(sed -E '/^\\s*(#|$)/d' \\\"${containerWorkspaceFolder}/$R1_SPARSE_INCLUDE_FILE\\\" 2>/dev/null || true)\"", +// +// // drop now-excluded files from the working tree +// "bash -lc \"git -C ${containerWorkspaceFolder} clean -ffdqx || true\"" +// ], + "build": { +// "dockerfile": "Dockerfile", "context": "../", +// "options": [ +// "--progress=plain" +// ] }, diff --git a/.devcontainer/sparse-exclude.txt b/.devcontainer/sparse-exclude.txt new file mode 100644 index 00000000..fa815542 --- /dev/null +++ b/.devcontainer/sparse-exclude.txt @@ -0,0 +1,4 @@ +/data/ +/logs/ +/third_party/huge_lib/ +/weights/ diff --git a/.devcontainer/sparse-include.txt b/.devcontainer/sparse-include.txt new file mode 100644 index 00000000..d5579c3d --- /dev/null +++ b/.devcontainer/sparse-include.txt @@ -0,0 +1,2 @@ +/data/private/secret.json # keep this file even though /data/ is excluded +/weights/README.md # another example file inside an excluded dir diff --git a/.dockerignore b/.dockerignore index e650c6cd..7a082181 100644 --- a/.dockerignore +++ b/.dockerignore @@ -8,11 +8,23 @@ Dockerfile_tegra_dev config_startup.json .env -*/_local_cache/* -*/_data/* 
-*/_logs/* -*/_models/* -*/_output/* -*/_vector_db* -*/ratio1_* +**/_local_cache/** +**/_data/** +**/_logs/** +**/_models/** +**/_output/** +**/vectordb/** +**/_vector_db**/ +**/ratio1_* +**/db_cache/** + +.idea/** +.git/** +**/__pycache__/** + +xperimental/llama_cpp/** + +!**/_data/box_configuration/** +!**/authorized_addrs + .venv diff --git a/.gitignore b/.gitignore index 3c30d57e..aa2139b2 100644 --- a/.gitignore +++ b/.gitignore @@ -136,17 +136,32 @@ dmypy.json # local python files *__local__.py +*/_local_cache/* +*/ratio1_0* */_data/* */_logs/* */_output/* */_models/* +**/_local_cache/** +**/_data/** +**/_logs/** +**/_models/** +**/_output/** +**/vectordb/** +**/_vector_db**/ +**/ratio1_* +**/db_cache/** + config_startup*.txt config_startup*.json config_startup*.yaml config_startup*.yml inference/model_testing/_local_cache/_logs/MPTF.txt inference/model_testing/_local_cache/_logs/20211224_102325_MPTF_001_log.txt - +vectordb +_vector_db_cache +_vector_db_cache_HNSWVectorDB +db_cache plugins/libs/_cache/ core/utils/_cache/ diff --git a/constants.py b/constants.py index ff85f5d8..8e6d0c78 100644 --- a/constants.py +++ b/constants.py @@ -160,7 +160,7 @@ class JeevesCt: JEEVES_API_SIGNATURES = [ "JEEVES_API", "KEYSOFT_JEEVES", - "BASE_INFERENCE_API", + "LLM_INFERENCE_API", ] JEEVES_AGENT_SIGNATURES = [ diff --git a/extensions/business/deeploy/deeploy_mixin.py b/extensions/business/deeploy/deeploy_mixin.py index 0d3e8b78..6b8568df 100644 --- a/extensions/business/deeploy/deeploy_mixin.py +++ b/extensions/business/deeploy/deeploy_mixin.py @@ -1824,7 +1824,28 @@ def _discover_plugin_instances( ): """ Discover the plugin instances for the given app_id and target nodes. - Returns a list of dictionaries containing infomration about plugin instances. + + Returns a list of dictionaries containing information about plugin instances. + + Parameters + ---------- + app_id : str + Generated application identifier in the form of {app_name}-{uuid4}. + job_id : str + Incremental job identifier. Interchangeable with app_id for discovery, but they are never the same. + target_nodes : list[str] + List of target node addresses to filter the search. + owner : str + Owner address to filter the search. + plugin_signature : str + Plugin signature to filter the search. + instance_id : str + Plugin instance ID to filter the search. + + Returns + ------- + list[dict] + List of discovered plugin instances with details. """ apps = self._get_online_apps(owner=owner, target_nodes=target_nodes) self.P(f"online apps for owner {owner} and target_nodes {target_nodes}: {self.json_dumps(apps)}") diff --git a/extensions/business/edge_inference_api/base_inference_api.py b/extensions/business/edge_inference_api/base_inference_api.py new file mode 100644 index 00000000..0af32813 --- /dev/null +++ b/extensions/business/edge_inference_api/base_inference_api.py @@ -0,0 +1,999 @@ +""" +BASE_INFERENCE_API Plugin + +Production-Grade Inference API + +This plugin exposes a hardened, FastAPI-powered interface for generic inference. +It keeps the lightweight loopback data flow used by the +Ratio1 node while adding security, observability, and request lifecycle +management. + +It can work with both async and sync requests. +In case of sync requests, they will be processed using PostponedRequest objects. +Otherwise, the request_id will be returned immediately, and the client can poll for results. + +Highlights +- Can be exposed through tunneling for remote access or kept local-only for third-party apps hosted through Ratio1. 
+- We recommend using it locally paired with a third-party app that manages the rate limiting, authentication, and + request tracking (e.g., a web app built with Streamlit, Gradio, or Flask). +- In case of need for remote access, it can be exposed through tunneling with bearer-token authentication and a +built-in rate limiting mechanism. +- Supports any AI engine supported by Ratio1 through the Loopback plugin type. +- Durable, restart-safe request tracking with health/metrics/list endpoints +- Async + sync inference payload layout +- Automatic timeout handling, TTL-based eviction, and persistence to cacheapi + +In case of no tunneling and local-only access, authentication will be disabled by default. +For tunneling export `INFERENCE_API_TOKEN` (comma-separated values for multiple clients) to enforce token +checks or provide the tokens through the `PREDEFINED_AUTH_TOKENS` config parameter. + +Available Endpoints: +- POST /predict - Compute prediction (sync) +- POST /predict_async - Compute prediction (async) +- GET /health - Health check +- GET /status - Status of API +- GET /metrics - Retrieve API metrics +- GET /request_status - Check for current status of async request results + +# TODO: find a legit example for generic inference API configuration +# or keep class as abstract only? +Example pipeline configuration: +{ + "NAME": "local_inference_api", + "TYPE": "Loopback", + "PLUGINS": [ + { + "SIGNATURE": "BASE_INFERENCE_API", + "INSTANCES": [ + { + "INSTANCE_ID": "llm_interface", + "AI_ENGINE": "llama_cpp", + "STARTUP_AI_ENGINE_PARAMS": { + "HF_TOKEN": "", + "MODEL_FILENAME": "llama-3.2-1b-instruct-q4_k_m.gguf", + "MODEL_NAME": "hugging-quants/Llama-3.2-1B-Instruct-Q4_K_M-GGUF", + "SERVER_COLLECTOR_TIMEDELTA": 360000 + } + } + ] + } + ] +} +""" +from naeural_core.business.default.web_app.fast_api_web_app import FastApiWebAppPlugin as BasePlugin +from extensions.business.mixins.base_agent_mixin import _BaseAgentMixin, BASE_AGENT_MIXIN_CONFIG + +from typing import Any, Dict, List, Optional + + +__VER__ = '0.1.0' + +_CONFIG = { + **BasePlugin.CONFIG, + **BASE_AGENT_MIXIN_CONFIG, + + # MANDATORY SETTING IN ORDER TO RECEIVE REQUESTS + "ALLOW_EMPTY_INPUTS": True, # allow processing even when no input data is present + + # MANDATORY LOOPBACK SETTINGS + "IS_LOOPBACK_PLUGIN": True, + "TUNNEL_ENGINE_ENABLED": False, + "API_TITLE": "Local Inference API", + "API_SUMMARY": "FastAPI server for local-only inference.", + + "PROCESS_DELAY": 0, + "REQUEST_TIMEOUT": 600, # 10 minutes + "SAVE_PERIOD": 300, # 5 minutes + + "REQUEST_TTL_SECONDS": 60 * 60 * 2, # keep historical results for 2 hours + "RATE_LIMIT_PER_MINUTE": 5, + "AUTH_TOKEN_ENV": "INFERENCE_API_TOKEN", + "PREDEFINED_AUTH_TOKENS": [], # e.g. ["token1", "token2"] + "ALLOW_ANONYMOUS_ACCESS": True, + + "METRICS_REFRESH_SECONDS": 5 * 60, # 5 minutes + + # Semaphore key for paired plugin synchronization (e.g., with WAR containers) + # When set, this plugin will signal readiness and expose env vars to paired plugins + "SEMAPHORE": None, + + "VALIDATION_RULES": { + **BasePlugin.CONFIG['VALIDATION_RULES'], + } +} + + +class BaseInferenceApiPlugin( + BasePlugin, + _BaseAgentMixin +): + CONFIG = _CONFIG + + STATUS_PENDING = "pending" + STATUS_COMPLETED = "completed" + STATUS_FAILED = "failed" + STATUS_TIMEOUT = "timeout" + + def on_init(self): + """ + Initialize plugin state and restore persisted request metadata. + + Returns + ------- + None + Method has no return value; it prepares in-memory stores, metrics, and persistence. 
+ """ + super(BaseInferenceApiPlugin, self).on_init() + if not self.cfg_ai_engine: + err_msg = f"AI_ENGINE must be specified for {self.get_signature()} plugin." + self.P(err_msg) + raise ValueError(err_msg) + # endif AI_ENGINE not specified + self._requests: Dict[str, Dict[str, Any]] = {} + self._api_errors: Dict[str, Dict[str, Any]] = {} + # TODO: add inference metrics tracking (latency, tokens, etc) + self._metrics = { + 'requests_total': 0, + 'requests_completed': 0, + 'requests_failed': 0, + 'requests_timeout': 0, + 'requests_active': 0, + } + self._rate_limit_state: Dict[str, Dict[str, Any]] = {} + # This is different from self.last_error_time in BasePlugin + # self.last_error_time tracks unhandled errors that occur in the plugin loop + # This one tracks all errors that occur during API request handling + self.last_handled_error_time = None + self.last_metrics_refresh = 0 + self.last_persistence_save = 0 + self.load_persistence_data() + tunneling_str = f"(with tunneling enabled)" if self.cfg_tunnel_engine_enabled else "" + start_msg = f"{self.get_signature()} initialized{tunneling_str}.\n" + lst_endpoint_names = list(self._endpoints.keys()) + endpoints_str = ", ".join([f"/{endpoint_name}" for endpoint_name in lst_endpoint_names]) + start_msg += f"\t\tEndpoints: {endpoints_str}\n" + start_msg += f"\t\tAI Engine: {self.cfg_ai_engine}\n" + start_msg += f"\t\tLoopback key: loopback_dct_{self._stream_id}" + self.P(start_msg) + return + + def _get_payload_field(self, data: dict, key: str, default=None): + """ + Retrieve a value from payload data using case-insensitive lookup. + + Parameters + ---------- + data : dict + Payload dictionary to search. + key : str + Target key to retrieve (case-insensitive). + default : Any, optional + Fallback value when the key is not present. + + Returns + ------- + Any + Matched value from the payload or the provided default. + """ + if not isinstance(data, dict): + return default + if key in data: + return data[key] + key_upper = key.upper() + if key_upper in data: + return data[key_upper] + return default + + def _setup_semaphore_env(self): + """Set semaphore environment variables for bundled plugins.""" + localhost_ip = self.log.get_localhost_ip() + port = self.cfg_port + self.semaphore_set_env('API_HOST', localhost_ip) + if port: + self.semaphore_set_env('API_PORT', str(port)) + self.semaphore_set_env('API_URL', f'http://{localhost_ip}:{port}') + return + + """PERSISTENCE + STATUS""" + if True: + def load_persistence_data(self): + """ + Restore cached request data, errors, and metrics from persistence. + + Returns + ------- + None + Updates in-memory state if cached data is available. + """ + cached_data = self.cacheapi_load_pickle() + if cached_data is not None: + # Useful only for debugging purposes + self._requests = cached_data.get('_requests', {}) + self._api_errors = cached_data.get('_api_errors', {}) + self._metrics = cached_data.get('_metrics', {}) + self.last_handled_error_time = cached_data.get('last_handled_error_time', None) + # endif cached_data is not None + return + + def maybe_save_persistence_data(self, force=False): + """ + Persist current request tracking state when needed. + + Parameters + ---------- + force : bool, optional + If True, persistence is forced regardless of elapsed time. + + Returns + ------- + None + Saves request, error, and metric data when the save interval has elapsed or force is True. 
+ """ + if force or (self.time() - self.last_persistence_save) > self.cfg_save_period: + data_to_save = { + '_requests': self._requests, + '_api_errors': self._api_errors, + '_metrics': self._metrics, + 'last_handled_error_time': self.last_handled_error_time, + } + self.cacheapi_save_pickle(data_to_save) + self.last_persistence_save = self.time() + # endif needs saving + return + + def cleanup_expired_requests(self): + """ + Remove completed requests that exceeded the TTL window. + + Returns + ------- + None + Evicts expired request entries and logs eviction counts when applicable. + """ + ttl_seconds = self.cfg_request_ttl_seconds + if ttl_seconds <= 0: + return + now_ts = self.time() + expired_ids = [] + for request_id, request_data in self._requests.items(): + finished_at = request_data.get('finished_at') + if finished_at is None: + continue + if (now_ts - finished_at) > ttl_seconds: + expired_ids.append(request_id) + for request_id in expired_ids: + self._requests.pop(request_id, None) + if expired_ids: + self.Pd(f"Evicted {len(expired_ids)} completed requests due to TTL policy.") + return + + def record_api_error(self, request_id: Optional[str], error_message: str): + """ + Record an API handling error and update metrics. + + Parameters + ---------- + request_id : str or None + Identifier of the request that failed, if available. + error_message : str + Description of the error encountered. + + Returns + ------- + None + Stores the error entry and increments failure metrics. + """ + self.last_handled_error_time = self.time() + key = request_id or f"error_{self.last_handled_error_time}" + self._api_errors[key] = { + 'request_id': request_id, + 'message': error_message, + 'ts': self.last_handled_error_time, + } + self._metrics['requests_failed'] += 1 + return + + def get_status(self): + """ + Compute the current status of the API based on recent errors. + + Returns + ------- + str + 'ok' when healthy, otherwise a degraded status annotated with time since last error. + """ + last_error_time = self.last_handled_error_time + status = "ok" + if last_error_time is not None: + delta_seconds = (self.time() - last_error_time) + if delta_seconds < 300: + status = f"degraded (last error {int(delta_seconds)}s ago)" + # endif enough time has passed since last error + # endif last_error_time is not None + return status + """END PERSISTENCE + STATUS""" + + """SECURITY + RATE LIMITING""" + if True: + def check_allow_all_requests(self): + """ + In case the API is not using tunneling and is only accessible locally, + we can allow all requests without token checks. + + Returns + ------- + bool + True if all requests are allowed without authentication. + """ + if self.cfg_is_loopback_plugin and not self.cfg_tunnel_engine_enabled: + return True + return False + + def env_allowed_tokens(self): + """ + Retrieve allowed tokens from the configured environment variable. + + Returns + ------- + list of str + Token strings parsed from the configured auth environment variable. + """ + env_name = self.cfg_auth_token_env + if not env_name: + return [] + raw_value = self.os_environ.get(env_name, '').strip() + if not raw_value: + return [] + return [token.strip() for token in raw_value.split(',') if token.strip()] + + def _configured_tokens(self) -> List[str]: + """ + Aggregate authentication tokens from environment and configuration. + + Returns + ------- + list of str + Unique list of tokens allowed for request authorization. 
+ """ + env_tokens = self.env_allowed_tokens() + predefined_tokens = self.cfg_predefined_auth_tokens or [] + all_tokens = set(env_tokens + predefined_tokens) + return list(all_tokens) + + def authorize_request(self, authorization: Optional[str]) -> str: + """ + Validate the authorization header and return the associated subject. + + Parameters + ---------- + authorization : str or None + Value of the Authorization header, expected to contain a bearer token. + + Returns + ------- + str + Identified subject token or 'anonymous' when anonymous access is permitted. + + Raises + ------ + PermissionError + If authorization is required and the provided token is missing or invalid. + """ + if self.check_allow_all_requests(): + # TODO: should the apps using this API also have identification tokens for usage analytics? + return "anonymous" + tokens = self._configured_tokens() + if not tokens: + if not self.cfg_allow_anonymous_access: + raise PermissionError( + "Authorization required but no tokens were configured. Provide tokens via INFERENCE_API_TOKEN." + ) + return "anonymous" + if authorization is None: + raise PermissionError("Missing Authorization header.") + token = authorization + if token.startswith('Bearer '): + token = token[7:] + token = token.strip() + if token not in tokens: + raise PermissionError("Invalid Authorization token.") + return token + + def enforce_rate_limit(self, subject: str): + """ + Enforce per-subject rate limiting when configured. + + Parameters + ---------- + subject : str + Identifier for the client or token being rate limited. + + Returns + ------- + None + Increments rate limit counters or raises if the limit is exceeded. + + Raises + ------ + RuntimeError + When the subject exceeds the configured requests-per-minute threshold. + """ + # TODO: maybe make the rate limit window configurable + if self.check_allow_all_requests(): + return + limit = self.cfg_rate_limit_per_minute + if limit <= 0: + return + bucket_key = subject or 'anonymous' + now_minute = int(self.time() // 60) + bucket = self._rate_limit_state.get(bucket_key) + if bucket is None or bucket['minute'] != now_minute: + bucket = {'minute': now_minute, 'count': 0} + self._rate_limit_state[bucket_key] = bucket + if bucket['count'] >= limit: + raise RuntimeError( + f"Rate limit exceeded for subject '{bucket_key}'. Max {limit} requests per minute." + ) + bucket['count'] += 1 + return + """END SECURITY + RATE LIMITING""" + + + """REQUEST TRACKING""" + if True: + def refresh_metrics(self): + """ + Update the active request count based on current request statuses. + + Returns + ------- + None + Recomputes in-memory metrics for active requests. + """ + self._metrics['requests_active'] = sum( + 1 for req in self._requests.values() if req['status'] == self.STATUS_PENDING + ) + # Maybe update other metrics here if needed + # Need to consider if expired requests should be counted in total metrics. + return + + def maybe_refresh_metrics(self): + """ + Refresh metrics when the refresh interval has elapsed. + + Returns + ------- + None + Triggers metric recomputation based on cfg_metrics_refresh_seconds. + """ + # For performance, we only refresh metrics every cfg_metrics_refresh_seconds + now_ts = self.time() + if (now_ts - self.last_metrics_refresh) > self.cfg_metrics_refresh_seconds: + self.refresh_metrics() + self.last_metrics_refresh = now_ts + return + + def maybe_mark_request_failed(self, request_id: str, request_data: Dict[str, Any]): + """ + Mark a pending request as failed when an error is attached. 
+ + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record containing status and error information. + + Returns + ------- + bool + True when the request transitions to failed, False otherwise. + """ + if request_data['status'] == self.STATUS_FAILED: + return True + if request_data['status'] != self.STATUS_PENDING: + return False + error = request_data.get('error', None) + if error is None: + return False + self.P(f"Request {request_id} failed: {error}") + request_data['status'] = self.STATUS_FAILED + request_data['updated_at'] = self.time() + request_data['result'] = { + 'error': error, + 'status': self.STATUS_FAILED, + 'request_id': request_id, + } + self._metrics['requests_failed'] += 1 + self._metrics['requests_active'] -= 1 + return True + + def maybe_mark_request_timeout(self, request_id: str, request_data: Dict[str, Any]): + """ + Mark a pending request as timed out when exceeding the configured timeout. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record containing timestamps and timeout settings. + + Returns + ------- + bool + True when the request transitions to timeout, False otherwise. + """ + if request_data['status'] == self.STATUS_TIMEOUT: + return True + if request_data['status'] != self.STATUS_PENDING: + return False + timeout = request_data.get('timeout', self.cfg_request_timeout) + # No timeout configured + if timeout is None or timeout <= 0: + return False + if (self.time() - request_data['created_at']) <= timeout: + return False + self.P(f"Request {request_id} timed out after {timeout} seconds.") + request_data['status'] = self.STATUS_TIMEOUT + request_data['updated_at'] = self.time() + request_data['error'] = f"Request timed out after {timeout} seconds." + request_data['result'] = { + 'error': request_data['error'], + 'status': self.STATUS_TIMEOUT, + 'request_id': request_id, + 'timeout': timeout, + } + self._metrics['requests_timeout'] += 1 + self._metrics['requests_active'] -= 1 + return True + + def solve_postponed_request(self, request_id: str): + """ + Resolve or requeue a postponed request by checking its current status. + + Parameters + ---------- + request_id : str + Identifier of the request to resolve. + + Returns + ------- + dict + Request result when completed or failed, or a PostponedRequest for pending work. + """ + if request_id in self._requests: + self.Pd(f"Checking status of request ID {request_id}...") + request_data = self._requests[request_id] + + self.maybe_mark_request_timeout(request_id=request_id, request_data=request_data) + self.maybe_mark_request_failed(request_id=request_id, request_data=request_data) + if request_data['status'] != self.STATUS_PENDING: + return request_data['result'] + # endif request not pending + else: + self.Pd(f"Request ID {request_id} not found in requests.") + return { + 'status': 'error', + "error": f"Request ID {request_id} not found.", + 'request_id': request_id, + } + # endif request exists + return self.create_postponed_request( + solver_method=self.solve_postponed_request, + method_kwargs={ + "request_id": request_id + } + ) + + def register_request( + self, + subject: str, + parameters: Dict[str, Any], + metadata: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = None + ): + """ + Register a new inference request and initialize tracking metadata. + + Parameters + ---------- + subject : str + Identifier representing the caller (token or user). 
+ parameters : dict + Request parameters to forward to the inference engine. + metadata : dict, optional + Additional metadata to store with the request. + timeout : int or None, optional + Override for request timeout in seconds. + + Returns + ------- + tuple + Generated request_id and the stored request data dictionary. + """ + request_id = self.uuid() + start_time = self.time() + request_data = { + "request_id": request_id, + 'subject': subject, + 'parameters': parameters, + 'metadata': metadata or {}, + 'status': self.STATUS_PENDING, + 'created_at': start_time, + 'updated_at': start_time, + 'timeout': timeout or self.cfg_request_timeout, + 'result': None, + 'error': None, + } + self._requests[request_id] = request_data + self._metrics['requests_total'] += 1 + self._metrics['requests_active'] += 1 + return request_id, request_data + + def serialize_request(self, request_id: str): + """ + Produce a client-friendly view of a tracked request. + + Parameters + ---------- + request_id : str + Identifier of the request to serialize. + + Returns + ------- + dict or None + Serialized request data including status and metadata, or None if not found. + """ + request_data = self._requests.get(request_id) + if request_data is None: + return None + serialized = { + 'request_id': request_id, + 'status': request_data['status'], + 'created_at': request_data['created_at'], + 'updated_at': request_data['updated_at'], + 'metadata': request_data.get('metadata') or {}, + 'subject': request_data.get('subject'), + } + if request_data['status'] != self.STATUS_PENDING: + serialized['result'] = request_data['result'] + if request_data.get('error') is not None: + serialized['error'] = request_data['error'] + return serialized + """END REQUEST TRACKING""" + + """API ENDPOINTS""" + if True: + @BasePlugin.endpoint(method="GET") + def health(self): + """ + Health check endpoint exposing plugin status and metrics. + + Returns + ------- + dict + Status information including uptime, last error time, and request metrics. + """ + return { + "status": self.get_status(), + "pipeline": self.get_stream_id(), + "plugin": self.get_signature(), + "instance_id": self.get_instance_id(), + "loopback_enabled": self.cfg_is_loopback_plugin, + "uptime": self.get_alive_time(), + "last_error_time": self.last_handled_error_time, + "total_errors": len(self._api_errors), + "metrics": self._metrics, + } + + @BasePlugin.endpoint(method="GET") + def status(self): + """ + Status endpoint summarizing API state. + + Returns + ------- + dict + Basic service info plus counts of pending and completed requests. + """ + pending = len([ + rid for rid, data in self._requests.items() + if data.get('status') == self.STATUS_PENDING + ]) + completed = len([ + rid for rid, data in self._requests.items() + if data.get('status') == self.STATUS_COMPLETED + ]) + return { + "status": self.get_status(), + "service": self.cfg_api_summary, + "version": __VER__, + "stream_id": self.get_stream_id(), + "plugin": self.get_signature(), + "instance_id": self.get_instance_id(), + "total_requests": len(self._requests), + "pending_requests": pending, + "completed_requests": completed, + "uptime_seconds": self.get_alive_time(), + } + + @BasePlugin.endpoint(method="GET") + def metrics(self): + """ + Metrics endpoint summarizing request counts and active requests. + + Returns + ------- + dict + Metric counters and identifiers of currently pending requests. 
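+
+      Example response (illustrative values only; counters reflect runtime
+      activity and are defined in on_init):
+
+        {
+          "metrics": {
+            "requests_total": 12,
+            "requests_completed": 9,
+            "requests_failed": 1,
+            "requests_timeout": 0,
+            "requests_active": 2
+          },
+          "active_requests": ["<request_id>"],
+          "errors_tracked": 1
+        }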
+ """ + self.maybe_refresh_metrics() + return { + "metrics": self._metrics, + "active_requests": [ + rid for rid, data in self._requests.items() + if data['status'] == self.STATUS_PENDING + ], + 'errors_tracked': len(self._api_errors), + } + + @BasePlugin.endpoint(method="GET") + def request_status(self, request_id: str, return_full: bool = False): + """ + Retrieve the status and result of a previously submitted request. + + Parameters + ---------- + request_id : str + The unique identifier of the request to retrieve. + return_full : bool, optional + If True, return the full serialized request data including status and result. + + Returns + ------- + dict + If return_full is True, returns the full serialized request data. + If the request is still pending, returns only the request_id and status. + If the request is completed, returns the result of the request. + If the request_id is not found, returns an error message. + """ + serialized = self.serialize_request(request_id=request_id) + if serialized is None: + return { + "error": f"Request ID {request_id} not found.", + 'request_id': request_id, + } + if return_full: + return serialized + if serialized['status'] == self.STATUS_PENDING: + return { + 'request_id': request_id, + 'status': serialized['status'], + } + return serialized['result'] + + @BasePlugin.endpoint(method="POST") + def predict( + self, + authorization: Optional[str] = None, + **kwargs + ): + """ + Synchronous prediction entrypoint. + + Parameters + ---------- + authorization : str or None, optional + Authorization token supplied by the caller. + **kwargs + Additional parameters forwarded to request handling. + + Returns + ------- + dict + Request result or error payload for synchronous processing. + """ + return self._predict_entrypoint( + authorization=authorization, + async_request=False, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def predict_async( + self, + authorization: Optional[str] = None, + **kwargs + ): + """ + Asynchronous prediction entrypoint. + + Parameters + ---------- + authorization : str or None, optional + Authorization token supplied by the caller. + **kwargs + Additional parameters forwarded to request handling. + + Returns + ------- + dict + Tracking information for the pending request or error payload. + """ + return self._predict_entrypoint( + authorization=authorization, + async_request=True, + **kwargs + ) + """END API ENDPOINTS""" + + """CHAT COMPLETION SECTION""" + if True: + def check_predict_params(self, **kwargs): + """ + Hook for checking generic predict parameters. + Will have all the parameters passed to the /predict endpoint. + Parameters + ---------- + kwargs : dict + The parameters to check. + + Returns + ------- + str or None + An error message if parameters are invalid, otherwise None. + """ + return None + + def process_predict_params(self, **kwargs): + """ + Hook for processing generic predict parameters. + Will have all the parameters passed to the /predict endpoint. + Parameters + ---------- + kwargs : dict + The parameters to process. + + Returns + ------- + dict + The processed parameters. + """ + return kwargs + + def compute_payload_kwargs_from_predict_params( + self, + request_id: str, + request_data: Dict[str, Any], + ): + """ + Build payload fields from request parameters for downstream processing. + + Parameters + ---------- + request_id : str + Identifier of the request being processed. + request_data : dict + Stored request record containing parameters and metadata. 
+ + Returns + ------- + dict + Payload keyword arguments to dispatch to the inference engine. + """ + return { + 'REQUEST_ID': request_id, + **request_data, + } + + def _predict_entrypoint( + self, + authorization: Optional[str], + async_request: bool, + **kwargs + ): + """ + Shared prediction handler performing auth, validation, and dispatch. + + Parameters + ---------- + authorization : str or None + Authorization token provided by the client. + async_request : bool + Whether the request should be processed asynchronously. + **kwargs + Arbitrary parameters passed to validation and request processing. + + Returns + ------- + dict + Response payload containing request status, errors, or results. + """ + try: + subject = self.authorize_request(authorization) + self.enforce_rate_limit(subject) + except PermissionError as exc: + return {'error': str(exc), 'status': 'unauthorized'} + except RuntimeError as exc: + return {'error': str(exc), 'status': 'rate_limited'} + except Exception as exc: + return {'error': f"Unexpected error: {str(exc)}", 'status': 'error'} + # endtry + + err = self.check_predict_params(**kwargs) + if err is not None: + return {'error': err} + parameters = self.process_predict_params(**kwargs) + metadata = {} + if 'metadata' in parameters: + metadata = parameters.pop('metadata') or {} + # endif 'metadata' in parameters + request_id, request_data = self.register_request( + subject=subject, + parameters=parameters, + metadata=metadata, + timeout=parameters.get('timeout') + ) + payload_kwargs = self.compute_payload_kwargs_from_predict_params( + request_id=request_id, + request_data=request_data, + ) + self.Pd( + f"Dispatching request {request_id} :: {self.json_dumps(payload_kwargs, indent=2)[:500]}" + ) + self.add_payload_by_fields( + **payload_kwargs, + signature=self.get_signature() + ) + + if async_request: + return { + 'request_id': request_id, + 'poll_url': f"/request_status?request_id={request_id}", + 'status': self.STATUS_PENDING, + } + return self.solve_postponed_request(request_id=request_id) + """END CHAT COMPLETION SECTION""" + + """INFERENCE HANDLING""" + if True: + def filter_valid_inference(self, inference): + """ + Validate that an inference payload corresponds to a tracked request. + + Parameters + ---------- + inference : dict + Inference payload produced by the downstream engine. + + Returns + ------- + bool + True when the inference is accepted for processing, False otherwise. + """ + is_valid = super(BaseInferenceApiPlugin, self).filter_valid_inference(inference=inference) + if is_valid: + request_id = inference.get('REQUEST_ID', None) + if request_id is None or request_id not in self._requests: + is_valid = False + # endif not is_valid + return is_valid + """END INFERENCE HANDLING""" + + def process(self): + """ + Main plugin loop handler to refresh metrics, prune requests, and process inferences. + + Returns + ------- + None + Drives inference handling for the current iteration. 
+ """ + self.maybe_refresh_metrics() + self.cleanup_expired_requests() + self.maybe_save_persistence_data() + data = self.dataapi_struct_datas() + inferences = self.dataapi_struct_data_inferences() + self.handle_inferences(inferences=inferences, data=data) + return diff --git a/extensions/business/edge_inference_api/cv_inference_api.py b/extensions/business/edge_inference_api/cv_inference_api.py new file mode 100644 index 00000000..852c862e --- /dev/null +++ b/extensions/business/edge_inference_api/cv_inference_api.py @@ -0,0 +1,486 @@ +""" +CV_INFERENCE_API Plugin + +Production-Grade Computer Vision Inference API + +This plugin exposes a hardened, FastAPI-powered interface for +computer-vision workloads. It reuses the BaseInferenceApi request lifecycle +while tailoring validation and response shaping for image analysis. + +Highlights +- Loopback-only surface paired with local third-party applications that use it +- Request tracking, persistence, auth, and rate limiting from BaseInferenceApi +- Base64 payload validation and metadata normalization for serving plugins +- Structured mapping of struct_data payloads and inferences back to requests +""" + +from typing import Any, Dict, Optional + +from extensions.business.edge_inference_api.base_inference_api import BaseInferenceApiPlugin as BasePlugin + + +__VER__ = '0.1.0' + +_CONFIG = { + **BasePlugin.CONFIG, + "API_TITLE": "CV Inference API", + "API_SUMMARY": "Local image analysis API", + "REQUEST_TIMEOUT": 240, + "MIN_IMAGE_DATA_LENGTH": 100, + + "VALIDATION_RULES": { + **BasePlugin.CONFIG['VALIDATION_RULES'], + "MIN_IMAGE_DATA_LENGTH": { + "DESCRIPTION": "Minimum base64 payload length used for coarse validation.", + "TYPE": "int", + "MIN_VAL": 10, + "MAX_VAL": 1_000_000, + }, + 'REQUEST_TIMEOUT': { + 'DESCRIPTION': 'Timeout for PostponedRequest polling (seconds)', + 'TYPE': 'int', + 'MIN_VAL': 30, + 'MAX_VAL': 600, + }, + }, +} + + +class CvInferenceApiPlugin(BasePlugin): + CONFIG = _CONFIG + + """VALIDATION""" + if True: + def check_predict_params( + self, + image_data: str, + metadata: Optional[Dict[str, Any]] = None, + **kwargs + ): + """ + Validate input parameters for image prediction requests. + + Parameters + ---------- + image_data : str + Base64-encoded image string. + metadata : dict or None, optional + Optional metadata accompanying the request. + **kwargs + Additional parameters ignored by validation. + + Returns + ------- + str or None + Error message when validation fails, otherwise None. + """ + if not isinstance(image_data, str) or len(image_data) < self.cfg_min_image_data_length: + return ( + "Invalid or missing image data. " + f"Expecting base64 content with at least {self.cfg_min_image_data_length} characters." + ) + if metadata is not None and not isinstance(metadata, dict): + return "`metadata` must be a dictionary when provided." + return None + + def process_predict_params( + self, + image_data: str, + metadata: Optional[Dict[str, Any]] = None, + **kwargs + ): + """ + Normalize and forward parameters for request registration. + + Parameters + ---------- + image_data : str + Base64-encoded image string. + metadata : dict or None, optional + Optional metadata accompanying the request. + **kwargs + Additional parameters to propagate downstream. + + Returns + ------- + dict + Processed parameters ready for dispatch to the inference engine. 
+ """ + cleaned_metadata = metadata or {} + return { + 'image_data': image_data, + 'metadata': cleaned_metadata, + 'request_type': 'prediction', + **{k: v for k, v in kwargs.items() if k not in {'metadata'}}, + } + + def compute_payload_kwargs_from_predict_params( + self, + request_id: str, + request_data: Dict[str, Any], + ): + """ + Build payload keyword arguments for Computer Vision inference. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record containing processed parameters. + + Returns + ------- + dict + Payload fields including image data, metadata, and submission info. + """ + params = request_data['parameters'] + submitted_at = request_data['created_at'] + metadata = params.get('metadata') or request_data.get('metadata') or {} + return { + 'request_id': request_id, + 'image_data': params['image_data'], + 'metadata': metadata, + 'type': params.get('request_type', 'prediction'), + 'submitted_at': submitted_at, + } + """END VALIDATION""" + + """API ENDPOINTS""" + if True: + @BasePlugin.endpoint(method="GET") + def list_results(self, limit: int = 50, include_pending: bool = False): + """ + List recent request results with optional pending entries. + + Parameters + ---------- + limit : int, optional + Maximum number of results to return (bounded to 1..100). + include_pending : bool, optional + Whether to include still-pending requests in the output. + + Returns + ------- + dict + Summary of results and metadata for each tracked request. + """ + limit = min(max(1, limit), 100) + results = [] + for request_id, request_data in self._requests.items(): + status = request_data.get('status') + if (not include_pending) and status == self.STATUS_PENDING: + continue + entry = { + 'request_id': request_id, + 'type': request_data.get('parameters', {}).get('request_type', 'prediction'), + 'status': status, + 'submitted_at': request_data.get('created_at'), + 'metadata': request_data.get('metadata') or {}, + } + if status != self.STATUS_PENDING and request_data.get('result') is not None: + entry['result'] = request_data['result'] + if request_data.get('error') is not None: + entry['error'] = request_data['error'] + results.append(entry) + results.sort(key=lambda item: item.get('submitted_at', 0), reverse=True) + results = results[:limit] + return { + "total_results": len(results), + "limit": limit, + "include_pending": include_pending, + "results": results + } + + @BasePlugin.endpoint(method="POST") + def predict( + self, + image_data: str = '', + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Synchronous Computer Vision prediction endpoint. + + Parameters + ---------- + image_data : str, optional + Base64-encoded image string to analyze. + metadata : dict or None, optional + Optional metadata accompanying the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Result payload for synchronous processing or an error message. + """ + return super(CvInferenceApiPlugin, self).predict( + image_data=image_data, + metadata=metadata, + authorization=authorization, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def predict_async( + self, + image_data: str = '', + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Asynchronous Computer Vision prediction endpoint. 
+ + Parameters + ---------- + image_data : str, optional + Base64-encoded image string to analyze. + metadata : dict or None, optional + Optional metadata accompanying the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Tracking payload for asynchronous processing or an error message. + """ + return super(CvInferenceApiPlugin, self).predict_async( + image_data=image_data, + metadata=metadata, + authorization=authorization, + **kwargs + ) + """END API ENDPOINTS""" + + """INFERENCE HANDLING""" + if True: + def _mark_request_failure(self, request_id: str, error_message: str): + """ + Mark a tracked request as failed and record error details. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + error_message : str + Description of the failure encountered. + + Returns + ------- + None + Updates request status and metrics in place. + """ + request_data = self._requests.get(request_id) + if request_data is None: + return + if request_data.get('status') != self.STATUS_PENDING: + return + self.P(f"Request {request_id} failed: {error_message}") + now_ts = self.time() + request_data['status'] = self.STATUS_FAILED + request_data['error'] = error_message + request_data['result'] = { + 'status': 'error', + 'error': error_message, + 'request_id': request_id, + } + request_data['finished_at'] = now_ts + request_data['updated_at'] = now_ts + self._metrics['requests_failed'] += 1 + self._metrics['requests_active'] -= 1 + return + + def _mark_request_completed( + self, + request_id: str, + request_data: Dict[str, Any], + inference_payload: Dict[str, Any], + metadata: Dict[str, Any], + ): + """ + Mark a tracked request as completed with the provided inference payload. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record to update. + inference_payload : dict + Result payload constructed from the inference output. + metadata : dict + Metadata associated with the request. + + Returns + ------- + None + Updates request status and metrics in place. + """ + now_ts = self.time() + request_data['status'] = self.STATUS_COMPLETED + request_data['finished_at'] = now_ts + request_data['updated_at'] = now_ts + request_data['result'] = inference_payload + self._metrics['requests_completed'] += 1 + self._metrics['requests_active'] -= 1 + return + + def _extract_request_id(self, payload: Optional[Dict[str, Any]], inference: Any): + """ + Extract a request identifier from payload or inference data. + + Parameters + ---------- + payload : dict or None + Structured data payload, if available. + inference : Any + Inference result that may contain identifiers. + + Returns + ------- + str or None + Extracted request ID when present, otherwise None. + """ + request_id = self._get_payload_field(payload, 'request_id') if payload else None + if request_id is None and isinstance(inference, dict): + request_id = self._get_payload_field(inference, 'request_id') + if request_id is None and isinstance(inference, dict): + request_id = self._get_payload_field(inference, 'REQUEST_ID') + return request_id + + def _build_result_from_inference( + self, + request_id: str, + inference: Dict[str, Any], + metadata: Dict[str, Any], + request_data: Dict[str, Any] + ): + """ + Construct a result payload from inference output and metadata. 
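+
+      The default implementation wraps the raw inference output together with
+      the request metadata, for example (illustrative):
+
+        {
+          'request_id': '<uuid>',
+          'inference': {...},
+          'metadata': {}
+        }
+
+      Subclasses can override this hook to reshape or post-process the serving
+      output before it is returned to the client.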
+ + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : dict + Inference result data. + metadata : dict + Metadata to include in the response. + request_data : dict + Stored request record for reference. + + Returns + ------- + dict + Structured result payload including analysis and image details. + + Raises + ------ + ValueError + If the inference result format is invalid. + RuntimeError + When the inference indicates an error status. + """ + return { + 'request_id': request_id, + 'inference': inference, + 'metadata': metadata or request_data.get('metadata') or {}, + } + + def handle_inference_for_request( + self, + request_id: str, + inference: Any, + metadata: Dict[str, Any] + ): + """ + Handle inference output for a specific tracked request. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : Any + Inference payload to process. + metadata : dict + Metadata associated with the request. + + Returns + ------- + None + Updates request tracking based on inference success or failure. + """ + if request_id not in self._requests: + self.Pd(f"Received inference for unknown request_id {request_id}.") + return + request_data = self._requests[request_id] + if request_data.get('status') != self.STATUS_PENDING: + return + if inference is None: + self._mark_request_failure(request_id, "No inference result available.") + return + try: + result_payload = self._build_result_from_inference( + request_id=request_id, + inference=inference, + metadata=metadata, + request_data=request_data + ) + except Exception as exc: + self._mark_request_failure(request_id, str(exc)) + return + self._mark_request_completed( + request_id=request_id, + request_data=request_data, + inference_payload=result_payload, + metadata=metadata + ) + return + + def handle_inferences(self, inferences, data=None): + """ + Process incoming inferences and map them back to pending requests. + + Parameters + ---------- + inferences : list or Any + Inference outputs from the serving pipeline. + data : list or Any, optional + Optional data payloads paired with inferences. + + Returns + ------- + None + Iterates over incoming results and updates tracked requests. + """ + payloads = data if isinstance(data, list) else self.dataapi_struct_datas(full=False, as_list=True) or [] + inferences = inferences or [] + + if not payloads and not inferences: + return + + max_len = max(len(payloads), len(inferences)) + for idx in range(max_len): + payload = payloads[idx] if idx < len(payloads) else None + inference = inferences[idx] if idx < len(inferences) else None + request_id = self._extract_request_id(payload, inference) + if request_id is None: + self.Pd(f"No request_id found for index {idx}, skipping.") + continue + metadata = self._get_payload_field(payload, 'metadata', {}) if payload else {} + self.handle_inference_for_request( + request_id=request_id, + inference=inference, + metadata=metadata or {} + ) + return + """END INFERENCE HANDLING""" diff --git a/extensions/business/edge_inference_api/llm_inference_api.py b/extensions/business/edge_inference_api/llm_inference_api.py new file mode 100644 index 00000000..b4fc0db7 --- /dev/null +++ b/extensions/business/edge_inference_api/llm_inference_api.py @@ -0,0 +1,684 @@ +""" +LLM_INFERENCE_API Plugin + +Production-Grade LLM Inference API + +This plugin exposes a hardened, FastAPI-powered interface for chat-completion +style LLM workloads. 
It keeps the lightweight loopback data flow used by the +Ratio1 node while adding security, observability, and request lifecycle +management that mirrors hosted LLM APIs. + +It can work with both async and sync requests. +In case of sync requests, they will be processed using PostponedRequest objects. +Otherwise, the request_id will be returned immediately, and the client can poll for results. + +Highlights +- Bearer-token authentication with optional anonymous fallback (env driven) +- Per-subject rate limiting and structured audit logging with request metrics +- Durable, restart-safe request tracking with health/metrics/list endpoints +- Async + sync chat completions with OpenAI-compatible payload layout +- Automatic timeout handling, TTL-based eviction, and persistence to cacheapi + +Export `LLM_API_TOKEN` (comma-separated values for multiple clients) to enforce token +checks or provide the tokens through the `PREDEFINED_AUTH_TOKENS` config parameter. + +Available Endpoints: +- POST /predict - Predict endpoint (sync) +- POST /predict_async - Predict endpoint (async) +- POST /create_chat_completion - Alias for predict and replicating the OpenAI standard (sync) +- POST /create_chat_completion_async - Alias for predict and replicating the OpenAI standard (async) +- GET /health - Health check +- GET /metrics - Retrieve API metrics endpoint +- GET /status_request - Check for current status of async request results + +Example pipeline configuration: +{ + "NAME": "llm_inference_api", + "TYPE": "Loopback", + "PLUGINS": [ + { + "SIGNATURE": "LLM_INFERENCE_API", + "INSTANCES": [ + { + "INSTANCE_ID": "llm_interface", + "AI_ENGINE": "llama_cpp", + "PORT": , + "STARTUP_AI_ENGINE_PARAMS": { + "HF_TOKEN": "", + "MODEL_FILENAME": "llama-3.2-1b-instruct-q4_k_m.gguf", + "MODEL_NAME": "hugging-quants/Llama-3.2-1B-Instruct-Q4_K_M-GGUF", + "SERVER_COLLECTOR_TIMEDELTA": 360000 + } + } + ] + }, + { + "SIGNATURE": "WORKER_APP_RUNNER", + "INSTANCES": [ + { + "INSTANCE_ID": "third_party_app", + "PORT": , + "BUILD_AND_RUN_COMMANDS": [ + "npm install", + "npm run dev" + ], + "VCS_DATA": { + "PROVIDER": "github", + "USERNAME": "", + "TOKEN": "", + "REPO_URL": "", + "BRANCH": "main", + "POLL_INTERVAL": 60 + }, + "AUTOUPDATE": true, + "TUNNEL_ENGINE_ENABLED": true, + "CLOUDFLARE_TOKEN": "", + "ENV": { + "INFERENCE_API_HOST": "$R1EN_HOST_IP", + "INFERENCE_API_PORT": "" + }, + "HEALTH_CHECK": { + "PATH": "/health", + } + } + ] + } + ] +} +""" + +from extensions.business.edge_inference_api.base_inference_api import BaseInferenceApiPlugin as BasePlugin +from extensions.serving.mixins_llm.llm_utils import LlmCT + +from typing import Any, Dict, List, Optional + + +_CONFIG = { + **BasePlugin.CONFIG, + "AI_ENGINE": "llama_cpp_small", + + "API_TITLE": "LLM Inference API", + + "TEMPERATURE_MIN": 0.0, + "TEMPERATURE_MAX": 1.5, + "MIN_COMPLETION_TOKENS": 16, + "MAX_COMPLETION_TOKENS": 4096, + + 'VALIDATION_RULES': { + **BasePlugin.CONFIG['VALIDATION_RULES'], + }, +} + + +class LLMInferenceApiPlugin(BasePlugin): + CONFIG = _CONFIG + + """VALIDATION SECTION""" + if True: + def check_messages(self, messages: list[dict]): + """ + Validate chat messages payload structure. + + Parameters + ---------- + messages : list of dict + Sequence of chat messages including role and content fields. + + Returns + ------- + str or None + Error message when validation fails, otherwise None. + """ + if not isinstance(messages, list) or len(messages) == 0: + return "`messages` must be a non-empty list of message dicts." 
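+      # Each entry is expected to follow the OpenAI-style chat format, e.g.
+      # (illustrative): {"role": "user", "content": "Hello!"}.
+      # The loop below rejects unknown roles and empty content.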
+ for idx, message in enumerate(messages): + if not isinstance(message, dict): + return f"Message at index {idx} from `messages` must be a dict." + role = message.get('role', None) + content = message.get('content', None) + if role not in {'system', 'user', 'assistant', 'tool'}: + return f"Message {idx} has invalid role '{role}'." + if not isinstance(content, str) or not content.strip(): + return f"Message {idx} content must be a non-empty string." + return None + + def check_generation_params( + self, + temperature: float, + max_tokens: int, + top_p: float = 1.0, + **kwargs + ): + """ + Validate generation hyperparameters. + + Parameters + ---------- + temperature : float + Sampling temperature requested by the client. + max_tokens : int + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling cutoff between 0 and 1. + **kwargs + Additional unused parameters. + + Returns + ------- + str or None + Error message when validation fails, otherwise None. + """ + if not self.cfg_temperature_min <= temperature <= self.cfg_temperature_max: + return ( + f"temperature must be between {self.cfg_temperature_min} and " + f"{self.cfg_temperature_max}." + ) + if not self.cfg_min_completion_tokens <= max_tokens <= self.cfg_max_completion_tokens: + return ( + f"max_tokens must be between {self.cfg_min_completion_tokens} and " + f"{self.cfg_max_completion_tokens}." + ) + if not 0 < top_p <= 1: + return "top_p must be between 0 and 1." + return None + + def normalize_messages(self, messages: List[Dict[str, Any]]): + """ + Normalize chat messages by trimming content. + + Parameters + ---------- + messages : list of dict + Original messages payload provided by the client. + + Returns + ------- + list of dict + Messages with whitespace-trimmed content fields. + """ + normalized = [] + for message in messages: + normalized.append({ + 'role': message['role'], + 'content': message['content'].strip(), + }) + return normalized + """END VALIDATION SECTION""" + + """API ENDPOINTS""" + if True: + @BasePlugin.endpoint(method="POST") + def predict( + self, + messages: List[Dict[str, Any]], + temperature: float = 0.7, + max_tokens: int = 512, + top_p: float = 1.0, + repeat_penalty: Optional[float] = 1.0, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Synchronous chat completion prediction endpoint. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float, optional + Sampling temperature. + max_tokens : int, optional + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float or None, optional + Penalty for repeated tokens if supported by the backend. + metadata : dict or None, optional + Additional metadata to store with the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Result payload for synchronous processing or an error message. 
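+
+      Example request body (illustrative):
+
+        {
+          "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Summarize the node status."}
+          ],
+          "temperature": 0.7,
+          "max_tokens": 256
+        }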
+ """ + return super(LLMInferenceApiPlugin, self).predict( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + repeat_penalty=repeat_penalty, + metadata=metadata, + authorization=authorization, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def predict_async( + self, + messages: List[Dict[str, Any]], + temperature: float = 0.7, + max_tokens: int = 512, + top_p: float = 1.0, + repeat_penalty: float = 1.0, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Asynchronous chat completion prediction endpoint. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float, optional + Sampling temperature. + max_tokens : int, optional + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float, optional + Penalty for repeated tokens if supported by the backend. + metadata : dict or None, optional + Additional metadata to store with the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Tracking payload for asynchronous processing or an error message. + """ + return super(LLMInferenceApiPlugin, self).predict_async( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + repeat_penalty=repeat_penalty, + metadata=metadata, + authorization=authorization, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def create_chat_completion( + self, + messages: List[Dict[str, Any]], + temperature: float = 0.7, + max_tokens: int = 512, + top_p: float = 1.0, + repeat_penalty: Optional[float] = 1.0, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Alias for predict endpoint, replicating the OpenAI chat completion interface. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float, optional + Sampling temperature. + max_tokens : int, optional + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float or None, optional + Penalty for repeated tokens if supported by the backend. + metadata : dict or None, optional + Additional metadata to store with the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Result payload for synchronous processing or an error message. + """ + return self.predict( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + repeat_penalty=repeat_penalty, + metadata=metadata, + authorization=authorization, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def create_chat_completion_async( + self, + messages: List[Dict[str, Any]], + temperature: float = 0.7, + max_tokens: int = 512, + top_p: float = 1.0, + repeat_penalty: float = 1.0, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Asynchronous alias mirroring OpenAI's chat completion API. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float, optional + Sampling temperature. + max_tokens : int, optional + Maximum number of tokens to generate. 
+ top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float, optional + Penalty for repeated tokens if supported by the backend. + metadata : dict or None, optional + Additional metadata to store with the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Tracking payload for asynchronous processing or an error message. + """ + return self.predict_async( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + repeat_penalty=repeat_penalty, + metadata=metadata, + authorization=authorization, + **kwargs + ) + """END API ENDPOINTS""" + + """PREDICT ENDPOINT HANDLING""" + if True: + def check_predict_params( + self, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int, + top_p: float = 1.0, + repeat_penalty: float = 1.0, + **kwargs + ): + """ + Validate request parameters for LLM predictions. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float + Sampling temperature. + max_tokens : int + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float, optional + Penalty for repeated tokens if supported by the backend. + **kwargs + Additional parameters not validated here. + + Returns + ------- + str or None + Error message when validation fails, otherwise None. + """ + err = self.check_messages(messages) + if err is not None: + return err + err = self.check_generation_params( + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + **kwargs + ) + if err is not None: + return err + return None + + def process_predict_params( + self, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int, + top_p: float = 1.0, + repeat_penalty: float = 1.0, + **kwargs + ): + """ + Normalize and forward parameters for request registration. + + Parameters + ---------- + messages : list of dict + Chat history for the model to complete. + temperature : float + Sampling temperature. + max_tokens : int + Maximum number of tokens to generate. + top_p : float, optional + Nucleus sampling probability threshold. + repeat_penalty : float, optional + Penalty for repeated tokens if supported by the backend. + **kwargs + Additional parameters to include as-is. + + Returns + ------- + dict + Processed parameters ready for dispatch. + """ + normalized_messages = self.normalize_messages(messages) + return { + 'messages': normalized_messages, + 'temperature': temperature, + 'max_tokens': max_tokens, + 'top_p': top_p, + 'repeat_penalty': repeat_penalty, + **kwargs + } + + def compute_payload_kwargs_from_predict_params( + self, + request_id: Optional[str], + request_data: Dict[str, Any] + ): + """ + Prepare payload fields for the loopback inference engine. + + Parameters + ---------- + request_id : str or None + Identifier of the registered request. + request_data : dict + Stored request record containing processed parameters. + + Returns + ------- + dict + Payload keyed for downstream LLM handling. 
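+
+      Notes
+      -----
+      Sketch of the returned structure, assuming the parameters produced by
+      process_predict_params (values shown are placeholders)::
+
+        {
+          'jeeves_content': {
+            'REQUEST_ID': '<request uuid>',
+            'request_type': 'LLM',
+            'messages': [...],
+            'temperature': 0.7,
+            'max_tokens': 256,
+            'top_p': 1.0,
+            'repeat_penalty': 1.0
+          }
+        }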
+ """ + request_parameters = request_data['parameters'] + return { + 'jeeves_content': { + 'REQUEST_ID': request_id, + 'request_type': 'LLM', + **request_parameters, + } + } + """END PREDICT ENDPOINT HANDLING""" + + """INFERENCE HANDLING""" + if True: + def inference_to_response(self, inference, model_name, input_data=None): + """ + Convert inference output into a lightweight response structure. + + Parameters + ---------- + inference : dict + Inference payload produced by the model. + model_name : str + Name of the model that generated the inference. + input_data : Any, optional + Optional original input for context. + + Returns + ------- + dict + Simplified response containing identifiers and text output. + """ + return { + 'REQUEST_ID': inference.get('REQUEST_ID'), + 'MODEL_NAME': model_name, + 'TEXT_RESPONSE': inference.get('text'), + } + + def handle_single_inference(self, inference, model_name=None, input_data=None): + """ + Handle a single inference result and update tracked request state. + + Parameters + ---------- + inference : dict + Inference payload produced by the model. + model_name : str or None, optional + Model name reported with the inference. + input_data : Any, optional + Optional original input for context. + + Returns + ------- + None + Updates request tracking and stores the completion payload. + """ + request_id = inference.get('REQUEST_ID', None) + self.Pd(f"Processing inference for request ID: {request_id}, model: {model_name}") + if request_id is None: + self.Pd("No REQUEST_ID found in inference. Skipping.") + return + request_data = self._requests.get(request_id) + if request_data is None: + self.Pd(f"Received inference for unknown request_id {request_id}.") + return + if request_data['status'] != self.STATUS_PENDING: + return + + response_payload = self.build_completion_response( + request_id=request_id, + model_name=model_name or request_data['model'], + inference=inference, + request_data=request_data + ) + request_data['result'] = response_payload + request_data['status'] = self.STATUS_COMPLETED + request_data['finished_at'] = self.time() + request_data['updated_at'] = request_data['finished_at'] + self._metrics['requests_completed'] += 1 + self._metrics['requests_active'] -= 1 + + text_response = inference.get(LlmCT.TEXT, None) + full_output = inference.get(LlmCT.FULL_OUTPUT, None) + # TODO: adapt this to match OpenAI-style response structure if flag active + self._requests[request_id]['result'] = { + 'REQUEST_ID': request_id, + 'MODEL_NAME': model_name, + 'TEXT_RESPONSE': text_response, + LlmCT.FULL_OUTPUT: full_output, + } + self._requests[request_id]['finished'] = True + return + + def build_completion_response( + self, + request_id: str, + model_name: str, + inference: dict, + request_data: dict + ): + """ + Build a completion-style response payload from an inference result. + TODO: adapt default response structure to match OpenAI-style APIs: + { + 'id': request_id, + 'object': 'chat.completion', + 'created': int(self.time()), + 'model': model_name, + 'choices': [ + { + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': text_response, + }, + 'finish_reason': inference.get('finish_reason', 'stop'), + } + ], + 'usage': { + 'prompt_tokens': usage.get('prompt_tokens'), + 'completion_tokens': usage.get('completion_tokens'), + 'total_tokens': usage.get('total_tokens'), + }, + 'metadata': request_data.get('metadata') or {}, + } + Parameters + ---------- + request_id : str + Identifier of the tracked request. 
+ model_name : str + Name of the model producing the inference. + inference : dict + Inference payload containing text and optional full output. + request_data : dict + Stored request record with metadata and parameters. + + Returns + ------- + dict + Chat-completion shaped response enriched with metadata and timestamps. + """ + text_response = inference.get(LlmCT.TEXT, None) + full_output = inference.get(LlmCT.FULL_OUTPUT, None) + + response_payload = { + 'REQUEST_ID': request_id, + 'MODEL_NAME': model_name, + 'TEXT_RESPONSE': text_response, + } + # Check if full_output is already an API-friendly dict. + # TODO: enhance this check based on expected structure. + if isinstance(full_output, dict): + response_payload = { + **response_payload, + **full_output, + } + else: + response_payload[LlmCT.FULL_OUTPUT] = full_output + # endif full_output is dict + response_payload['metadata'] = request_data.get('metadata') or {} + response_payload['object'] = 'chat.completion' + response_payload['created'] = int(self.time()) + response_payload['id'] = request_id + response_payload['model'] = model_name + return response_payload + """END INFERENCE HANDLING""" + diff --git a/extensions/business/edge_inference_api/sd_inference_api.py b/extensions/business/edge_inference_api/sd_inference_api.py new file mode 100644 index 00000000..4e94317c --- /dev/null +++ b/extensions/business/edge_inference_api/sd_inference_api.py @@ -0,0 +1,532 @@ +""" +SD_INFERENCE_API Plugin + +Production-Grade Structured Data Inference API + +This plugin exposes a hardened, FastAPI-powered interface for structured-data +workloads. It reuses the BaseInferenceApi request lifecycle while tailoring +validation and response shaping for general-purpose tabular/JSON inference. + +Highlights +- Loopback-only surface paired with local clients +- Request tracking, persistence, auth, and rate limiting from BaseInferenceApi +- Structured payload validation and metadata normalization +- Mapping of struct_data payloads and inferences back to requests +""" + +from typing import Any, Dict, Optional + +from extensions.business.edge_inference_api.base_inference_api import BaseInferenceApiPlugin as BasePlugin + + +__VER__ = '0.1.0' + +_CONFIG = { + **BasePlugin.CONFIG, + "API_TITLE": "Structured Data Inference API", + "API_SUMMARY": "Local structured-data analysis API for paired clients.", + "REQUEST_TIMEOUT": 240, + "MIN_STRUCT_DATA_FIELDS": 1, + + "VALIDATION_RULES": { + **BasePlugin.CONFIG['VALIDATION_RULES'], + "MIN_STRUCT_DATA_FIELDS": { + "DESCRIPTION": "Minimum number of top-level fields required in struct_data.", + "TYPE": "int", + "MIN_VAL": 1, + "MAX_VAL": 10000, + }, + 'REQUEST_TIMEOUT': { + 'DESCRIPTION': 'Timeout for PostponedRequest polling (seconds)', + 'TYPE': 'int', + 'MIN_VAL': 30, + 'MAX_VAL': 600, + }, + }, +} + + +class SdInferenceApiPlugin(BasePlugin): + CONFIG = _CONFIG + + def _normalize_struct_data(self, struct_data: Any): + """ + Normalize and validate struct_data input shape. + + Parameters + ---------- + struct_data : Any + Incoming structured data payload. + + Returns + ------- + tuple + (normalized_struct_data, error_message). error_message is None when valid. + """ + if isinstance(struct_data, dict): + if len(struct_data) < self.cfg_min_struct_data_fields: + return None, ( + f"`struct_data` must contain at least {self.cfg_min_struct_data_fields} fields." + ) + return struct_data, None + if isinstance(struct_data, list): + if not struct_data: + return None, "`struct_data` list must not be empty." 
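+      # A non-empty list is accepted only when every element is itself a dict
+      # (one record per entry); anything else is rejected below.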
+ if not all(isinstance(item, dict) for item in struct_data): + return None, "`struct_data` list items must all be dictionaries." + return struct_data, None + return None, "`struct_data` must be a dictionary or list of dictionaries." + + """VALIDATION""" + if True: + def check_predict_params( + self, + struct_data: Any, + metadata: Optional[Dict[str, Any]] = None, + **kwargs + ): + """ + Validate input parameters for structured-data prediction requests. + + Parameters + ---------- + struct_data : Any + Structured payload (dict or list of dicts). + metadata : dict or None, optional + Optional metadata accompanying the request. + **kwargs + Additional parameters ignored by validation. + + Returns + ------- + str or None + Error message when validation fails, otherwise None. + """ + _, err = self._normalize_struct_data(struct_data) + if err: + return err + if metadata is not None and not isinstance(metadata, dict): + return "`metadata` must be a dictionary when provided." + return None + + def process_predict_params( + self, + struct_data: Any, + metadata: Optional[Dict[str, Any]] = None, + **kwargs + ): + """ + Normalize and forward parameters for request registration. + + Parameters + ---------- + struct_data : Any + Structured payload (dict or list of dicts). + metadata : dict or None, optional + Optional metadata accompanying the request. + **kwargs + Additional parameters to propagate downstream. + + Returns + ------- + dict + Processed parameters ready for dispatch to the inference engine. + """ + normalized_struct, _ = self._normalize_struct_data(struct_data) + cleaned_metadata = metadata or {} + return { + 'struct_data': normalized_struct, + 'metadata': cleaned_metadata, + 'request_type': 'prediction', + **{k: v for k, v in kwargs.items() if k not in {'metadata'}}, + } + + def compute_payload_kwargs_from_predict_params( + self, + request_id: str, + request_data: Dict[str, Any], + ): + """ + Build payload keyword arguments for structured-data inference. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record containing processed parameters. + + Returns + ------- + dict + Payload fields including struct_data, metadata, and submission info. + """ + params = request_data['parameters'] + submitted_at = request_data['created_at'] + metadata = params.get('metadata') or request_data.get('metadata') or {} + return { + 'request_id': request_id, + 'struct_data': params['struct_data'], + 'metadata': metadata, + 'type': params.get('request_type', 'prediction'), + 'submitted_at': submitted_at, + } + """END VALIDATION""" + + """API ENDPOINTS""" + if True: + @BasePlugin.endpoint(method="GET") + def list_results(self, limit: int = 50, include_pending: bool = False): + """ + List recent request results with optional pending entries. + + Parameters + ---------- + limit : int, optional + Maximum number of results to return (bounded to 1..100). + include_pending : bool, optional + Whether to include still-pending requests in the output. + + Returns + ------- + dict + Summary of results and metadata for each tracked request. 
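+
+      Notes
+      -----
+      Illustrative response shape; status strings come from the base plugin
+      STATUS_* constants and the values below are placeholders::
+
+        {
+          "total_results": 1,
+          "limit": 50,
+          "include_pending": false,
+          "results": [
+            {
+              "request_id": "<uuid>",
+              "type": "prediction",
+              "status": "<STATUS_COMPLETED>",
+              "submitted_at": 1700000000.0,
+              "metadata": {}
+            }
+          ]
+        }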
+ """ + limit = min(max(1, limit), 100) + results = [] + for request_id, request_data in self._requests.items(): + status = request_data.get('status') + if (not include_pending) and status == self.STATUS_PENDING: + continue + entry = { + 'request_id': request_id, + 'type': request_data.get('parameters', {}).get('request_type', 'prediction'), + 'status': status, + 'submitted_at': request_data.get('created_at'), + 'metadata': request_data.get('metadata') or {}, + } + if status != self.STATUS_PENDING and request_data.get('result') is not None: + entry['result'] = request_data['result'] + if request_data.get('error') is not None: + entry['error'] = request_data['error'] + results.append(entry) + results.sort(key=lambda item: item.get('submitted_at', 0), reverse=True) + results = results[:limit] + return { + "total_results": len(results), + "limit": limit, + "include_pending": include_pending, + "results": results + } + + @BasePlugin.endpoint(method="POST") + def predict( + self, + struct_data: Any = None, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Synchronous structured-data prediction endpoint. + + Parameters + ---------- + struct_data : Any, optional + Structured payload (dict or list of dicts) to analyze. + metadata : dict or None, optional + Optional metadata accompanying the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Result payload for synchronous processing or an error message. + """ + return super(SdInferenceApiPlugin, self).predict( + struct_data=struct_data, + metadata=metadata, + authorization=authorization, + **kwargs + ) + + @BasePlugin.endpoint(method="POST") + def predict_async( + self, + struct_data: Any = None, + metadata: Optional[Dict[str, Any]] = None, + authorization: Optional[str] = None, + **kwargs + ): + """ + Asynchronous structured-data prediction endpoint. + + Parameters + ---------- + struct_data : Any, optional + Structured payload (dict or list of dicts) to analyze. + metadata : dict or None, optional + Optional metadata accompanying the request. + authorization : str or None, optional + Bearer token used for authentication. + **kwargs + Extra parameters forwarded to the base handler. + + Returns + ------- + dict + Tracking payload for asynchronous processing or an error message. + """ + return super(SdInferenceApiPlugin, self).predict_async( + struct_data=struct_data, + metadata=metadata, + authorization=authorization, + **kwargs + ) + """END API ENDPOINTS""" + + """INFERENCE HANDLING""" + if True: + def _mark_request_failure(self, request_id: str, error_message: str): + """ + Mark a tracked request as failed and record error details. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + error_message : str + Description of the failure encountered. + + Returns + ------- + None + Updates request status and metrics in place. 
+ """ + request_data = self._requests.get(request_id) + if request_data is None: + return + if request_data.get('status') != self.STATUS_PENDING: + return + self.P(f"Request {request_id} failed: {error_message}") + now_ts = self.time() + request_data['status'] = self.STATUS_FAILED + request_data['error'] = error_message + request_data['result'] = { + 'status': 'error', + 'error': error_message, + 'request_id': request_id, + } + request_data['finished_at'] = now_ts + request_data['updated_at'] = now_ts + self._metrics['requests_failed'] += 1 + self._metrics['requests_active'] -= 1 + return + + def _mark_request_completed( + self, + request_id: str, + request_data: Dict[str, Any], + inference_payload: Dict[str, Any], + metadata: Dict[str, Any], + ): + """ + Mark a tracked request as completed with the provided inference payload. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + request_data : dict + Stored request record to update. + inference_payload : dict + Result payload constructed from the inference output. + metadata : dict + Metadata associated with the request. + + Returns + ------- + None + Updates request status and metrics in place. + """ + now_ts = self.time() + request_data['status'] = self.STATUS_COMPLETED + request_data['finished_at'] = now_ts + request_data['updated_at'] = now_ts + request_data['result'] = inference_payload + self._metrics['requests_completed'] += 1 + self._metrics['requests_active'] -= 1 + return + + def _extract_request_id(self, payload: Optional[Dict[str, Any]], inference: Any): + """ + Extract a request identifier from payload or inference data. + + Parameters + ---------- + payload : dict or None + Structured data payload, if available. + inference : Any + Inference result that may contain identifiers. + + Returns + ------- + str or None + Extracted request ID when present, otherwise None. + """ + request_id = self._get_payload_field(payload, 'request_id') if payload else None + if request_id is None and isinstance(inference, dict): + request_id = self._get_payload_field(inference, 'request_id') + if request_id is None and isinstance(inference, dict): + request_id = self._get_payload_field(inference, 'REQUEST_ID') + return request_id + + def _build_result_from_inference( + self, + request_id: str, + inference: Dict[str, Any], + metadata: Dict[str, Any], + request_data: Dict[str, Any] + ): + """ + Construct a result payload from inference output and metadata. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : dict + Inference result data. + metadata : dict + Metadata to include in the response. + request_data : dict + Stored request record for reference. + + Returns + ------- + dict + Structured result payload including prediction and auxiliary details. + + Raises + ------ + ValueError + If the inference result format is invalid. + RuntimeError + When the inference indicates an error status. 
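+
+      Notes
+      -----
+      Sketch of the payload returned on success; model_name, scores and
+      probabilities are copied only when present in the inference data::
+
+        {
+          'status': 'completed',
+          'request_id': '<uuid>',
+          'prediction': '<prediction or result field from the inference>',
+          'metadata': {},
+          'processed_at': 1700000000.0,
+          'processor_version': 'unknown'
+        }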
+ """ + if not isinstance(inference, dict): + raise ValueError("Invalid inference result format.") + inference_data = inference.get('data', inference) + status = inference_data.get('status', inference.get('status', 'completed')) + if status == 'error': + err_msg = inference_data.get('error', 'Unknown error') + raise RuntimeError(err_msg) + + prediction = inference_data.get('prediction', inference_data.get('result')) + result_payload = { + 'status': 'completed', + 'request_id': request_id, + 'prediction': prediction, + 'metadata': metadata or request_data.get('metadata') or {}, + 'processed_at': inference_data.get('processed_at', self.time()), + 'processor_version': inference_data.get('processor_version', 'unknown'), + } + if 'model_name' in inference_data: + result_payload['model_name'] = inference_data['model_name'] + if 'scores' in inference_data: + result_payload['scores'] = inference_data['scores'] + if 'probabilities' in inference_data: + result_payload['probabilities'] = inference_data['probabilities'] + return result_payload + + def handle_inference_for_request( + self, + request_id: str, + inference: Any, + metadata: Dict[str, Any] + ): + """ + Handle inference output for a specific tracked request. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : Any + Inference payload to process. + metadata : dict + Metadata associated with the request. + + Returns + ------- + None + Updates request tracking based on inference success or failure. + """ + if request_id not in self._requests: + self.Pd(f"Received inference for unknown request_id {request_id}.") + return + request_data = self._requests[request_id] + if request_data.get('status') != self.STATUS_PENDING: + return + if inference is None: + self._mark_request_failure(request_id, "No inference result available.") + return + try: + result_payload = self._build_result_from_inference( + request_id=request_id, + inference=inference, + metadata=metadata, + request_data=request_data + ) + except Exception as exc: + self._mark_request_failure(request_id, str(exc)) + return + self._mark_request_completed( + request_id=request_id, + request_data=request_data, + inference_payload=result_payload, + metadata=metadata + ) + return + + def handle_inferences(self, inferences, data=None): + """ + Process incoming inferences and map them back to pending requests. + + Parameters + ---------- + inferences : list or Any + Inference outputs from the serving pipeline. + data : list or Any, optional + Optional data payloads paired with inferences. + + Returns + ------- + None + Iterates over incoming results and updates tracked requests. 
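+
+      Notes
+      -----
+      Payloads and inferences are paired positionally: the longer of the two
+      lists drives the iteration, missing counterparts are treated as None,
+      and entries without a resolvable request_id are skipped.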
+ """ + payloads = data if isinstance(data, list) else self.dataapi_struct_datas(full=False, as_list=True) or [] + inferences = inferences or [] + + if not payloads and not inferences: + return + + max_len = max(len(payloads), len(inferences)) + for idx in range(max_len): + payload = payloads[idx] if idx < len(payloads) else None + inference = inferences[idx] if idx < len(inferences) else None + request_id = self._extract_request_id(payload, inference) + if request_id is None: + self.Pd(f"No request_id found for index {idx}, skipping.") + continue + metadata = self._get_payload_field(payload, 'metadata', {}) if payload else {} + self.handle_inference_for_request( + request_id=request_id, + inference=inference, + metadata=metadata or {} + ) + return + """END INFERENCE HANDLING""" diff --git a/extensions/business/inference_api/base_inference_api.py b/extensions/business/inference_api/base_inference_api.py deleted file mode 100644 index 3ad897f2..00000000 --- a/extensions/business/inference_api/base_inference_api.py +++ /dev/null @@ -1,357 +0,0 @@ -""" -LOCAL_SERVING_API Plugin - -This plugin creates a FastAPI server for both local-only access (localhost) and through tunneling -that works with a loopback data capture pipeline. -It can work with both async and sync requests. -In case of sync requests, they will be processed using PostponedRequest objects. -Otherwise, the request_id will be returned immediately, and the client can poll for results. - -Key Features: -- Loopback mode: Outputs return to DCT queue for processing -- Designed for LLM chat completions - -Available Endpoints: -- POST /create_chat_completion - Create chat completion (sync) -- POST /create_chat_completion_async - Create chat completion (async) -- GET /health - Health check -- GET /status_request - Check for current status of async request results - -Example pipeline configuration: -{ - "NAME": "local_inference_api", - "TYPE": "Loopback", - "PLUGINS": [ - { - "SIGNATURE": "BASE_INFERENCE_API", - "INSTANCES": [ - { - "INSTANCE_ID": "llm_interface", - "AI_ENGINE": "llama_cpp", - "STARTUP_AI_ENGINE_PARAMS": { - "HF_TOKEN": "", - "MODEL_FILENAME": "llama-3.2-1b-instruct-q4_k_m.gguf", - "MODEL_NAME": "hugging-quants/Llama-3.2-1B-Instruct-Q4_K_M-GGUF", - "SERVER_COLLECTOR_TIMEDELTA": 360000 - } - } - ] - } - ] -} -""" -from naeural_core.business.default.web_app.fast_api_web_app import FastApiWebAppPlugin as BasePlugin -from extensions.business.mixins.nlp_agent_mixin import _NlpAgentMixin, NLP_AGENT_MIXIN_CONFIG - - -__VER__ = '0.1.0' - -_CONFIG = { - **BasePlugin.CONFIG, - **NLP_AGENT_MIXIN_CONFIG, - - # MANDATORY SETTING IN ORDER TO RECEIVE REQUESTS - "ALLOW_EMPTY_INPUTS": True, # allow processing even when no input data is present - - # MANDATORY LOOPBACK SETTINGS - "IS_LOOPBACK_PLUGIN": True, - "TUNNEL_ENGINE_ENABLED": False, - "API_TITLE": "Local Inference API", - "API_SUMMARY": "FastAPI server for local-only inference.", - - "PROCESS_DELAY": 0, - "REQUEST_TIMEOUT": 600, # 10 minutes - "SAVE_PERIOD": 300, # 5 minutes - - "VALIDATION_RULES": { - **BasePlugin.CONFIG['VALIDATION_RULES'], - } -} - - -class BaseInferenceApiPlugin( - BasePlugin, - _NlpAgentMixin -): - CONFIG = _CONFIG - - def on_init(self): - super(BaseInferenceApiPlugin, self).on_init() - self._requests = {} - self._api_errors = {} - # This is different from self.last_error_time in BasePlugin - # self.last_error_time tracks unhandled errors that occur in the plugin loop - # This one tracks all errors that occur during API request handling - 
self.last_handled_error_time = None - self.last_persistence_save = 0 - self.load_persistence_data() - return - - """UTIL METHODS""" - if True: - def load_persistence_data(self): - cached_data = self.cacheapi_load_pickle() - if cached_data is not None: - # Useful only for debugging purposes - self._requests = cached_data.get('_requests', {}) - self._api_errors = cached_data.get('_api_errors', {}) - self.last_handled_error_time = cached_data.get('last_handled_error_time', None) - # endif cached_data is not None - return - - def maybe_save_persistence_data(self, force=False): - if force or (self.time() - self.last_persistence_save) > self.cfg_save_period: - data_to_save = { - '_requests': self._requests, - '_api_errors': self._api_errors, - 'last_handled_error_time': self.last_handled_error_time, - } - self.cacheapi_save_pickle(data_to_save) - self.last_persistence_save = self.time() - # endif needs saving - return - - def get_status(self): - last_error_time = self.last_handled_error_time - status = "ok" - if last_error_time is not None: - delta_seconds = (self.time() - last_error_time) - if delta_seconds < 300: - status = f"degraded (last error {int(delta_seconds)}s ago)" - # endif enough time has passed since last error - # endif last_error_time is not None - return status - - def solve_postponed_request(self, request_id: str): - if request_id in self._requests: - self.Pd(f"Checking status of request ID {request_id}...") - request_data = self._requests[request_id] - start_time = request_data.get("start_time", None) - timeout = request_data.get("timeout", self.cfg_request_timeout) - is_finished = request_data.get("finished", False) - if is_finished: - return request_data["result"] - elif start_time is not None and (self.time() - start_time) > timeout: - self.Pd(f"Request ID {request_id} has timed out after {timeout} seconds.") - error_response = f"Request ID {request_id} has timed out after {timeout} seconds." - request_data['result'] = { - "error": error_response, - "request_id": request_id, - } - request_data["finished"] = True - return request_data['result'] - # endif check finished or timeout - else: - self.Pd(f"Request ID {request_id} not found in requests.") - return { - "error": f"Request ID {request_id} not found." - } - # endif request exists - return self.create_postponed_request( - solver_method=self.solve_postponed_request, - method_kwargs={ - "request_id": request_id - } - ) - - def register_request( - self, - **kwargs - ): - request_id = self.uuid() - start_time = self.time() - request_data = { - **kwargs, - "request_id": request_id, - "start_time": start_time, - "finished": None, - "error": None, - } - self._requests[request_id] = request_data - return request_id, request_data - """END UTIL METHODS""" - - """GENERIC API ENDPOINTS""" - if True: - @BasePlugin.endpoint(method="GET") - def health(self): - return { - "status": self.get_status(), - "pipeline": self.get_stream_id(), - "plugin": self.get_signature(), - "instance_id": self.get_instance_id(), - "loopback_enabled": self.cfg_is_loopback_plugin, - "uptime": self.get_alive_time(), - "last_error_time": self.last_handled_error_time, - "total_errors": len(self._api_errors), - } - - @BasePlugin.endpoint(method="GET") - def check_request(self, request_id: str): - res = { - "error": f"Request ID {request_id} not found." 
- } - if request_id in self._requests: - request_data = self._requests[request_id] - is_finished = request_data.get("finished", False) - if is_finished: - res = request_data["result"] - else: - res = { - "status": "pending", - "request_id": request_id, - } - # endif request exists - return res - """END GENERIC API ENDPOINTS""" - - """CHAT COMPLETION SECTION""" - if True: - """VALIDATION SECTION""" - if True: - def check_messages(self, messages: list[dict]): - err_msg = None - if not isinstance(messages, list) or len(messages) == 0: - err_msg = "`messages` must be a non-empty list of message dicts." - if err_msg is None and not all(isinstance(m, dict) for m in messages): - err_msg = "Each message in `messages` must be a dict." - if err_msg is not None: - all_messages_valid = all( - isinstance(m, dict) and - 'role' in m and isinstance(m['role'], str) and - 'content' in m and isinstance(m['content'], str) - for m in messages - ) - if err_msg is None and not all_messages_valid: - err_msg = "Each message dict must contain 'role' (str) and 'content' (str) keys." - # endif err_msg is not None - return err_msg - - def check_chat_completion_params( - self, - messages: list[dict], - temperature: float = 0.7, - max_tokens: int = 512, - repeat_penalty: float = 1.0, - **kwargs - ): - err_msg = None - err_msg = self.check_messages(messages) - - return err_msg - """END VALIDATION SECTION""" - - def create_chat_completion_helper( - self, - messages: list[dict], - temperature: float = 0.7, - max_tokens: int = 512, - repeat_penalty: float = 1.0, - async_request=False, - **kwargs - ): - err_msg = self.check_chat_completion_params( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - repeat_penalty=repeat_penalty, - **kwargs - ) - if err_msg is not None: - return { - "error": err_msg - } - # endif invalid params - request_id, request_data = self.register_request( - async_request=async_request, - **kwargs - ) - jeeves_content = { - 'REQUEST_ID': request_id, - **kwargs, - 'messages': messages, - 'temperature': temperature, - 'max_tokens': max_tokens, - 'repeat_penalty': repeat_penalty, - 'request_type': 'LLM', - } - self.Pd(f"Creating chat completion request {request_id} with data:\n{self.json_dumps(jeeves_content, indent=2)}") - self.add_payload_by_fields( - jeeves_content=jeeves_content, - signature=self.get_signature(), - ) - if async_request: - return { - "request_id": request_id, - "poll_url": f"/status_request?request_id={request_id}" - } - return self.solve_postponed_request(request_id=request_id) - - @BasePlugin.endpoint(method="POST") - def create_chat_completion( - self, - messages: list[dict], - temperature: float = 0.7, - max_tokens: int = 512, - repeat_penalty: float = 1.0, - **kwargs - ): - return self.create_chat_completion_helper( - async_request=False, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - repeat_penalty=repeat_penalty, - **kwargs - ) - - @BasePlugin.endpoint(method="POST") - def create_chat_completion_async( - self, - messages: list[dict], - temperature: float = 0.7, - max_tokens: int = 512, - repeat_penalty: float = 1.0, - **kwargs - ): - return self.create_chat_completion_helper( - async_request=True, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - repeat_penalty=repeat_penalty, - **kwargs - ) - """END CHAT COMPLETION SECTION""" - - def filter_valid_inference(self, inference): - is_valid = super(BaseInferenceApiPlugin, self).filter_valid_inference(inference=inference) - if is_valid: - request_id = 
inference.get('REQUEST_ID', None) - if request_id is None or request_id not in self._requests: - is_valid = False - # endif not is_valid - return is_valid - - def handle_single_inference(self, inference, model_name=None): - request_id = inference.get('REQUEST_ID', None) - self.Pd(f"Processing inference for request ID: {request_id}, model: {model_name}") - if request_id is None: - self.Pd("No REQUEST_ID found in inference; skipping.") - return - text_response = inference.get('text', None) - self._requests[request_id]['result'] = { - 'REQUEST_ID': request_id, - 'MODEL_NAME': model_name, - 'TEXT_RESPONSE': text_response, - } - self._requests[request_id]['finished'] = True - return - - def process(self): - self.maybe_save_persistence_data() - inferences = self.dataapi_struct_data_inferences() - self.handle_inferences(inferences=inferences) - return - - diff --git a/extensions/business/jeeves/jeeves_api.py b/extensions/business/jeeves/jeeves_api.py index fd217069..fa5f777b 100644 --- a/extensions/business/jeeves/jeeves_api.py +++ b/extensions/business/jeeves/jeeves_api.py @@ -13,6 +13,8 @@ 'CHAINSTORE_RESPONSE_KEY': None, "MAX_INPUTS_QUEUE_SIZE": 100, + # The verifying is not required since it is done in the DCT. + "SKIP_MESSAGE_VERIFY": True, 'PORT': 15033, 'ASSETS': 'extensions/business/fastapi/jeeves_api', diff --git a/extensions/business/mixins/base_agent_mixin.py b/extensions/business/mixins/base_agent_mixin.py new file mode 100644 index 00000000..18c2577c --- /dev/null +++ b/extensions/business/mixins/base_agent_mixin.py @@ -0,0 +1,93 @@ +BASE_AGENT_MIXIN_CONFIG = { + 'OBJECT_TYPE': [], + "ALLOW_EMPTY_INPUTS": False, # if this is set to true the on-idle will be triggered continuously the process + "DEBUG_LOGGING_ENABLED": True, + + "VALIDATION_RULES": { + }, +} + + +class _BaseAgentMixin(object): + def filter_valid_inference(self, inference): + return isinstance(inference, dict) and inference.get("IS_VALID", True) + + def filter_valid_inferences(self, inferences, return_idxs=False): + res = [] + idxs = [] + for idx, inf in enumerate(inferences): + if self.filter_valid_inference(inference=inf): + res.append(inf) + idxs.append(idx) + # endfor inferences + return res if not return_idxs else (res, idxs) + + def inference_to_response(self, inference, model_name, input_data): + return inference + + def handle_single_inference(self, inference: dict, model_name: str = None, input_data: dict = None): + """ + Method for handling a single inference, along with the input data that generated it. + + Parameters + ---------- + inference : dict + The inference dictionary + model_name: str, optional + The name of the model + input_data: dict, optional + The input data + """ + request_id = inference.get('REQUEST_ID', None) + self.Pd(f"Processing inference for request ID: {request_id}, model: {model_name}") + request_result = self.inference_to_response( + inference=inference, + model_name=model_name, + input_data=input_data + ) + current_payload_kwargs = { + 'result': request_result, + 'request_id': request_id, + } + self.add_payload_by_fields(**current_payload_kwargs) + return + + def handle_inferences(self, inferences, data=None): + """ + Method for handling list of inferences, along with the input data that generated them. 
+ This will filter the valid inference and handle them using handle_single_inference() + + Parameters + ---------- + inferences : list + Array of inference dictionaries + data : dict or list, optional + List of inputs or dictionary of {int_idx: input_data} + """ + if not isinstance(inferences, list): + return + if len(inferences) > 0 and not isinstance(inferences[0], dict): + return + model_name = inferences[0].get('MODEL_NAME', None) if len(inferences) > 0 else None + cnt_initial_inferences = len(inferences) + inferences, valid_idxs = self.filter_valid_inferences(inferences, return_idxs=True) + self.Pd(f"Filtered {cnt_initial_inferences} inferences to {len(inferences)} valid inferences.") + filtered_data = None + if data is not None: + filtered_data = [ + data[idx] for idx in valid_idxs + ] + if len(filtered_data) > 0: + self.Pd(f"Received requests: {self.json_dumps(self.shorten_str(filtered_data), indent=2)}") + # endif data is not None + + for idx, inf in enumerate(inferences): + current_input = filtered_data[idx] if filtered_data else {} + self.handle_single_inference( + inference=inf, + model_name=model_name, + input_data=current_input, + ) + # endfor inferences + return + diff --git a/extensions/business/mixins/nlp_agent_mixin.py b/extensions/business/mixins/nlp_agent_mixin.py index 4c2af5f2..f02fd3e0 100644 --- a/extensions/business/mixins/nlp_agent_mixin.py +++ b/extensions/business/mixins/nlp_agent_mixin.py @@ -1,69 +1,15 @@ -NLP_AGENT_MIXIN_CONFIG = { - 'OBJECT_TYPE': [], - "ALLOW_EMPTY_INPUTS": False, # if this is set to true the on-idle will be triggered continuously the process - "DEBUG_MODE": True, +from extensions.business.mixins.base_agent_mixin import BASE_AGENT_MIXIN_CONFIG, _BaseAgentMixin - "VALIDATION_RULES": { - }, +NLP_AGENT_MIXIN_CONFIG = { + **BASE_AGENT_MIXIN_CONFIG, } -class _NlpAgentMixin(object): - def Pd(self, msg, **kwargs): - if self.cfg_debug_mode: - self.P(msg, **kwargs) - return - - def filter_valid_inference(self, inference): - return isinstance(inference, dict) and inference.get("IS_VALID", True) - - def filter_valid_inferences(self, inferences, return_idxs=False): - res = [] - idxs = [] - for idx, inf in enumerate(inferences): - if self.filter_valid_inference(inference=inf): - res.append(inf) - idxs.append(idx) - # endfor inferences - return res if not return_idxs else (res, idxs) - - def inference_to_response(self, inference, model_name): +class _NlpAgentMixin(_BaseAgentMixin): + def inference_to_response(self, inference, model_name, input_data): return { 'REQUEST_ID': inference.get('REQUEST_ID'), 'MODEL_NAME': model_name, 'TEXT_RESPONSE': inference.get('text'), } - def handle_single_inference(self, inference, model_name=None): - request_id = inference.get('REQUEST_ID', None) - self.Pd(f"Processing inference for request ID: {request_id}, model: {model_name}") - request_result = self.inference_to_response(inference, model_name) - current_payload_kwargs = { - 'result': request_result, - 'request_id': request_id, - } - self.add_payload_by_fields(**current_payload_kwargs) - return - - def handle_inferences(self, inferences, data=None): - if not isinstance(inferences, list): - return - if len(inferences) > 0 and not isinstance(inferences[0], dict): - return - model_name = inferences[0].get('MODEL_NAME', None) if len(inferences) > 0 else None - cnt_initial_inferences = len(inferences) - inferences, valid_idxs = self.filter_valid_inferences(inferences, return_idxs=True) - self.Pd(f"Filtered {cnt_initial_inferences} inferences to {len(inferences)} valid 
inferences.") - if data is not None: - filtered_data = [ - data[idx] for idx in valid_idxs - ] - if len(filtered_data) > 0: - self.Pd(f"Received requests: {self.json_dumps(self.shorten_str(filtered_data), indent=2)}") - # endif data is not None - - for inf in inferences: - self.handle_single_inference(inference=inf, model_name=model_name) - # endfor inferences - return - diff --git a/extensions/business/nlp/doc_embedding_agent.py b/extensions/business/nlp/doc_embedding_agent.py index 2515dd83..8762a859 100644 --- a/extensions/business/nlp/doc_embedding_agent.py +++ b/extensions/business/nlp/doc_embedding_agent.py @@ -15,7 +15,7 @@ "DOC_EMBED_STATUS_PERIOD": 20, 'ALLOW_EMPTY_INPUTS': True, # if this is set to true the on-idle will continuously trigger the process - "DEBUG_MODE": True, + "DEBUG_LOGGING_ENABLED": True, 'CHAINSTORE_RESPONSE_KEY': None, @@ -97,7 +97,7 @@ def maybe_send_status(self, inf_meta): # endif time to send status return - def inference_to_response(self, inference, model_name): + def inference_to_response(self, inference, model_name, input_data=None): return inference def _process(self): diff --git a/extensions/business/nlp/vllm_agent.py b/extensions/business/nlp/vllm_agent.py index 77603d7a..c26ba8ab 100644 --- a/extensions/business/nlp/vllm_agent.py +++ b/extensions/business/nlp/vllm_agent.py @@ -1,6 +1,6 @@ -# from naeural_core.business.base.network_processor import NetworkProcessorPlugin as BasePlugin from naeural_core.business.base import BasePluginExecutor as BasePlugin from extensions.business.mixins.nlp_agent_mixin import _NlpAgentMixin, NLP_AGENT_MIXIN_CONFIG +from extensions.serving.mixins_llm.llm_utils import LlmCT from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass @@ -569,8 +569,8 @@ def extract_request_result(self, request_id: str, req_entry: _ReqEntry) -> Dict: "MODEL_NAME": self.cfg_model_name, "REQUEST_ID": request_id, "IS_VALID": True, - "text": res.get("content", None), - "RAW": res.get("raw", None), + LlmCT.TEXT: res.get("content", None), + LlmCT.FULL_OUTPUT: res.get("raw", None), "ELAPSED_TIME": req_entry.elapsed_time, } except Exception as e: diff --git a/extensions/business/oracle_management/oracle_api.py b/extensions/business/oracle_management/oracle_api.py index 8ad07b8e..a566c998 100644 --- a/extensions/business/oracle_management/oracle_api.py +++ b/extensions/business/oracle_management/oracle_api.py @@ -1,10 +1,10 @@ """ The OracleApiPlugin is a FastAPI web app that provides endpoints to interact with the -oracle network of the Naeural Edge Protocol +oracle network of the Ratio1 Edge Protocol Each request will generate data as follows: -- availablity data is requested from the oracle API +- availability data is requested from the oracle API - the data is signed with EVM signature and signature/address is added - other oracle peers signatures are added - all must be on same agreed availability - package is node-signed and returned to the client diff --git a/extensions/business/oracle_sync/oracle_sync_01.py b/extensions/business/oracle_sync/oracle_sync_01.py index 0f5b72df..a2c38616 100644 --- a/extensions/business/oracle_sync/oracle_sync_01.py +++ b/extensions/business/oracle_sync/oracle_sync_01.py @@ -66,7 +66,7 @@ POTENTIALLY_FULL_AVAILABILITY_THRESHOLD, SUPERVISOR_MIN_AVAIL_PRC, - MAX_RECEIVED_MESSAGES_SIZE, + MAX_RECEIVED_MESSAGES_PER_ORACLE, ORACLE_SYNC_USE_R1FS, ) @@ -320,9 +320,15 @@ def _prepare_job_state_transition_map(self): 'TRANSITIONS': [ { 'NEXT_STATE': self.STATES.S0_WAIT_FOR_EPOCH_CHANGE, - 
'TRANSITION_CONDITION': self.state_machine_api_callback_always_true, + 'TRANSITION_CONDITION': self._check_enough_signatures_collected, 'ON_TRANSITION_CALLBACK': self._reset_to_initial_state, - 'DESCRIPTION': "Wait for the epoch to change to start a new sync process", + 'DESCRIPTION': "Consensus reached. Wait for the epoch to change to start a new sync process", + }, + { + 'NEXT_STATE': self.STATES.S8_SEND_REQUEST_AGREED_MEDIAN_TABLE, + 'TRANSITION_CONDITION': self._check_not_enough_signatures_collected, + 'ON_TRANSITION_CALLBACK': self.state_machine_api_callback_do_nothing, + 'DESCRIPTION': 'If not enough signatures were collected, request the agreement from the other oracles.' } ], }, @@ -417,7 +423,7 @@ def on_init(self): # because they have to request the agreed median table and wait to receive # the agreed median table from the previous epochs. self.state_machine_name = 'OracleSyncPlugin' - self._received_messages_from_oracles = self.deque(maxlen=MAX_RECEIVED_MESSAGES_SIZE) + self._oracle_received_messages = self.defaultdict(lambda: self.deque(maxlen=MAX_RECEIVED_MESSAGES_PER_ORACLE)) self.state_machine_api_init( name=self.state_machine_name, state_machine_transitions=self._prepare_job_state_transition_map(), @@ -469,6 +475,8 @@ def _reset_to_initial_state(self): self.dct_agreed_availability_cid = {} self.dct_agreement_signatures_cid = {} + self.enough_signatures_collected = None + # This will store the number of iterations the oracle has performed after the early stopping condition # is met. self.early_stopping_iterations = {} @@ -479,16 +487,6 @@ def _reset_to_initial_state(self): self.P(f'Current epoch: {self._current_epoch}, Last epoch synced: {self._last_epoch_synced}.') return - # """STATE MACHINE CALLBACKS SECTION""" - # if True: - # - # """END STATE MACHINE SECTION""" - - # """UTILS SECTION""" - # if True: - # - # """END UTILS SECTION""" - """MESSAGE HANDLING UTILS SUBSECTION""" if True: @NetworkProcessorPlugin.payload_handler() @@ -504,25 +502,24 @@ def handle_received_payloads(self, payload: dict): sender = payload.get(self.ct.PAYLOAD_DATA.EE_SENDER) if not self._is_oracle(sender): return - self._received_messages_from_oracles.append(payload) + self._oracle_received_messages[sender].append(payload) return def get_received_messages_from_oracles(self): """ - Get the messages received from the oracles. - This method returns a generator for memory efficiency. + Get the messages received from all the oracles. + This method will gather the last received message from each oracle and return them Returns ------- - generator : The messages received from the oracles + received_messages : list of dict + The received messages """ - # retrieve messages from self._received_messages_from_oracles - dct_messages = list(self._received_messages_from_oracles) - self._received_messages_from_oracles.clear() - # This will return a generator that will be used in the next steps. 
- received_messages = (dct_messages[i] for i in range(len(dct_messages))) - - return received_messages + return [ + message_deque.popleft() + for message_deque in self._oracle_received_messages.values() + if message_deque + ] """END MESSAGE HANDLING UTILS SUBSECTION""" def maybe_self_assessment(self): diff --git a/extensions/business/oracle_sync/sync_mixins/ora_sync_constants.py b/extensions/business/oracle_sync/sync_mixins/ora_sync_constants.py index 4d254ab7..b2fcfaff 100644 --- a/extensions/business/oracle_sync/sync_mixins/ora_sync_constants.py +++ b/extensions/business/oracle_sync/sync_mixins/ora_sync_constants.py @@ -4,7 +4,7 @@ ORACLE_SYNC_ONLINE_PRESENCE_MIN_THRESHOLD ) -MAX_RECEIVED_MESSAGES_SIZE = 1000 +MAX_RECEIVED_MESSAGES_PER_ORACLE = 50 DEBUG_MODE = False SIGNATURES_EXCHANGE_MULTIPLIER = 2 REQUEST_AGREEMENT_TABLE_MULTIPLIER = 5 if DEBUG_MODE else 2 diff --git a/extensions/business/oracle_sync/sync_mixins/ora_sync_states_mixin.py b/extensions/business/oracle_sync/sync_mixins/ora_sync_states_mixin.py index 62f3f1d6..ce90cc88 100644 --- a/extensions/business/oracle_sync/sync_mixins/ora_sync_states_mixin.py +++ b/extensions/business/oracle_sync/sync_mixins/ora_sync_states_mixin.py @@ -1010,7 +1010,21 @@ def _receive_agreement_signature_and_maybe_send_agreement_signature(self): if not self._check_received_agreement_signature_ok(sender, oracle_data): continue - signature_dict = oracle_data[OracleSyncCt.AGREEMENT_SIGNATURE] + # Attempting to extract both in case of multiple signature message from S10 + signature_dict = oracle_data.get(OracleSyncCt.AGREEMENT_SIGNATURE) + signatures_dict = oracle_data.get(OracleSyncCt.AGREEMENT_SIGNATURES) + is_single_signature = False + if signature_dict is not None: + is_single_signature = True + signatures_dict = {sender: signature_dict} + # endif single signature message + + temp_agreement_signatures = { + **signatures_dict, + **self.compiled_agreed_median_table_signatures, + } + is_duplicated = len(temp_agreement_signatures) == len(self.compiled_agreed_median_table_signatures) + current_count = len(temp_agreement_signatures) if self.cfg_debug_sync: stage = oracle_data[OracleSyncCt.STAGE] @@ -1018,14 +1032,20 @@ def _receive_agreement_signature_and_maybe_send_agreement_signature(self): sender=sender, stage=stage, data=self.compiled_agreed_median_table_signatures, - return_str=True + return_str=True, + is_duplicated=is_duplicated, + current_count=current_count, ) if self.cfg_debug_sync_full: - log_str += f", {signature_dict = }" + added_log_str = f", {signature_dict = }" + if not is_single_signature: + added_log_str = f", signatures_dict = {self.json_dumps(signatures_dict, indent=2)}" + # endif multi signature + log_str += added_log_str # endif debug_sync_full self.P(log_str) # endif debug_sync - self.compiled_agreed_median_table_signatures[sender] = signature_dict + self.compiled_agreed_median_table_signatures = temp_agreement_signatures # endfor received messages return @@ -1123,6 +1143,14 @@ def _exchange_signatures_timeout(self): """S7_UPDATE_EPOCH_MANAGER CALLBACKS""" if True: + def _check_enough_signatures_collected(self): + return self.enough_signatures_collected + + def _check_not_enough_signatures_collected(self): + # This check is written in this way in order to be able to differentiate between + # the consensus being in progress and it being failed at the final state. 
+ return self.enough_signatures_collected is False + def _update_epoch_manager_with_agreed_median_table( self, epoch=None, compiled_agreed_median_table=None, agreement_signatures=None, epoch_is_valid=None, agreement_cid=None, signatures_cid=None, debug=True @@ -1177,6 +1205,12 @@ def _update_epoch_manager_with_agreed_median_table( oracle_list = self.get_oracle_list() oracle_signers = [oracle for oracle in oracle_list if oracle in signers] epoch_is_valid = len(oracle_signers) > 0 + if is_single_call and epoch_is_valid: + self.enough_signatures_collected = self._check_enough_oracles( + participating_oracles=signers + ) + epoch_is_valid = self.enough_signatures_collected + # endif is_single_call # endif epoch_is_valid if epoch <= self._last_epoch_synced: @@ -1223,8 +1257,11 @@ def _update_epoch_manager_with_agreed_median_table( if debug: valid_str = "VALID" if epoch_is_valid else "INVALID" announced_cnt = len(self._announced_participating) - log_str = f'Successfully synced epoch {epoch} with {valid_str} agreed median table ' - log_str += f'and {len(agreement_signatures)} agreement signatures from ' + log_str = f'Successfully synced epoch {epoch} with {valid_str} agreed median table and' + if is_single_call and not epoch_is_valid: + log_str = f"Failed to reach consensus for epoch {epoch} with" + # endif check if failed consensus + log_str += f' {len(agreement_signatures)} agreement signatures from ' log_str += f'{announced_cnt} announced participants at the start.\n' log_str += f"Initially announced participants:" log_str += "".join([ @@ -1250,9 +1287,13 @@ def _update_epoch_manager_with_agreed_median_table( if self.cfg_debug_sync_full: self.P(f'DEBUG EM data after update:\n{self.netmon.epoch_manager.data}') - self._last_epoch_synced = epoch - # In case of multiple updates, the save is only needed after the last update. if is_single_call: + if epoch_is_valid: + self._last_epoch_synced = epoch + else: + self._last_epoch_synced = epoch + # In case of multiple updates, the save is only needed after the last update. + if is_single_call and epoch_is_valid: self.netmon.epoch_manager.maybe_update_cached_data(force=True) self.netmon.epoch_manager.save_status() # endif part of consensus process diff --git a/extensions/business/oracle_sync/sync_mixins/ora_sync_utils_mixin.py b/extensions/business/oracle_sync/sync_mixins/ora_sync_utils_mixin.py index ec2c3cd2..ec4d2360 100644 --- a/extensions/business/oracle_sync/sync_mixins/ora_sync_utils_mixin.py +++ b/extensions/business/oracle_sync/sync_mixins/ora_sync_utils_mixin.py @@ -642,9 +642,13 @@ def log_received_message( stage: str, data: dict, return_str: bool = False, + is_duplicated: bool = None, + current_count: int = None, ): - is_duplicated = sender in data.keys() - current_count = len(data) + (1 - is_duplicated) + if is_duplicated is None: + is_duplicated = sender in data.keys() + if current_count is None: + current_count = len(data) + (1 - is_duplicated) duplicated_str = "(duplicated)" if is_duplicated else "" progress_str = f"[{current_count}/{self.total_participating_oracles()}]" sender_str = self.get_sender_str(sender) @@ -956,8 +960,9 @@ def _check_received_oracle_data_for_values( expected_variable_names : list[str] The list of expected variable names in `oracle_data` Will be used to retrieve the standards from `VALUE_STANDARDS` - expected_stage : str, optional + expected_stage : str or list, optional The expected stage of the message. If None, this will be skipped. + If list[str] is provided all the states will be valid. 
verify : bool, optional If True, oracle_data has to contain `EE_SIGN` key with the signature of the data in it. @@ -1452,6 +1457,43 @@ def _check_agreement_signature( return False return True + def _check_multiple_signatures(self, agreement_signatures_dict, sender_str: str): + """ + Check if signatures from multiple senders are valid. + + Parameters + ---------- + agreement_signatures_dict : dict + A dictionary of {sender: agreement_signature} with the signatures of each sender. + + sender_str : str + String from sender of the entire message + + Returns + ------- + bool : True if the signatures are valid, False otherwise + """ + for sig_sender, signature_dict in agreement_signatures_dict.items(): + sig_sender_str = self.get_sender_str(sender=sig_sender) + if not self.is_participating.get(sig_sender, False): + self.P(f"Oracle {sig_sender_str} should not have sent signature for agreement[Received from {sender_str}]. ignoring...", color='r') + # TODO: review if this should make the entire message invalid + # One node can be seen as potentially full online by some oracles and not by others. + # But for the node to actually participate in the agreement, it has to see itself as full online. + # That can not happen if the node is not seen as potentially full online by at least another + # participating oracle. + return False + # endif not expected to participate + if not self._check_agreement_signature( + sender=sig_sender, + signature_dict=signature_dict, + ): + self.P(f"Invalid agreement signature from oracle {sig_sender_str}!", color='r') + return False + # endif valid agreement signature + # endfor agreement signatures + return True + def _check_received_agreement_signature_ok(self, sender, oracle_data): """ Check if the received signature for agreement is ok. Print the error message if not. 
@@ -1467,7 +1509,7 @@ def _check_received_agreement_signature_ok(self, sender, oracle_data): ------- bool : True if the received agreed value is ok, False otherwise """ - if not self._check_received_oracle_data_for_values( + valid_single_signature_message = self._check_received_oracle_data_for_values( sender=sender, oracle_data=oracle_data, expected_variable_names=[ @@ -1477,11 +1519,29 @@ def _check_received_agreement_signature_ok(self, sender, oracle_data): # No need to verify here, since the agreement signature itself is signed and verified # in the self._check_agreement_signature method verify=False, - ): + ) + valid_multi_signature_message = self._check_received_oracle_data_for_values( + sender=sender, + oracle_data=oracle_data, + expected_variable_names=[ + OracleSyncCt.STAGE, OracleSyncCt.AGREEMENT_SIGNATURES + ], + expected_stage=self.STATES.S10_EXCHANGE_AGREEMENT_SIGNATURES, + # No need to verify here, since the agreement signature itself is signed and verified + # in the self._check_agreement_signature method + verify=False, + ) + if not valid_single_signature_message and not valid_multi_signature_message: return False - signature_dict = oracle_data[OracleSyncCt.AGREEMENT_SIGNATURE] sender_str = self.get_sender_str(sender=sender) + signatures_dict = {} + if valid_single_signature_message: + signature_dict = oracle_data[OracleSyncCt.AGREEMENT_SIGNATURE] + signatures_dict[sender] = signature_dict + else: + signatures_dict = oracle_data[OracleSyncCt.AGREEMENT_SIGNATURES] + # endif single signature or multi signature # In the is_participating dictionary, only oracles that were seen # as full online are marked as True @@ -1490,20 +1550,14 @@ def _check_received_agreement_signature_ok(self, sender, oracle_data): self.P(f"Oracle {sender_str} should not have sent signature for agreement. ignoring...", color='r') return False - if not self._check_agreement_signature( - sender=sender, - signature_dict=signature_dict - ): - return False - - if sender != signature_dict.get('EE_SENDER'): - self.P( - f"Agreement signature from oracle {sender_str} does not match the sender! Possible impersonation attack!", - color='r' - ) + if sender not in signatures_dict: + self.P(f"Oracle {sender_str} sent agreement signatures, but without their own! Possible impersonation!", color='r') return False - # endif identity check - return True + # endif sender has own signature in message + return self._check_multiple_signatures( + agreement_signatures_dict=signatures_dict, + sender_str=sender_str, + ) def _check_received_agreement_signatures_ok(self, sender: str, oracle_data: dict): """ @@ -1542,26 +1596,10 @@ def _check_received_agreement_signatures_ok(self, sender: str, oracle_data: dict agreement_signatures = oracle_data[OracleSyncCt.AGREEMENT_SIGNATURES] - for sig_sender, signature_dict in agreement_signatures.items(): - sig_sender_str = self.get_sender_str(sender=sig_sender) - if not self.is_participating.get(sig_sender, False): - self.P(f"Oracle {sig_sender_str} should not have sent signature for agreement. ignoring...", color='r') - # TODO: review if this should make the entire message invalid - # One node can be seen as potentially full online by some oracles and not by others. - # But for the node to actually participate in the agreement, it has to see itself as full online. - # That can not happen if the node is not seen as potentially full online by at least another - # participating oracle. 
- return False - # endif not expected to participate - if not self._check_agreement_signature( - sender=sig_sender, - signature_dict=signature_dict - ): - self.P(f"Invalid agreement signature from oracle {sig_sender_str}!", color='r') - return False - # endif agreement signature - # endfor agreement signatures - return True + return self._check_multiple_signatures( + agreement_signatures_dict=agreement_signatures, + sender_str=sender_str + ) def _check_received_epoch__agreed_median_table_ok(self, sender, oracle_data): """ diff --git a/extensions/data/default/jeeves/jeeves_listener.py b/extensions/data/default/jeeves/jeeves_listener.py index b523769e..7eb888c2 100644 --- a/extensions/data/default/jeeves/jeeves_listener.py +++ b/extensions/data/default/jeeves/jeeves_listener.py @@ -9,6 +9,7 @@ "PATH_FILTER": JeevesCt.UNIFIED_PATH_FILTER, "SUPPORTED_REQUEST_TYPES": None, # supported request types, None means all are supported + "ALLOW_UNVERIFIED_MESSAGES": False, # if set to True, messages that fail verification will still be processed 'VALIDATION_RULES': { **BaseClass.CONFIG['VALIDATION_RULES'], @@ -19,14 +20,6 @@ class JeevesListenerDataCapture(BaseClass, _JeevesUtilsMixin): CONFIG = _CONFIG - def Pd(self, s, color=None, **kwargs): - """ - Print debug message with Jeeves agent prefix. - """ - if self.cfg_debug_iot_payloads: - self.P(s, color=color, **kwargs) - return - def filter_message_for_agent(self, normalized_message: dict): """ Method for filtering messages intended for Jeeves agent processing. @@ -100,6 +93,28 @@ def check_message_for_agent(self, message: dict) -> bool: return payload_signature in JeevesCt.JEEVES_API_SIGNATURES + def maybe_verify_message(self, message: dict): + if self.cfg_allow_unverified_messages: + return message + verified = False + verify_msg = None + try: + verify_results = self.bc.verify( + dct_data=message, + str_signature=None, sender_address=None, + return_full_info=True + ) + verified = verify_results.valid + verify_msg = verify_results.message + except Exception as e: + verify_msg = f"Error during message signature verification: {e}" + if not verified: + self.P( + f"Message signature verification failed: {verify_msg}. Message: {message}", + color='r' + ) + return message if verified else None + def _filter_message(self, unfiltered_message: dict): """ Method for checking if the message should be kept or not during the filtering process. 
@@ -129,13 +144,28 @@ def _filter_message(self, unfiltered_message: dict): self.Pd(f"Invalid message format: {self.shorten_str(prefiltered_message)}", color='r') return None + self.Pd(f"Initial prefiltered message: {prefiltered_message}") + + prefiltered_message = self.maybe_verify_message(prefiltered_message) + if prefiltered_message is None: + self.P(f"Message verification failed, dropping message.", color='r') + return None + self.Pd(f"Verified prefiltered message: {prefiltered_message}") + prefiltered_message = self.receive_and_decrypt_payload(prefiltered_message) + self.Pd(f"Decrypted prefiltered message: {self.shorten_str(prefiltered_message)}") + if not prefiltered_message: + self.P(f"Message decryption failed, dropping message.", color='r') + return None + normalized_message = { (k.upper() if isinstance(k, str) else k): v for k, v in prefiltered_message.items() } if self.check_message_for_agent(normalized_message): + self.P(f"Message intended for Jeeves agent processing.") prefiltered_message = self.filter_message_for_agent(normalized_message) # endif message for agent + self.Pd(f"Final prefiltered message: {self.shorten_str(prefiltered_message)}") return prefiltered_message diff --git a/extensions/serving/base/base_llm_serving.py b/extensions/serving/base/base_llm_serving.py index 00480646..b8135094 100644 --- a/extensions/serving/base/base_llm_serving.py +++ b/extensions/serving/base/base_llm_serving.py @@ -238,7 +238,7 @@ def get_relevant_signatures(self): def get_local_path(self): models_cache = self.log.get_models_folder() - model_name = 'models/{}'.format(self.get_model_name()) + model_name = 'models/{}'.format(self.cfg_model_name) model_subfolder = model_name.replace('/', '--') path = self.os_path.join(models_cache, model_subfolder) if self.os_path.isdir(path): @@ -442,11 +442,14 @@ def check_relevant_input(self, input_dict: dict): normalized_signature = str(inp_signature).upper() if inp_signature is not None else None if normalized_signature not in self.get_relevant_signatures(): - # self.P(f"[DEBUG]Skipping irrelevant signature: {normalized_signature}. Relevant signatures: {self.get_relevant_signatures()}", color='y') + self.P(f"[DEBUG]Skipping irrelevant signature: {normalized_signature}. Relevant signatures: {self.get_relevant_signatures()}", color='y') return False jeeves_content = input_dict.get(self.ct.JeevesCt.JEEVES_CONTENT, {}) - # self.P(f"[DEBUG]Extracted jeeves content for relevance check: {self.shorten_str(jeeves_content)}", color='g') + if not jeeves_content: + self.P(f"[DEBUG]No jeeves content found in input: {self.shorten_str(input_dict)}", color='y') + return False + self.P(f"[DEBUG]Extracted jeeves content for relevance check: {self.shorten_str(jeeves_content)}", color='g') return self.check_supported_request_type(message_data=jeeves_content) def process_predict_kwargs(self, predict_kwargs: dict): @@ -928,11 +931,16 @@ def _predict(self, preprocessed_batch): # LlmCT.PRED: yhat, LlmCT.PRMP: prompt_lst, # LlmCT.TKNS: batch_tokens, + # TODO: add back the tps and maybe additional metrics to mimic the openai API structure # LlmCT.TPS: num_tps, LlmCT.ADDITIONAL: additional_lst, LlmCT.TEXT: text_lst, "RELEVANT_IDS": relevant_input_ids, - "TOTAL_INPUTS": cnt_total_inputs + "TOTAL_INPUTS": cnt_total_inputs, + # Placeholder for full output if needed in post-processing. + # In the future, this can be populated with more detailed + # generation information. 
+ LlmCT.FULL_OUTPUT: text_lst, } return dct_result @@ -949,6 +957,7 @@ def _post_process(self, preds_batch): text_lst = preds_batch[LlmCT.TEXT] relevant_input_ids = preds_batch["RELEVANT_IDS"] cnt_total_inputs = preds_batch["TOTAL_INPUTS"] + full_output_lst = preds_batch[LlmCT.FULL_OUTPUT] for i, additional in enumerate(additionals): self.processed_requests.add(additional[LlmCT.REQUEST_ID]) @@ -961,12 +970,14 @@ def _post_process(self, preds_batch): # LlmCT.PRED : yhat[i].tolist(), LlmCT.PRMP : prompts[i], LlmCT.TEXT : decoded, + LlmCT.FULL_OUTPUT: full_output_lst[i], # LlmCT.TKNS : tokens[i].tolist(), # LlmCT.TPS : tps, **preds_batch[LlmCT.ADDITIONAL][i], # TODO: find a way to send the model metadata to the plugin, other than through the inferences. - 'MODEL_NAME': self.get_model_name() + 'MODEL_NAME': self.get_model_name(), } + # endif full_output_lst is not None result.append(dct_result) # endfor each text current_text_idx = 0 diff --git a/extensions/serving/default_inference/nlp/llama_cpp_base.py b/extensions/serving/default_inference/nlp/llama_cpp_base.py index d86c9e70..51b8845e 100644 --- a/extensions/serving/default_inference/nlp/llama_cpp_base.py +++ b/extensions/serving/default_inference/nlp/llama_cpp_base.py @@ -1,3 +1,6 @@ +""" +TODO: example pipeline with additional explanations +""" from extensions.serving.base.base_llm_serving import BaseLlmServing as BaseServingProcess from llama_cpp import Llama from extensions.serving.mixins_llm.llm_utils import LlmCT @@ -41,8 +44,16 @@ def _load_tokenizer(self): # llama.cpp uses built-in tokenizer return + def get_model_name(self): + model_id = self.cfg_model_name + model_filename = self.cfg_model_filename + if model_id is None or model_filename is None: + raise ValueError("Both MODEL_NAME and MODEL_FILENAME must be specified for Llama_cpp models.") + # endif model id/filename check + return f"{model_id}/{model_filename}" + def _load_model(self): - model_id = self.get_model_name() + model_id = self.cfg_model_name model_filename = self.cfg_model_filename n_ctx = self.cfg_model_n_ctx @@ -214,8 +225,8 @@ def _predict(self, preprocessed_batch): ] = preprocessed_batch results = [ - # (idx, valid, process_method, reply) - (idx, valid_condition, process_methods[idx], None) + # (idx, valid, process_method, reply, full_output) + (idx, valid_condition, process_methods[idx], None, None) for idx, valid_condition in enumerate(valid_conditions) ] obj_for_inference = [ @@ -227,6 +238,7 @@ def _predict(self, preprocessed_batch): tries = 0 while not conditions_satisfied: reply_lst = [] + full_output_lst = [] t0 = self.time() timings = [] total_generated_tokens = 0 @@ -244,6 +256,7 @@ def _predict(self, preprocessed_batch): num_tokens_generated = out["usage"]["completion_tokens"] total_generated_tokens += num_tokens_generated reply_lst.append(reply) + full_output_lst.append(out) # endfor obj_for_inference t_total = self.time() - t0 curr_tps = total_generated_tokens / t_total if t_total > 0 else 0 @@ -256,6 +269,7 @@ def _predict(self, preprocessed_batch): valid_condition = results[idx_orig][1] process_method = results[idx_orig][2] current_text = reply_lst[idx_curr] + full_output = full_output_lst[idx_curr] self.P(f"Checking condition for object {idx_orig}:\nvalid:`{valid_condition}`|process:`{process_method}`|text:\n{current_text}") current_text = self.maybe_process_text(current_text, process_method) self.P(f"Processed text:\n{current_text}") @@ -269,7 +283,7 @@ def _predict(self, preprocessed_batch): current_condition_satisfied = valid_text or (tries >= 
max_tries) if current_condition_satisfied: # If the condition is satisfied, we can save the result - results[idx_orig] = (idx_orig, valid_condition, process_method, current_text) + results[idx_orig] = (idx_orig, valid_condition, process_method, current_text, full_output) else: invalid_objects.append((idx_orig, len(invalid_objects))) # endif current condition satisfied @@ -281,13 +295,15 @@ def _predict(self, preprocessed_batch): conditions_satisfied = True # endwhile conditions_satisfied - text_lst = [text for _, _, _, text in results] + text_lst = [text for _, _, _, text, _ in results] + full_output_lst = [full_output for _, _, _, _, full_output in results] dct_result = { LlmCT.PRMP: messages_lst, LlmCT.TEXT: text_lst, LlmCT.ADDITIONAL: additional_lst, "RELEVANT_IDS": relevant_input_ids, - "TOTAL_INPUTS": cnt_total_inputs + "TOTAL_INPUTS": cnt_total_inputs, + LlmCT.FULL_OUTPUT: full_output_lst, } return dct_result diff --git a/extensions/serving/mixins_llm/llm_utils.py b/extensions/serving/mixins_llm/llm_utils.py index 2a833740..6f6e36a6 100644 --- a/extensions/serving/mixins_llm/llm_utils.py +++ b/extensions/serving/mixins_llm/llm_utils.py @@ -38,6 +38,7 @@ class LlmCT: REQUEST_ID = 'REQUEST_ID' REQUEST_TYPE = 'REQUEST_TYPE' VALID_MASK = 'VALID_MASK' + FULL_OUTPUT = 'FULL_OUTPUT' # Constants for encoding a prompt using chat templates REQUEST_ROLE = 'user' diff --git a/plugins/business/cerviguard/cerviguard_api.py b/plugins/business/cerviguard/cerviguard_api.py new file mode 100644 index 00000000..85babec7 --- /dev/null +++ b/plugins/business/cerviguard/cerviguard_api.py @@ -0,0 +1,178 @@ +""" +CV_INFERENCE_API Plugin + +Production-Grade CerviGuard Inference API + +This plugin exposes a hardened, FastAPI-powered interface for CerviGuard +computer-vision workloads. It reuses the BaseInferenceApi request lifecycle +while tailoring validation and response shaping for image analysis. + +Highlights +- Loopback-only surface paired with local CerviGuard clients +- Request tracking, persistence, auth, and rate limiting from BaseInferenceApi +- Base64 payload validation and metadata normalization for serving plugins +- Structured mapping of struct_data payloads and inferences back to requests +""" + +from extensions.business.edge_inference_api.cv_inference_api import CvInferenceApiPlugin as BasePlugin +from naeural_core.utils.fastapi_utils import PostponedRequest + +__VER__ = '0.1.0' + +_CONFIG = { + **BasePlugin.CONFIG, + + # Server configuration + 'PORT': 5082, + + # API metadata + 'API_TITLE': 'CerviGuard Local Serving API', + 'API_SUMMARY': 'Local image analysis API for CerviGuard', + 'API_DESCRIPTION': 'FastAPI server for cervical image analysis', + + # AI Engine for image processing + 'AI_ENGINE': 'CERVIGUARD_IMAGE_ANALYZER', + + 'VALIDATION_RULES': { + **BasePlugin.CONFIG['VALIDATION_RULES'], + }, +} + + +class CerviguardApiPlugin(BasePlugin): + """ + LOCAL_SERVING_API Plugin + + A FastAPI plugin designed for localhost-only access with loopback data capture. + This plugin: + - Does NOT require token authentication (localhost only) + - Routes outputs back to the loopback DCT queue (IS_LOOPBACK_PLUGIN = True) + - Provides simple REST endpoints for data processing + - Works with Loopback data capture type pipelines + """ + + CONFIG = _CONFIG + + """INFERENCE HANDLING""" + if True: + def _validate_analysis(self, analysis: dict) -> dict: + """ + Validate and sanitize analysis fields returned by inference. 
+ + Parameters + ---------- + analysis : dict + Raw analysis data from the inference engine. + + Returns + ------- + dict + Analysis payload with validated and defaulted fields. + """ + safe_defaults = { + 'tz_type': 'Type 1', + 'lesion_assessment': 'none', + 'lesion_summary': 'Analysis unavailable', + 'risk_score': 0, + 'image_quality': 'unknown', + 'image_quality_sufficient': True + } + if not isinstance(analysis, dict): + return safe_defaults + + validated = {} + + tz_type = analysis.get('tz_type', safe_defaults['tz_type']) + if tz_type not in ['Type 0', 'Type 1', 'Type 2', 'Type 3']: + tz_type = safe_defaults['tz_type'] + validated['tz_type'] = tz_type + + lesion_assessment = analysis.get('lesion_assessment', safe_defaults['lesion_assessment']) + if lesion_assessment not in ['none', 'low', 'moderate', 'high']: + lesion_assessment = safe_defaults['lesion_assessment'] + validated['lesion_assessment'] = lesion_assessment + + lesion_summary = analysis.get('lesion_summary', safe_defaults['lesion_summary']) + if not isinstance(lesion_summary, str): + lesion_summary = safe_defaults['lesion_summary'] + validated['lesion_summary'] = lesion_summary + + risk_score = analysis.get('risk_score', safe_defaults['risk_score']) + try: + risk_score = int(risk_score) + if risk_score < 0 or risk_score > 100: + risk_score = max(0, min(100, risk_score)) + except (TypeError, ValueError): + risk_score = safe_defaults['risk_score'] + validated['risk_score'] = risk_score + + image_quality = analysis.get('image_quality', safe_defaults['image_quality']) + if not isinstance(image_quality, str): + image_quality = safe_defaults['image_quality'] + validated['image_quality'] = image_quality + + image_quality_sufficient = analysis.get( + 'image_quality_sufficient', + safe_defaults['image_quality_sufficient'] + ) + if not isinstance(image_quality_sufficient, bool): + image_quality_sufficient = safe_defaults['image_quality_sufficient'] + validated['image_quality_sufficient'] = image_quality_sufficient + return validated + + def _build_result_from_inference( + self, + request_id: str, + inference: Dict[str, Any], + metadata: Dict[str, Any], + request_data: Dict[str, Any] + ): + """ + Construct a result payload from inference output and metadata. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : dict + Inference result data. + metadata : dict + Metadata to include in the response. + request_data : dict + Stored request record for reference. + + Returns + ------- + dict + Structured result payload including analysis and image details. + + Raises + ------ + ValueError + If the inference result format is invalid. + RuntimeError + When the inference indicates an error status. 
+ """ + if not isinstance(inference, dict): + raise ValueError("Invalid inference result format.") + inference_data = inference.get('data', inference) + status = inference_data.get('status', inference.get('status', 'completed')) + if status == 'error': + err_msg = inference_data.get('error', 'Unknown error') + raise RuntimeError(err_msg) + + analysis = self._validate_analysis(inference_data.get('analysis', {})) + image_info = inference_data.get('image_info', {}) + result_payload = { + 'status': 'completed', + 'request_id': request_id, + 'analysis': analysis, + 'image_info': image_info, + 'processed_at': inference_data.get('processed_at', self.time()), + 'processor_version': inference_data.get('processor_version', 'unknown'), + 'metadata': metadata or request_data.get('metadata') or {}, + } + if 'model_name' in inference_data: + result_payload['model_name'] = inference_data['model_name'] + return result_payload + """END INFERENCE HANDLING""" diff --git a/extensions/business/cerviguard/local_serving_api.py b/plugins/business/cerviguard/local_serving_api.py similarity index 100% rename from extensions/business/cerviguard/local_serving_api.py rename to plugins/business/cerviguard/local_serving_api.py diff --git a/plugins/business/llm/code_assist_01.py b/plugins/business/llm/code_assist_01.py index 85258f15..e63a6e57 100644 --- a/plugins/business/llm/code_assist_01.py +++ b/plugins/business/llm/code_assist_01.py @@ -26,11 +26,10 @@ class CodeAssist01Plugin(BasePlugin, _NlpAgentMixin): def _process(self): # we always receive input from the upstream due to the fact that _process # is called only when we have input based on ALLOW_EMPTY_INPUTS=False (from NLP_AGENT_MIXIN_CONFIG) - if self.is_debug_mode: - full_input = self.dataapi_full_input() - self.P("Processing received input: {}".format(full_input)) - str_dump = self.json_dumps(full_input, indent=2) - self.P("Received input from pipeline:\n{}".format(str_dump)) + full_input = self.dataapi_full_input() + self.Pd("Processing received input: {}".format(full_input)) + str_dump = self.json_dumps(full_input, indent=2) + self.Pd("Received input from pipeline:\n{}".format(str_dump)) # endif debug string data = self.dataapi_struct_data() inferences = self.dataapi_struct_data_inferences() diff --git a/ver.py b/ver.py index 3e50febc..29ec0381 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.10' +__VER__ = '2.10.13' diff --git a/xperimental/oracle_sync/oracle_sync_test_plan.md b/xperimental/oracle_sync/oracle_sync_test_plan.md new file mode 100644 index 00000000..2ba2e2b5 --- /dev/null +++ b/xperimental/oracle_sync/oracle_sync_test_plan.md @@ -0,0 +1,343 @@ +# Oracle Sync (ORACLE_SYNC_01) — Testing Development Plan + +**Goal:** Provide Codex (coding agent) a concrete, code-grounded plan to build an automated test harness and scenario suite for the Oracle Sync plugin, focused on correctness, robustness, and regression prevention. + +Target code (4 files): + +- `edge_node/extensions/business/oracle_sync/oracle_sync_01.py` +- `edge_node/extensions/business/oracle_sync/sync_mixins/ora_sync_states_mixin.py` +- `edge_node/extensions/business/oracle_sync/sync_mixins/ora_sync_utils_mixin.py` +- `edge_node/extensions/business/oracle_sync/sync_mixins/ora_sync_constants.py` + +Deliverable to implement: `edge_node/xperimental/oracle_sync/test_ora_sync.py` (pytest-friendly, runnable as a script as well). 
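+
+A minimal sketch of the dual entrypoint implied by "pytest-friendly, runnable as a script as well" (the exact `pytest.main` arguments are illustrative, not mandated):
+
+```python
+# Illustrative only: the file is collected by pytest as usual, but can also be
+# executed directly, in which case it hands itself back to pytest.
+if __name__ == "__main__":
+    import sys
+    import pytest
+    sys.exit(pytest.main([__file__, "-q"]))
+```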
+ +--- + +## 1) Code-grounded system understanding (what must be tested) + +### 1.1 State machine flow (high-level) +The plugin is a state machine that: +1) **On startup** begins in **S8** (request historical agreements), then transitions to **S0** once caught up. +2) **Per epoch**, after epoch change: announce participants (**S11**), compute local tables (**S1**), exchange local tables (**S2**), compute median tables (**S3**), exchange median tables (**S4**), compute agreed median (**S5**), gather signatures (**S6**), exchange signatures (**S10**), update epoch manager (**S7**), then wait (**S0**). + +### 1.2 Message handling guarantees and constraints +- Incoming payloads are buffered **per-oracle** as deques with `MAX_RECEIVED_MESSAGES_PER_ORACLE`, and the state machine consumes at most **one message per oracle per step** (`popleft`). +- Non-oracle senders are ignored (`_is_oracle`). +- There are multiple message schemas enforced via `_check_received_oracle_data_for_values()` using `VALUE_STANDARDS` (type validation + optional CID decoding). + +### 1.3 Consensus & fault behavior +- Participant set is negotiated in **S11** and influences thresholds such as “half of valid oracles”. +- When **agreement is not reached**, or when **not enough signatures** are collected, nodes move into the “request agreed tables” path (S8/S9) and may **mark epochs as faulty**. +- “Too close to epoch change” requests should be ignored (3 minutes before end). + +### 1.4 R1FS (IPFS-like) conditional behavior +- Messages may include **CIDs instead of full data** for certain keys (`maybe_cid=True` in `VALUE_STANDARDS`). Retrieval failures must cause the message to be rejected. +- Add-to-R1FS has multiple fallback paths (not warmed, add failure, exception). + +--- + +## 2) Test strategy overview + +### 2.1 What “good” looks like (acceptance criteria) +Codex should aim for: + +1) **Deterministic simulation tests** that exercise full state-machine cycles with multiple oracles. +2) **Robust validation tests** (type/stage/signature/CID) that ensure bad payloads are ignored, not crashing the state machine. +3) **Corner-case regression tests** for timeouts, early stopping, oracle dropouts, and partial/historical sync. +4) Tests that run locally in seconds, with **no dependency on real blockchain/IPFS/network** (fully mocked). + +Recommended tooling: +- **pytest fixtures + monkeypatch** for dependency injection and time control. +- **unittest.mock autospec** where appropriate to keep mocks honest. +- **Hypothesis stateful testing** for randomized message ordering & adversarial sequences. + +--- + +## 3) Test harness architecture (what Codex should build) + +### 3.1 Core idea +Build a **closed-loop simulation** that runs N OracleSync instances, routes their outbound messages to each other, and advances “time” deterministically. + +**Key constraints to respect:** +- The plugin relies on `NetworkProcessorPlugin` methods (payload handlers, `add_payload_by_fields`, state machine API). +- For unit tests, we don’t need the full framework; we need **minimal shims** + **monkeypatched methods**. 
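+
+As a rough illustration only, the closed loop could be driven as below, assuming the harness pieces described in 3.2; `run_sync_cycle` and `max_steps` are hypothetical names, not part of the target code:
+
+```python
+# Hypothetical driver: one process() step per oracle per tick, with fake time
+# advanced so send intervals and timeouts fire deterministically.
+def run_sync_cycle(instances, clock, max_steps=200):
+    for _ in range(max_steps):
+        for oracle in instances:
+            oracle.process()
+        clock.sleep(1)
+        done = all(
+            o.state_machine_api_get_current_state(o.state_machine_name) == o.STATES.S0_WAIT_FOR_EPOCH_CHANGE
+            for o in instances
+        )
+        if done:
+            break
+```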
+ +### 3.2 Components to implement in `test_ora_sync.py` + +#### A) `FakeClock` +- `time()` → float seconds (manual advance) +- `sleep(dt)` → advance time rather than real sleeping +- `datetime.now(tz)` support (only what `_check_too_close_to_epoch_change` needs) + +#### B) `FakeEpochManager` +Provide the subset used by the plugin/mixins, including: +- `epoch_length`, `get_current_epoch()`, `get_time_epoch()`, `maybe_close_epoch()` +- `get_current_epoch_end(current_epoch)` (for “ignore requests” window) +- `get_last_sync_epoch()` and persistence stubs: `maybe_update_cached_data(force=True)`, `save_status()` +- For availability: + - `get_current_epoch_availability(return_absolute=True, return_max=True)` (self-assessment) + - `get_node_previous_epoch(node)` (local view of previous epoch availability) + - `get_epoch_availability(epoch, return_additional=True)` returns: `(availability_table, signatures_table, agreement_cid, signatures_cid)` +- `is_epoch_valid(epoch)` and (if used) methods to mark invalid/faulty epochs. + +Keep internal data as simple dicts: +- `epoch_availability[epoch][node] -> int` +- `epoch_signatures[epoch] -> dict(oracle->sig)` (or per-node signatures if needed) +- `epoch_valid[epoch] -> bool` +- `last_sync_epoch` + +#### C) `FakeBlockchain` +Must support: +- `get_oracles()` to return the oracle list. +- `sign(dct, add_data=True, use_digest=True)` to attach a deterministic signature field (`EE_SIGN`). +- `verify(dct_data, str_signature=None, sender_address=None)` returning an object with `.valid` and `.message`. +- `maybe_add_prefix(addr)` (can be no-op) +- (Optional) helpers to “tamper” signatures for negative tests. + +**Deterministic signature approach for tests:** +- signature := sha256(json_sorted(data_without_EE_SIGN) + sender) +- store sender address inside the signed object if production does that. + +#### D) `FakeR1FS` +Supports: +- `is_ipfs_warmed` bool +- `add_pickle(obj)` returns CID or None +- `get_file(cid)` returns a temp file path written by the fake (pickle dump) +- Optional knobs to inject failures and timeouts. + +#### E) `MessageBus` (simulation network) +- Captures outbound `oracle_data` payloads produced by each oracle and delivers them to other oracles by calling their `handle_received_payloads()`. +- Must model: + - broadcast vs targeted (production seems broadcast) + - delivery delays and reordering (for adversarial tests) + - duplicate deliveries + +#### F) `OracleHarness` +Wraps a plugin instance and injects: +- `netmon` (with `epoch_manager` + node helpers used in logs/formatting) +- `bc`, `r1fs`, `time`, `sleep`, and any other module references used via `self.*` +- override `add_payload_by_fields(oracle_data=...)` to push into `MessageBus` instead of real networking + +**Implementation tactic:** +- Instantiate `OracleSync01Plugin` without framework bootstrapping, or subclass it in test to bypass base init. +- Then set required attributes and call `on_init()` (but patch `on_init` loops to avoid waiting). 
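+
+The deterministic test signature from C) above can be sketched as follows; this is test-only hashing, not the production signing scheme:
+
+```python
+# Test-only signing: hash of the sorted-key JSON payload (minus EE_SIGN)
+# concatenated with the sender address, as described in C).
+import hashlib
+import json
+
+def fake_sign(data: dict, sender: str) -> str:
+    payload = json.dumps(
+        {k: v for k, v in data.items() if k != "EE_SIGN"}, sort_keys=True
+    )
+    return hashlib.sha256((payload + sender).encode("utf-8")).hexdigest()
+```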
+ +--- + +## 4) Test inventory (unit, integration, property-based) + +### 4.1 Constants tests (`ora_sync_constants.py`) +1) **Threshold sanity:** + - `FULL_AVAILABILITY_THRESHOLD == round(SUPERVISOR_MIN_AVAIL_PRC * EPOCH_MAX_VALUE)` + - `POTENTIALLY_FULL_AVAILABILITY_THRESHOLD` math stays within `[0, EPOCH_MAX_VALUE]` +2) **Timeout multipliers nonzero and stable** + - `*_SEND_MULTIPLIER`, `REQUEST_AGREEMENT_TABLE_MULTIPLIER`, `SIGNATURES_EXCHANGE_MULTIPLIER` +3) **VALUE_STANDARDS coherence** + - keys exist for used message fields and `maybe_cid` matches intended fields (LOCAL_TABLE, MEDIAN_TABLE). + +### 4.2 Utils mixin tests (`ora_sync_utils_mixin.py`) +**A) R1FS helpers** +- `r1fs_add_data_to_message()`: + - warmup false → embeds full dict + - warmup true + add succeeds → embeds CID + - warmup true + add returns None → embeds full dict + - add raises exception → embeds full dict +- `r1fs_get_data_from_message()`: + - value is dict → returns as-is + - value is CID → loads pickle + - CID retrieval fails → returns None and triggers rejection upstream + +**B) Message validation** +- `_check_received_oracle_data_for_values()` matrix: + - non-dict oracle_data → reject + - missing fields / None fields → reject + - wrong types vs `VALUE_STANDARDS` → reject + - stage mismatch (single stage and list-of-stages) → reject + - `maybe_cid` field with CID that fails retrieval → reject + - `verify=True` invalid signature → reject + +**C) Epoch-range validation** +- `_check_received_epoch__agreed_median_table_ok()`: + - non-contiguous `epoch_keys` → reject + - mismatch between epoch_keys and table keys/signature keys/is_valid keys → reject + +**D) “Too close to epoch change” rule** +- Freeze time near `get_current_epoch_end()` and verify `_check_too_close_to_epoch_change()` flips at `ORACLE_SYNC_IGNORE_REQUESTS_SECONDS` + +### 4.3 States mixin tests (`ora_sync_states_mixin.py`) +Because the mixin functions are state callbacks, cover them in two layers: + +#### Layer 1: “Pure-ish” unit tests (minimal environment) +Focus on compute/check functions that can run with mocked dependencies: +- `_compute_simple_median_table()` and `_compute_simple_agreed_value_table()` (if present) +- `_compute_agreed_median_table()` (driven by a prepared `dct_median_tables` and `is_participating`) +- `_compute_requested_agreed_median_table()` (hash-frequency consensus and faulty epoch marking) + +**Key corner cases:** +- **Strict majority rule**: verify behavior when max_frequency == floor(n/2) (must fail) and when > floor(n/2) (must succeed). +- “Faulty nodes” path: median frequency below `min_frequency` excludes nodes; ensure exclusion is deterministic and doesn’t crash. +- Potential float threshold from `_count_half_of_valid_oracles()` (it returns `/2`): ensure comparisons behave as intended and don’t cause off-by-one. + +#### Layer 2: Integration tests (full multi-oracle simulation) +Run N oracles through: +- Participant announcement (**S11**) and threshold update based on local availability. +- Local table exchange (**S2**) and median computation (**S3**) leading to **S5** agreement and signature exchange. +- Epoch manager update (**S7**) and transition back to **S0** only after signature criteria holds. + +### 4.4 Plugin-level tests (`oracle_sync_01.py`) +1) **Message queue bounds** + - Push > `MAX_RECEIVED_MESSAGES_PER_ORACLE` and verify older messages are dropped, and `get_received_messages_from_oracles()` drains one per oracle. 
+2) **Startup state** + - `on_init()` initializes state machine in **S8** and sets message buffers. +3) **Process exception containment** + - Force an exception in a state callback and verify `process()` sets `exception_occurred` and does not crash the test runner (it sleeps briefly). +4) **Oracle list refresh logic** + - Ensure `maybe_refresh_oracle_list()` is rate limited and handles empty blockchain response. + +--- + +## 5) Scenario suite (must-have end-to-end tests) + +### Scenario A — Happy-path consensus (3 oracles, 1 epoch) +**Objective:** validate S11→S7 pipeline produces identical agreement across oracles and updates epoch manager. + +**Setup:** +- 3 oracle nodes (A,B,C) with high previous availability. +- Several non-oracle nodes with availability values; each oracle’s local view should match (or have small differences if median logic expects it). + +**Assertions:** +- `compiled_agreed_median_table` matches expected strict-majority values. +- Signatures are collected, exchanged, and stored. +- `epoch_manager.get_last_sync_epoch()` increments to previous epoch. + +### Scenario B — One oracle cannot participate +**Objective:** ensure the “non-participating oracle” follows S8 request path and still catches up. + +**Setup:** +- Oracle C has previous availability below `FULL_AVAILABILITY_THRESHOLD`. + +**Assertions:** +- C does not announce participation; A,B still reach consensus. +- C requests agreed tables and updates epoch manager after receiving responses. + +### Scenario C — Disordered + duplicated messages +**Objective:** message routing noise shouldn’t break correctness. + +**Setup:** +- Inject random reorder/duplication in MessageBus for S2/S4/S6/S10 phases. + +**Assertions:** +- Final agreement still converges (or deterministically fails if strict majority is impossible). +- No infinite loops; timeouts or early-stopping triggers transition. + +### Scenario D — Invalid signatures / tampering +**Objective:** invalid data is ignored, not poisoning consensus. + +**Setup:** +- Tamper with one oracle’s signed median table entries or agreement signature. + +**Assertions:** +- `_check_received_oracle_data_for_values(...verify=True...)` rejects it. +- Consensus still succeeds if enough honest oracles remain. + +### Scenario E — R1FS CID path and retrieval failures +**Objective:** verify both (CID success) and (CID failure -> message rejected) paths. + +**Setup:** +- Enable `cfg_use_r1fs=True`. +- For LOCAL_TABLE/MEDIAN_TABLE, send CID and ensure receivers fetch from FakeR1FS. +- Inject CID retrieval failure for one sender. + +**Assertions:** +- Good CID messages are accepted and decoded. +- Failed retrieval causes message rejection (and may trigger timeout/fallback). + +### Scenario F — Historical sync on startup (multi-epoch range) +**Objective:** node starts late and needs epochs `[last_synced+1, current_epoch-1]`. + +**Setup:** +- Set `_last_epoch_synced` behind by K epochs. +- Ensure other oracles respond with `EPOCH__AGREED_MEDIAN_TABLE`, signatures, keys, and `EPOCH__IS_VALID`. + +**Assertions:** +- Receiver only accepts complete continuous ranges. +- Epoch manager gets updated for each epoch in range. +- If consensus hashes don’t reach strict majority, epochs are marked faulty. 
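+
+The "strict majority" rule referenced in 4.3 and in Scenarios C/F can be pinned down with an integer comparison; `has_strict_majority` below is a hypothetical helper for the tests, not a function in the target code:
+
+```python
+# Boundary behaviour to assert: exactly floor(n/2) matching votes must fail,
+# floor(n/2) + 1 must pass.
+def has_strict_majority(votes: int, n_participants: int) -> bool:
+    return votes > n_participants // 2
+
+assert has_strict_majority(2, 3)       # 2 of 3 -> strict majority
+assert not has_strict_majority(2, 4)   # 2 of 4 == floor(n/2) -> not a majority
+```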
+ +--- + +## 6) Property-based / stateful testing (high value, optional but recommended) + +### 6.1 Fuzz message validators +Use Hypothesis strategies to generate: +- missing keys, wrong types, None values +- epoch_keys with gaps, mismatched dict keys, mixed str/int keys +- random stages + +Property: validator functions must not throw, only return True/False with logs. + +### 6.2 Rule-based state machine simulation +Model: +- actions: deliver message, drop message, advance time, flip oracle availability participation, toggle R1FS warmed state +- invariants: + - state transitions remain within known states + - buffers never exceed configured maxlen + - if strict majority is possible and enough time passes, sync eventually completes + +Hypothesis provides `RuleBasedStateMachine` for this. + +--- + +## 7) Implementation steps for Codex (ordered, concrete tasks) + +### Step 0 — Create the skeleton file +- Create `edge_node/xperimental/oracle_sync/test_ora_sync.py` +- Add `pytest` entrypoint compatibility (`if __name__ == "__main__": ... pytest.main([...])`) + +### Step 1 — Build fakes (clock, epoch manager, blockchain, r1fs) +- Implement each fake with explicit failure injection switches. +- Add unit tests per fake (sanity + failure modes). + +### Step 2 — Build MessageBus + OracleHarness +- Override plugin outbound send path to push into MessageBus. +- Route inbound messages via `handle_received_payloads()`. + +### Step 3 — Wire 3-oracle simulation and run a single sync cycle +- Deterministic time advance: + - call `process()` in a loop + - advance time so that send intervals/timeouts trigger +- Assert correct state progression and epoch manager update. + +### Step 4 — Add the scenario suite (A–F) +- Each scenario must run in < 2–5 seconds and be deterministic. + +### Step 5 — Add validator unit tests +- Directly call `_check_received_oracle_data_for_values`, `_check_received_epoch__agreed_median_table_ok`, etc., with crafted payloads. + +### Step 6 — Add Hypothesis tests (optional) +- If runtime is acceptable in CI, keep them enabled with tuned settings; otherwise mark as nightly. + +### Step 7 — Quality gates +- Add coverage reporting for these 4 files (target: 80%+ on mixins). +- Ensure logs do not flood: use pytest’s `caplog` or silence debug flags. + +--- + +## 8) Suggested micro-refactors (only if tests reveal pain points) +These are not required for the test file, but will dramatically improve testability and long-term reliability: + +1) Extract “compute median” and “compute agreed median” into pure functions that accept dicts and return dicts (no `self` access). +2) Normalize strict-majority thresholds to integers (`ceil(n/2)` or `n//2 + 1`) to avoid float comparisons. +3) Add a thin abstraction for outbound messaging (e.g., `send_oracle_data(oracle_data)`) to simplify mocking. + +--- + +## 9) Done definition (what to merge) +A PR is “done” when: +- `test_ora_sync.py` contains: + - fakes + harness + message bus + - ≥ 6 scenario tests (A–F) + - validator unit tests +- Tests pass reliably with `pytest -q` and do not depend on network, real blockchain, or real IPFS. +- Clear docstring at top explains how to run locally and how to extend scenarios. + diff --git a/xperimental/oracle_sync/test_ora_sync.py b/xperimental/oracle_sync/test_ora_sync.py new file mode 100644 index 00000000..f7dfa3f4 --- /dev/null +++ b/xperimental/oracle_sync/test_ora_sync.py @@ -0,0 +1,1000 @@ +""" +Oracle Sync test harness and scenario suite. 
+ +Run: + pytest -q xperimental/oracle_sync/test_ora_sync.py + python3 xperimental/oracle_sync/test_ora_sync.py +""" + +from __future__ import annotations + +import hashlib +import json +import os +import pickle +import sys +import tempfile +import types +from collections import defaultdict, deque +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +import pytest + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +__VER__ = "0.1.0" + + +def _install_naeural_core_stubs(): + if "naeural_core" in sys.modules: + return + + naeural_core = types.ModuleType("naeural_core") + naeural_core_constants = types.ModuleType("naeural_core.constants") + naeural_core_business = types.ModuleType("naeural_core.business") + naeural_core_business_base = types.ModuleType("naeural_core.business.base") + naeural_core_business_base_np = types.ModuleType("naeural_core.business.base.network_processor") + + naeural_core_constants.SUPERVISOR_MIN_AVAIL_PRC = 0.8 + naeural_core_constants.EPOCH_MAX_VALUE = 100 + naeural_core_constants.ORACLE_SYNC_USE_R1FS = False + naeural_core_constants.ORACLE_SYNC_BLOCKCHAIN_PRESENCE_MIN_THRESHOLD = 0.5 + naeural_core_constants.ORACLE_SYNC_ONLINE_PRESENCE_MIN_THRESHOLD = 0.5 + + class _StubPayloadData: + EE_SENDER = "EE_SENDER" + + class _StubCt: + PAYLOAD_DATA = _StubPayloadData + + class NetworkProcessorPlugin: + CONFIG = { + "VALIDATION_RULES": {}, + } + + def __init__(self): + self.ct = _StubCt() + self._state_machines = {} + + @staticmethod + def payload_handler(): + def _decorator(fn): + return fn + return _decorator + + def P(self, msg, **kwargs): + return msg + + def state_machine_api_init(self, name, state_machine_transitions, initial_state, on_successful_step_callback): + self._state_machines[name] = { + "transitions": state_machine_transitions, + "state": initial_state, + "on_success": on_successful_step_callback, + } + + def state_machine_api_get_current_state(self, name): + return self._state_machines[name]["state"] + + def state_machine_api_set_current_state(self, name, state): + self._state_machines[name]["state"] = state + + def state_machine_api_callback_do_nothing(self): + return + + def state_machine_api_step(self, name): + state = self._state_machines[name]["state"] + transitions = self._state_machines[name]["transitions"] + state_info = transitions[state] + state_callback = state_info["STATE_CALLBACK"] + state_callback() + for transition in state_info.get("TRANSITIONS", []): + if transition["TRANSITION_CONDITION"](): + transition["ON_TRANSITION_CALLBACK"]() + self._state_machines[name]["state"] = transition["NEXT_STATE"] + break + on_success = self._state_machines[name]["on_success"] + if on_success: + on_success() + + naeural_core_business_base_np.NetworkProcessorPlugin = NetworkProcessorPlugin + + sys.modules["naeural_core"] = naeural_core + sys.modules["naeural_core.constants"] = naeural_core_constants + sys.modules["naeural_core.business"] = naeural_core_business + sys.modules["naeural_core.business.base"] = naeural_core_business_base + sys.modules["naeural_core.business.base.network_processor"] = naeural_core_business_base_np + + +_install_naeural_core_stubs() + +from extensions.business.oracle_sync.oracle_sync_01 import OracleSync01Plugin +from extensions.business.oracle_sync.sync_mixins.ora_sync_constants import ( + FULL_AVAILABILITY_THRESHOLD, + POTENTIALLY_FULL_AVAILABILITY_THRESHOLD, + VALUE_STANDARDS, + 
EPOCH_MAX_VALUE, + LOCAL_TABLE_SEND_MULTIPLIER, + MEDIAN_TABLE_SEND_MULTIPLIER, + REQUEST_AGREEMENT_TABLE_MULTIPLIER, + SIGNATURES_EXCHANGE_MULTIPLIER, + OracleSyncCt, + SUPERVISOR_MIN_AVAIL_PRC, +) + + +class FakeClock: + def __init__(self, start: float = 0.0): + self._now = start + + def time(self): + return self._now + + def sleep(self, dt: float): + self._now += dt + + def now(self, tz=None): + return datetime.fromtimestamp(self._now, tz=tz) + + +class FakeDateTime: + def __init__(self, clock: FakeClock): + self._clock = clock + + def now(self, tz=None): + return self._clock.now(tz=tz) + + +class FakeR1FS: + def __init__(self, warmed=True): + self.is_ipfs_warmed = warmed + self._store = {} + self._counter = 0 + self.fail_add = False + self.fail_get = False + self.raise_add = False + + def add_pickle(self, obj, show_logs=True): + if self.raise_add: + raise RuntimeError("fake add error") + if self.fail_add: + return None + self._counter += 1 + cid = f"cid_{self._counter}" + self._store[cid] = obj + return cid + + def get_file(self, cid, show_logs=True): + if self.fail_get or cid not in self._store: + raise FileNotFoundError(cid) + fd, path = tempfile.mkstemp(prefix="fake_r1fs_", suffix=".pkl") + os.close(fd) + with open(path, "wb") as f: + pickle.dump(self._store[cid], f) + return path + + +class FakeEpochManager: + def __init__(self, clock: FakeClock, epoch_length: int = 100, current_epoch: int = 2): + self.clock = clock + self.epoch_length = epoch_length + self._current_epoch = current_epoch + self._last_sync_epoch = current_epoch - 1 + self.epoch_availability = defaultdict(dict) + self.epoch_signatures = defaultdict(dict) + self.epoch_valid = defaultdict(lambda: True) + self.epoch_cids = defaultdict(dict) + self.faulty_epochs = set() + + def get_current_epoch(self): + return self._current_epoch + + def set_current_epoch(self, epoch: int): + self._current_epoch = epoch + + def get_time_epoch(self): + return self._current_epoch + + def maybe_close_epoch(self): + return + + def get_current_epoch_end(self, current_epoch): + epoch_end_ts = (current_epoch + 1) * self.epoch_length + return datetime.fromtimestamp(epoch_end_ts, tz=timezone.utc) + + def get_current_epoch_availability(self, return_absolute=True, return_max=True): + total_from_start = min(self.clock.time(), self.epoch_length) + return total_from_start, total_from_start + + def get_node_previous_epoch(self, node): + prev_epoch = self._current_epoch - 1 + return self.epoch_availability[prev_epoch].get(node, 0) + + def get_epoch_availability(self, epoch, return_additional=True): + availability = self.epoch_availability.get(epoch, {}) + signatures = self.epoch_signatures.get(epoch, {}) + cids = self.epoch_cids.get(epoch, {}) + return availability, signatures, cids.get("agreement"), cids.get("signatures") + + def update_epoch_availability(self, epoch, availability_table, agreement_signatures, debug=False, + agreement_cid=None, signatures_cid=None): + self.epoch_availability[epoch] = dict(availability_table) + self.epoch_signatures[epoch] = dict(agreement_signatures) + if agreement_cid or signatures_cid: + self.epoch_cids[epoch] = { + "agreement": agreement_cid, + "signatures": signatures_cid, + } + self.epoch_valid[epoch] = True + return True + + def mark_epoch_as_faulty(self, epoch, debug=False): + self.epoch_valid[epoch] = False + self.faulty_epochs.add(epoch) + return True + + def is_epoch_valid(self, epoch): + return self.epoch_valid[epoch] + + def get_last_sync_epoch(self): + return self._last_sync_epoch + + def 
set_last_sync_epoch(self, epoch): + self._last_sync_epoch = epoch + + def add_cid_for_epoch(self, epoch, agreement_cid, signatures_cid, debug=False): + self.epoch_cids[epoch] = { + "agreement": agreement_cid, + "signatures": signatures_cid, + } + + def maybe_update_cached_data(self, force=True): + return + + def save_status(self): + return + + +class FakeBlockchain: + def __init__(self, oracles, current_address): + self._oracles = list(oracles) + self.current_address = current_address + self._verify_override = None + self.calls = defaultdict(int) + + def get_oracles(self): + self.calls["get_oracles"] += 1 + return list(self._oracles), None + + def sign(self, dct, add_data=True, use_digest=True): + sender = self.current_address + if add_data: + dct["EE_SENDER"] = sender + payload = json.dumps({k: dct[k] for k in sorted(dct) if k != "EE_SIGN"}, sort_keys=True) + sig = hashlib.sha256((payload + sender).encode("utf-8")).hexdigest() + dct["EE_SIGN"] = sig + return sig + + def verify(self, dct_data, str_signature=None, sender_address=None): + if self._verify_override is not None: + return self._verify_override + sender = sender_address or dct_data.get("EE_SENDER", "") + payload = json.dumps({k: dct_data[k] for k in sorted(dct_data) if k != "EE_SIGN"}, sort_keys=True) + expected = hashlib.sha256((payload + sender).encode("utf-8")).hexdigest() + signature = dct_data.get("EE_SIGN") if str_signature is None else str_signature + valid = signature == expected + message = "valid" if valid else "invalid signature" + return types.SimpleNamespace(valid=valid, message=message) + + def maybe_add_prefix(self, addr): + return addr + + def address_is_valid(self, addr): + return isinstance(addr, str) and len(addr) > 0 + + +class FakeNetmon: + def __init__(self, epoch_manager: FakeEpochManager, oracles: list[str]): + self.epoch_manager = epoch_manager + self._oracles = list(oracles) + self.all_nodes = list(oracles) + + def network_node_eeid(self, addr): + return addr[-4:] + + def network_node_is_supervisor(self, addr): + return addr in self._oracles + + def network_node_is_online(self, addr): + return addr in self._oracles + + +class MessageBus: + def __init__(self, duplicate_rate=0.0, reorder=False, seed=123): + self._oracles = {} + self._duplicate_rate = duplicate_rate + self._reorder = reorder + self._rng = __import__("random").Random(seed) + + def add_oracle(self, node_addr, oracle): + self._oracles[node_addr] = oracle + + def broadcast(self, sender, oracle_data): + deliveries = [] + for addr, oracle in self._oracles.items(): + if addr == sender: + continue + payload = { + oracle.ct.PAYLOAD_DATA.EE_SENDER: sender, + "ORACLE_DATA": oracle_data, + } + deliveries.append((oracle, payload)) + if self._duplicate_rate > 0 and self._rng.random() < self._duplicate_rate: + deliveries.append((oracle, payload)) + if self._reorder: + self._rng.shuffle(deliveries) + for oracle, payload in deliveries: + oracle.handle_received_payloads(payload) + + +def _get_numpy_like(): + try: + import numpy as np # type: ignore + return np + except Exception: + import statistics + + class _NP: + @staticmethod + def median(values): + return statistics.median(values) + + @staticmethod + def mean(values): + return statistics.mean(values) + + class random: + @staticmethod + def choice(values): + return values[0] + + return _NP() + + +def _json_dumps(data, **kwargs): + return json.dumps(data, sort_keys=True, **kwargs) + + +def _get_hash(data, algorithm="sha256"): + h = hashlib.new(algorithm) + h.update(data.encode("utf-8")) + return 
h.hexdigest() + + +@dataclass +class OracleHarness: + node_addr: str + oracles: list[str] + epoch_manager: FakeEpochManager + clock: FakeClock + bus: MessageBus | None = None + use_r1fs: bool = False + use_r1fs_during_consensus: bool = False + + def build(self): + oracle = OracleSync01Plugin() + oracle._name__ = "oracle_sync_test" + oracle.node_addr = self.node_addr + oracle.time = self.clock.time + oracle.sleep = self.clock.sleep + oracle.datetime = FakeDateTime(self.clock) + oracle.timezone = timezone + oracle.deque = deque + oracle.defaultdict = defaultdict + oracle.json_dumps = _json_dumps + oracle.deepcopy = lambda d: json.loads(json.dumps(d)) + oracle.os_path = os.path + oracle.diskapi_load_pickle_from_output = lambda filename: pickle.load(open(filename, "rb")) + oracle.get_hash = _get_hash + oracle.np = _get_numpy_like() + oracle.trace_info = lambda: "trace" + oracle.get_sender_str = lambda sender: sender[-4:] if isinstance(sender, str) else str(sender) + oracle.get_elapsed_and_total_time_of_stage = lambda stage=None: (0.0, 0.0) + oracle.r1fs = FakeR1FS(warmed=True) + oracle.cfg_debug_sync = False + oracle.cfg_debug_sync_full = False + oracle.cfg_send_interval = 1 + oracle.cfg_send_period = 1 + oracle.cfg_oracle_list_refresh_interval = 10 + oracle.cfg_self_assessment_interval = 60 + oracle.cfg_use_r1fs = self.use_r1fs + oracle.cfg_use_r1fs_during_consensus = self.use_r1fs_during_consensus + + oracle.bc = FakeBlockchain(oracles=self.oracles, current_address=self.node_addr) + oracle.netmon = FakeNetmon(epoch_manager=self.epoch_manager, oracles=self.oracles) + + def _add_payload_by_fields(oracle_data): + if self.bus is not None: + self.bus.broadcast(self.node_addr, oracle_data) + + oracle.add_payload_by_fields = _add_payload_by_fields + oracle.on_init() + return oracle + + +def _set_state(oracle, state): + if hasattr(oracle, "state_machine_api_set_current_state"): + oracle.state_machine_api_set_current_state(oracle.state_machine_name, state) + elif hasattr(oracle, "_state_machines"): + oracle._state_machines[oracle.state_machine_name]["state"] = state + else: + oracle._current_state = state + + +def _sign_agreement(oracle, compiled_table, epoch): + signature_dict = { + OracleSyncCt.COMPILED_AGREED_MEDIAN_TABLE: compiled_table, + OracleSyncCt.EPOCH: epoch, + } + oracle.bc.sign(signature_dict, add_data=True, use_digest=True) + signature_dict.pop(OracleSyncCt.EPOCH) + signature_dict.pop(OracleSyncCt.COMPILED_AGREED_MEDIAN_TABLE) + return signature_dict + + +@pytest.fixture() +def fake_clock(): + return FakeClock() + + +@pytest.fixture() +def fake_epoch_manager(fake_clock): + return FakeEpochManager(clock=fake_clock, epoch_length=EPOCH_MAX_VALUE, current_epoch=2) + + +def test_constants_sanity(): + assert FULL_AVAILABILITY_THRESHOLD == round(SUPERVISOR_MIN_AVAIL_PRC * EPOCH_MAX_VALUE) + assert 0 <= POTENTIALLY_FULL_AVAILABILITY_THRESHOLD <= EPOCH_MAX_VALUE + assert LOCAL_TABLE_SEND_MULTIPLIER > 0 + assert MEDIAN_TABLE_SEND_MULTIPLIER > 0 + assert REQUEST_AGREEMENT_TABLE_MULTIPLIER > 0 + assert SIGNATURES_EXCHANGE_MULTIPLIER > 0 + assert VALUE_STANDARDS[OracleSyncCt.LOCAL_TABLE]["maybe_cid"] is True + assert VALUE_STANDARDS[OracleSyncCt.MEDIAN_TABLE]["maybe_cid"] is True + + +def test_r1fs_add_data_to_message(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock, use_r1fs=True) + oracle = harness.build() + message = {} + oracle.r1fs.is_ipfs_warmed = False + oracle.r1fs_add_data_to_message(message_dict=message, data_dict={"a": 1}, 
data_key="K") + assert message["K"] == {"a": 1} + + oracle.r1fs.is_ipfs_warmed = True + message = {} + oracle.r1fs.fail_add = False + oracle.r1fs_add_data_to_message(message_dict=message, data_dict={"b": 2}, data_key="K") + assert isinstance(message["K"], str) + + message = {} + oracle.r1fs.fail_add = True + oracle.r1fs_add_data_to_message(message_dict=message, data_dict={"c": 3}, data_key="K") + assert message["K"] == {"c": 3} + + message = {} + oracle.r1fs.fail_add = False + oracle.r1fs.raise_add = True + oracle.r1fs_add_data_to_message(message_dict=message, data_dict={"d": 4}, data_key="K") + assert message["K"] == {"d": 4} + + +def test_r1fs_get_data_from_message(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock, use_r1fs=True) + oracle = harness.build() + message = {"K": {"a": 1}} + assert oracle.r1fs_get_data_from_message(message_dict=message, data_key="K") == {"a": 1} + + cid = oracle.r1fs.add_pickle({"b": 2}) + message = {"K": cid} + assert oracle.r1fs_get_data_from_message(message_dict=message, data_key="K") == {"b": 2} + + oracle.r1fs.fail_get = True + message = {"K": "missing"} + assert oracle.r1fs_get_data_from_message(message_dict=message, data_key="K") is None + + +def test_check_received_oracle_data_for_values(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock) + oracle = harness.build() + sender = "oracle_b" + bad = {"STAGE": oracle.STATES.S2_SEND_LOCAL_TABLE} + assert not oracle._check_received_oracle_data_for_values( + sender=sender, + oracle_data=bad, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + ) + + good = { + OracleSyncCt.STAGE: oracle.STATES.S2_SEND_LOCAL_TABLE, + OracleSyncCt.LOCAL_TABLE: {"n1": 1}, + } + oracle.bc.sign(good, add_data=True, use_digest=True) + assert oracle._check_received_oracle_data_for_values( + sender=sender, + oracle_data=good, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + expected_stage=oracle.STATES.S2_SEND_LOCAL_TABLE, + ) + + bad_type = { + OracleSyncCt.STAGE: oracle.STATES.S2_SEND_LOCAL_TABLE, + OracleSyncCt.LOCAL_TABLE: ["not-a-dict"], + } + oracle.bc.sign(bad_type, add_data=True, use_digest=True) + assert not oracle._check_received_oracle_data_for_values( + sender=sender, + oracle_data=bad_type, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + expected_stage=oracle.STATES.S2_SEND_LOCAL_TABLE, + ) + + bad_stage = { + OracleSyncCt.STAGE: oracle.STATES.S4_SEND_MEDIAN_TABLE, + OracleSyncCt.LOCAL_TABLE: {"n1": 1}, + } + oracle.bc.sign(bad_stage, add_data=True, use_digest=True) + assert not oracle._check_received_oracle_data_for_values( + sender=sender, + oracle_data=bad_stage, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + expected_stage=oracle.STATES.S2_SEND_LOCAL_TABLE, + ) + + bad_sig = { + OracleSyncCt.STAGE: oracle.STATES.S2_SEND_LOCAL_TABLE, + OracleSyncCt.LOCAL_TABLE: {"n1": 1}, + "EE_SIGN": "bad", + } + assert not oracle._check_received_oracle_data_for_values( + sender=sender, + oracle_data=bad_sig, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + expected_stage=oracle.STATES.S2_SEND_LOCAL_TABLE, + ) + + +def test_check_received_epoch_agreed_median_table_ok(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock) + oracle = harness.build() + _set_state(oracle, oracle.STATES.S0_WAIT_FOR_EPOCH_CHANGE) + 
oracle_data = { + OracleSyncCt.EPOCH__AGREED_MEDIAN_TABLE: {"1": {"n1": 1}, "3": {"n1": 2}}, + OracleSyncCt.EPOCH__AGREEMENT_SIGNATURES: {"1": {}, "3": {}}, + OracleSyncCt.EPOCH__IS_VALID: {"1": True, "3": True}, + OracleSyncCt.EPOCH_KEYS: [1, 3], + OracleSyncCt.STAGE: oracle.STATES.S0_WAIT_FOR_EPOCH_CHANGE, + } + assert not oracle._check_received_epoch__agreed_median_table_ok("oracle_b", oracle_data) + + oracle_data = { + OracleSyncCt.EPOCH__AGREED_MEDIAN_TABLE: {"1": {"n1": 1}, "2": {"n1": 2}}, + OracleSyncCt.EPOCH__AGREEMENT_SIGNATURES: {"1": {}}, + OracleSyncCt.EPOCH__IS_VALID: {"1": True, "2": True}, + OracleSyncCt.EPOCH_KEYS: [1, 2], + OracleSyncCt.STAGE: oracle.STATES.S0_WAIT_FOR_EPOCH_CHANGE, + } + assert not oracle._check_received_epoch__agreed_median_table_ok("oracle_b", oracle_data) + + +def test_check_too_close_to_epoch_change(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle._current_epoch = 1 + fake_clock._now = (oracle.netmon.epoch_manager.epoch_length * 2) - 10 + assert oracle._check_too_close_to_epoch_change(show_logs=False) + + +def test_compute_agreed_median_table_majority(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle.is_participating = {k: True for k in oracles} + oracle.dct_median_tables = { + "oracle_a": {"n1": {"VALUE": 10, "EE_SENDER": "oracle_a"}}, + "oracle_b": {"n1": {"VALUE": 10, "EE_SENDER": "oracle_b"}}, + "oracle_c": {"n1": {"VALUE": 12, "EE_SENDER": "oracle_c"}}, + } + oracle._compute_agreed_median_table() + assert oracle.compiled_agreed_median_table["n1"] == 10 + + +def test_compute_agreed_median_table_failure(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle.is_participating = {k: True for k in oracles} + oracle.dct_median_tables = { + "oracle_a": {"n1": {"VALUE": 10, "EE_SENDER": "oracle_a"}}, + "oracle_b": {"n1": {"VALUE": 11, "EE_SENDER": "oracle_b"}}, + "oracle_c": {"n1": {"VALUE": 12, "EE_SENDER": "oracle_c"}}, + } + oracle._compute_agreed_median_table() + assert oracle.compiled_agreed_median_table is None + + +def test_compute_requested_agreed_median_table_majority(fake_epoch_manager, fake_clock, monkeypatch): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle._current_epoch = 4 + oracle._last_epoch_synced = 1 + oracle.dct_agreed_availability_table = { + "oracle_a": {2: {"n1": 10}, 3: {"n1": 12}}, + "oracle_b": {2: {"n1": 10}, 3: {"n1": 12}}, + "oracle_c": {2: {"n1": 11}, 3: {"n1": 13}}, + } + oracle.dct_agreed_availability_signatures = { + "oracle_a": {2: {"oracle_a": {"EE_SIGN": "sig"}}, 3: {"oracle_a": {"EE_SIGN": "sig"}}}, + "oracle_b": {2: {"oracle_b": {"EE_SIGN": "sig"}}, 3: {"oracle_b": {"EE_SIGN": "sig"}}}, + "oracle_c": {2: {"oracle_c": {"EE_SIGN": "sig"}}, 3: {"oracle_c": {"EE_SIGN": "sig"}}}, + } + oracle.dct_agreed_availability_is_valid = { + "oracle_a": {2: True, 3: True}, + "oracle_b": {2: True, 3: True}, + "oracle_c": {2: True, 3: True}, + } + oracle.dct_agreed_availability_cid = { + "oracle_a": {2: None, 3: None}, + "oracle_b": {2: None, 3: None}, + "oracle_c": {2: None, 3: None}, + } + oracle.dct_agreement_signatures_cid = { + 
"oracle_a": {2: None, 3: None}, + "oracle_b": {2: None, 3: None}, + "oracle_c": {2: None, 3: None}, + } + monkeypatch.setattr(oracle.np.random, "choice", lambda x: x[0]) + oracle._compute_requested_agreed_median_table() + assert fake_epoch_manager.epoch_availability[2]["n1"] == 10 + assert fake_epoch_manager.epoch_availability[3]["n1"] == 12 + + +def test_compute_requested_agreed_median_table_failure(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle._current_epoch = 3 + oracle._last_epoch_synced = 0 + oracle.dct_agreed_availability_table = { + "oracle_a": {1: {"n1": 10}, 2: {"n1": 11}}, + "oracle_b": {1: {"n1": 12}, 2: {"n1": 13}}, + } + oracle.dct_agreed_availability_signatures = { + "oracle_a": {1: {}, 2: {}}, + "oracle_b": {1: {}, 2: {}}, + } + oracle.dct_agreed_availability_is_valid = { + "oracle_a": {1: True, 2: True}, + "oracle_b": {1: True, 2: True}, + } + oracle.dct_agreed_availability_cid = { + "oracle_a": {1: None, 2: None}, + "oracle_b": {1: None, 2: None}, + } + oracle.dct_agreement_signatures_cid = { + "oracle_a": {1: None, 2: None}, + "oracle_b": {1: None, 2: None}, + } + oracle._compute_requested_agreed_median_table() + assert 1 in fake_epoch_manager.faulty_epochs + assert 2 in fake_epoch_manager.faulty_epochs + + +def test_message_queue_bounds(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + sender = "oracle_b" + for i in range(100): + oracle.handle_received_payloads({ + oracle.ct.PAYLOAD_DATA.EE_SENDER: sender, + "ORACLE_DATA": {"idx": i}, + }) + assert len(oracle._oracle_received_messages[sender]) <= 50 + messages = oracle.get_received_messages_from_oracles() + assert len(messages) == 1 + + +def test_on_init_sets_state(fake_epoch_manager, fake_clock): + oracles = ["oracle_a"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + assert oracle.state_machine_api_get_current_state(oracle.state_machine_name) == oracle.STATES.S8_SEND_REQUEST_AGREED_MEDIAN_TABLE + assert oracle._oracle_received_messages is not None + + +def test_process_exception_sets_flag(fake_epoch_manager, fake_clock, monkeypatch): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock) + oracle = harness.build() + monkeypatch.setattr(oracle, "state_machine_api_step", lambda name: (_ for _ in ()).throw(RuntimeError("boom"))) + oracle.process() + assert oracle.exception_occurred + + +def test_maybe_refresh_oracle_list_rate_limit(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a"], fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle._last_oracle_list_refresh_attempt = oracle.time() + oracle.bc.calls["get_oracles"] = 0 + oracle.maybe_refresh_oracle_list() + assert oracle.bc.calls["get_oracles"] == 0 + + +def test_scenario_a_happy_path_consensus(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + bus = MessageBus() + harnesses = [ + OracleHarness(addr, oracles, fake_epoch_manager, fake_clock, bus=bus) + for addr in oracles + ] + instances = [h.build() for h in harnesses] + for oracle in instances: + bus.add_oracle(oracle.node_addr, oracle) + + prev_epoch = fake_epoch_manager.get_current_epoch() - 1 + fake_epoch_manager.set_last_sync_epoch(prev_epoch - 1) + for oracle in instances: + 
oracle._last_epoch_synced = prev_epoch - 1 + for node in oracles: + fake_epoch_manager.epoch_availability[prev_epoch][node] = FULL_AVAILABILITY_THRESHOLD + 1 + fake_epoch_manager.epoch_availability[prev_epoch]["node_x"] = 40 + + for oracle in instances: + _set_state(oracle, oracle.STATES.S11_ANNOUNCE_PARTICIPANTS) + oracle._announce_and_observe_participants() + for oracle in instances: + oracle._announce_and_observe_participants() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S1_COMPUTE_LOCAL_TABLE) + oracle._compute_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S2_SEND_LOCAL_TABLE) + oracle._receive_local_table_and_maybe_send_local_table() + for oracle in instances: + oracle._receive_local_table_and_maybe_send_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S3_COMPUTE_MEDIAN_TABLE) + oracle._compute_median_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S4_SEND_MEDIAN_TABLE) + oracle._receive_median_table_and_maybe_send_median_table() + for oracle in instances: + oracle._receive_median_table_and_maybe_send_median_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S5_COMPUTE_AGREED_MEDIAN_TABLE) + oracle._compute_agreed_median_table() + assert oracle.compiled_agreed_median_table is not None + + compiled = instances[0].compiled_agreed_median_table + for oracle in instances: + assert oracle.compiled_agreed_median_table == compiled + + epoch = fake_epoch_manager.get_current_epoch() - 1 + for oracle in instances: + oracle.compiled_agreed_median_table_signatures[oracle.node_addr] = _sign_agreement( + oracle, compiled, epoch + ) + instances[0]._update_epoch_manager_with_agreed_median_table( + epoch=epoch, + compiled_agreed_median_table=compiled, + agreement_signatures=instances[0].compiled_agreed_median_table_signatures, + epoch_is_valid=True, + ) + assert fake_epoch_manager.epoch_availability[epoch] == compiled + + +def test_scenario_b_one_oracle_cannot_participate(fake_epoch_manager, fake_clock, monkeypatch): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + bus = MessageBus() + harnesses = [ + OracleHarness(addr, oracles, fake_epoch_manager, fake_clock, bus=bus) + for addr in oracles + ] + instances = [h.build() for h in harnesses] + for oracle in instances: + bus.add_oracle(oracle.node_addr, oracle) + + prev_epoch = fake_epoch_manager.get_current_epoch() - 1 + for node in oracles: + fake_epoch_manager.epoch_availability[prev_epoch][node] = FULL_AVAILABILITY_THRESHOLD + 1 + fake_epoch_manager.epoch_availability[prev_epoch]["oracle_c"] = FULL_AVAILABILITY_THRESHOLD - 1 + + for oracle in instances: + _set_state(oracle, oracle.STATES.S11_ANNOUNCE_PARTICIPANTS) + oracle._announce_and_observe_participants() + for oracle in instances: + oracle._announce_and_observe_participants() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S1_COMPUTE_LOCAL_TABLE) + oracle._compute_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S2_SEND_LOCAL_TABLE) + oracle._receive_local_table_and_maybe_send_local_table() + for oracle in instances: + oracle._receive_local_table_and_maybe_send_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S3_COMPUTE_MEDIAN_TABLE) + oracle._compute_median_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S4_SEND_MEDIAN_TABLE) + oracle._receive_median_table_and_maybe_send_median_table() + for oracle in instances: + oracle._receive_median_table_and_maybe_send_median_table() + 
+ for oracle in instances: + _set_state(oracle, oracle.STATES.S5_COMPUTE_AGREED_MEDIAN_TABLE) + oracle._compute_agreed_median_table() + + compiled = instances[0].compiled_agreed_median_table + assert compiled is not None + + non_participant = instances[2] + non_participant._current_epoch = 3 + non_participant._last_epoch_synced = 1 + non_participant.dct_agreed_availability_table = { + "oracle_a": {2: compiled}, + "oracle_b": {2: compiled}, + } + non_participant.dct_agreed_availability_signatures = { + "oracle_a": {2: {"oracle_a": {"EE_SIGN": "sig"}}}, + "oracle_b": {2: {"oracle_b": {"EE_SIGN": "sig"}}}, + } + non_participant.dct_agreed_availability_is_valid = { + "oracle_a": {2: True}, + "oracle_b": {2: True}, + } + non_participant.dct_agreed_availability_cid = { + "oracle_a": {2: None}, + "oracle_b": {2: None}, + } + non_participant.dct_agreement_signatures_cid = { + "oracle_a": {2: None}, + "oracle_b": {2: None}, + } + monkeypatch.setattr(non_participant.np.random, "choice", lambda x: x[0]) + non_participant._compute_requested_agreed_median_table() + assert fake_epoch_manager.epoch_availability[2] == compiled + + +def test_scenario_c_disordered_and_duplicated_messages(fake_epoch_manager, fake_clock): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + bus = MessageBus(duplicate_rate=0.5, reorder=True) + harnesses = [ + OracleHarness(addr, oracles, fake_epoch_manager, fake_clock, bus=bus) + for addr in oracles + ] + instances = [h.build() for h in harnesses] + for oracle in instances: + bus.add_oracle(oracle.node_addr, oracle) + + prev_epoch = fake_epoch_manager.get_current_epoch() - 1 + for node in oracles: + fake_epoch_manager.epoch_availability[prev_epoch][node] = FULL_AVAILABILITY_THRESHOLD + 2 + + for oracle in instances: + _set_state(oracle, oracle.STATES.S11_ANNOUNCE_PARTICIPANTS) + oracle._announce_and_observe_participants() + for oracle in instances: + oracle._announce_and_observe_participants() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S1_COMPUTE_LOCAL_TABLE) + oracle._compute_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S2_SEND_LOCAL_TABLE) + oracle._receive_local_table_and_maybe_send_local_table() + for oracle in instances: + oracle._receive_local_table_and_maybe_send_local_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S3_COMPUTE_MEDIAN_TABLE) + oracle._compute_median_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S4_SEND_MEDIAN_TABLE) + oracle._receive_median_table_and_maybe_send_median_table() + for oracle in instances: + oracle._receive_median_table_and_maybe_send_median_table() + + for oracle in instances: + _set_state(oracle, oracle.STATES.S5_COMPUTE_AGREED_MEDIAN_TABLE) + oracle._compute_agreed_median_table() + assert oracle.compiled_agreed_median_table is not None + + +def test_scenario_d_invalid_signature_rejected(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a", "oracle_b"], fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle_data = { + OracleSyncCt.STAGE: oracle.STATES.S2_SEND_LOCAL_TABLE, + OracleSyncCt.LOCAL_TABLE: {"n1": 1}, + } + oracle.bc.sign(oracle_data, add_data=True, use_digest=True) + oracle_data[OracleSyncCt.LOCAL_TABLE]["n1"] = 999 + assert not oracle._check_received_oracle_data_for_values( + sender="oracle_b", + oracle_data=oracle_data, + expected_variable_names=[OracleSyncCt.STAGE, OracleSyncCt.LOCAL_TABLE], + expected_stage=oracle.STATES.S2_SEND_LOCAL_TABLE, + ) + + +def 
test_scenario_e_r1fs_cid_success_and_failure(fake_epoch_manager, fake_clock): + harness = OracleHarness("oracle_a", ["oracle_a", "oracle_b"], fake_epoch_manager, fake_clock, use_r1fs=True) + oracle = harness.build() + oracle.r1fs.is_ipfs_warmed = True + msg = {} + oracle.r1fs_add_data_to_message(msg, {"n1": 1}, "DATA") + cid = msg["DATA"] + assert isinstance(cid, str) + assert oracle.r1fs_get_data_from_message(msg, "DATA") == {"n1": 1} + + oracle.r1fs.fail_get = True + msg = {"DATA": cid} + assert oracle.r1fs_get_data_from_message(msg, "DATA") is None + + +def test_scenario_f_historical_sync_multi_epoch(fake_epoch_manager, fake_clock, monkeypatch): + oracles = ["oracle_a", "oracle_b", "oracle_c"] + harness = OracleHarness("oracle_a", oracles, fake_epoch_manager, fake_clock) + oracle = harness.build() + oracle._current_epoch = 5 + oracle._last_epoch_synced = 1 + oracle.dct_agreed_availability_table = { + "oracle_a": {2: {"n1": 10}, 3: {"n1": 11}, 4: {"n1": 12}}, + "oracle_b": {2: {"n1": 10}, 3: {"n1": 11}, 4: {"n1": 12}}, + "oracle_c": {2: {"n1": 10}, 3: {"n1": 11}, 4: {"n1": 12}}, + } + oracle.dct_agreed_availability_signatures = { + "oracle_a": {2: {}, 3: {}, 4: {}}, + "oracle_b": {2: {}, 3: {}, 4: {}}, + "oracle_c": {2: {}, 3: {}, 4: {}}, + } + oracle.dct_agreed_availability_is_valid = { + "oracle_a": {2: True, 3: True, 4: True}, + "oracle_b": {2: True, 3: True, 4: True}, + "oracle_c": {2: True, 3: True, 4: True}, + } + oracle.dct_agreed_availability_cid = { + "oracle_a": {2: None, 3: None, 4: None}, + "oracle_b": {2: None, 3: None, 4: None}, + "oracle_c": {2: None, 3: None, 4: None}, + } + oracle.dct_agreement_signatures_cid = { + "oracle_a": {2: None, 3: None, 4: None}, + "oracle_b": {2: None, 3: None, 4: None}, + "oracle_c": {2: None, 3: None, 4: None}, + } + monkeypatch.setattr(oracle.np.random, "choice", lambda x: x[0]) + oracle._compute_requested_agreed_median_table() + assert fake_epoch_manager.epoch_availability[2]["n1"] == 10 + assert fake_epoch_manager.epoch_availability[3]["n1"] == 11 + assert fake_epoch_manager.epoch_availability[4]["n1"] == 12 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) From 9dd5e73cbd65322870ab608ff5341a6a149a9790 Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Wed, 28 Jan 2026 16:17:12 +0200 Subject: [PATCH 2/5] fix: rollback on .devcontainer --- .devcontainer/devcontainer.json | 61 -------------------------------- .devcontainer/sparse-exclude.txt | 4 --- .devcontainer/sparse-include.txt | 2 -- ver.py | 2 +- 4 files changed, 1 insertion(+), 68 deletions(-) delete mode 100644 .devcontainer/sparse-exclude.txt delete mode 100644 .devcontainer/sparse-include.txt diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index e6ebcdd7..67f82a05 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -17,70 +17,9 @@ "--privileged" ], - // mount your Windows ~/.ssh read-only into the container - "mounts": [ - "source=${localEnv:USERPROFILE}/.ssh,target=/host-ssh,type=bind,consistency=cached,readonly" - ], - - // Wait for the clone+sparse step to finish before the IDE connects - "waitFor": "onCreateCommand", - - // Set your repo URL once here (or pass it via env) - "containerEnv": { - "REPO_URL1": "git@github.com-cristibleotiu:Ratio1/edge_node", - "REPO_URL": "https://github.com/Ratio1/edge_node", - "R1_SSH_HOST": "github.com", - - "R1_SSH_ALIAS": "github.com-cristibleotiu", - - // where the mounted keys live in the container - "R1_SSH_KEY_SOURCE_DIR": 
"${localEnv:USERPROFILE}/.ssh", - // filename of your private key on Windows (e.g. id_ed25519 or my_github_key) - "R1_SSH_KEY_FILENAME": "github-cristibleotiu", - - // where to place the key inside the container - "R1_SSH_DIR": "/root/.ssh", - "R1_SSH_KEY_DEST": "github-cristibleotiu", - - // sparse rule files inside the repo (adjust if you keep them elsewhere) - "R1_SPARSE_EXCLUDE_FILE": ".devcontainer/sparse-exclude.txt", - "R1_SPARSE_INCLUDE_FILE": ".devcontainer/sparse-include.txt" - }, - -// "onCreateCommand": [ -// "bash ls -all", -// // strict shell + ensure ~/.ssh with safe perms -// "bash -lc \"set -euo pipefail; mkdir -p \\\"$R1_SSH_DIR\\\"; chmod 700 \\\"$R1_SSH_DIR\\\"\"", -// -// // copy the Windows key into place with correct perms (public key optional) -// "bash -lc \"cp \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME\\\" \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST\\\"; chmod 600 \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST\\\"; if [ -f \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME.pub\\\" ]; then cp \\\"$R1_SSH_KEY_SOURCE_DIR/$R1_SSH_KEY_FILENAME.pub\\\" \\\"$R1_SSH_DIR/$R1_SSH_KEY_DEST.pub\\\"; fi\"", -// -// // write SSH config for the alias host (e.g., github.com-) → real host + key -// "bash -lc \"cat > \\\"$R1_SSH_DIR/config\\\" <> \\\"$R1_SSH_DIR/known_hosts\\\" || true\"", -// -// // blobless (partial) clone via the alias URL (e.g., git@github.com-:org/repo) -// "bash -lc \"cd ${containerWorkspaceFolder} && git clone --filter=blob:none --sparse \\\"$R1_REPO_SSH_URL\\\" .\"", -// -// // non-cone sparse mode so we can use excludes and fine-grained files -// "bash -lc \"git -C ${containerWorkspaceFolder} sparse-checkout init --no-cone\"", -// -// // include everything, then apply excludes from file, then optional re-includes -// "bash -lc \"git -C ${containerWorkspaceFolder} sparse-checkout set --no-cone '/*' $(sed -E '/^\\s*(#|$)/d; s/^/!/' \\\"${containerWorkspaceFolder}/$R1_SPARSE_EXCLUDE_FILE\\\" 2>/dev/null || true) $(sed -E '/^\\s*(#|$)/d' \\\"${containerWorkspaceFolder}/$R1_SPARSE_INCLUDE_FILE\\\" 2>/dev/null || true)\"", -// -// // drop now-excluded files from the working tree -// "bash -lc \"git -C ${containerWorkspaceFolder} clean -ffdqx || true\"" -// ], - "build": { -// "dockerfile": "Dockerfile", "context": "../", -// "options": [ -// "--progress=plain" -// ] }, diff --git a/.devcontainer/sparse-exclude.txt b/.devcontainer/sparse-exclude.txt deleted file mode 100644 index fa815542..00000000 --- a/.devcontainer/sparse-exclude.txt +++ /dev/null @@ -1,4 +0,0 @@ -/data/ -/logs/ -/third_party/huge_lib/ -/weights/ diff --git a/.devcontainer/sparse-include.txt b/.devcontainer/sparse-include.txt deleted file mode 100644 index d5579c3d..00000000 --- a/.devcontainer/sparse-include.txt +++ /dev/null @@ -1,2 +0,0 @@ -/data/private/secret.json # keep this file even though /data/ is excluded -/weights/README.md # another example file inside an excluded dir diff --git a/ver.py b/ver.py index 29ec0381..8fb05202 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.13' +__VER__ = '2.10.20' From 72fadca7ed9c14a506c62ec6371e4e8c591245a2 Mon Sep 17 00:00:00 2001 From: Vitalii <87299468+toderian@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:40:44 +0200 Subject: [PATCH 3/5] Cerviguard inference (#348) * fix: use models loading for cerviguard * fix: use actual models for inference * fix: cerviguard response body * fix: exclude some classes from validation * chore: inc version * chore: upd version --- .../cerviguard/cerviguard_image_analyzer.py | 789 +++++++++++++----- 
plugins/business/cerviguard/cerviguard_api.py | 185 ++-- ver.py | 2 +- 3 files changed, 665 insertions(+), 311 deletions(-) diff --git a/extensions/serving/cerviguard/cerviguard_image_analyzer.py b/extensions/serving/cerviguard/cerviguard_image_analyzer.py index 413056b0..c5d9b3ec 100644 --- a/extensions/serving/cerviguard/cerviguard_image_analyzer.py +++ b/extensions/serving/cerviguard/cerviguard_image_analyzer.py @@ -30,10 +30,44 @@ import base64 from io import BytesIO from PIL import Image +import numpy as np +import torch from naeural_core.serving.base import ModelServingProcess as BaseServingProcess -__VER__ = '0.1.2' +__VER__ = '0.2.0' + + +class SimpleImageProcessor: + """Simple image processor for custom CNN models.""" + + def __init__(self, size=(256, 256)): + self.size = size if isinstance(size, tuple) else (size, size) + + def __call__(self, images, return_tensors="pt"): + if not isinstance(images, list): + images = [images] + + processed = [] + for img in images: + if isinstance(img, np.ndarray): + img = Image.fromarray(img) + + # Resize to expected size (height, width) -> PIL uses (width, height) + img = img.resize((self.size[1], self.size[0]), Image.Resampling.BILINEAR) + + # Convert to tensor and normalize to [0, 1] + img_array = np.array(img).astype(np.float32) / 255.0 + + # HWC to CHW + if img_array.ndim == 3: + img_array = img_array.transpose(2, 0, 1) + + processed.append(torch.tensor(img_array)) + + if return_tensors == "pt": + return {"pixel_values": torch.stack(processed)} + return processed _CONFIG = { **BaseServingProcess.CONFIG, @@ -44,6 +78,50 @@ # Allow running without input for initialization "RUNS_ON_EMPTY_INPUT": False, + # Image validation settings (uses model_1/ImageNet to detect non-medical images) + # If ImageNet confidently recognizes an object, it's not a cervical image + # Note: Medical images can trigger moderate confidence on various classes (e.g., tissue + # colors triggering "lipstick"), so threshold should be high to avoid false positives + "IMAGE_VALIDATION_ENABLED": True, + "IMAGE_VALIDATION_CONFIDENCE_THRESHOLD": 0.85, # 85% - reject only when ImageNet is very confident + + # ImageNet classes to skip during validation (these commonly trigger false positives on medical images) + # Add class names exactly as they appear in ImageNet labels + "IMAGE_VALIDATION_SKIP_CLASSES": [ + "lipstick, lip rouge", + "Band Aid", + "shower curtain", + "velvet", + "wool, woolen, woollen", + "theater curtain, theatre curtain", + "window shade", + "pill bottle", + "rubber eraser, rubber, eraser", + "handkerchief, hankie, hanky, hankey", + ], + + # HuggingFace model configurations + # Model 1: Lightweight ImageNet classifier (MobileNetV2) - used for image validation + "MODEL_1_NAME": "google/mobilenet_v2_1.0_224", + "MODEL_1_ENABLED": True, + "MODEL_1_TYPE": "huggingface", # "huggingface" or "custom" + "MODEL_1_CLASS_NAME": None, # Only needed for custom models + + # Model 2: Custom CerviGuard lesion model + "MODEL_2_NAME": None, # e.g., "toderian/cerviguard_lesion" + "MODEL_2_ENABLED": False, + "MODEL_2_TYPE": "custom", # Custom model with model.py + "MODEL_2_CLASS_NAME": "CervicalCancerCNN", # Class name in model.py + + # Model 3: Custom CerviGuard transfer zones model + "MODEL_3_NAME": None, # e.g., "toderian/cerviguard_transfer_zones" + "MODEL_3_ENABLED": False, + "MODEL_3_TYPE": "custom", # Custom model with model.py + "MODEL_3_CLASS_NAME": "BaseCNN", # Class name in model.py + + # HuggingFace token for private models (optional) + "HF_TOKEN": None, + 
'VALIDATION_RULES': { **BaseServingProcess.CONFIG['VALIDATION_RULES'], }, @@ -68,43 +146,435 @@ def on_init(self): """ super(CerviguardImageAnalyzer, self).on_init() self._processed_count = 0 - self.rng = self.np.random.default_rng() - self.base_risks = {'none': 10, 'low': 30, 'moderate': 55, 'high': 75} - self.tz_descriptions = { - 'Type 0': 'Type 0 transformation zone (normal-appearing cervix, no visible lesions).', - 'Type 1': 'Type 1 transformation zone (fully ectocervical and fully visible).', - 'Type 2': 'Type 2 transformation zone (partly endocervical but fully visible).', - 'Type 3': 'Type 3 transformation zone (endocervical and not fully visible).' - } - self.lesion_text = { - 'none': 'No significant acetowhite or vascular changes seen.', - 'low': 'Minor acetowhite changes with regular vascular patterns; low-grade lesion possible.', - 'moderate': 'Acetowhite epithelium with irregular vessels; moderate-grade lesion suspected.', - 'high': 'Dense acetowhite areas with atypical vessels; high-grade lesion suspected.' - } - self.lesion_templates = { - 'Type 3': { - 'none': 'No obvious ectocervical lesions, but assessment is limited because the transformation zone is not fully visible; colposcopy with endocervical evaluation is recommended.', - 'low': 'Subtle acetowhite change seen on the ectocervix; Type 3 zone limits visualization—colposcopy/endocervical sampling advised.', - 'moderate': 'Suspicious acetowhite and vascular changes with a Type 3 zone; colposcopy and endocervical assessment recommended.', - 'high': 'Marked high-grade features with a Type 3 zone; urgent colposcopy with endocervical evaluation recommended.' - }, - 'Type 0': { - 'none': 'No lesions detected; cervix appears normal.', - 'low': 'Minor findings noted, but overall appearance is normal; routine screening advised.', - 'moderate': 'Patchy findings with otherwise normal cervix; consider follow-up colposcopy.', - 'high': 'Focal concerning area despite overall normal appearance; colposcopy recommended.' - }, - 'default': { - 'none': f"{self.lesion_text['none']} Routine screening appropriate.", - 'low': f"{self.lesion_text['low']} Follow-up in 6-12 months recommended.", - 'moderate': f"{self.lesion_text['moderate']} Colposcopy and biopsy recommended.", - 'high': f"{self.lesion_text['high']} Immediate colposcopy and biopsy strongly recommended." - } - } + + # Initialize model containers + self.model_1 = None + self.model_2 = None + self.model_3 = None + self.processor_1 = None + self.processor_2 = None + self.processor_3 = None + self.model_1_type = None + self.model_2_type = None + self.model_3_type = None + + # Load HuggingFace models + self._load_hf_models() + self.P("CerviGuard Image Analyzer initialized", color='g') self.P(f" Version: {__VER__}", color='g') self.P(f" Accepts STRUCT_DATA input (base64 images)", color='g') + if self.cfg_image_validation_enabled: + self.P(f" Image validation: ENABLED (threshold: {self.cfg_image_validation_confidence_threshold:.0%})", color='g') + else: + self.P(f" Image validation: DISABLED", color='y') + + return + + def _validate_image_content(self, model_1_result): + """ + Validate image content using ImageNet classification confidence. + + Medical/cervical images confuse ImageNet, resulting in low confidence + spread across random classes. Real everyday objects (cats, cars, etc.) + get high confidence classifications. + + Parameters + ---------- + model_1_result : dict + Result from model_1 (ImageNet) inference. 
+ + Returns + ------- + dict + Validation result with 'valid' bool and 'reason' string. + """ + if not self.cfg_image_validation_enabled: + return {'valid': True, 'reason': 'Validation disabled'} + + if model_1_result is None or 'error' in model_1_result: + return {'valid': True, 'reason': 'ImageNet model not available'} + + top_confidence = model_1_result.get('top_confidence', 0) + top_label = model_1_result.get('top_label', 'unknown') + threshold = self.cfg_image_validation_confidence_threshold + skip_classes = self.cfg_image_validation_skip_classes or [] + + # Check if detected class is in skip list (case-insensitive comparison) + top_label_lower = top_label.lower() + for skip_class in skip_classes: + if skip_class.lower() == top_label_lower: + return { + 'valid': True, + 'reason': f"Image passed validation ('{top_label}' is in skip list)", + 'top_label': top_label, + 'confidence': top_confidence, + 'skipped': True + } + + if top_confidence >= threshold: + return { + 'valid': False, + 'reason': f"Image rejected: ImageNet detected '{top_label}' with {top_confidence:.1%} confidence. " + f"This does not appear to be a valid cervical image.", + 'detected_label': top_label, + 'confidence': top_confidence + } + + return { + 'valid': True, + 'reason': f"Image passed validation (ImageNet confidence {top_confidence:.1%} < {threshold:.0%} threshold)", + 'top_label': top_label, + 'confidence': top_confidence + } + + def _get_cache_dir(self): + """Get the cache directory for HuggingFace models.""" + return self.log.get_models_folder() + + def _get_device(self): + """Get the device for model inference (GPU if available, else CPU).""" + try: + import torch as th + if th.cuda.is_available(): + return th.device('cuda') + return th.device('cpu') + except Exception: + return 'cpu' + + def _load_module_from_path(self, module_name, file_path): + """Load a Python module from a file path (avoids import caching issues).""" + import importlib.util + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def _load_standard_hf_model(self, model_name, model_num): + """Load a standard HuggingFace transformers model.""" + try: + from transformers import AutoImageProcessor, AutoModelForImageClassification + + cache_dir = self._get_cache_dir() + hf_token = self.cfg_hf_token + + self.P(f"Loading Model {model_num}: {model_name} (Standard HF)...", color='b') + self.P(f" Cache directory: {cache_dir}", color='b') + + # Load the image processor + processor = AutoImageProcessor.from_pretrained( + model_name, + cache_dir=cache_dir, + token=hf_token, + ) + + # Load the model + model = AutoModelForImageClassification.from_pretrained( + model_name, + cache_dir=cache_dir, + token=hf_token, + ) + + # Move model to device + device = self._get_device() + model = model.to(device) + model.eval() + + self.P(f" Model {model_num} loaded successfully on {device}", color='g') + self.P(f" Model labels: {len(model.config.id2label)} classes", color='g') + + return model, processor, "huggingface" + + except Exception as e: + self.P(f"Error loading Model {model_num} ({model_name}): {e}", color='r') + return None, None, None + + def _load_custom_model(self, model_name, model_num, class_name): + """ + Load a custom model from HuggingFace with model.py definition. 
+ + Parameters + ---------- + model_name : str + HuggingFace model ID (e.g., 'toderian/cerviguard_lesion') + model_num : int + Model number (1, 2, or 3) for logging + class_name : str + Name of the model class in model.py (e.g., 'CervicalCancerCNN') + + Returns + ------- + tuple + (model, processor, model_type) or (None, None, None) if loading fails + """ + try: + import torch as th + import json + from pathlib import Path + from huggingface_hub import snapshot_download + + cache_dir = self._get_cache_dir() + hf_token = self.cfg_hf_token + + self.P(f"Loading Model {model_num}: {model_name} (Custom: {class_name})...", color='b') + self.P(f" Cache directory: {cache_dir}", color='b') + + # Download model files + model_dir = snapshot_download( + repo_id=model_name, + cache_dir=cache_dir, + token=hf_token, + ) + model_path = Path(model_dir) + self.P(f" Downloaded to: {model_dir}", color='b') + + # Load config + config_path = model_path / "config.json" + with open(config_path, 'r') as f: + config = json.load(f) + self.P(f" Config model_type: {config.get('model_type', 'unknown')}", color='b') + + # Load the model module dynamically + model_py = model_path / "model.py" + if not model_py.exists(): + raise FileNotFoundError(f"model.py not found in {model_path}") + + unique_module_name = f"model_{model_name.replace('/', '_')}" + model_module = self._load_module_from_path(unique_module_name, model_py) + + # Get the model class + if not hasattr(model_module, class_name): + available = [x for x in dir(model_module) if not x.startswith('_')] + raise AttributeError(f"Class '{class_name}' not found. Available: {available}") + + ModelClass = getattr(model_module, class_name) + + # Check if there's a from_pretrained method + if hasattr(ModelClass, 'from_pretrained'): + device = str(self._get_device()) + model = ModelClass.from_pretrained(str(model_path), device=device) + self.P(f" Loaded via {class_name}.from_pretrained()", color='g') + else: + # Create model and load weights manually + model_config = config.get('model_config', {}) + if class_name == "CervicalCancerCNN": + model = ModelClass(config=model_config) + else: + model = ModelClass(**model_config) + + # Load weights + safetensors_path = model_path / 'model.safetensors' + pytorch_path = model_path / 'pytorch_model.bin' + + if safetensors_path.exists(): + try: + from safetensors.torch import load_file + state_dict = load_file(str(safetensors_path)) + except ImportError: + state_dict = th.load(pytorch_path, map_location='cpu', weights_only=True) + elif pytorch_path.exists(): + state_dict = th.load(pytorch_path, map_location='cpu', weights_only=True) + else: + raise FileNotFoundError("No model weights found") + + model.load_state_dict(state_dict) + self.P(f" Loaded weights manually", color='g') + + device = self._get_device() + model = model.to(device) + model.eval() + + # Get labels + id2label = config.get('id2label', {}) + if not id2label and hasattr(ModelClass, 'CLASSES'): + id2label = ModelClass.CLASSES + model.id2label = {int(k) if isinstance(k, str) else k: v for k, v in id2label.items()} + + # Get input size from config + input_size = (256, 256) + if 'input_size' in config: + size_config = config['input_size'] + if isinstance(size_config, dict): + input_size = (int(size_config['height']), int(size_config['width'])) + else: + input_size = tuple(int(x) for x in size_config) + elif class_name == "CervicalCancerCNN": + input_size = (224, 298) + + processor = SimpleImageProcessor(size=input_size) + + self.P(f" Model {model_num} loaded on {device}, 
classes: {model.id2label}", color='g') + return model, processor, "custom" + + except Exception as e: + self.P(f"Error loading Model {model_num} ({model_name}): {e}", color='r') + import traceback + self.P(traceback.format_exc(), color='r') + return None, None, None + + def _load_single_model(self, model_name, model_num, model_type, class_name=None): + """ + Load a single model based on its type. + + Parameters + ---------- + model_name : str + HuggingFace model ID + model_num : int + Model number (1, 2, or 3) + model_type : str + "huggingface" for standard HF models, "custom" for custom models + class_name : str, optional + Class name for custom models + + Returns + ------- + tuple + (model, processor, model_type) or (None, None, None) if loading fails + """ + if model_type == "huggingface": + return self._load_standard_hf_model(model_name, model_num) + else: + return self._load_custom_model(model_name, model_num, class_name) + + def _load_hf_models(self): + """ + Load all configured HuggingFace models. + Models are loaded based on their enabled status in config. + """ + self.P("=" * 50, color='b') + self.P("Loading HuggingFace Models...", color='b') + self.P("=" * 50, color='b') + + # Model 1 + if self.cfg_model_1_enabled and self.cfg_model_1_name: + self.model_1, self.processor_1, self.model_1_type = self._load_single_model( + self.cfg_model_1_name, + model_num=1, + model_type=self.cfg_model_1_type or "huggingface", + class_name=self.cfg_model_1_class_name, + ) + else: + self.P("Model 1: Disabled or not configured", color='y') + self.model_1_type = None + + # Model 2 + if self.cfg_model_2_enabled and self.cfg_model_2_name: + self.model_2, self.processor_2, self.model_2_type = self._load_single_model( + self.cfg_model_2_name, + model_num=2, + model_type=self.cfg_model_2_type or "custom", + class_name=self.cfg_model_2_class_name, + ) + else: + self.P("Model 2: Disabled or not configured", color='y') + self.model_2_type = None + + # Model 3 + if self.cfg_model_3_enabled and self.cfg_model_3_name: + self.model_3, self.processor_3, self.model_3_type = self._load_single_model( + self.cfg_model_3_name, + model_num=3, + model_type=self.cfg_model_3_type or "custom", + class_name=self.cfg_model_3_class_name, + ) + else: + self.P("Model 3: Disabled or not configured", color='y') + self.model_3_type = None + + # Summary + loaded_count = sum([ + self.model_1 is not None, + self.model_2 is not None, + self.model_3 is not None + ]) + self.P("=" * 50, color='b') + self.P(f"Model loading complete: {loaded_count}/3 models loaded", color='g') + self.P("=" * 50, color='b') + + return + + + def _run_model_inference(self, img_array, model, processor, model_type, model_name="model"): + """ + Run inference on a single image using a loaded model. 
+ + Parameters + ---------- + img_array : np.ndarray + Image as numpy array (HWC format) + model : nn.Module + Loaded model (HuggingFace or custom) + processor : ImageProcessor + Image processor (HuggingFace AutoImageProcessor or SimpleImageProcessor) + model_type : str + "huggingface" or "custom" + model_name : str + Name for logging + + Returns + ------- + dict + Inference results with predictions and probabilities + """ + if model is None or processor is None: + return {'error': f'{model_name} not loaded'} + + try: + import torch as th + + # Convert numpy array to PIL Image for processor + pil_image = Image.fromarray(img_array) + + # Preprocess the image + inputs = processor(images=pil_image, return_tensors="pt") + + # Move inputs to same device as model + device = next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + + # Run inference - different handling for HuggingFace vs custom models + with th.no_grad(): + if model_type == "huggingface": + outputs = model(**inputs) + logits = outputs.logits + else: + # Custom models return logits directly + logits = model(inputs["pixel_values"]) + + # Get predictions + probabilities = th.nn.functional.softmax(logits, dim=-1) + + # Get top predictions + top_k = min(5, probabilities.shape[-1]) + top_probs, top_indices = th.topk(probabilities[0], top_k) + + # Get labels based on model type + if model_type == "huggingface": + id2label = model.config.id2label + else: + id2label = getattr(model, 'id2label', {}) + + predictions = [] + for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy()): + label = id2label.get(int(idx), f"class_{idx}") + predictions.append({ + 'label': label, + 'confidence': float(prob), + 'class_id': int(idx) + }) + + return { + 'predictions': predictions, + 'top_label': predictions[0]['label'] if predictions else None, + 'top_confidence': predictions[0]['confidence'] if predictions else 0.0, + } + + except Exception as e: + self.P(f"Error running inference with {model_name}: {e}", color='r') + import traceback + self.P(traceback.format_exc(), color='r') + return {'error': str(e)} def _decode_base64_image(self, image_data): """ @@ -151,7 +621,7 @@ def _decode_base64_image(self, image_data): def _extract_image_info(self, img_array): """ - Extract comprehensive information from image array. + Extract basic information from image array. 
Parameters ---------- @@ -169,156 +639,14 @@ def _extract_image_info(self, img_array): 'valid': False } - # Extract basic dimensions height, width = img_array.shape[:2] channels = img_array.shape[2] if len(img_array.shape) > 2 else 1 - # Calculate size info - total_pixels = height * width - size_mb = img_array.nbytes / (1024 * 1024) - - result = { + return { 'valid': True, 'width': int(width), 'height': int(height), 'channels': int(channels), - 'total_pixels': int(total_pixels), - 'size_mb': round(size_mb, 3), - 'dtype': str(img_array.dtype), - 'shape': list(img_array.shape), - } - - # Add color information for RGB images - if channels >= 3: - result['color_info'] = { - 'mean_r': float(img_array[:, :, 0].mean()), - 'mean_g': float(img_array[:, :, 1].mean()), - 'mean_b': float(img_array[:, :, 2].mean()), - 'std_r': float(img_array[:, :, 0].std()), - 'std_g': float(img_array[:, :, 1].std()), - 'std_b': float(img_array[:, :, 2].std()), - } - - # Add quality assessment - result['quality_info'] = { - 'resolution_category': self._categorize_resolution(width, height), - 'aspect_ratio': round(width / height, 3) if height > 0 else 0, - 'is_square': abs(width - height) < 10, - } - - return result - - def _categorize_resolution(self, width, height): - """ - Categorize image resolution for quality assessment. - - Parameters - ---------- - width : int - Image width in pixels - height : int - Image height in pixels - - Returns - ------- - str - Resolution category - """ - total_pixels = width * height - - if total_pixels < 100000: # < 0.1 MP - return 'very_low' - elif total_pixels < 500000: # < 0.5 MP - return 'low' - elif total_pixels < 2000000: # < 2 MP - return 'medium' - elif total_pixels < 5000000: # < 5 MP - return 'high' - else: - return 'very_high' - - def _generate_cervical_analysis(self, img_array, image_info): - """ - Generate cervical screening analysis results. - - This is a mock implementation that generates plausible analysis based on - image characteristics. In production, this would be replaced with actual - ML model inference for cervical cancer detection. - - Parameters - ---------- - img_array : np.ndarray - Image as numpy array - image_info : dict - Extracted image information - - Returns - ------- - dict - Analysis results with tz_type, lesion_assessment, lesion_summary, and risk_score - """ - if img_array is None or not image_info.get('valid', False): - return { - 'tz_type': 'Type 1', - 'lesion_assessment': 'none', - 'lesion_summary': 'Image quality insufficient for analysis', - 'risk_score': 0 - } - - quality_info = image_info.get('quality_info', {}) - resolution_category = quality_info.get('resolution_category', 'unknown') - image_quality_sufficient = resolution_category not in ['very_low', 'low'] - - # Purely random (but internally consistent) lesion and TZ selection - rng = self.rng - - tz_type = rng.choice( - ['Type 0', 'Type 1', 'Type 2', 'Type 3'], - p=[0.2, 0.3, 0.25, 0.25] - ) - - lesion_assessment = rng.choice( - ['none', 'low', 'moderate', 'high'], - p=[0.35, 0.3, 0.2, 0.15] - ) - - risk_score = self.base_risks[lesion_assessment] - - img_width = image_info.get('width', 0) - img_height = image_info.get('height', 0) - - visualization_limited = tz_type == 'Type 3' - if tz_type == 'Type 3': - risk_score = max(risk_score, 40) - - if resolution_category in ['very_low', 'low']: - quality_note = f"Image resolution ({img_width}x{img_height}) limits detailed assessment." 
- elif resolution_category == 'medium': - quality_note = f"Image resolution ({img_width}x{img_height}) is adequate for analysis." - else: - quality_note = f"Image resolution ({img_width}x{img_height}) is optimal for analysis." - - if tz_type == 'Type 3': - lesion_templates = self.lesion_templates['Type 3'] - elif tz_type == 'Type 0': - lesion_templates = self.lesion_templates['Type 0'] - else: - lesion_templates = self.lesion_templates['default'] - - lesion_summary = " ".join([ - self.tz_descriptions.get(tz_type, tz_type), - lesion_templates.get(lesion_assessment, self.lesion_text['none']), - quality_note - ]) - - return { - 'tz_type': tz_type, - 'lesion_assessment': lesion_assessment, - 'lesion_summary': lesion_summary, - 'risk_score': risk_score, - 'image_quality': resolution_category, - 'image_quality_sufficient': image_quality_sufficient, - 'assessment_confidence': 'reduced' if visualization_limited else 'normal' } def _pre_process(self, inputs): @@ -360,10 +688,10 @@ def _pre_process(self, inputs): def _predict(self, inputs): """ - Main prediction: extract image information. + Main prediction: extract image information and run model inference. - In the future, this will call the actual AI model for cervical - cancer detection. + Runs inference using loaded HuggingFace models for cervical + cancer detection and classification. Parameters ---------- @@ -373,7 +701,7 @@ def _predict(self, inputs): Returns ------- list - List of analysis results + List of analysis results including model predictions """ self._processed_count += 1 @@ -393,8 +721,63 @@ def _predict(self, inputs): # Extract image information image_info = self._extract_image_info(img_array) - # Generate cervical screening analysis - analysis = self._generate_cervical_analysis(img_array, image_info) + # Step 1: Run ImageNet classifier for image validation (not included in results) + imagenet_result = None + if self.model_1 is not None: + imagenet_result = self._run_model_inference( + img_array, self.model_1, self.processor_1, + model_type=self.model_1_type, + model_name=self.cfg_model_1_name + ) + + # Step 2: Validate image content using ImageNet results + # If ImageNet confidently recognizes an object, this is not a medical image + validation = self._validate_image_content(imagenet_result) + self.P(f"validation result: {self.json_dumps(validation)}") + + if not validation['valid']: + self.P(f"Image validation failed: {validation['reason']}", color='r') + self.Pd(f"Validation result details: {self.json_dumps(validation)}", color='r') + results.append({ + 'index': idx, + 'error': validation['reason'], + 'valid': False, + 'image_info': image_info, + 'processed_at': self.time(), + 'processor_version': __VER__, + }) + continue + + # Step 3: Image passed validation - run medical analysis models + # Model 2: Lesion classification (Normal, LSIL, HSIL, Cancer) + lesion_result = None + if self.model_2 is not None: + lesion_result = self._run_model_inference( + img_array, self.model_2, self.processor_2, + model_type=self.model_2_type, + model_name=self.cfg_model_2_name + ) + + # Model 3: Transformation zone classification (Type 1, Type 2, Type 3) + tz_result = None + if self.model_3 is not None: + tz_result = self._run_model_inference( + img_array, self.model_3, self.processor_3, + model_type=self.model_3_type, + model_name=self.cfg_model_3_name + ) + + # Build analysis from model results + analysis = { + 'lesion': lesion_result, + 'transformation_zone': tz_result, + } + + self.P("=============================================") + 
self.P(f"Processed input index: {idx}", color='g') + self.P(f"Image info: {image_info}", color='g') + self.P(f"Analysis: {self.json_dumps(analysis)}", color='g') + self.P("=============================================") # Add processing metadata result = { @@ -403,8 +786,6 @@ def _predict(self, inputs): 'analysis': analysis, 'processed_at': self.time(), 'processor_version': __VER__, - 'model_name': 'cerviguard_image_analyzer', - 'iteration': self._processed_count, } results.append(result) @@ -429,15 +810,41 @@ def _post_process(self, preds): formatted_results = [] for pred in preds: + # Determine status: check both image decoding and content validation + has_error = 'error' in pred + image_decoded = pred.get('image_info', {}).get('valid', False) + + if has_error: + status = 'error' + elif image_decoded: + status = 'completed' + else: + status = 'error' + # Format the result for output formatted = { - 'status': 'completed' if pred.get('image_info', {}).get('valid', False) else 'error', + 'status': status, 'data': pred, } - # Add error message if present - if 'error' in pred: - formatted['error'] = pred['error'] + # Add explicit error fields for UI consumption + if has_error: + error_msg = pred['error'] + formatted['error'] = error_msg + + # Determine error type for UI handling + if 'ImageNet detected' in error_msg: + formatted['error_code'] = 'INVALID_IMAGE_CONTENT' + formatted['error_type'] = 'validation' + formatted['error_message'] = 'The uploaded image does not appear to be a valid cervical image.' + elif 'Failed to decode' in error_msg: + formatted['error_code'] = 'DECODE_ERROR' + formatted['error_type'] = 'decoding' + formatted['error_message'] = 'Failed to decode the image. Please ensure the image is valid.' + else: + formatted['error_code'] = 'PROCESSING_ERROR' + formatted['error_type'] = 'processing' + formatted['error_message'] = error_msg formatted_results.append(formatted) diff --git a/plugins/business/cerviguard/cerviguard_api.py b/plugins/business/cerviguard/cerviguard_api.py index 85babec7..975b914c 100644 --- a/plugins/business/cerviguard/cerviguard_api.py +++ b/plugins/business/cerviguard/cerviguard_api.py @@ -15,7 +15,6 @@ """ from extensions.business.edge_inference_api.cv_inference_api import CvInferenceApiPlugin as BasePlugin -from naeural_core.utils.fastapi_utils import PostponedRequest __VER__ = '0.1.0' @@ -53,126 +52,74 @@ class CerviguardApiPlugin(BasePlugin): CONFIG = _CONFIG - """INFERENCE HANDLING""" - if True: - def _validate_analysis(self, analysis: dict) -> dict: - """ - Validate and sanitize analysis fields returned by inference. - - Parameters - ---------- - analysis : dict - Raw analysis data from the inference engine. - - Returns - ------- - dict - Analysis payload with validated and defaulted fields. 
- """ - safe_defaults = { - 'tz_type': 'Type 1', - 'lesion_assessment': 'none', - 'lesion_summary': 'Analysis unavailable', - 'risk_score': 0, - 'image_quality': 'unknown', - 'image_quality_sufficient': True - } - if not isinstance(analysis, dict): - return safe_defaults - - validated = {} - - tz_type = analysis.get('tz_type', safe_defaults['tz_type']) - if tz_type not in ['Type 0', 'Type 1', 'Type 2', 'Type 3']: - tz_type = safe_defaults['tz_type'] - validated['tz_type'] = tz_type - - lesion_assessment = analysis.get('lesion_assessment', safe_defaults['lesion_assessment']) - if lesion_assessment not in ['none', 'low', 'moderate', 'high']: - lesion_assessment = safe_defaults['lesion_assessment'] - validated['lesion_assessment'] = lesion_assessment - - lesion_summary = analysis.get('lesion_summary', safe_defaults['lesion_summary']) - if not isinstance(lesion_summary, str): - lesion_summary = safe_defaults['lesion_summary'] - validated['lesion_summary'] = lesion_summary - - risk_score = analysis.get('risk_score', safe_defaults['risk_score']) - try: - risk_score = int(risk_score) - if risk_score < 0 or risk_score > 100: - risk_score = max(0, min(100, risk_score)) - except (TypeError, ValueError): - risk_score = safe_defaults['risk_score'] - validated['risk_score'] = risk_score - - image_quality = analysis.get('image_quality', safe_defaults['image_quality']) - if not isinstance(image_quality, str): - image_quality = safe_defaults['image_quality'] - validated['image_quality'] = image_quality - - image_quality_sufficient = analysis.get( - 'image_quality_sufficient', - safe_defaults['image_quality_sufficient'] - ) - if not isinstance(image_quality_sufficient, bool): - image_quality_sufficient = safe_defaults['image_quality_sufficient'] - validated['image_quality_sufficient'] = image_quality_sufficient - return validated - - def _build_result_from_inference( - self, - request_id: str, - inference: Dict[str, Any], - metadata: Dict[str, Any], - request_data: Dict[str, Any] - ): - """ - Construct a result payload from inference output and metadata. - - Parameters - ---------- - request_id : str - Identifier of the tracked request. - inference : dict - Inference result data. - metadata : dict - Metadata to include in the response. - request_data : dict - Stored request record for reference. - - Returns - ------- - dict - Structured result payload including analysis and image details. - - Raises - ------ - ValueError - If the inference result format is invalid. - RuntimeError - When the inference indicates an error status. - """ - if not isinstance(inference, dict): - raise ValueError("Invalid inference result format.") - inference_data = inference.get('data', inference) - status = inference_data.get('status', inference.get('status', 'completed')) - if status == 'error': - err_msg = inference_data.get('error', 'Unknown error') - raise RuntimeError(err_msg) - - analysis = self._validate_analysis(inference_data.get('analysis', {})) - image_info = inference_data.get('image_info', {}) - result_payload = { - 'status': 'completed', + def _build_result_from_inference( + self, + request_id: str, + inference: dict, + metadata: dict, + request_data: dict + ): + """ + Construct a result payload from inference output and metadata. + + Parameters + ---------- + request_id : str + Identifier of the tracked request. + inference : dict + Inference result data from cerviguard_image_analyzer. + metadata : dict + Metadata to include in the response. + request_data : dict + Stored request record for reference. 
+ + Returns + ------- + dict + Structured result payload with lesion and transformation zone predictions. + + Raises + ------ + ValueError + If the inference result format is invalid. + RuntimeError + When the inference indicates an error status. + """ + if not isinstance(inference, dict): + raise ValueError("Invalid inference result format.") + + inference_data = inference.get('data', inference) + status = inference_data.get('status', inference.get('status', 'completed')) + + if status == 'error': + # Return structured error response instead of raising exception + # This allows the UI to display meaningful error information + error_payload = { + 'status': 'error', 'request_id': request_id, - 'analysis': analysis, - 'image_info': image_info, + 'error': inference_data.get('error', inference.get('error', 'Unknown error')), + 'error_code': inference_data.get('error_code', inference.get('error_code', 'PROCESSING_ERROR')), + 'error_type': inference_data.get('error_type', inference.get('error_type', 'processing')), + 'error_message': inference_data.get('error_message', inference.get('error_message', 'An error occurred during processing.')), + 'image_info': inference_data.get('image_info', {}), 'processed_at': inference_data.get('processed_at', self.time()), 'processor_version': inference_data.get('processor_version', 'unknown'), 'metadata': metadata or request_data.get('metadata') or {}, } - if 'model_name' in inference_data: - result_payload['model_name'] = inference_data['model_name'] - return result_payload - """END INFERENCE HANDLING""" + return error_payload + + # Extract analysis from serving plugin (contains lesion and transformation_zone) + analysis = inference_data.get('analysis', {}) + image_info = inference_data.get('image_info', {}) + + result_payload = { + 'status': 'completed', + 'request_id': request_id, + 'analysis': analysis, + 'image_info': image_info, + 'processed_at': inference_data.get('processed_at', self.time()), + 'processor_version': inference_data.get('processor_version', 'unknown'), + 'metadata': metadata or request_data.get('metadata') or {}, + } + + return result_payload diff --git a/ver.py b/ver.py index 8fb05202..099b5f0c 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.20' +__VER__ = '2.10.21' From 31dee1507d5ad548aa7cf5b6729a183ddae4a070 Mon Sep 17 00:00:00 2001 From: Vitalii <87299468+toderian@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:45:38 +0200 Subject: [PATCH 4/5] fix: read of worker entries (#349) * fix: read of worker entries * chore: inc version --- .../cybersec/red_mesh/pentester_api_01.py | 49 +++++++++++++++---- ver.py | 2 +- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/extensions/business/cybersec/red_mesh/pentester_api_01.py b/extensions/business/cybersec/red_mesh/pentester_api_01.py index e4dd1931..11bbeb22 100644 --- a/extensions/business/cybersec/red_mesh/pentester_api_01.py +++ b/extensions/business/cybersec/red_mesh/pentester_api_01.py @@ -47,6 +47,8 @@ "CHAINSTORE_PEERS": [], "CHECK_JOBS_EACH" : 5, + + "REDMESH_VERBOSE" : 10, # Verbosity level for debug messages (0 = off, 1+ = debug) "NR_LOCAL_WORKERS" : 8, @@ -155,8 +157,33 @@ def P(self, s, *args, **kwargs): """ s = "[REDMESH] " + s return super(PentesterApi01Plugin, self).P(s, *args, **kwargs) - - + + + def Pd(self, s, *args, score=-1, **kwargs): + """ + Print debug message if verbosity level allows. + + Parameters + ---------- + s : str + Message to print. + score : int, optional + Verbosity threshold (default: -1). 
Message prints if cfg_redmesh_verbose > score. + *args + Additional positional arguments passed to P(). + **kwargs + Additional keyword arguments passed to P(). + + Returns + ------- + None + """ + if self.cfg_redmesh_verbose > score: + s = "[DEBUG] " + s + self.P(s, *args, **kwargs) + return + + def __post_init(self): """ Perform warmup: reconcile existing jobs in CStore, migrate legacy keys, @@ -281,29 +308,28 @@ def _normalize_job_record(self, job_key, job_spec, migrate=False): return job_key, normalized - def _ensure_worker_entry(self, job_id, job_spec): + def _get_worker_entry(self, job_id, job_spec): """ - Ensure current worker has an entry in the distributed job spec. + Get the worker entry for this node from the job spec. Parameters ---------- job_id : str Identifier of the job. job_spec : dict - Mutable job specification stored in CStore. + Job specification stored in CStore. Returns ------- - dict - Worker entry for this edge node. + dict | None + Worker entry for this edge node, or None if not assigned. """ workers = job_spec.setdefault("workers", {}) worker_entry = workers.get(self.ee_addr) if worker_entry is None: - self.P("No worker entry found for this node in job spec job_id={}, workers={}".format( + self.Pd("No worker entry found for this node in job spec job_id={}, workers={}".format( job_id, self.json_dumps(workers)), - color='r' ) return worker_entry @@ -460,7 +486,10 @@ def _maybe_launch_jobs(self, nr_local_workers=None): enabled_features = job_specs.get("enabled_features", []) if job_id is None: continue - worker_entry = self._ensure_worker_entry(job_id, job_specs) + worker_entry = self._get_worker_entry(job_id, job_specs) + if worker_entry is None: + # This node is not assigned to this job, skip it + continue current_worker_finished = worker_entry.get("finished", False) if current_worker_finished: continue diff --git a/ver.py b/ver.py index 099b5f0c..13de1756 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.21' +__VER__ = '2.10.22' From f12d5c54d5b688a805b3a08bd928c84f6adc15db Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Wed, 28 Jan 2026 17:22:00 +0200 Subject: [PATCH 5/5] chore: inc ver --- ver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ver.py b/ver.py index 13de1756..e77c000f 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.22' +__VER__ = '2.10.30'
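
Note on the image-content gate added in PATCH 3/5: the analyzer rejects an upload when the ImageNet model (model_1) is very confident about some everyday-object class, unless that class is on the configured skip list of labels known to fire on medical images. A minimal standalone sketch of that decision rule follows, using the same threshold and skip-list semantics as the diff; the function name and free-standing signature are illustrative only — in the plugin this lives in _validate_image_content and reads its values from CONFIG.

def validate_image_content(top_label, top_confidence,
                           threshold=0.85, skip_classes=()):
  """Return (valid, reason) following the PATCH 3/5 validation heuristic."""
  # Labels known to trigger false positives on medical images always pass.
  if top_label.lower() in (c.lower() for c in skip_classes):
    return True, f"'{top_label}' is in the skip list"
  # A confident ImageNet hit means the image is a recognizable everyday
  # object, not a cervical image, so it is rejected.
  if top_confidence >= threshold:
    return False, f"ImageNet detected '{top_label}' at {top_confidence:.1%}"
  return True, f"confidence {top_confidence:.1%} below {threshold:.0%} threshold"

# e.g. validate_image_content("tabby cat", 0.97) -> (False, ...)
#      validate_image_content("lipstick, lip rouge", 0.90,
#                             skip_classes=["lipstick, lip rouge"]) -> (True, ...)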
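
Note on the worker-entry fix in PATCH 4/5: _maybe_launch_jobs now skips any job whose chainstore spec does not list this node's address under "workers" (instead of force-creating an entry), and still skips entries already marked finished. A minimal sketch of that selection logic, assuming the job-spec layout shown in the diff; the helper names below are illustrative — the plugin does this through _get_worker_entry keyed on self.ee_addr.

def get_worker_entry(job_spec, ee_addr):
  """Return this node's worker entry, or None if the node is not assigned."""
  return job_spec.get("workers", {}).get(ee_addr)

def select_launchable_jobs(jobs, ee_addr):
  """Yield (job_id, worker_entry) pairs this node should still work on."""
  for job_id, job_spec in jobs.items():
    worker_entry = get_worker_entry(job_spec, ee_addr)
    if worker_entry is None:
      continue  # node not assigned to this job
    if worker_entry.get("finished", False):
      continue  # this node already finished its share
    yield job_id, worker_entry

# e.g. jobs = {"job-1": {"workers": {"0xA": {"finished": False}}}}
#      list(select_launchable_jobs(jobs, "0xA")) -> [("job-1", {"finished": False})]
#      list(select_launchable_jobs(jobs, "0xB")) -> []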