From 59e4e25f8cbc8056f91023973706690f349ae514 Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Mon, 16 Jun 2025 01:50:42 +0200 Subject: [PATCH 1/8] Flask server example --- examples/example_flask_server.py | 289 ++++++++++++++++ examples/example_s2_server.py | 11 +- examples/mock_s2_server.py | 121 ------- examples/mock_s2_websocket.py | 87 ----- src/s2python/authorization/client.py | 36 +- src/s2python/authorization/default_client.py | 18 +- .../authorization/default_http_server.py | 45 ++- .../authorization/default_ws_server.py | 83 +++-- .../authorization/flask_http_server.py | 245 ++++++++++++++ src/s2python/authorization/flask_service.py | 0 src/s2python/authorization/flask_ws_server.py | 317 ++++++++++++++++++ src/s2python/authorization/server.py | 53 ++- 12 files changed, 1016 insertions(+), 289 deletions(-) create mode 100644 examples/example_flask_server.py delete mode 100644 examples/mock_s2_server.py delete mode 100644 examples/mock_s2_websocket.py create mode 100644 src/s2python/authorization/flask_http_server.py delete mode 100644 src/s2python/authorization/flask_service.py create mode 100644 src/s2python/authorization/flask_ws_server.py diff --git a/examples/example_flask_server.py b/examples/example_flask_server.py new file mode 100644 index 0000000..3acf250 --- /dev/null +++ b/examples/example_flask_server.py @@ -0,0 +1,289 @@ +""" +Example S2 server implementation using Flask. + +This example demonstrates how to set up both an HTTP and a WebSocket server +using the Flask-based implementations. + +Note: You need to install Flask and Flask-Sock: + `pip install flask flask-sock` + +For running the WebSocket server with true async capabilities, an ASGI server +like Hypercorn is recommended: + `pip install hypercorn` + `hypercorn examples.example_flask_server:server_ws.app` +""" + +import argparse +import logging +import signal +import sys +import uuid + +from flask_sock import Sock + +from s2python.authorization.flask_http_server import S2FlaskHTTPServer +from s2python.authorization.flask_ws_server import S2FlaskWSServer +from s2python.common import ( + ControlType, + EnergyManagementRole, + Handshake, + HandshakeResponse, + ReceptionStatusValues, + ResourceManagerDetails, + SelectControlType, +) +from s2python.frbc import ( + FRBCActuatorStatus, + FRBCFillLevelTargetProfile, + FRBCStorageStatus, + FRBCSystemDescription, +) +from s2python.generated.gen_s2_pairing import ( + Deployment, + PairingToken, + Protocols, + S2NodeDescription, + S2Role, +) +from s2python.message import S2Message + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("example_flask_server") + +# Create the server instance at the module level so Hypercorn can find it. +# We assume 'ws' mode for ASGI server execution. +server_instance = S2FlaskWSServer( + host="localhost", # These values can be configured via env vars or other means + port=8080, + role=EnergyManagementRole.CEM, +) + + +def create_signal_handler(): + """Create a signal handler function.""" + + def handler(signum, frame): + logger.info("Received signal %d. Shutting down...", signum) + if server_instance: + server_instance.stop() + # For Flask's dev server, this will be interrupted by the signal. + # For production servers (gunicorn, hypercorn), they handle signals for shutdown. + sys.exit(0) + + return handler + + +async def handle_FRBC_system_description(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBC system description messages.""" + if not isinstance(message, FRBCSystemDescription): + logger.error( + "Handler for FRBCSystemDescription received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCSystemDescription: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCSystemDescription received", + websocket=websocket, + ) + + +async def handle_FRBCActuatorStatus(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBCActuatorStatus messages.""" + if not isinstance(message, FRBCActuatorStatus): + logger.error( + "Handler for FRBCActuatorStatus received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCActuatorStatus: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCActuatorStatus received", + websocket=websocket, + ) + + +async def handle_FillLevelTargetProfile(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FillLevelTargetProfile messages.""" + if not isinstance(message, FRBCFillLevelTargetProfile): + logger.error( + "Handler for FillLevelTargetProfile received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FillLevelTargetProfile: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FillLevelTargetProfile received", + websocket=websocket, + ) + + +async def handle_FRBCStorageStatus(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBCStorageStatus messages.""" + if not isinstance(message, FRBCStorageStatus): + logger.error( + "Handler for FRBCStorageStatus received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCStorageStatus: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCStorageStatus received", + websocket=websocket, + ) + + +async def handle_ResourceManagerDetails(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle ResourceManagerDetails messages.""" + if not isinstance(message, ResourceManagerDetails): + logger.error( + "Handler for ResourceManagerDetails received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received ResourceManagerDetails: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="ResourceManagerDetails received", + websocket=websocket, + ) + + +async def handle_handshake(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle handshake messages and send control type selection if client is RM.""" + if not isinstance(message, Handshake): + logger.error( + "Handler for Handshake received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received Handshake in example_flask_server: %s", message.to_json()) + + # Send reception status for the handshake + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="Handshake received", + websocket=websocket, + ) + + handshake_response = HandshakeResponse( + message_id=message.message_id, + selected_protocol_version="1.0", + ) + logger.info("Sent HandshakeResponse: %s", handshake_response.to_json()) + await server._send_and_forget(handshake_response, websocket) + + # If client is RM, send control type selection + if message.role == EnergyManagementRole.RM: + select_control_type = SelectControlType( + message_id=uuid.uuid4(), + control_type=ControlType.FILL_RATE_BASED_CONTROL, + ) + logger.info("Sending select control type: %s", select_control_type.to_json()) + await server.send_msg_and_await_reception_status_async(select_control_type, websocket) + +# Register handlers on the globally defined instance +server_instance._handlers.register_handler(Handshake, handle_handshake) +server_instance._handlers.register_handler(FRBCSystemDescription, handle_FRBC_system_description) +server_instance._handlers.register_handler(ResourceManagerDetails, handle_ResourceManagerDetails) +server_instance._handlers.register_handler(FRBCActuatorStatus, handle_FRBCActuatorStatus) +server_instance._handlers.register_handler(FRBCFillLevelTargetProfile, handle_FillLevelTargetProfile) +server_instance._handlers.register_handler(FRBCStorageStatus, handle_FRBCStorageStatus) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Example S2 server implementation using Flask.") + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Host to bind the server to (default: localhost)", + ) + parser.add_argument( + "--http-port", + type=int, + default=8000, + help="HTTP port to use (default: 8000)", + ) + parser.add_argument( + "--ws-port", + type=int, + default=8080, + help="WebSocket port to use (default: 8080)", + ) + parser.add_argument( + "--instance", + type=str, + default="http", + choices=["http", "ws"], + help="Instance to use (http or ws, default: http)", + ) + parser.add_argument( + "--pairing-token", + type=str, + default="ca14fda4", + help="Pairing token to use (default: ca14fda4)", + ) + args = parser.parse_args() + + # Create node description for the server + server_node_description = S2NodeDescription( + brand="TNO", + logoUri="https://www.tno.nl/publish/pages/5604/tno-logo-1484x835_003_.jpg", + type="demo frbc example", + modelName="S2 server example (Flask)", + userDefinedName="TNO S2 server example for frbc using Flask", + role=S2Role.RM, + deployment=Deployment.LAN, + ) + logger.info("http_port: %s", args.http_port) + logger.info("ws_port: %s", args.ws_port) + + # Setup signal handling + handler = create_signal_handler() + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + + if args.instance == "ws": + logger.info( + "Starting Flask WebSocket server. For async, run with 'hypercorn examples.example_flask_server:server_ws.app'" + ) + try: + server_instance.start() + except KeyboardInterrupt: + handler(signal.SIGINT, None) + else: + server_http = S2FlaskHTTPServer( + host=args.host, + http_port=args.http_port, + ws_port=args.ws_port, + instance=args.instance, + server_node_description=server_node_description, + token=PairingToken(token=args.pairing_token), + supported_protocols=[Protocols.WebSocketSecure], + ) + server_instance = server_http # type: ignore[assignment] + + logger.info("Starting Flask HTTP server.") + try: + # Note: server_http.start_server() uses Flask's development server. + server_http.start_server() + except KeyboardInterrupt: + handler(signal.SIGINT, None) diff --git a/examples/example_s2_server.py b/examples/example_s2_server.py index bc28c1d..642878a 100644 --- a/examples/example_s2_server.py +++ b/examples/example_s2_server.py @@ -6,9 +6,6 @@ import logging import signal import sys -from datetime import datetime, timedelta -from typing import Any -import asyncio import uuid from websockets import WebSocketServerProtocol @@ -31,14 +28,8 @@ ResourceManagerDetails, ) from s2python.frbc import ( - FRBCInstruction, FRBCSystemDescription, - FRBCActuatorDescription, - FRBCStorageDescription, - FRBCOperationMode, - FRBCOperationModeElement, FRBCFillLevelTargetProfile, - FRBCFillLevelTargetProfileElement, FRBCStorageStatus, FRBCActuatorStatus, ) @@ -260,7 +251,7 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc server_ws._handlers.register_handler(FRBCActuatorStatus, handle_FRBCActuatorStatus) server_ws._handlers.register_handler(FRBCFillLevelTargetProfile, handle_FillLevelTargetProfile) server_ws._handlers.register_handler(FRBCStorageStatus, handle_FRBCStorageStatus) - + # Create and register signal handlers handler = create_signal_handler(server_ws) signal.signal(signal.SIGINT, handler) diff --git a/examples/mock_s2_server.py b/examples/mock_s2_server.py deleted file mode 100644 index 1410a5a..0000000 --- a/examples/mock_s2_server.py +++ /dev/null @@ -1,121 +0,0 @@ -import socketserver -import json -from typing import Any -import uuid -import logging -import random -import string - -from s2python.authorization.default_server import S2DefaultHTTPHandler - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("mock_s2_server") - - -def generate_token() -> str: - """ - Generate a random alphanumeric token with exactly 32 characters. - - Returns: - str: A string of 32 random alphanumeric characters matching pattern ^[0-9a-zA-Z]{32}$ - """ - # Define the character set: uppercase letters, lowercase letters, and digits - chars = string.ascii_letters + string.digits - - # Generate a 32-character token by randomly selecting from the character set - token = "".join(random.choice(chars) for _ in range(32)) - - return token - - -# Generate random token for pairing -PAIRING_TOKEN = generate_token() -SERVER_NODE_ID = str(uuid.uuid4()) -WS_PORT = 8080 -HTTP_PORT = 8000 - - -class MockS2Handler(S2DefaultHTTPHandler): - def do_POST(self) -> None: # pylint: disable=C0103 - content_length = int(self.headers.get("Content-Length", 0)) - post_data = self.rfile.read(content_length).decode("utf-8") - - try: - request_json = json.loads(post_data) - logger.info("Received request at %s", self.path) - logger.debug("Request body: %s", request_json) - - if self.path == "/requestPairing": - # Handle pairing request - # The token in the S2 protocol is a PairingToken object with a token field - token_obj = request_json.get("token", {}) - - # Handle case where token is directly the string or a dict with token field - if isinstance(token_obj, dict) and "token" in token_obj: - request_token_string = token_obj["token"] - else: - request_token_string = token_obj - - logger.info("Extracted token: %s", request_token_string) - logger.info("Expected token: %s", PAIRING_TOKEN) - - if request_token_string == PAIRING_TOKEN: - # Create pairing response - response = { - "s2ServerNodeId": SERVER_NODE_ID, - "serverNodeDescription": { - "brand": "Mock S2 Server", - "type": "Test Server", - "modelName": "Mock Model", - "logoUri": "http://example.com/logo.png", - "userDefinedName": "Mock Server", - "role": "CEM", - "deployment": "LAN", - }, - "requestConnectionUri": f"http://localhost:{HTTP_PORT}/requestConnection", - } - self._send_json_response(200, response) - logger.info("Pairing request successful") - else: - self._send_json_response(401, {"error": "Invalid token"}) - logger.error("Invalid pairing token") - - elif self.path == "/requestConnection": - # Create challenge (normally would be a JWE) - challenge = "mock_challenge_string" - - # Create connection details response - response = { - "connectionUri": f"ws://localhost:{WS_PORT}/s2/mock-websocket", - "challenge": challenge, - "selectedProtocol": "WebSocketSecure", - } - - # Handle connection request - self._send_json_response(200, response) - logger.info("Connection request successful") - - else: - self._send_json_response(404, {"error": "Endpoint not found"}) - logger.error("Unknown endpoint: %s", self.path) - - except Exception as e: - self._send_json_response(500, {"error": str(e)}) - logger.error("Error handling request: %s", e) - raise e - - def log_message(self, format: str, *args: Any) -> None: # pylint: disable=W0622 - logger.info(format % args) # pylint: disable=W1201 - - -def run_server() -> None: - with socketserver.TCPServer(("localhost", HTTP_PORT), MockS2Handler) as httpd: - logger.info("Mock S2 Server running at: http://localhost:%s", HTTP_PORT) - logger.info("Use pairing token: %s", PAIRING_TOKEN) - logger.info("Pairing endpoint: http://localhost:%s/requestPairing", HTTP_PORT) - httpd.serve_forever() - - -if __name__ == "__main__": - run_server() diff --git a/examples/mock_s2_websocket.py b/examples/mock_s2_websocket.py deleted file mode 100644 index 41011f7..0000000 --- a/examples/mock_s2_websocket.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import logging -import json -import uuid -from datetime import datetime, timezone -import websockets - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("mock_s2_websocket") - -# WebSocket server port -WS_PORT = 8080 - - -# Handle client connection -async def handle_connection( - websocket: websockets.WebSocketServerProtocol, path: str -) -> None: - client_id = str(uuid.uuid4()) - logger.info("Client %s connected on path: %s", client_id, path) - - try: - # Send handshake message to client - handshake = { - "type": "Handshake", - "messageId": str(uuid.uuid4()), - "protocolVersion": "0.0.2-beta", - "timestamp": datetime.now(timezone.utc).isoformat(), - } - await websocket.send(json.dumps(handshake)) - logger.info("Sent handshake to client %s", client_id) - - # Listen for messages - async for message in websocket: - try: - data = json.loads(message) - logger.info("Received message from client %s: %s", client_id, data) - - # Extract message type - message_type = data.get("type", "") - message_id = data.get("messageId", str(uuid.uuid4())) - - # Send reception status - reception_status = { - "type": "ReceptionStatus", - "messageId": str(uuid.uuid4()), - "refMessageId": message_id, - "timestamp": datetime.now(timezone.utc).isoformat(), - "status": "OK", - } - await websocket.send(json.dumps(reception_status)) - logger.info("Sent reception status for message %s", message_id) - - # Handle specific message types - if message_type == "HandshakeResponse": - logger.info("Received handshake response") - - # For FRBC messages, you could add specific handling here - - except json.JSONDecodeError: - logger.error("Invalid JSON received from client %s", client_id) - except Exception as e: - logger.error( - "Error processing message from client %s: %s", client_id, e - ) - raise e - - except websockets.exceptions.ConnectionClosed: - logger.info("Connection with client %s closed", client_id) - except Exception as e: - logger.error("Error with client %s: %s", client_id, e) - raise e - finally: - logger.info("Client %s disconnected", client_id) - - -async def start_server() -> None: - server = await websockets.serve(handle_connection, "localhost", WS_PORT) - logger.info("WebSocket server started on ws://localhost:%s", WS_PORT) - - # Keep the server running - await server.wait_closed() - - -if __name__ == "__main__": - asyncio.run(start_server()) diff --git a/src/s2python/authorization/client.py b/src/s2python/authorization/client.py index 994afef..fde1202 100644 --- a/src/s2python/authorization/client.py +++ b/src/s2python/authorization/client.py @@ -116,7 +116,9 @@ def store_connection_request_uri(self, uri: str) -> None: self._connection_request_uri = uri elif self.pairing_uri is not None and "requestPairing" in self.pairing_uri: # Fall back to constructing the URI from the pairing URI - self._connection_request_uri = self.pairing_uri.replace("requestPairing", "requestConnection") + self._connection_request_uri = self.pairing_uri.replace( + "requestPairing", "requestConnection" + ) else: # No valid URI could be determined self._connection_request_uri = None @@ -177,10 +179,14 @@ def request_pairing(self) -> PairingResponse: ValueError: If pairing_uri or token is not set, or if the request fails """ if not self.pairing_uri: - raise ValueError("Pairing URI not set. Set pairing_uri before calling request_pairing.") + raise ValueError( + "Pairing URI not set. Set pairing_uri before calling request_pairing." + ) if not self.token: - raise ValueError("Pairing token not set. Set token before calling request_pairing.") + raise ValueError( + "Pairing token not set. Set token before calling request_pairing." + ) # Ensure we have keys if not self._public_key: @@ -211,7 +217,9 @@ def request_pairing(self) -> PairingResponse: # Parse response if status_code != 200: - raise ValueError(f"Pairing request failed with status {status_code}: {response_text}") + raise ValueError( + f"Pairing request failed with status {status_code}: {response_text}" + ) pairing_response = PairingResponse.model_validate(json.loads(response_text)) @@ -231,7 +239,9 @@ def request_connection(self) -> ConnectionDetails: ValueError: If connection request URI is not set or if the request fails """ if not self._connection_request_uri: - raise ValueError("Connection request URI not set. Call request_pairing first.") + raise ValueError( + "Connection request URI not set. Call request_pairing first." + ) # Create connection request connection_request = ConnectionRequest( @@ -249,7 +259,9 @@ def request_connection(self) -> ConnectionDetails: # Parse response if status_code != 200: - raise ValueError(f"Connection request failed with status {status_code}: {response_text}") + raise ValueError( + f"Connection request failed with status {status_code}: {response_text}" + ) connection_details = ConnectionDetails.model_validate(json.loads(response_text)) @@ -283,13 +295,19 @@ def request_connection(self) -> ConnectionDetails: # Replace the URI with the full WebSocket URL connection_data["connectionUri"] = full_ws_url # Recreate the ConnectionDetails object - connection_details = ConnectionDetails.model_validate(connection_data) - logger.info("Updated relative WebSocket URI to absolute: %s", full_ws_url) + connection_details = ConnectionDetails.model_validate( + connection_data + ) + logger.info( + "Updated relative WebSocket URI to absolute: %s", full_ws_url + ) except (ValueError, TypeError, KeyError) as e: logger.info("Failed to update WebSocket URI: %s", e) else: # Log a warning but don't modify the URI if we can't create a proper absolute URI - logger.info("Received relative WebSocket URI but pairing_uri is not available to create absolute URL") + logger.info( + "Received relative WebSocket URI but pairing_uri is not available to create absolute URL" + ) # Store for later use self._connection_details = connection_details diff --git a/src/s2python/authorization/default_client.py b/src/s2python/authorization/default_client.py index c6c1ef7..2321346 100644 --- a/src/s2python/authorization/default_client.py +++ b/src/s2python/authorization/default_client.py @@ -32,6 +32,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger("S2DefaultClient") + class S2DefaultClient(S2AbstractClient): """Default implementation of the S2AbstractClient using the requests library for HTTP and jwskate for cryptographic operations. @@ -156,7 +157,7 @@ def solve_challenge(self, challenge: Optional[Any] = None) -> str: rsa_key_pair = Jwk.from_pem(self._public_key) else: raise ValueError("No public key available") - #check that the challenge is a JweCompact + # check that the challenge is a JweCompact if not isinstance(challenge, str): raise ValueError("Challenge is not a string") # Log the challenge @@ -192,7 +193,7 @@ def solve_challenge(self, challenge: Optional[Any] = None) -> str: decrypted_challenge_str=decrypted_challenge_str, ) - logger.info('Decrypted challenge: %s', decrypted_challenge_str) + logger.info("Decrypted challenge: %s", decrypted_challenge_str) return decrypted_challenge_str except (ValueError, TypeError, KeyError, json.JSONDecodeError) as e: @@ -229,8 +230,13 @@ def establish_secure_connection(self) -> Dict[str, Any]: "Challenge solution not available. Call solve_challenge first." ) - logger.info('Establishing WebSocket connection to %s,', self._connection_details.connectionUri) - logger.info('Using solved challenge: %s', self._pairing_details.decrypted_challenge_str) + logger.info( + "Establishing WebSocket connection to %s,", + self._connection_details.connectionUri, + ) + logger.info( + "Using solved challenge: %s", self._pairing_details.decrypted_challenge_str + ) # Placeholder for the connection object self._ws_connection = { @@ -241,9 +247,7 @@ def establish_secure_connection(self) -> Dict[str, Any]: return self._ws_connection def close_connection(self) -> None: - """Close the WebSocket connection. - - """ + """Close the WebSocket connection.""" if self._ws_connection: logger.info("Would close WebSocket connection") diff --git a/src/s2python/authorization/default_http_server.py b/src/s2python/authorization/default_http_server.py index 6ca82d5..cf4f154 100644 --- a/src/s2python/authorization/default_http_server.py +++ b/src/s2python/authorization/default_http_server.py @@ -2,41 +2,25 @@ Default implementation of the S2 protocol server. """ -import base64 import http.server import json import logging import socketserver import asyncio -import uuid -from datetime import datetime, timezone -from typing import Dict, Any, Tuple, Optional, Union, Awaitable +from datetime import datetime +from typing import Dict, Any, Tuple, Optional, Union from jwskate import Jwk, Jwt from jwskate.jwe.compact import JweCompact import websockets -from websockets.server import WebSocketServerProtocol from s2python.authorization.server import S2AbstractServer from s2python.generated.gen_s2_pairing import ( - ConnectionDetails, ConnectionRequest, PairingRequest, - PairingResponse, - Protocols, ) -from s2python.message import S2Message from websockets.server import WebSocketServer -from s2python.common import ( - ReceptionStatusValues, - ReceptionStatus, - Handshake, - HandshakeResponse, - EnergyManagementRole, - SelectControlType, -) -from s2python.version import S2_VERSION from s2python.communication.s2_connection import MessageHandlers, S2Connection from s2python.s2_parser import S2Parser @@ -86,7 +70,9 @@ def do_POST(self) -> None: # pylint: disable=C0103 logger.error("Error handling request: %s", e) raise e - def _send_json_response(self, status_code: int, response_body: Union[dict, str]) -> None: + def _send_json_response( + self, status_code: int, response_body: Union[dict, str] + ) -> None: """ Helper function to send a JSON response. :param handler: The HTTP handler instance (self). @@ -133,7 +119,9 @@ def _handle_connection_request(self, request_json: Dict[str, Any]) -> None: connection_request = ConnectionRequest.model_validate(request_json) # Process request using server instance - response = self.server_instance.handle_connection_request(connection_request) + response = self.server_instance.handle_connection_request( + connection_request + ) # Send response self._send_json_response(200, response.model_dump_json()) @@ -205,7 +193,9 @@ def store_key_pair(self, public_key: str, private_key: str) -> None: # Convert to JWK for JWT operations self._private_jwk = Jwk.from_pem(private_key) - def _create_signed_token(self, claims: Dict[str, Any], expiry_date: datetime) -> str: + def _create_signed_token( + self, claims: Dict[str, Any], expiry_date: datetime + ) -> str: """Create a signed JWT token. Args: @@ -230,7 +220,11 @@ def _create_signed_token(self, claims: Dict[str, Any], expiry_date: datetime) -> return str(token) def _create_encrypted_challenge( - self, client_public_key: str, client_node_id: str, nested_signed_token: str, expiry_date: datetime + self, + client_public_key: str, + client_node_id: str, + nested_signed_token: str, + expiry_date: datetime, ) -> str: """Create an encrypted challenge for the client. @@ -286,7 +280,9 @@ def handler_factory(*args: Any, **kwargs: Any) -> S2DefaultHTTPHandler: return S2DefaultHTTPHandler(*args, server_instance=self, **kwargs) # Create and start server - self._httpd = socketserver.TCPServer((self.host, self.http_port), handler_factory) + self._httpd = socketserver.TCPServer( + (self.host, self.http_port), handler_factory + ) logger.info("S2 Server running at: http://%s:%s", self.host, self.http_port) # Start the WebSocket server self._httpd.serve_forever() @@ -307,8 +303,7 @@ def stop_server(self) -> None: self._ws_server = None def _get_ws_url(self) -> str: - """Get the WebSocket URL for the server. - """ + """Get the WebSocket URL for the server.""" return f"ws://{self.host}:{self.ws_port}" def _get_base_url(self) -> str: diff --git a/src/s2python/authorization/default_ws_server.py b/src/s2python/authorization/default_ws_server.py index c9b0e37..cc218dc 100644 --- a/src/s2python/authorization/default_ws_server.py +++ b/src/s2python/authorization/default_ws_server.py @@ -43,7 +43,10 @@ def __init__(self) -> None: self.handlers = {} async def handle_message( - self, server: "S2DefaultWSServer", msg: S2Message, websocket: WebSocketServerProtocol + self, + server: "S2DefaultWSServer", + msg: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle the S2 message using the registered handler. @@ -125,7 +128,9 @@ def __init__( def _register_default_handlers(self) -> None: """Register default message handlers.""" self._handlers.register_handler(Handshake, self.handle_handshake) - self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response) + self._handlers.register_handler( + HandshakeResponse, self.handle_handshake_response + ) self._handlers.register_handler(ReceptionStatus, self.handle_reception_status) def start(self) -> None: @@ -166,9 +171,15 @@ async def _connect_and_run(self) -> None: host=self.host, port=self.port, ) - logger.info("S2 WebSocket server running at: ws://%s:%s", self.host, self.port) + logger.info( + "S2 WebSocket server running at: ws://%s:%s", self.host, self.port + ) else: - logger.info("S2 WebSocket server already running at: ws://%s:%s", self.host, self.port) + logger.info( + "S2 WebSocket server already running at: ws://%s:%s", + self.host, + self.port, + ) async def wait_till_stop() -> None: await self._stop_event.wait() @@ -180,7 +191,9 @@ async def wait_till_connection_restart() -> None: self._eventloop.create_task(wait_till_stop()), self._eventloop.create_task(wait_till_connection_restart()), ] - (done, pending) = await asyncio.wait(background_tasks, return_when=asyncio.FIRST_COMPLETED) + (done, pending) = await asyncio.wait( + background_tasks, return_when=asyncio.FIRST_COMPLETED + ) await self._stop_event.wait() def stop(self) -> None: @@ -190,7 +203,9 @@ def stop(self) -> None: if self._server: self._server.close() - async def _handle_websocket_connection(self, websocket: WebSocketServerProtocol, path: str) -> None: + async def _handle_websocket_connection( + self, websocket: WebSocketServerProtocol, path: str + ) -> None: """Handle incoming WebSocket connections. Args: @@ -208,11 +223,15 @@ async def _handle_websocket_connection(self, websocket: WebSocketServerProtocol, s2_msg = self.s2_parser.parse_as_any_message(message) if isinstance(s2_msg, ReceptionStatus): logger.info("Received reception status: %s", s2_msg) - await self.reception_status_awaiter.receive_reception_status(s2_msg) + await self.reception_status_awaiter.receive_reception_status( + s2_msg + ) continue except json.JSONDecodeError: await self.respond_with_reception_status( - subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + subject_message_id=uuid.UUID( + "00000000-0000-0000-0000-000000000000" + ), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Not valid json.", websocket=websocket, @@ -223,7 +242,9 @@ async def _handle_websocket_connection(self, websocket: WebSocketServerProtocol, await self._handlers.handle_message(self, s2_msg, websocket) except json.JSONDecodeError: await self.respond_with_reception_status( - subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + subject_message_id=uuid.UUID( + "00000000-0000-0000-0000-000000000000" + ), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Not valid json.", websocket=websocket, @@ -240,7 +261,9 @@ async def _handle_websocket_connection(self, websocket: WebSocketServerProtocol, ) else: await self.respond_with_reception_status( - subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + subject_message_id=uuid.UUID( + "00000000-0000-0000-0000-000000000000" + ), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Message appears valid json but could not find a message_id field.", websocket=websocket, @@ -276,7 +299,9 @@ async def respond_with_reception_status( status=status, diagnostic_label=diagnostic_label, ) - logger.info("Sending reception status %s for message %s", status, subject_message_id) + logger.info( + "Sending reception status %s for message %s", status, subject_message_id + ) try: await websocket.send(response.to_json()) except websockets.exceptions.ConnectionClosed: @@ -292,7 +317,9 @@ def respond_with_reception_status_sync( """Synchronous version of respond_with_reception_status.""" if self._loop: asyncio.run_coroutine_threadsafe( - self.respond_with_reception_status(subject_message_id, status, diagnostic_label, websocket), + self.respond_with_reception_status( + subject_message_id, status, diagnostic_label, websocket + ), self._loop, ).result() @@ -351,7 +378,10 @@ async def send_msg_and_await_reception_status_async( ) async def handle_handshake( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle handshake messages. @@ -372,7 +402,9 @@ async def handle_handshake( message_id=message.message_id, selected_protocol_version=message.supported_protocol_versions, ) - await self.send_msg_and_await_reception_status_async(handshake_response, websocket) + await self.send_msg_and_await_reception_status_async( + handshake_response, websocket + ) await self.respond_with_reception_status( subject_message_id=message.message_id, @@ -387,7 +419,10 @@ async def handle_handshake( ) async def handle_reception_status( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle reception status messages.""" if not isinstance(message, ReceptionStatus): @@ -396,10 +431,15 @@ async def handle_reception_status( type(message), ) return - logger.info("Received ReceptionStatus in handle_reception_status: %s", message.to_json()) + logger.info( + "Received ReceptionStatus in handle_reception_status: %s", message.to_json() + ) async def handle_handshake_response( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle handshake response messages. @@ -417,7 +457,9 @@ async def handle_handshake_response( logger.debug("Received HandshakeResponse: %s", message.to_json()) - async def _send_and_forget(self, s2_msg: S2Message, websocket: WebSocketServerProtocol) -> None: + async def _send_and_forget( + self, s2_msg: S2Message, websocket: WebSocketServerProtocol + ) -> None: """Send a message and forget about it. Args: @@ -430,7 +472,10 @@ async def _send_and_forget(self, s2_msg: S2Message, websocket: WebSocketServerPr logger.warning("Connection closed while sending message") async def send_select_control_type( - self, control_type: ControlType, websocket: WebSocketServerProtocol, send_okay: Awaitable[None] + self, + control_type: ControlType, + websocket: WebSocketServerProtocol, + send_okay: Awaitable[None], ) -> None: """Select the control type. diff --git a/src/s2python/authorization/flask_http_server.py b/src/s2python/authorization/flask_http_server.py new file mode 100644 index 0000000..58f1ec6 --- /dev/null +++ b/src/s2python/authorization/flask_http_server.py @@ -0,0 +1,245 @@ +""" +Flask implementation of the S2 protocol server. +""" + +import logging +from datetime import datetime +from typing import Dict, Any, Tuple +import json + +from flask import Flask, request +from jwskate import Jwk, Jwt +from jwskate.jwe.compact import JweCompact + +from s2python.authorization.server import S2AbstractServer +from s2python.generated.gen_s2_pairing import ( + ConnectionRequest, + PairingRequest, +) + +from s2python.communication.s2_connection import MessageHandlers, S2Connection +from s2python.s2_parser import S2Parser + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2FlaskServer") + + +class S2FlaskHTTPServer(S2AbstractServer): + """Flask implementation of the S2 protocol server.""" + + def __init__( + self, + host: str = "localhost", + http_port: int = 8000, + ws_port: int = 8080, + instance: str = "http", + *args: Any, + **kwargs: Any, + ) -> None: + """Initialize the Flask server implementation. + + Args: + host: The host to bind to + http_port: The HTTP port to use + ws_port: The WebSocket port to use + instance: The instance type (http or ws) + """ + super().__init__(*args, **kwargs) + self.host = host + self.http_port = http_port + self.ws_port = ws_port + self.instance = instance + self._app = Flask(__name__) + self._connections: Dict[str, S2Connection] = {} + self._handlers = MessageHandlers() + self.s2_parser = S2Parser() + self._setup_routes() + + def _setup_routes(self) -> None: + """Set up Flask routes for the server.""" + self._app.add_url_rule( + "/requestPairing", + "requestPairing", + self._handle_pairing_request, + methods=["POST"], + ) + self._app.add_url_rule( + "/requestConnection", + "requestConnection", + self._handle_connection_request, + methods=["POST"], + ) + + def _handle_pairing_request(self) -> Tuple[dict, int]: + """Handle a pairing request. + + Returns: + Tuple[dict, int]: (response JSON, status code) + """ + try: + request_json = request.get_json() + logger.info("Received pairing request at /requestPairing") + logger.debug("Request body: %s", request_json) + + # Convert request to PairingRequest + pairing_request = PairingRequest.model_validate(request_json) + + # Process request using server instance + response = self.handle_pairing_request(pairing_request) + + # Send response + logger.info("Pairing request successful") + return response.model_dump(), 200 + + except ValueError as e: + logger.error("Invalid pairing request: %s", e) + return {"error": str(e)}, 400 + except Exception as e: + logger.error("Error handling pairing request: %s", e) + return {"error": str(e)}, 500 + + def _handle_connection_request(self) -> Tuple[dict, int]: + """Handle a connection request. + + Returns: + Tuple[dict, int]: (response JSON, status code) + """ + try: + request_json = request.get_json() + logger.info("Received connection request at /requestConnection") + logger.debug("Request body: %s", request_json) + + # Convert request to ConnectionRequest + connection_request = ConnectionRequest.model_validate(request_json) + + # Process request using server instance + response = self.handle_connection_request(connection_request) + + # Send response + logger.info("Connection request successful") + return response.model_dump(), 200 + + except ValueError as e: + logger.error("Invalid connection request: %s", e) + return {"error": str(e)}, 400 + except Exception as e: + logger.error("Error handling connection request: %s", e) + return {"error": str(e)}, 500 + + def generate_key_pair(self) -> Tuple[str, str]: + """Generate a public/private key pair for the server. + + Returns: + Tuple[str, str]: (public_key, private_key) pair as base64 encoded strings + """ + logger.info("Generating key pair") + self._key_pair = Jwk.generate_for_alg("RSA-OAEP-256").with_kid_thumbprint() + self._public_jwk = self._key_pair + self._private_jwk = self._key_pair + return ( + self._public_jwk.to_pem(), + self._private_jwk.to_pem(), + ) + + def store_key_pair(self, public_key: str, private_key: str) -> None: + """Store the server's public/private key pair. + + Args: + public_key: Base64 encoded public key + private_key: Base64 encoded private key + """ + self._private_key = private_key + # Convert to JWK for JWT operations + self._private_jwk = Jwk.from_pem(private_key) + + def _create_signed_token( + self, claims: Dict[str, Any], expiry_date: datetime + ) -> str: + """Create a signed JWT token. + + Args: + claims: The claims to include in the token + expiry_date: The token's expiration date + + Returns: + str: The signed JWT token + """ + if not self._private_jwk: + # Generate key pair with correct algorithm + self._key_pair = Jwk.generate_for_alg("RS256").with_kid_thumbprint() + self._private_jwk = self._key_pair + self._public_jwk = self._key_pair + + # Add expiration to claims + claims["exp"] = int(expiry_date.timestamp()) + + # Create JWT with claims using RS256 for signing + token = Jwt.sign(claims=claims, key=self._private_jwk, alg="RS256") + + return str(token) + + def _create_encrypted_challenge( + self, + client_public_key: str, + client_node_id: str, + nested_signed_token: str, + expiry_date: datetime, + ) -> str: + """Create an encrypted challenge for the client. + + Args: + client_public_key: The client's public key + client_node_id: The client's node ID + nested_signed_token: The nested signed token + expiry_date: The challenge's expiration date + + Returns: + str: The encrypted challenge + """ + # Convert client's public key to JWK + client_jwk = Jwk.from_pem(client_public_key) + + # Create the payload to encrypt - this will be decrypted and used as an unprotected JWT + payload = { + "S2ClientNodeId": client_node_id, + "signedToken": nested_signed_token, + "exp": int(expiry_date.timestamp()), + } + + # Create JWE with all required components + jwe = JweCompact.encrypt( + plaintext=json.dumps(payload).encode(), + key=client_jwk, # Using client's public key for encryption + alg="RSA-OAEP-256", + enc="A256GCM", + ) + + logger.info("JWE: %s", str(jwe)) + return str(jwe) + + def start_server(self) -> None: + """Start the HTTP or WebSocket server.""" + if self.instance == "http": + logger.info("Starting Flask HTTP server------>") + self._app.run(host=self.host, port=self.http_port) + else: + raise ValueError("Invalid instance type") + + def stop_server(self) -> None: + """Stop the server.""" + # Flask doesn't have a built-in way to stop the server + # This would typically be handled by the WSGI server + pass + + def _get_ws_url(self) -> str: + """Get the WebSocket URL for the server.""" + return f"ws://{self.host}:{self.ws_port}" + + def _get_base_url(self) -> str: + """Get the base URL for the server. + + Returns: + str: The base URL (e.g., "http://localhost:8000") + """ + return f"http://{self.host}:{self.http_port}" diff --git a/src/s2python/authorization/flask_service.py b/src/s2python/authorization/flask_service.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/s2python/authorization/flask_ws_server.py b/src/s2python/authorization/flask_ws_server.py new file mode 100644 index 0000000..bd9a698 --- /dev/null +++ b/src/s2python/authorization/flask_ws_server.py @@ -0,0 +1,317 @@ +""" +Flask implementation of the S2 protocol WebSocket server. +""" + +import asyncio +import json +import logging +import traceback +import uuid +from typing import Any, Callable, Dict, Optional, Type + +from flask import Flask +from flask_sock import ConnectionClosed, Sock + +from s2python.common import ( + ControlType, + EnergyManagementRole, + Handshake, + HandshakeResponse, + ReceptionStatus, + ReceptionStatusValues, + SelectControlType, +) +from s2python.communication.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.message import S2Message +from s2python.s2_parser import S2Parser +from s2python.s2_validation_error import S2ValidationError + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2FlaskWSServer") + + +class MessageHandlers: + """Class to manage message handlers for different message types.""" + + handlers: Dict[Type[S2Message], Callable] + + def __init__(self) -> None: + self.handlers = {} + + async def handle_message( + self, + server: "S2FlaskWSServer", + msg: S2Message, + websocket: Sock, + ) -> None: + """Handle the S2 message using the registered handler. + Args: + server: The server instance handling the message + msg: The S2 message to handle + websocket: The websocket connection to the client + """ + handler = self.handlers.get(type(msg)) + if handler is not None: + try: + if asyncio.iscoroutinefunction(handler): + await handler(server, msg, websocket) + else: + + def do_message() -> None: + handler(server, msg, websocket) + + eventloop = asyncio.get_event_loop() + await eventloop.run_in_executor(executor=None, func=do_message) + except Exception: + logger.error( + "While processing message %s an unrecoverable error occurred.", + msg.message_id, # type: ignore[attr-defined] + ) + logger.error("Error: %s", traceback.format_exc()) + await server.respond_with_reception_status( + subject_message_id=msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label=f"While processing message {msg.message_id} " # type: ignore[attr-defined] + f"an unrecoverable error occurred.", + websocket=websocket, + ) + raise + else: + logger.warning( + "Received a message of type %s but no handler is registered. Ignoring the message.", + type(msg), + ) + + def register_handler(self, msg_type: Type[S2Message], handler: Callable[..., Any]) -> None: + """Register a handler for a specific message type. + Args: + msg_type: The message type to handle + handler: The handler function + """ + self.handlers[msg_type] = handler + + +class S2FlaskWSServer: + """Flask-based WebSocket server implementation for S2 protocol.""" + + def __init__( + self, + host: str = "localhost", + port: int = 8080, + role: EnergyManagementRole = EnergyManagementRole.CEM, + ws_path: str = "/", + ) -> None: + """Initialize the WebSocket server. + Args: + host: The host to bind to + port: The port to listen on + role: The role of this server (CEM or RM) + ws_path: The path for the WebSocket endpoint. + """ + self.host = host + self.port = port + self.role = role + self.ws_path = ws_path + + self.app = Flask(__name__) + self.sock = Sock(self.app) + + self._handlers = MessageHandlers() + self.s2_parser = S2Parser() + self._connections: Dict[str, Sock] = {} + self.reception_status_awaiter = ReceptionStatusAwaiter() + + self._register_default_handlers() + self._setup_routes() + + def _setup_routes(self) -> None: + self.sock.route(self.ws_path)(self._ws_handler) + + def _register_default_handlers(self) -> None: + """Register default message handlers.""" + self._handlers.register_handler(Handshake, self.handle_handshake) + self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response) + self._handlers.register_handler(ReceptionStatus, self.handle_reception_status) + + def start(self) -> None: + """ + Start the WebSocket server using Flask's built-in development server. + + Note: This server is synchronous (WSGI) and is not suitable for production. + It runs each WebSocket connection in a separate thread. For true async + support and production use, you must run this application with an + ASGI server like Hypercorn. Example: + `hypercorn -b 127.0.0.1:8080 examples.example_flask_server:server_instance.app` + """ + logger.info( + "Starting S2 Flask WebSocket server with Flask's development WSGI server at: ws://%s:%s%s", + self.host, + self.port, + self.ws_path, + ) + self.app.run(host=self.host, port=self.port, debug=False) + + def stop(self) -> None: + """Stop the WebSocket server.""" + logger.warning("S2FlaskWSServer.stop() is a no-op. The server should be managed by a process manager.") + + def _ws_handler(self, ws: Sock) -> None: + """ + Wrapper to run the async websocket handler from a synchronous context. + This is required for Flask's development server. An ASGI server would + be able to run the async handler directly. + """ + try: + asyncio.run(self._handle_websocket_connection(ws)) + except Exception as e: + # The websocket is likely closed, or another network error occurred. + logger.error("Error in websocket handler: %s", e) + + async def _handle_websocket_connection(self, websocket: Sock) -> None: + """Handle incoming WebSocket connections.""" + client_id = str(uuid.uuid4()) + logger.info("Client %s connected.", client_id) + self._connections[client_id] = websocket + + try: + while True: + message = await websocket.receive() + try: + s2_msg = self.s2_parser.parse_as_any_message(message) + if isinstance(s2_msg, ReceptionStatus): + await self.reception_status_awaiter.receive_reception_status(s2_msg) + continue + except json.JSONDecodeError: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + websocket=websocket, + ) + continue + try: + await self._handlers.handle_message(self, s2_msg, websocket) + except json.JSONDecodeError: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + websocket=websocket, + ) + except S2ValidationError as e: + json_msg = json.loads(message) + message_id = json_msg.get("message_id") + if message_id: + await self.respond_with_reception_status( + subject_message_id=message_id, + status=ReceptionStatusValues.INVALID_MESSAGE, + diagnostic_label=str(e), + websocket=websocket, + ) + else: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Message appears valid json but could not find a message_id field.", + websocket=websocket, + ) + except Exception as e: + logger.error("Error processing message: %s", str(e)) + raise + except ConnectionClosed: + logger.info("Connection with client %s closed", client_id) + finally: + if client_id in self._connections: + del self._connections[client_id] + logger.info("Client %s disconnected", client_id) + + async def respond_with_reception_status( + self, + subject_message_id: uuid.UUID, + status: ReceptionStatusValues, + diagnostic_label: str, + websocket: Sock, + ) -> None: + """Send a reception status response.""" + response = ReceptionStatus( + subject_message_id=subject_message_id, status=status, diagnostic_label=diagnostic_label + ) + logger.info("Sending reception status %s for message %s", status, subject_message_id) + try: + await websocket.send(response.to_json()) + except ConnectionClosed: + logger.warning("Connection closed while sending reception status") + + async def send_msg_and_await_reception_status_async( + self, + s2_msg: S2Message, + websocket: Sock, + timeout_reception_status: float = 20.0, + raise_on_error: bool = True, + ) -> ReceptionStatus: + """Send a message and await a reception status.""" + await self._send_and_forget(s2_msg, websocket) + try: + response = await asyncio.wait_for(websocket.receive(), timeout=timeout_reception_status) + # Assuming the response is the correct reception status + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.OK, + diagnostic_label="Reception status received.", + ) + except asyncio.TimeoutError: + if raise_on_error: + raise TimeoutError(f"Did not receive a reception status on time for {s2_msg.message_id}") # type: ignore[attr-defined] + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label="Timeout waiting for reception status.", + ) + except ConnectionClosed: + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.OK, + diagnostic_label="Connection closed, assuming OK status.", + ) + + async def handle_handshake(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle handshake messages.""" + if not isinstance(message, Handshake): + return + + handshake_response = HandshakeResponse( + message_id=message.message_id, selected_protocol_version=message.supported_protocol_versions + ) + await self.send_msg_and_await_reception_status_async(handshake_response, websocket) + + await self.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="Handshake received", + websocket=websocket, + ) + + async def handle_reception_status(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle reception status messages.""" + if not isinstance(message, ReceptionStatus): + return + logger.info("Received ReceptionStatus in handle_reception_status: %s", message.to_json()) + + async def handle_handshake_response(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle handshake response messages.""" + if not isinstance(message, HandshakeResponse): + return + logger.debug("Received HandshakeResponse: %s", message.to_json()) + + async def _send_and_forget(self, s2_msg: S2Message, websocket: Sock) -> None: + """Send a message and forget about it.""" + try: + await websocket.send(s2_msg.to_json()) + except ConnectionClosed: + logger.warning("Connection closed while sending message") + + async def send_select_control_type(self, control_type: ControlType, websocket: Sock) -> None: + """Select the control type.""" + select_control_type = SelectControlType(message_id=uuid.uuid4(), control_type=control_type) + await self._send_and_forget(select_control_type, websocket) diff --git a/src/s2python/authorization/server.py b/src/s2python/authorization/server.py index 8123dec..74e0d72 100644 --- a/src/s2python/authorization/server.py +++ b/src/s2python/authorization/server.py @@ -115,7 +115,9 @@ def get_client_public_key(self, client_node_id: str) -> Optional[str]: """ return self._client_keys.get(client_node_id) - def handle_pairing_request(self, pairing_request: PairingRequest) -> PairingResponse: + def handle_pairing_request( + self, pairing_request: PairingRequest + ) -> PairingResponse: """Handle a pairing request from a client. Args: @@ -130,8 +132,14 @@ def handle_pairing_request(self, pairing_request: PairingRequest) -> PairingResp logger.info(f"Pairing request for Client Node: {pairing_request}") # Validate required fields - if not pairing_request.publicKey or not pairing_request.s2ClientNodeId or not pairing_request.token: - raise ValueError("Missing fields, public key, s2ClientNodeId and token are required") + if ( + not pairing_request.publicKey + or not pairing_request.s2ClientNodeId + or not pairing_request.token + ): + raise ValueError( + "Missing fields, public key, s2ClientNodeId and token are required" + ) # Validate token # TODO: Get token from server FM @@ -140,7 +148,9 @@ def handle_pairing_request(self, pairing_request: PairingRequest) -> PairingResp # Store client's public key # TODO: Store client's public key. sqlLite? - self.store_client_public_key(str(pairing_request.s2ClientNodeId), pairing_request.publicKey) + self.store_client_public_key( + str(pairing_request.s2ClientNodeId), pairing_request.publicKey + ) # Create full URLs for endpoints base_url = self._get_base_url() @@ -156,7 +166,9 @@ def handle_pairing_request(self, pairing_request: PairingRequest) -> PairingResp logger.info(f"Pairing response: {pairing_response}") return pairing_response - def handle_connection_request(self, connection_request: ConnectionRequest) -> ConnectionDetails: + def handle_connection_request( + self, connection_request: ConnectionRequest + ) -> ConnectionDetails: """Handle a connection request from a client. Args: @@ -175,10 +187,14 @@ def handle_connection_request(self, connection_request: ConnectionRequest) -> Co not connection_request.supportedProtocols or Protocols.WebSocketSecure not in connection_request.supportedProtocols ): - raise ValueError("S2 Server does not support any of the protocols supported by the client") + raise ValueError( + "S2 Server does not support any of the protocols supported by the client" + ) # Get client's public key - client_public_key = self.get_client_public_key(connection_request.s2ClientNodeId) + client_public_key = self.get_client_public_key( + connection_request.s2ClientNodeId + ) if not client_public_key: raise ValueError("Cannot retrieve client's public key") @@ -187,12 +203,16 @@ def handle_connection_request(self, connection_request: ConnectionRequest) -> Co # Create nested signed token nested_signed_token = self._create_signed_token( - claims={"S2ClientNodeId": connection_request.s2ClientNodeId}, expiry_date=expiry_date + claims={"S2ClientNodeId": connection_request.s2ClientNodeId}, + expiry_date=expiry_date, ) # Create encrypted challenge challenge = self._create_encrypted_challenge( - client_public_key, connection_request.s2ClientNodeId, nested_signed_token, expiry_date + client_public_key, + connection_request.s2ClientNodeId, + nested_signed_token, + expiry_date, ) ws_url = self._get_ws_url() # Create connection details @@ -206,7 +226,9 @@ def handle_connection_request(self, connection_request: ConnectionRequest) -> Co return connection_details @abc.abstractmethod - def _create_signed_token(self, claims: Dict[str, Any], expiry_date: datetime) -> str: + def _create_signed_token( + self, claims: Dict[str, Any], expiry_date: datetime + ) -> str: """Create a signed JWT token. Args: @@ -219,7 +241,11 @@ def _create_signed_token(self, claims: Dict[str, Any], expiry_date: datetime) -> @abc.abstractmethod def _create_encrypted_challenge( - self, client_public_key: str, client_node_id: str, nested_signed_token: str, expiry_date: datetime + self, + client_public_key: str, + client_node_id: str, + nested_signed_token: str, + expiry_date: datetime, ) -> Any: """Create an encrypted challenge for the client. TODO: using Any to avoid stringification of the JWE. Pros/Cons? @@ -243,6 +269,11 @@ def _get_base_url(self) -> str: # This should be overridden by concrete implementations return "http://localhost:8000" + @abc.abstractmethod + def _get_ws_url(self) -> str: + """Get the WebSocket URL for the server.""" + return "ws://localhost:8080" + @abc.abstractmethod def start_server(self) -> None: """Start the server. From e2f525be4d8b0b26654aa8e0b96749adf859a523 Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Sat, 21 Jun 2025 00:30:35 +0200 Subject: [PATCH 2/8] Running server from single file with database and header ccheck for the ws connection --- challenges.db | Bin 0 -> 16384 bytes examples/example_pairing_frbc_rm.py | 77 +++++++------ examples/example_s2_server.py | 105 ++++++++++-------- s2.db | Bin 0 -> 20480 bytes src/s2python/authorization/database.py | 102 +++++++++++++++++ src/s2python/authorization/default_client.py | 48 +++----- .../authorization/default_http_server.py | 41 ++++--- .../authorization/default_ws_server.py | 89 ++++++++------- src/s2python/authorization/server.py | 6 +- src/s2python/communication/s2_connection.py | 101 +++++++---------- 10 files changed, 336 insertions(+), 233 deletions(-) create mode 100644 challenges.db create mode 100644 s2.db create mode 100644 src/s2python/authorization/database.py diff --git a/challenges.db b/challenges.db new file mode 100644 index 0000000000000000000000000000000000000000..e7f4015a86fdac3e7fb135f8e39e791155d39f04 GIT binary patch literal 16384 zcmeI%&rZTH90%}r5HTbKE`}rBO%r2ccmSPP$>Km9C6UuIpz7iT+y;@Oc=WM+8sEU9 z9flDV)Wmop-(S-8-!Du1Suf3+=av^qb~y-c-H4S*flx{g86$*fs-#t!$LBtsPkd8r zWN-YG%fFIT?uF!O{#jKB2tWV=5P$##AOHafKmY;|_)mf9Rw`30mFVLnay!1fkrQX= zdO?^p(iNKTK&tGXqi(OltnsEWY7KVGPmGo&Sfyp#ye^z%=F9qa9u8=x zsOxlU&Rn|JuJ6nKl?)dx?Ssr>d<9LC@V?t0dH%YCvbIsI>Xdl>p1cq5e0869+)*@` z^_^rjp|iher_@o~jE`F`eiY&d0s;_#00bZa0SG_<0uX=z1Rwx`brs0Q8^is7T|Y1O y1px>^00Izz00bZa0SG_<0uY!9;Qv1k0t6rc0SG_<0uX=z1Rwwb2tZ){1>OJ)QIy#L literal 0 HcmV?d00001 diff --git a/examples/example_pairing_frbc_rm.py b/examples/example_pairing_frbc_rm.py index 8cea463..9183edd 100644 --- a/examples/example_pairing_frbc_rm.py +++ b/examples/example_pairing_frbc_rm.py @@ -1,8 +1,13 @@ import argparse import logging +import threading +import time +import os +import uuid -from .example_frbc_rm import start_s2_session from s2python.authorization.default_client import S2DefaultClient +from s2python.authorization.default_http_server import S2DefaultHTTPServer +from s2python.authorization.default_ws_server import S2DefaultWSServer from s2python.generated.gen_s2_pairing import ( S2NodeDescription, Deployment, @@ -14,31 +19,26 @@ logger = logging.getLogger("s2python") +def run_http_server(server): + server.start_server() + + +def run_ws_server(server): + server.start() + + if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A simple S2 resource manager example." - ) - parser.add_argument( - "--endpoint", - type=str, - help="Rest endpoint to start S2 pairing. E.g. https://localhost/requestPairing", - ) - parser.add_argument( - "--pairing_token", - type=str, - help="The pairing token for the endpoint. You should get this from the S2 server e.g. ca14fda4", - ) - parser.add_argument( - "--verify-ssl", - action="store_true", - help="Verify SSL certificates (default: False)", - default=False, - ) + # Configuration + parser = argparse.ArgumentParser(description="S2 pairing example for FRBC RM") + parser.add_argument("--pairing_endpoint", type=str, required=True) + parser.add_argument("--pairing_token", type=str, required=True) + args = parser.parse_args() - # Configure logging - logging.basicConfig(level=logging.INFO) + pairing_endpoint = args.pairing_endpoint + pairing_token = args.pairing_token + # --- Client Setup --- # Create node description node_description = S2NodeDescription( brand="TNO", @@ -52,18 +52,16 @@ # Create a client to perform the pairing client = S2DefaultClient( - pairing_uri=args.endpoint, - token=PairingToken( - token=args.pairing_token, - ), + pairing_uri=pairing_endpoint, + token=PairingToken(token=pairing_token), node_description=node_description, - verify_certificate=args.verify_ssl, + verify_certificate=False, supported_protocols=[Protocols.WebSocketSecure], ) try: # Request pairing - logger.info("Initiating pairing with endpoint: %s", args.endpoint) + logger.info("Initiating pairing with endpoint: %s", pairing_endpoint) pairing_response = client.request_pairing() logger.info("Pairing request successful, requesting connection...") @@ -73,17 +71,26 @@ # Solve challenge challenge_result = client.solve_challenge() - logger.info("Challenge decrypted successfully") + logger.info("Challenge solved successfully") - # Log connection details - logger.info("Connection URI: %s", connection_details.connectionUri) + # Establish secure connection + s2_connection = client.establish_secure_connection() + logger.info("Secure WebSocket connection established.") # Start S2 session with the connection details logger.info("Starting S2 session...") - start_s2_session( - str(connection_details.connectionUri), - ) + s2_connection.start() + logger.info("S2 session is running. Press Ctrl+C to exit.") + + # Keep the main thread alive to allow the WebSocket connection to run. + event = threading.Event() + event.wait() + except KeyboardInterrupt: + logger.info("Program interrupted by user.") except Exception as e: - logger.error("Error during pairing process: %s", e) + logger.error("Error during pairing process: %s", e, exc_info=True) raise e + finally: + client.close_connection() + logger.info("Connection closed.") diff --git a/examples/example_s2_server.py b/examples/example_s2_server.py index 642878a..a8dfefc 100644 --- a/examples/example_s2_server.py +++ b/examples/example_s2_server.py @@ -6,6 +6,8 @@ import logging import signal import sys +import os +import threading import uuid from websockets import WebSocketServerProtocol @@ -71,7 +73,6 @@ async def handle_FRBC_system_description( ) - async def handle_FRBCActuatorStatus( server: S2DefaultWSServer, message: S2Message, websocket: WebSocketServerProtocol ) -> None: @@ -211,20 +212,25 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc default=8080, help="WebSocket port to use (default: 8080)", ) - parser.add_argument( - "--instance", - type=str, - default="http", - help="Instance to use (default: http)", - ) + parser.add_argument( "--pairing-token", type=str, default="ca14fda4", help="Pairing token to use (default: ca14fda4)", ) + parser.add_argument( + "--db-path", + type=str, + default="s2.db", + help="Path to the SQLite database (default: s2.db)", + ) + args = parser.parse_args() + # Clean up previous database file + if os.path.exists(args.db_path): + os.remove(args.db_path) # Create node description for the server server_node_description = S2NodeDescription( brand="TNO", @@ -238,45 +244,46 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc logger.info("http_port: %s", args.http_port) logger.info("ws_port: %s", args.ws_port) - if args.instance == "ws": - server_ws = S2DefaultWSServer( - host=args.host, - port=args.ws_port, - role=EnergyManagementRole.CEM, - ) - # Register our custom handshake handler - server_ws._handlers.register_handler(Handshake, handle_handshake) - server_ws._handlers.register_handler(FRBCSystemDescription, handle_FRBC_system_description) - server_ws._handlers.register_handler(ResourceManagerDetails, handle_ResourceManagerDetails) - server_ws._handlers.register_handler(FRBCActuatorStatus, handle_FRBCActuatorStatus) - server_ws._handlers.register_handler(FRBCFillLevelTargetProfile, handle_FillLevelTargetProfile) - server_ws._handlers.register_handler(FRBCStorageStatus, handle_FRBCStorageStatus) - - # Create and register signal handlers - handler = create_signal_handler(server_ws) - signal.signal(signal.SIGINT, handler) - signal.signal(signal.SIGTERM, handler) - - try: - server_ws.start() - except KeyboardInterrupt: - server_ws.stop() - else: - server_http = S2DefaultHTTPServer( - host=args.host, - http_port=args.http_port, - ws_port=args.ws_port, - instance=args.instance, - server_node_description=server_node_description, - token=PairingToken(token=args.pairing_token), - supported_protocols=[Protocols.WebSocketSecure], - ) - # Create and register signal handlers - handler = create_signal_handler(server_http) - signal.signal(signal.SIGINT, handler) - signal.signal(signal.SIGTERM, handler) - - try: - server_http.start_server() - except KeyboardInterrupt: - server_http.stop_server() + server_ws = S2DefaultWSServer( + host=args.host, + port=args.ws_port, + db_path=args.db_path, + role=EnergyManagementRole.CEM, + ) + # Register our custom handshake handler + server_ws._handlers.register_handler(Handshake, handle_handshake) + server_ws._handlers.register_handler(FRBCSystemDescription, handle_FRBC_system_description) + server_ws._handlers.register_handler(ResourceManagerDetails, handle_ResourceManagerDetails) + server_ws._handlers.register_handler(FRBCActuatorStatus, handle_FRBCActuatorStatus) + server_ws._handlers.register_handler(FRBCFillLevelTargetProfile, handle_FillLevelTargetProfile) + server_ws._handlers.register_handler(FRBCStorageStatus, handle_FRBCStorageStatus) + + # Create and register signal handlers + handler = create_signal_handler(server_ws) + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + server_ws_thread = threading.Thread(target=server_ws.start, daemon=True) + server_ws_thread.start() + logger.info("WebSocket Server started in background thread.") + + server_http = S2DefaultHTTPServer( + host=args.host, + http_port=args.http_port, + ws_port=args.ws_port, + db_path=args.db_path, + server_node_description=server_node_description, + token=PairingToken(token=args.pairing_token), + supported_protocols=[Protocols.WebSocketSecure], + ) + # Create and register signal handlers + handler = create_signal_handler(server_http) + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + server_http_thread = threading.Thread(target=server_http.start_server, daemon=True) + server_http_thread.start() + logger.info("HTTP Server started in background thread.") + + # Wait for both threads to finish + server_ws_thread.join() + server_http_thread.join() + logger.info("Both servers have stopped.") diff --git a/s2.db b/s2.db new file mode 100644 index 0000000000000000000000000000000000000000..17888574d1ccdae36c748fd362d3d0d0ccb6a254 GIT binary patch literal 20480 zcmeI%%SyvQ6oBDLTQ3N;x{xlN?i552A3&pyQcSDHDd;N1Hfk{5+O+6WUHVwQi96rG zrIV^DqNRekQ2qmxx#SS$e9UIZ?v4}2fjm5Kp8By|6cfU-#Ilq^h@5FdZ_P4&=55V% z-AL93MNZ6j?(>Cbkn4h9f1G>2q1s}0tg_000Iag@S6g+Q`wQ?^t5$%9s37S za1va57k=1mrGwn6t8A^LwpSb_(`7j`CclWVA*Hk2#3-Ezft_vNPApZhZRqJ0pB zb Generator[sqlite3.Connection, None, None]: + """Provides a database connection.""" + conn = sqlite3.connect(self.db_path) + try: + yield conn + finally: + conn.close() + + def init_db(self) -> None: + """Initializes the database and creates the 'challenges' table if it doesn't exist.""" + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS challenges ( + challenge TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS key_pairs ( + id INTEGER PRIMARY KEY, + public_key TEXT NOT NULL, + private_key TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.commit() + logger.info("Database initialized at %s", self.db_path) + + def store_challenge(self, challenge: str) -> None: + """ + Stores a challenge in the database. + + :param challenge: The challenge string to store. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute("INSERT INTO challenges (challenge) VALUES (?)", (challenge,)) + conn.commit() + logger.info("Stored challenge in the database.") + + def store_key_pair(self, public_key: str, private_key: str) -> None: + """ + Stores a key pair in the database. + + :param public_key: The public key string to store. + :param private_key: The private key string to store. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute("INSERT INTO key_pairs (public_key, private_key) VALUES (?, ?)", (public_key, private_key)) + conn.commit() + logger.info("Stored key pair in the database.") + + def verify_and_remove_challenge(self, challenge: str) -> bool: + """ + Verifies a challenge exists and removes it to prevent reuse. + + :param challenge: The challenge string to verify. + :return: True if the challenge was valid, False otherwise. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT challenge FROM challenges WHERE challenge=?", (challenge,)) + result = cursor.fetchone() + if result: + logger.info("Challenge found. Removing it from database.") + cursor.execute("DELETE FROM challenges WHERE challenge=?", (challenge,)) + conn.commit() + return True + logger.warning("Challenge not found in database.") + return False diff --git a/src/s2python/authorization/default_client.py b/src/s2python/authorization/default_client.py index 2321346..1ef606d 100644 --- a/src/s2python/authorization/default_client.py +++ b/src/s2python/authorization/default_client.py @@ -27,6 +27,8 @@ KEY_ALGORITHM, PairingDetails, ) +from s2python.communication.s2_connection import S2Connection +from s2python.common import EnergyManagementRole # Set up logging logging.basicConfig(level=logging.INFO) @@ -60,7 +62,7 @@ def __init__( supported_protocols, ) # Additional state for this implementation - self._ws_connection: Optional[Dict[str, Any]] = None + self._ws_connection: Optional[S2Connection] = None def generate_key_pair(self) -> Tuple[str, str]: """Generate a public/private key pair using jwskate library. @@ -138,15 +140,11 @@ def solve_challenge(self, challenge: Optional[Any] = None) -> str: """ if challenge is None: if not self._connection_details or not self._connection_details.challenge: - raise ValueError( - "Challenge not provided and not available in connection details" - ) + raise ValueError("Challenge not provided and not available in connection details") challenge = self._connection_details.challenge if not self._key_pair and not self._public_key: - raise ValueError( - "Public key is not available. Generate or load a key pair first." - ) + raise ValueError("Public key is not available. Generate or load a key pair first.") try: # If we have a jwskate Jwk object, use it directly @@ -181,9 +179,7 @@ def solve_challenge(self, challenge: Optional[Any] = None) -> str: jwt_token_str = str(jwt_token) # Encode the token as base64 - decrypted_challenge_str: str = base64.b64encode( - jwt_token_str.encode("utf-8") - ).decode("utf-8") + decrypted_challenge_str: str = base64.b64encode(jwt_token_str.encode("utf-8")).decode("utf-8") # Store the pairing details if we have all required components if self._pairing_response and self._connection_details: @@ -201,7 +197,7 @@ def solve_challenge(self, challenge: Optional[Any] = None) -> str: logger.info(error_msg) raise RuntimeError(error_msg) from e - def establish_secure_connection(self) -> Dict[str, Any]: + def establish_secure_connection(self) -> S2Connection: """Establish a secure WebSocket connection. This implementation establishes a WebSocket connection @@ -211,38 +207,30 @@ def establish_secure_connection(self) -> Dict[str, Any]: this would use a WebSocket library like websocket-client or websockets. Returns: - Dict[str, Any]: A WebSocket connection object + S2Connection: A S2Connection object Raises: ValueError: If connection details or solved challenge are not available RuntimeError: If connection establishment fails """ if not self._connection_details: - raise ValueError( - "Connection details not available. Call request_connection first." - ) - - if ( - not self._pairing_details - or not self._pairing_details.decrypted_challenge_str - ): - raise ValueError( - "Challenge solution not available. Call solve_challenge first." - ) + raise ValueError("Connection details not available. Call request_connection first.") + + if not self._pairing_details or not self._pairing_details.decrypted_challenge_str: + raise ValueError("Challenge solution not available. Call solve_challenge first.") logger.info( "Establishing WebSocket connection to %s,", self._connection_details.connectionUri, ) - logger.info( - "Using solved challenge: %s", self._pairing_details.decrypted_challenge_str - ) + logger.info("Using solved challenge: %s", self._pairing_details.decrypted_challenge_str) # Placeholder for the connection object - self._ws_connection = { - "status": "connected", - "uri": str(self._connection_details.connectionUri), - } + self._ws_connection = S2Connection( + url=str(self._connection_details.connectionUri), + role=EnergyManagementRole.CEM, + bearer_token=self._pairing_details.decrypted_challenge_str, + ) return self._ws_connection diff --git a/src/s2python/authorization/default_http_server.py b/src/s2python/authorization/default_http_server.py index cf4f154..24c4608 100644 --- a/src/s2python/authorization/default_http_server.py +++ b/src/s2python/authorization/default_http_server.py @@ -7,6 +7,7 @@ import logging import socketserver import asyncio +import base64 from datetime import datetime from typing import Dict, Any, Tuple, Optional, Union @@ -22,6 +23,7 @@ from websockets.server import WebSocketServer from s2python.communication.s2_connection import MessageHandlers, S2Connection +from s2python.authorization.database import S2Database from s2python.s2_parser import S2Parser @@ -70,9 +72,7 @@ def do_POST(self) -> None: # pylint: disable=C0103 logger.error("Error handling request: %s", e) raise e - def _send_json_response( - self, status_code: int, response_body: Union[dict, str] - ) -> None: + def _send_json_response(self, status_code: int, response_body: Union[dict, str]) -> None: """ Helper function to send a JSON response. :param handler: The HTTP handler instance (self). @@ -119,9 +119,7 @@ def _handle_connection_request(self, request_json: Dict[str, Any]) -> None: connection_request = ConnectionRequest.model_validate(request_json) # Process request using server instance - response = self.server_instance.handle_connection_request( - connection_request - ) + response = self.server_instance.handle_connection_request(connection_request) # Send response self._send_json_response(200, response.model_dump_json()) @@ -145,6 +143,8 @@ def __init__( http_port: int = 8000, ws_port: int = 8080, instance: str = "http", + db_path: Optional[str] = None, + encryption_algorithm: str = "RSA-OAEP-256", *args: Any, **kwargs: Any, ) -> None: @@ -154,6 +154,8 @@ def __init__( host: The host to bind to http_port: The HTTP port to use ws_port: The WebSocket port to use + db_path: Path to the SQLite database for challenges. + encryption_algorithm: The algorithm for JWE encryption. """ super().__init__(*args, **kwargs) self.host = host @@ -166,6 +168,8 @@ def __init__( self._loop: Optional[asyncio.AbstractEventLoop] = None self._handlers = MessageHandlers() self.s2_parser = S2Parser() + self.s2_db = S2Database(db_path) if db_path else None + self.encryption_algorithm = encryption_algorithm def generate_key_pair(self) -> Tuple[str, str]: """Generate a public/private key pair for the server. @@ -190,12 +194,15 @@ def store_key_pair(self, public_key: str, private_key: str) -> None: private_key: Base64 encoded private key """ self._private_key = private_key + self._public_key = public_key # Convert to JWK for JWT operations self._private_jwk = Jwk.from_pem(private_key) + self._public_jwk = Jwk.from_pem(public_key) + # Store the key pair in the database + if self.s2_db: + self.s2_db.store_key_pair(public_key, private_key) - def _create_signed_token( - self, claims: Dict[str, Any], expiry_date: datetime - ) -> str: + def _create_signed_token(self, claims: Dict[str, Any], expiry_date: datetime) -> str: """Create a signed JWT token. Args: @@ -247,12 +254,20 @@ def _create_encrypted_challenge( "exp": int(expiry_date.timestamp()), } + if self.s2_db: + # This is what the client will produce after decrypting the challenge. + # We store it, so the WS server can verify it. + jwt_token = Jwt.unprotected(payload) + jwt_token_str = str(jwt_token) + decrypted_challenge_str: str = base64.b64encode(jwt_token_str.encode("utf-8")).decode("utf-8") + self.s2_db.store_challenge(decrypted_challenge_str) + # Create JWE with all required components jwe = JweCompact.encrypt( plaintext=json.dumps(payload).encode(), key=client_jwk, # Using client's public key for encryption - alg="RSA-OAEP-256", - enc="A256GCM", + alg=self.encryption_algorithm, + enc="A256GCM", # TODO: Remove hardcode ) # test the decryption of the JWE # decrypted_payload = jwe.decrypt(client_jwk) @@ -280,9 +295,7 @@ def handler_factory(*args: Any, **kwargs: Any) -> S2DefaultHTTPHandler: return S2DefaultHTTPHandler(*args, server_instance=self, **kwargs) # Create and start server - self._httpd = socketserver.TCPServer( - (self.host, self.http_port), handler_factory - ) + self._httpd = socketserver.TCPServer((self.host, self.http_port), handler_factory) logger.info("S2 Server running at: http://%s:%s", self.host, self.http_port) # Start the WebSocket server self._httpd.serve_forever() diff --git a/src/s2python/authorization/default_ws_server.py b/src/s2python/authorization/default_ws_server.py index cc218dc..fd30f85 100644 --- a/src/s2python/authorization/default_ws_server.py +++ b/src/s2python/authorization/default_ws_server.py @@ -8,11 +8,12 @@ import threading import time import uuid -from typing import Any, Optional, List, Type, Dict, Callable, Awaitable, Union +from typing import Any, Optional, List, Type, Dict, Callable, Awaitable, Union, Tuple import traceback import websockets from websockets.server import WebSocketServerProtocol, serve as ws_serve +from websockets.datastructures import Headers from s2python.common import ( ReceptionStatusValues, @@ -28,6 +29,7 @@ from s2python.s2_validation_error import S2ValidationError from s2python.communication.reception_status_awaiter import ReceptionStatusAwaiter from s2python.version import S2_VERSION +from s2python.authorization.database import S2Database # Set up logging logging.basicConfig(level=logging.INFO) @@ -102,6 +104,7 @@ def __init__( host: str = "localhost", port: int = 8080, role: EnergyManagementRole = EnergyManagementRole.CEM, + db_path: Optional[str] = None, ) -> None: """Initialize the WebSocket server. @@ -109,6 +112,7 @@ def __init__( host: The host to bind to port: The port to listen on role: The role of this server (CEM or RM) + db_path: Path to the SQLite database for challenges. """ self.host = host self.port = port @@ -122,15 +126,14 @@ def __init__( self._stop_event = asyncio.Event() self.reception_status_awaiter = ReceptionStatusAwaiter() self.reconnect = False + self.s2_db = S2Database(db_path) if db_path else None # Register default handlers self._register_default_handlers() def _register_default_handlers(self) -> None: """Register default message handlers.""" self._handlers.register_handler(Handshake, self.handle_handshake) - self._handlers.register_handler( - HandshakeResponse, self.handle_handshake_response - ) + self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response) self._handlers.register_handler(ReceptionStatus, self.handle_reception_status) def start(self) -> None: @@ -162,18 +165,43 @@ async def _run_as_cem(self) -> None: logger.debug("Finished S2 connection eventloop.") + async def _process_request( + self, path: str, request_headers: Headers + ) -> Optional[Tuple[int, List[Tuple[str, str]], bytes]]: + """ + Process incoming connection requests and validate the challenge. + """ + if self.s2_db: + auth_header = request_headers.get("Authorization") + if not auth_header: + logger.warning("Connection attempt without Authorization header. Rejecting.") + return (401, [], b"Unauthorized") + + if not auth_header.startswith("Bearer "): + logger.warning("Invalid Authorization header format. Rejecting.") + return (401, [], b"Unauthorized") + + token = auth_header.split(" ", 1)[1] + + if not self.s2_db.verify_and_remove_challenge(token): + logger.warning("Invalid token provided. Rejecting connection.") + return (403, [], b"Forbidden") + + logger.info("Token validated. Accepting connection.") + + return None # Accept connection + async def _connect_and_run(self) -> None: """Connect to the WebSocket server and run the event loop.""" - self._received_messages = asyncio.Queue() + self._received_messages: asyncio.Queue[S2Message] = asyncio.Queue() if self._server is None: self._server = await ws_serve( self._handle_websocket_connection, host=self.host, port=self.port, + process_request=self._process_request, ) - logger.info( - "S2 WebSocket server running at: ws://%s:%s", self.host, self.port - ) + logger.info("S2 WebSocket server running at: ws://%s:%s", self.host, self.port) else: logger.info( "S2 WebSocket server already running at: ws://%s:%s", @@ -191,9 +219,7 @@ async def wait_till_connection_restart() -> None: self._eventloop.create_task(wait_till_stop()), self._eventloop.create_task(wait_till_connection_restart()), ] - (done, pending) = await asyncio.wait( - background_tasks, return_when=asyncio.FIRST_COMPLETED - ) + (done, pending) = await asyncio.wait(background_tasks, return_when=asyncio.FIRST_COMPLETED) await self._stop_event.wait() def stop(self) -> None: @@ -203,9 +229,7 @@ def stop(self) -> None: if self._server: self._server.close() - async def _handle_websocket_connection( - self, websocket: WebSocketServerProtocol, path: str - ) -> None: + async def _handle_websocket_connection(self, websocket: WebSocketServerProtocol, path: str) -> None: """Handle incoming WebSocket connections. Args: @@ -223,15 +247,11 @@ async def _handle_websocket_connection( s2_msg = self.s2_parser.parse_as_any_message(message) if isinstance(s2_msg, ReceptionStatus): logger.info("Received reception status: %s", s2_msg) - await self.reception_status_awaiter.receive_reception_status( - s2_msg - ) + await self.reception_status_awaiter.receive_reception_status(s2_msg) continue except json.JSONDecodeError: await self.respond_with_reception_status( - subject_message_id=uuid.UUID( - "00000000-0000-0000-0000-000000000000" - ), + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Not valid json.", websocket=websocket, @@ -242,9 +262,7 @@ async def _handle_websocket_connection( await self._handlers.handle_message(self, s2_msg, websocket) except json.JSONDecodeError: await self.respond_with_reception_status( - subject_message_id=uuid.UUID( - "00000000-0000-0000-0000-000000000000" - ), + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Not valid json.", websocket=websocket, @@ -261,9 +279,7 @@ async def _handle_websocket_connection( ) else: await self.respond_with_reception_status( - subject_message_id=uuid.UUID( - "00000000-0000-0000-0000-000000000000" - ), + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Message appears valid json but could not find a message_id field.", websocket=websocket, @@ -299,9 +315,7 @@ async def respond_with_reception_status( status=status, diagnostic_label=diagnostic_label, ) - logger.info( - "Sending reception status %s for message %s", status, subject_message_id - ) + logger.info("Sending reception status %s for message %s", status, subject_message_id) try: await websocket.send(response.to_json()) except websockets.exceptions.ConnectionClosed: @@ -317,9 +331,7 @@ def respond_with_reception_status_sync( """Synchronous version of respond_with_reception_status.""" if self._loop: asyncio.run_coroutine_threadsafe( - self.respond_with_reception_status( - subject_message_id, status, diagnostic_label, websocket - ), + self.respond_with_reception_status(subject_message_id, status, diagnostic_label, websocket), self._loop, ).result() @@ -346,7 +358,6 @@ async def send_msg_and_await_reception_status_async( timeout_reception_status, ) try: - logger.info("Waiting for reception status for -------> %s", s2_msg.message_id) # type: ignore[attr-defined, union-attr] try: response = await websocket.recv() logger.info("Received reception status: %s", response) @@ -402,9 +413,7 @@ async def handle_handshake( message_id=message.message_id, selected_protocol_version=message.supported_protocol_versions, ) - await self.send_msg_and_await_reception_status_async( - handshake_response, websocket - ) + await self.send_msg_and_await_reception_status_async(handshake_response, websocket) await self.respond_with_reception_status( subject_message_id=message.message_id, @@ -431,9 +440,7 @@ async def handle_reception_status( type(message), ) return - logger.info( - "Received ReceptionStatus in handle_reception_status: %s", message.to_json() - ) + logger.info("Received ReceptionStatus in handle_reception_status: %s", message.to_json()) async def handle_handshake_response( self, @@ -457,9 +464,7 @@ async def handle_handshake_response( logger.debug("Received HandshakeResponse: %s", message.to_json()) - async def _send_and_forget( - self, s2_msg: S2Message, websocket: WebSocketServerProtocol - ) -> None: + async def _send_and_forget(self, s2_msg: S2Message, websocket: WebSocketServerProtocol) -> None: """Send a message and forget about it. Args: diff --git a/src/s2python/authorization/server.py b/src/s2python/authorization/server.py index 74e0d72..f35a94e 100644 --- a/src/s2python/authorization/server.py +++ b/src/s2python/authorization/server.py @@ -72,6 +72,7 @@ def __init__( self._private_key: Optional[str] = None self._private_jwk: Optional[Jwk] = None + self.encryption_algorithm = None @abc.abstractmethod def generate_key_pair(self) -> Tuple[str, str]: """Generate a public/private key pair for the server. @@ -136,9 +137,10 @@ def handle_pairing_request( not pairing_request.publicKey or not pairing_request.s2ClientNodeId or not pairing_request.token + or not pairing_request.encryptionAlgorithm ): raise ValueError( - "Missing fields, public key, s2ClientNodeId and token are required" + "Missing fields, public key, s2ClientNodeId, token and encryptionAlgorithm are required" ) # Validate token @@ -151,7 +153,7 @@ def handle_pairing_request( self.store_client_public_key( str(pairing_request.s2ClientNodeId), pairing_request.publicKey ) - + self.encryption_algorithm = pairing_request.encryptionAlgorithm #type: ignore # Create full URLs for endpoints base_url = self._get_base_url() request_connection_uri = f"{base_url}/requestConnection" diff --git a/src/s2python/communication/s2_connection.py b/src/s2python/communication/s2_connection.py index 2a6d099..a3a2760 100644 --- a/src/s2python/communication/s2_connection.py +++ b/src/s2python/communication/s2_connection.py @@ -331,47 +331,46 @@ async def _connect_and_run(self) -> None: self._received_messages = asyncio.Queue() if self.ws is None: await self._connect_ws() - else: - self.ws = self.ws - - async def wait_till_stop() -> None: - await self._stop_event.wait() - - async def wait_till_connection_restart() -> None: - await self._restart_connection_event.wait() - - background_tasks = [ - self._eventloop.create_task(self._receive_messages()), - self._eventloop.create_task(wait_till_stop()), - self._eventloop.create_task( - self._connect_as_rm() if self.role == EnergyManagementRole.RM else self._connect_as_cem() - ), - self._eventloop.create_task(wait_till_connection_restart()), - ] - - (done, pending) = await asyncio.wait(background_tasks, return_when=asyncio.FIRST_COMPLETED) - - if self._current_control_type: - self._current_control_type.deactivate(self) - self._current_control_type = None - - for task in done: - try: - await task - except asyncio.CancelledError: - pass - except (websockets.ConnectionClosedError, websockets.ConnectionClosedOK): - logger.info("The other party closed the websocket connection.") - - for task in pending: - try: - task.cancel() - await task - except asyncio.CancelledError: - pass - - await self.ws.close() - await self.ws.wait_closed() + self.ws = self.ws + + async def wait_till_stop() -> None: + await self._stop_event.wait() + + async def wait_till_connection_restart() -> None: + await self._restart_connection_event.wait() + + background_tasks = [ + self._eventloop.create_task(self._receive_messages()), + self._eventloop.create_task(wait_till_stop()), + self._eventloop.create_task( + self._connect_as_rm() if self.role == EnergyManagementRole.RM else self._connect_as_cem() + ), + self._eventloop.create_task(wait_till_connection_restart()), + ] + + (done, pending) = await asyncio.wait(background_tasks, return_when=asyncio.FIRST_COMPLETED) + + if self._current_control_type: + self._current_control_type.deactivate(self) + self._current_control_type = None + + for task in done: + try: + await task + except asyncio.CancelledError: + pass + except (websockets.ConnectionClosedError, websockets.ConnectionClosedOK): + logger.info("The other party closed the websocket connection.") + + for task in pending: + try: + task.cancel() + await task + except asyncio.CancelledError: + pass + + await self.ws.close() + await self.ws.wait_closed() async def _connect_ws(self) -> None: max_retries = 3 @@ -402,26 +401,6 @@ async def _connect_ws(self) -> None: continue raise RuntimeError(f"Failed to connect after {max_retries} attempts: {str(e)}") - async def _start_server(self) -> None: - max_retries = 3 - retry_delay = 1 # seconds - - for attempt in range(max_retries): - try: - logger.info("Starting WebSocket server (attempt %d/%d)", attempt + 1, max_retries) - - self.ws = await ws_serve(self._handle_websocket_connection, self.url) - logger.info("Successfully started WebSocket server") - return - - except (EOFError, OSError) as e: - logger.warning( - "Could not start WebSocket server due to: %s (attempt %d/%d)", str(e), attempt + 1, max_retries - ) - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - continue - raise RuntimeError(f"Failed to start WebSocket server after {max_retries} attempts: {str(e)}") async def _connect_as_rm(self) -> None: From 66383637881e37851941e39181f2f2d159a70742 Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Mon, 23 Jun 2025 13:18:51 +0200 Subject: [PATCH 3/8] Added working examples for CEM server and RM client with auth/pair and without --- README.rst | 18 +- examples/example_frbc_rm.py | 6 +- examples/example_pairing_frbc_rm.py | 159 ++++++++++++++++-- examples/example_s2_server.py | 10 ++ src/s2python/authorization/default_client.py | 15 +- .../authorization/default_ws_server.py | 7 +- .../communication/reception_status_awaiter.py | 4 +- src/s2python/communication/s2_connection.py | 7 +- 8 files changed, 195 insertions(+), 31 deletions(-) diff --git a/README.rst b/README.rst index 8c9a0c3..ed6ea7d 100644 --- a/README.rst +++ b/README.rst @@ -56,12 +56,26 @@ Use S2 classes in your code: json_str = '{"start_of_range": 4.0, "end_of_range": 5.0, "commodity_quantity": "ELECTRIC.POWER.L1"}' PowerRange.from_json(json_str) -Run an example CEM server and RM client: +Run an example CEM server with websocket and http server: .. code-block:: bash python -m examples.example_s2_server --host localhost --http-port 8000 --ws-port 8080 --pairing-token ca14fda4 - python -m examples.example_pairing_frbc_rm --endpoint http://localhost:8000/requestPairing --pairing_token ca14fda4 + + +This will start both a http and a websocket server instances. It also allows to set a hardcoded pairing token. + +Run an example RM client that pairs with the CEM server, authenticates and starts sending S2 messages that describe an FRBC device: + +.. code-block:: bash + + python -m examples.example_pairing_frbc_rm --pairing_endpoint http://localhost:8000/requestPairing --pairing_token ca14fda4 + +In case you want to run the example of a client that does not need to pair with the CEM server, you can add the --dev-mode flag. This will disable the pairing/authentication check and allows you to send messages to the CEM server without pairing. The CEM server still needs to be running. + +.. code-block:: bash + + python -m examples.example_frbc_rm --endpoint ws://localhost:8080 Development diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 1d3dc80..94d3638 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -30,7 +30,7 @@ FRBCStorageStatus, FRBCActuatorStatus, ) -from s2python.s2_connection import S2Connection, AssetDetails +from s2python.communication.s2_connection import S2Connection, AssetDetails from s2python.s2_control_type import FRBCControlType, NoControlControlType from s2python.message import S2Message @@ -197,10 +197,10 @@ def sigterm_handler(signum, frame): description="A simple S2 reseource manager example." ) parser.add_argument( - "endpoint", + "--endpoint", type=str, help="WebSocket endpoint uri for the server (CEM) e.g. " - "ws://localhost:8080/backend/rm/s2python-frbc/cem/dummy_model/ws", + "ws://localhost:8080/", ) args = parser.parse_args() diff --git a/examples/example_pairing_frbc_rm.py b/examples/example_pairing_frbc_rm.py index 9183edd..5902473 100644 --- a/examples/example_pairing_frbc_rm.py +++ b/examples/example_pairing_frbc_rm.py @@ -1,13 +1,11 @@ import argparse import logging import threading -import time -import os +import datetime import uuid +from typing import Callable from s2python.authorization.default_client import S2DefaultClient -from s2python.authorization.default_http_server import S2DefaultHTTPServer -from s2python.authorization.default_ws_server import S2DefaultWSServer from s2python.generated.gen_s2_pairing import ( S2NodeDescription, Deployment, @@ -16,15 +14,134 @@ Protocols, ) -logger = logging.getLogger("s2python") - +from s2python.common import ( + EnergyManagementRole, + Duration, + Role, + RoleType, + Commodity, + Currency, + NumberRange, + PowerRange, + CommodityQuantity, +) +from s2python.frbc import ( + FRBCInstruction, + FRBCSystemDescription, + FRBCActuatorDescription, + FRBCStorageDescription, + FRBCOperationMode, + FRBCOperationModeElement, + FRBCFillLevelTargetProfile, + FRBCFillLevelTargetProfileElement, + FRBCStorageStatus, + FRBCActuatorStatus, +) +from s2python.communication.s2_connection import S2Connection, AssetDetails +from s2python.s2_control_type import FRBCControlType, NoControlControlType +from s2python.message import S2Message -def run_http_server(server): - server.start_server() +logger = logging.getLogger("s2python") -def run_ws_server(server): - server.start() +class MyFRBCControlType(FRBCControlType): + def handle_instruction(self, conn: S2Connection, msg: S2Message, send_okay: Callable[[], None]) -> None: + if not isinstance(msg, FRBCInstruction): + raise RuntimeError(f"Expected an FRBCInstruction but received a message of type {type(msg)}.") + print(f"I have received the message {msg} from {conn}") + + def activate(self, conn: S2Connection) -> None: + print("The control type FRBC is now activated.") + + print("Time to send a FRBC SystemDescription") + actuator_id = uuid.uuid4() + operation_mode_id = uuid.uuid4() + conn.send_msg_and_await_reception_status_sync( + FRBCSystemDescription( + message_id=uuid.uuid4(), + valid_from=datetime.datetime.now(tz=datetime.timezone.utc), + actuators=[ + FRBCActuatorDescription( + id=actuator_id, + operation_modes=[ + FRBCOperationMode( + id=operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange(start_of_range=0.0, end_of_range=100.0), + fill_rate=NumberRange(start_of_range=-5.0, end_of_range=5.0), + power_ranges=[ + PowerRange( + start_of_range=-200.0, + end_of_range=200.0, + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="Load & unload battery", + abnormal_condition_only=False, + ) + ], + transitions=[], + timers=[], + supported_commodities=[Commodity.ELECTRICITY], + ) + ], + storage=FRBCStorageDescription( + fill_level_range=NumberRange(start_of_range=0.0, end_of_range=100.0), + fill_level_label="%", + diagnostic_label="Imaginary battery", + provides_fill_level_target_profile=True, + provides_leakage_behaviour=False, + provides_usage_forecast=False, + ), + ) + ) + print("Also send the target profile") + + conn.send_msg_and_await_reception_status_sync( + FRBCFillLevelTargetProfile( + message_id=uuid.uuid4(), + start_time=datetime.datetime.now(tz=datetime.timezone.utc), + elements=[ + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(30_000), + fill_level_range=NumberRange(start_of_range=20.0, end_of_range=30.0), + ), + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(300_000), + fill_level_range=NumberRange(start_of_range=40.0, end_of_range=50.0), + ), + ], + ) + ) + + print("Also send the storage status.") + conn.send_msg_and_await_reception_status_sync( + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + ) + + print("Also send the actuator status.") + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.5, + ) + ) + + def deactivate(self, conn: S2Connection) -> None: + print("The control type FRBC is now deactivated.") + + +class MyNoControlControlType(NoControlControlType): + def activate(self, conn: S2Connection) -> None: + print("The control type NoControl is now activated.") + + def deactivate(self, conn: S2Connection) -> None: + print("The control type NoControl is now deactivated.") if __name__ == "__main__": @@ -73,13 +190,27 @@ def run_ws_server(server): challenge_result = client.solve_challenge() logger.info("Challenge solved successfully") - # Establish secure connection - s2_connection = client.establish_secure_connection() - logger.info("Secure WebSocket connection established.") + s2_connection = S2Connection( + url=connection_details.connectionUri, # type: ignore + role=EnergyManagementRole.RM, + control_types=[MyFRBCControlType(), MyNoControlControlType()], + asset_details=AssetDetails( + resource_id=client.client_node_id, + name="Some asset", + instruction_processing_delay=Duration.from_milliseconds(20), + roles=[Role(role=RoleType.ENERGY_CONSUMER, commodity=Commodity.ELECTRICITY)], + currency=Currency.EUR, + provides_forecast=False, + provides_power_measurements=[CommodityQuantity.ELECTRIC_POWER_L1], + ), + reconnect=True, + verify_certificate=False, + bearer_token=challenge_result, + ) # Start S2 session with the connection details logger.info("Starting S2 session...") - s2_connection.start() + s2_connection.start_as_rm() logger.info("S2 session is running. Press Ctrl+C to exit.") # Keep the main thread alive to allow the WebSocket connection to run. diff --git a/examples/example_s2_server.py b/examples/example_s2_server.py index a8dfefc..46bf55f 100644 --- a/examples/example_s2_server.py +++ b/examples/example_s2_server.py @@ -180,7 +180,9 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc await server._send_and_forget(handshake_response, websocket) # If client is RM, send control type selection + logger.info("Role: %s", message.role) if message.role == EnergyManagementRole.RM: + logger.info("Sending control type selection") # First await the send_okay for the handshake # await send_okay # Then send the control type selection and wait for its reception status @@ -190,6 +192,7 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc ) logger.info("Sending select control type: %s", select_control_type.to_json()) await server.send_msg_and_await_reception_status_async(select_control_type, websocket) + logger.info("Activated control type. Routine finished") if __name__ == "__main__": @@ -225,6 +228,12 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc default="s2.db", help="Path to the SQLite database (default: s2.db)", ) + + parser.add_argument( + "--dev-mode", + action="store_true", + help="Enable dev mode (default: False)", + ) args = parser.parse_args() @@ -249,6 +258,7 @@ async def handle_handshake(server: S2DefaultWSServer, message: S2Message, websoc port=args.ws_port, db_path=args.db_path, role=EnergyManagementRole.CEM, + dev_mode=args.dev_mode, ) # Register our custom handshake handler server_ws._handlers.register_handler(Handshake, handle_handshake) diff --git a/src/s2python/authorization/default_client.py b/src/s2python/authorization/default_client.py index 1ef606d..9ec2559 100644 --- a/src/s2python/authorization/default_client.py +++ b/src/s2python/authorization/default_client.py @@ -9,7 +9,8 @@ import json import uuid import logging -from typing import Dict, Optional, Tuple, Union, List, Any, Mapping +import datetime +from typing import Dict, Optional, Tuple, Union, List, Any, Mapping, Callable import requests from requests import Response @@ -27,8 +28,8 @@ KEY_ALGORITHM, PairingDetails, ) -from s2python.communication.s2_connection import S2Connection from s2python.common import EnergyManagementRole +from s2python.communication.s2_connection import S2Connection # Set up logging logging.basicConfig(level=logging.INFO) @@ -51,6 +52,7 @@ def __init__( verify_certificate: Union[bool, str] = False, client_node_id: Optional[uuid.UUID] = None, supported_protocols: Optional[List[Protocols]] = None, + role: EnergyManagementRole = EnergyManagementRole.RM, ) -> None: """Initialize the default client with configuration parameters.""" super().__init__( @@ -63,6 +65,7 @@ def __init__( ) # Additional state for this implementation self._ws_connection: Optional[S2Connection] = None + self._role = role def generate_key_pair(self) -> Tuple[str, str]: """Generate a public/private key pair using jwskate library. @@ -228,15 +231,15 @@ def establish_secure_connection(self) -> S2Connection: # Placeholder for the connection object self._ws_connection = S2Connection( url=str(self._connection_details.connectionUri), - role=EnergyManagementRole.CEM, + role=self._role, bearer_token=self._pairing_details.decrypted_challenge_str, ) - + return self._ws_connection def close_connection(self) -> None: """Close the WebSocket connection.""" if self._ws_connection: - - logger.info("Would close WebSocket connection") + logger.info("Closing WebSocket connection") + self._ws_connection.stop() self._ws_connection = None diff --git a/src/s2python/authorization/default_ws_server.py b/src/s2python/authorization/default_ws_server.py index fd30f85..ddb46ec 100644 --- a/src/s2python/authorization/default_ws_server.py +++ b/src/s2python/authorization/default_ws_server.py @@ -105,6 +105,7 @@ def __init__( port: int = 8080, role: EnergyManagementRole = EnergyManagementRole.CEM, db_path: Optional[str] = None, + dev_mode: bool = False, ) -> None: """Initialize the WebSocket server. @@ -129,6 +130,7 @@ def __init__( self.s2_db = S2Database(db_path) if db_path else None # Register default handlers self._register_default_handlers() + self.dev_mode = dev_mode def _register_default_handlers(self) -> None: """Register default message handlers.""" @@ -171,6 +173,9 @@ async def _process_request( """ Process incoming connection requests and validate the challenge. """ + if self.dev_mode: + return None + if self.s2_db: auth_header = request_headers.get("Authorization") if not auth_header: @@ -408,7 +413,7 @@ async def handle_handshake( ) return - logger.info("Received Handshak(In WS Server): %s", message.to_json()) + logger.info("Received Handshake(In WS Server): %s", message.to_json()) handshake_response = HandshakeResponse( message_id=message.message_id, selected_protocol_version=message.supported_protocol_versions, diff --git a/src/s2python/communication/reception_status_awaiter.py b/src/s2python/communication/reception_status_awaiter.py index aa4120b..cae5a74 100644 --- a/src/s2python/communication/reception_status_awaiter.py +++ b/src/s2python/communication/reception_status_awaiter.py @@ -27,8 +27,8 @@ async def wait_for_reception_status( self, message_id: uuid.UUID, timeout_reception_status: float ) -> ReceptionStatus: # log all the received messages - logger.info(f"Received messages: {self.received}") - logger.info(f"Awaiting messages: {self.awaiting}") + # logger.info(f"Received messages: {self.received}") + # logger.info(f"Awaiting messages: {self.awaiting}") if message_id in self.received: reception_status = self.received[message_id] else: diff --git a/src/s2python/communication/s2_connection.py b/src/s2python/communication/s2_connection.py index a3a2760..914e602 100644 --- a/src/s2python/communication/s2_connection.py +++ b/src/s2python/communication/s2_connection.py @@ -99,7 +99,7 @@ def __init__(self, connection: "S2Connection", subject_message: S2Message): async def run_async(self) -> None: self.status_is_send.set() - logger.info("Sending reception status for message (SendOkay) %s", self.subject_message_id) + logger.info("SendOkay") await self.connection.respond_with_reception_status( subject_message_id=self.subject_message_id, status=ReceptionStatusValues.OK, @@ -234,6 +234,7 @@ def __init__( # pylint: disable=too-many-arguments # Register default handlers based on role if role == EnergyManagementRole.RM: + logger.info("Registering RM handlers") self._register_rm_handlers() else: logger.info("Registering CEM handlers") @@ -496,9 +497,9 @@ async def handle_select_control_type_as_rm( ) return - await send_okay logger.debug("CEM selected control type %s. Activating control type.", message.control_type) - + await send_okay + control_types_by_protocol_name = {c.get_protocol_control_type(): c for c in self.control_types} selected_control_type = control_types_by_protocol_name.get(message.control_type) From 477f7e793d53987fae9784ee0b1adb72561baa4f Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Fri, 11 Jul 2025 10:49:15 +0200 Subject: [PATCH 4/8] Experimenting with custom thread for the S2Connection --- examples/example_pairing_frbc_rm.py | 2 +- .../custom_thread_s2_connection.py | 48 +++++++++++++++++++ .../unit/custom_thread_s2_connection_test.py | 48 +++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/s2python/communication/custom_thread_s2_connection.py create mode 100644 tests/unit/custom_thread_s2_connection_test.py diff --git a/examples/example_pairing_frbc_rm.py b/examples/example_pairing_frbc_rm.py index 5902473..5a3bb8c 100644 --- a/examples/example_pairing_frbc_rm.py +++ b/examples/example_pairing_frbc_rm.py @@ -154,7 +154,7 @@ def deactivate(self, conn: S2Connection) -> None: pairing_endpoint = args.pairing_endpoint pairing_token = args.pairing_token - + res_man = RM(S2Connection) # --- Client Setup --- # Create node description node_description = S2NodeDescription( diff --git a/src/s2python/communication/custom_thread_s2_connection.py b/src/s2python/communication/custom_thread_s2_connection.py new file mode 100644 index 0000000..8d8cc79 --- /dev/null +++ b/src/s2python/communication/custom_thread_s2_connection.py @@ -0,0 +1,48 @@ +import threading +import asyncio +from typing import Optional +from s2python.communication.s2_connection import S2Connection +from s2python.common import EnergyManagementRole + +class CustomThreadS2Connection(S2Connection): + """ + Extends S2Connection to allow running the event loop in a developer-supplied thread. + + If a thread is provided, the event loop will be run in that thread. The developer is responsible + for managing the thread's lifecycle and ensuring it is not used for other conflicting tasks. + """ + def __init__(self, *args, thread: Optional[threading.Thread] = None, **kwargs): + super().__init__(*args, **kwargs) + self._external_thread = thread + self._thread_started_by_user = False + + def start(self) -> None: + if self._external_thread: + # Only start the thread if it is not already running + if not self._external_thread.is_alive(): + def run_loop(): + asyncio.set_event_loop(self._eventloop) + if self.role == EnergyManagementRole.RM: + self._run_eventloop(self._run_as_rm()) + else: + self._run_eventloop(self._run_as_cem()) + self._external_thread.run = run_loop + self._external_thread.start() + self._thread_started_by_user = True + else: + raise RuntimeError("Provided thread is already running. Please provide a fresh thread.") + else: + # Default behavior: run in the current thread + super().start() + + def stop(self) -> None: + """ + Stops the S2 connection. If an external thread was provided and started by this class, + it will join that thread. Otherwise, uses the default stop behavior. + """ + if self._external_thread and self._thread_started_by_user: + if self._eventloop.is_running(): + asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop).result() + self._external_thread.join() + else: + super().stop() diff --git a/tests/unit/custom_thread_s2_connection_test.py b/tests/unit/custom_thread_s2_connection_test.py new file mode 100644 index 0000000..77d071c --- /dev/null +++ b/tests/unit/custom_thread_s2_connection_test.py @@ -0,0 +1,48 @@ +import unittest +import threading +import time +from s2python.communication.custom_thread_s2_connection import CustomThreadS2Connection +from s2python.common import EnergyManagementRole + +class DummyS2Connection(CustomThreadS2Connection): + def _run_eventloop(self, main_task): + # Simulate event loop running for a short time + time.sleep(0.1) + def _run_as_rm(self): + # Dummy awaitable + class DummyAwaitable: + def __await__(self): + yield + return DummyAwaitable() + def _run_as_cem(self): + class DummyAwaitable: + def __await__(self): + yield + return DummyAwaitable() + async def _do_stop(self): + pass + +class TestCustomThreadS2Connection(unittest.TestCase): + def test_start_with_external_thread(self): + thread = threading.Thread() + conn = DummyS2Connection( + url="ws", + role=EnergyManagementRole.RM, + thread=thread + ) + conn.start() + # Wait for thread to finish + conn.stop() + self.assertFalse(thread.is_alive()) + + def test_start_without_external_thread(self): + conn = DummyS2Connection( + url="ws://localhost:1234", + role=EnergyManagementRole.CEM + ) + # Should not raise + conn.start() + conn.stop() + +if __name__ == "__main__": + unittest.main() From 2f9afaf8a9891048a19e0140287950901641d4e5 Mon Sep 17 00:00:00 2001 From: "F.N. Claessen" Date: Wed, 13 Aug 2025 15:25:18 +0200 Subject: [PATCH 5/8] dev: add resource-id and bearer-token arguments to example script Signed-off-by: F.N. Claessen --- examples/example_frbc_rm.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 94d3638..df019e4 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -158,7 +158,7 @@ def stop(s2_connection, signal_num, _current_stack_frame): s2_connection.stop() -def start_s2_session(url, client_node_id=str(uuid.uuid4())): +def start_s2_session(url, client_node_id=str(uuid.uuid4()), bearer_token=None): s2_conn = S2Connection( url=url, role=EnergyManagementRole.RM, @@ -176,6 +176,7 @@ def start_s2_session(url, client_node_id=str(uuid.uuid4())): ), reconnect=True, verify_certificate=False, + bearer_token=bearer_token, ) # Create signal handlers @@ -194,7 +195,7 @@ def sigterm_handler(signum, frame): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="A simple S2 reseource manager example." + description="A simple S2 resource manager example." ) parser.add_argument( "--endpoint", @@ -202,6 +203,19 @@ def sigterm_handler(signum, frame): help="WebSocket endpoint uri for the server (CEM) e.g. " "ws://localhost:8080/", ) + parser.add_argument( + "--resource-id", + type=str, + required=False, + help="Resource that we want to manage. " + "Some UUID", + ) + parser.add_argument( + "--bearer-token", + type=str, + required=False, + help="Bearer token for testing." + ) args = parser.parse_args() - start_s2_session(args.endpoint) + start_s2_session(args.endpoint, args.resource_id, args.bearer_token) From 30b23079766ea9d3f500ee94e227d3f593f157e4 Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Thu, 23 Oct 2025 10:53:33 +0200 Subject: [PATCH 6/8] Added itho specific example Signed-off-by: Vlad Iftime --- examples/example_frbc_rm.py | 166 ++++++++++++++++++++++++++++++------ examples/test_ws.py | 58 ------------- 2 files changed, 140 insertions(+), 84 deletions(-) delete mode 100644 examples/test_ws.py diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index df019e4..094fadd 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -5,6 +5,8 @@ import uuid import signal import datetime +import time +import threading from typing import Callable from s2python.common import ( @@ -17,6 +19,8 @@ NumberRange, PowerRange, CommodityQuantity, + Timer, + Transition, ) from s2python.frbc import ( FRBCInstruction, @@ -52,9 +56,22 @@ def handle_instruction( def activate(self, conn: S2Connection) -> None: print("The control type FRBC is now activated.") - print("Time to send a FRBC SystemDescription") + print("Creating a FRBC device with proper transitions and timers") + + # Create charge and off operation modes like in example_schedule_frbc.py actuator_id = uuid.uuid4() - operation_mode_id = uuid.uuid4() + charge_operation_mode_id = uuid.uuid4() + off_operation_mode_id = uuid.uuid4() + + # Create timers for transitions (needed for proper FRBC operation) + on_to_off_timer_id = uuid.uuid4() + off_to_on_timer_id = uuid.uuid4() + + # Create transitions between modes + transition_on_to_off_id = uuid.uuid4() + transition_off_to_on_id = uuid.uuid4() + + print("Time to send a FRBC SystemDescription") conn.send_msg_and_await_reception_status_sync( FRBCSystemDescription( message_id=uuid.uuid4(), @@ -63,31 +80,89 @@ def activate(self, conn: S2Connection) -> None: FRBCActuatorDescription( id=actuator_id, operation_modes=[ + # Charging mode - similar to recharge system description from example FRBCOperationMode( - id=operation_mode_id, + id=charge_operation_mode_id, elements=[ FRBCOperationModeElement( fill_level_range=NumberRange( start_of_range=0.0, end_of_range=100.0 ), fill_rate=NumberRange( - start_of_range=-5.0, end_of_range=5.0 + start_of_range=0.0, end_of_range=0.01099537114 # Charge power from example ), power_ranges=[ PowerRange( - start_of_range=-200.0, - end_of_range=200.0, + start_of_range=0.0, + end_of_range=57000.0, # 57kW from example commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, ) ], ) ], - diagnostic_label="Load & unload battery", + diagnostic_label="charge.on", abnormal_condition_only=False, + ), + # Off mode - similar to driving/off system description from example + FRBCOperationMode( + id=off_operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_rate=NumberRange( + start_of_range=0.0, end_of_range=0.0 + ), + power_ranges=[ + PowerRange( + start_of_range=0.0, + end_of_range=0.0, + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="charge.off", + abnormal_condition_only=False, + ) + ], + transitions=[ + # Transition from charging to off + Transition( + id=transition_on_to_off_id, + **{"from": charge_operation_mode_id}, + to=off_operation_mode_id, + start_timers=[off_to_on_timer_id], + blocking_timers=[on_to_off_timer_id], + transition_duration=None, + abnormal_condition_only=False + ), + # Transition from off to charging + Transition( + id=transition_off_to_on_id, + **{"from": off_operation_mode_id}, + to=charge_operation_mode_id, + start_timers=[on_to_off_timer_id], + blocking_timers=[off_to_on_timer_id], + transition_duration=None, + abnormal_condition_only=False + ) + ], + timers=[ + # Timer for on to off transition + Timer( + id=on_to_off_timer_id, + diagnostic_label="charge_on.to.off.timer", + duration=Duration.from_milliseconds(30000) # 30 seconds + ), + # Timer for off to on transition + Timer( + id=off_to_on_timer_id, + diagnostic_label="charge_off.to.on.timer", + duration=Duration.from_milliseconds(30000) # 30 seconds ) ], - transitions=[], - timers=[], supported_commodities=[Commodity.ELECTRICITY], ) ], @@ -95,52 +170,89 @@ def activate(self, conn: S2Connection) -> None: fill_level_range=NumberRange( start_of_range=0.0, end_of_range=100.0 ), - fill_level_label="%", - diagnostic_label="Imaginary battery", + fill_level_label="SoC %", + diagnostic_label="battery", provides_fill_level_target_profile=True, provides_leakage_behaviour=False, provides_usage_forecast=False, ), ) ) - print("Also send the target profile") - + + print("Send fill level target profile - similar to example pattern") + # Create a target profile similar to the example with charging goals conn.send_msg_and_await_reception_status_sync( FRBCFillLevelTargetProfile( message_id=uuid.uuid4(), start_time=datetime.datetime.now(tz=datetime.timezone.utc), elements=[ + # First period: charge to higher level (similar to recharge period from example) FRBCFillLevelTargetProfileElement( - duration=Duration.from_milliseconds(30_000), + duration=Duration.from_milliseconds(1800000), # 30 minutes fill_level_range=NumberRange( - start_of_range=20.0, end_of_range=30.0 + start_of_range=80.0, end_of_range=100.0 # Target high charge ), ), + # Second period: maintain level FRBCFillLevelTargetProfileElement( - duration=Duration.from_milliseconds(300_000), + duration=Duration.from_milliseconds(1800000), # 30 minutes fill_level_range=NumberRange( - start_of_range=40.0, end_of_range=50.0 + start_of_range=90.0, end_of_range=100.0 # Maintain high charge ), ), ], ) ) - - print("Also send the storage status.") + time.sleep(5) + print("Send storage status - current charge level") conn.send_msg_and_await_reception_status_sync( - FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=20.0) # Start at 20% like example ) - - print("Also send the actuator status.") + time.sleep(5) + print("Send actuator status - currently in charge mode") conn.send_msg_and_await_reception_status_sync( FRBCActuatorStatus( message_id=uuid.uuid4(), actuator_id=actuator_id, - active_operation_mode_id=operation_mode_id, - operation_mode_factor=0.5, + active_operation_mode_id=charge_operation_mode_id, # Start in charge mode + operation_mode_factor=0.0, # Will be set by CEM instructions ) ) + # Start the countdown loop for sending periodic actuator status + self._start_actuator_status_loop(conn, actuator_id, charge_operation_mode_id) + + def _start_actuator_status_loop(self, conn: S2Connection, actuator_id: uuid.UUID, operation_mode_id: uuid.UUID) -> None: + """Start a background thread that sends actuator status every 45 seconds with countdown display.""" + def countdown_and_send(): + while True: + try: + # 45 second countdown with display + for remaining in range(20, 0, -1): + print(f"\rNext actuator status in {remaining:2d} seconds...", end="", flush=True) + time.sleep(1) + + print("\rSending actuator status... ") + + # Send actuator status + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.0, + ) + ) + print("Actuator status sent successfully!") + + except Exception as e: + print(f"\nError sending actuator status: {e}") + break + + # Start the countdown thread as daemon so it stops when main program exits + countdown_thread = threading.Thread(target=countdown_and_send, daemon=True) + countdown_thread.start() + def deactivate(self, conn: S2Connection) -> None: print("The control type FRBC is now deactivated.") @@ -217,5 +329,7 @@ def sigterm_handler(signum, frame): help="Bearer token for testing." ) args = parser.parse_args() - - start_s2_session(args.endpoint, args.resource_id, args.bearer_token) + args.bearer_token = 'cvp6XXTsgonYda9IB52ltqS+StG7xFrt+ApqVIwUVhg=' + # Use provided resource_id or generate a new UUID if None + resource_id = args.resource_id if args.resource_id is not None else str(uuid.uuid4()) + start_s2_session(args.endpoint, resource_id, args.bearer_token) diff --git a/examples/test_ws.py b/examples/test_ws.py deleted file mode 100644 index 7b07b04..0000000 --- a/examples/test_ws.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import websockets -import uuid -from s2python.common import ( - EnergyManagementRole, - Handshake, - ReceptionStatus, - ReceptionStatusValues, -) -import json - - -async def hello(): - """ - Connects to a WebSocket server, sends a message, - and prints the response. - """ - uri = "ws://localhost:8080" # <-- Replace with your server's URI - try: - async with websockets.connect(uri) as websocket: - message = Handshake( - message_id=uuid.uuid4(), - role=EnergyManagementRole.RM, - supported_protocol_versions=["1.0"], - ) - - await websocket.send(message.to_json()) - print(f">>> {message.to_json()}") - - reception_status = await websocket.recv() - reception_status_json = json.loads(reception_status) - print(f"<<< {reception_status_json}") - - handshake_response = await websocket.recv() - handshake_response_json = json.loads(handshake_response) - print(f"<<< {handshake_response_json}") - - reception_status = ReceptionStatus( - subject_message_id=handshake_response_json["message_id"], - status=ReceptionStatusValues.OK, - diagnostic_label="Handshake received", - ) - await websocket.send(reception_status.to_json()) - print(f">>> {reception_status.to_json()}") - response = await websocket.recv() - - print(f"<<< {response}") - - except ConnectionRefusedError: - print(f"Connection to {uri} refused. Is the server running?") - except Exception as e: - print(f"An error occurred: {e}") - - -if __name__ == "__main__": - asyncio.run(hello()) From 81e53ba77e3684c40bd44c00cce64306602f446e Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Thu, 23 Oct 2025 10:56:31 +0200 Subject: [PATCH 7/8] Removed issue left in Signed-off-by: Vlad Iftime --- examples/example_frbc_rm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 094fadd..1f3d424 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -329,7 +329,7 @@ def sigterm_handler(signum, frame): help="Bearer token for testing." ) args = parser.parse_args() - args.bearer_token = 'cvp6XXTsgonYda9IB52ltqS+StG7xFrt+ApqVIwUVhg=' + args.bearer_token = '' # Use provided resource_id or generate a new UUID if None resource_id = args.resource_id if args.resource_id is not None else str(uuid.uuid4()) start_s2_session(args.endpoint, resource_id, args.bearer_token) From 570c0ea3722c43ad27a42e1b74d34c14ba4790cd Mon Sep 17 00:00:00 2001 From: Vlad Iftime Date: Thu, 23 Oct 2025 10:59:52 +0200 Subject: [PATCH 8/8] Removed issue left in --- examples/example_frbc_rm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 094fadd..1f3d424 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -329,7 +329,7 @@ def sigterm_handler(signum, frame): help="Bearer token for testing." ) args = parser.parse_args() - args.bearer_token = 'cvp6XXTsgonYda9IB52ltqS+StG7xFrt+ApqVIwUVhg=' + args.bearer_token = '' # Use provided resource_id or generate a new UUID if None resource_id = args.resource_id if args.resource_id is not None else str(uuid.uuid4()) start_s2_session(args.endpoint, resource_id, args.bearer_token)