diff --git a/examples/example_flask_server.py b/examples/example_flask_server.py new file mode 100644 index 0000000..3acf250 --- /dev/null +++ b/examples/example_flask_server.py @@ -0,0 +1,289 @@ +""" +Example S2 server implementation using Flask. + +This example demonstrates how to set up both an HTTP and a WebSocket server +using the Flask-based implementations. + +Note: You need to install Flask and Flask-Sock: + `pip install flask flask-sock` + +For running the WebSocket server with true async capabilities, an ASGI server +like Hypercorn is recommended: + `pip install hypercorn` + `hypercorn examples.example_flask_server:server_ws.app` +""" + +import argparse +import logging +import signal +import sys +import uuid + +from flask_sock import Sock + +from s2python.authorization.flask_http_server import S2FlaskHTTPServer +from s2python.authorization.flask_ws_server import S2FlaskWSServer +from s2python.common import ( + ControlType, + EnergyManagementRole, + Handshake, + HandshakeResponse, + ReceptionStatusValues, + ResourceManagerDetails, + SelectControlType, +) +from s2python.frbc import ( + FRBCActuatorStatus, + FRBCFillLevelTargetProfile, + FRBCStorageStatus, + FRBCSystemDescription, +) +from s2python.generated.gen_s2_pairing import ( + Deployment, + PairingToken, + Protocols, + S2NodeDescription, + S2Role, +) +from s2python.message import S2Message + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("example_flask_server") + +# Create the server instance at the module level so Hypercorn can find it. +# We assume 'ws' mode for ASGI server execution. +server_instance = S2FlaskWSServer( + host="localhost", # These values can be configured via env vars or other means + port=8080, + role=EnergyManagementRole.CEM, +) + + +def create_signal_handler(): + """Create a signal handler function.""" + + def handler(signum, frame): + logger.info("Received signal %d. Shutting down...", signum) + if server_instance: + server_instance.stop() + # For Flask's dev server, this will be interrupted by the signal. + # For production servers (gunicorn, hypercorn), they handle signals for shutdown. + sys.exit(0) + + return handler + + +async def handle_FRBC_system_description(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBC system description messages.""" + if not isinstance(message, FRBCSystemDescription): + logger.error( + "Handler for FRBCSystemDescription received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCSystemDescription: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCSystemDescription received", + websocket=websocket, + ) + + +async def handle_FRBCActuatorStatus(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBCActuatorStatus messages.""" + if not isinstance(message, FRBCActuatorStatus): + logger.error( + "Handler for FRBCActuatorStatus received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCActuatorStatus: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCActuatorStatus received", + websocket=websocket, + ) + + +async def handle_FillLevelTargetProfile(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FillLevelTargetProfile messages.""" + if not isinstance(message, FRBCFillLevelTargetProfile): + logger.error( + "Handler for FillLevelTargetProfile received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FillLevelTargetProfile: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FillLevelTargetProfile received", + websocket=websocket, + ) + + +async def handle_FRBCStorageStatus(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle FRBCStorageStatus messages.""" + if not isinstance(message, FRBCStorageStatus): + logger.error( + "Handler for FRBCStorageStatus received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received FRBCStorageStatus: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="FRBCStorageStatus received", + websocket=websocket, + ) + + +async def handle_ResourceManagerDetails(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle ResourceManagerDetails messages.""" + if not isinstance(message, ResourceManagerDetails): + logger.error( + "Handler for ResourceManagerDetails received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received ResourceManagerDetails: %s", message.to_json()) + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="ResourceManagerDetails received", + websocket=websocket, + ) + + +async def handle_handshake(server: S2FlaskWSServer, message: S2Message, websocket: Sock) -> None: + """Handle handshake messages and send control type selection if client is RM.""" + if not isinstance(message, Handshake): + logger.error( + "Handler for Handshake received a message of the wrong type: %s", + type(message), + ) + return + + logger.info("Received Handshake in example_flask_server: %s", message.to_json()) + + # Send reception status for the handshake + await server.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="Handshake received", + websocket=websocket, + ) + + handshake_response = HandshakeResponse( + message_id=message.message_id, + selected_protocol_version="1.0", + ) + logger.info("Sent HandshakeResponse: %s", handshake_response.to_json()) + await server._send_and_forget(handshake_response, websocket) + + # If client is RM, send control type selection + if message.role == EnergyManagementRole.RM: + select_control_type = SelectControlType( + message_id=uuid.uuid4(), + control_type=ControlType.FILL_RATE_BASED_CONTROL, + ) + logger.info("Sending select control type: %s", select_control_type.to_json()) + await server.send_msg_and_await_reception_status_async(select_control_type, websocket) + +# Register handlers on the globally defined instance +server_instance._handlers.register_handler(Handshake, handle_handshake) +server_instance._handlers.register_handler(FRBCSystemDescription, handle_FRBC_system_description) +server_instance._handlers.register_handler(ResourceManagerDetails, handle_ResourceManagerDetails) +server_instance._handlers.register_handler(FRBCActuatorStatus, handle_FRBCActuatorStatus) +server_instance._handlers.register_handler(FRBCFillLevelTargetProfile, handle_FillLevelTargetProfile) +server_instance._handlers.register_handler(FRBCStorageStatus, handle_FRBCStorageStatus) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Example S2 server implementation using Flask.") + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Host to bind the server to (default: localhost)", + ) + parser.add_argument( + "--http-port", + type=int, + default=8000, + help="HTTP port to use (default: 8000)", + ) + parser.add_argument( + "--ws-port", + type=int, + default=8080, + help="WebSocket port to use (default: 8080)", + ) + parser.add_argument( + "--instance", + type=str, + default="http", + choices=["http", "ws"], + help="Instance to use (http or ws, default: http)", + ) + parser.add_argument( + "--pairing-token", + type=str, + default="ca14fda4", + help="Pairing token to use (default: ca14fda4)", + ) + args = parser.parse_args() + + # Create node description for the server + server_node_description = S2NodeDescription( + brand="TNO", + logoUri="https://www.tno.nl/publish/pages/5604/tno-logo-1484x835_003_.jpg", + type="demo frbc example", + modelName="S2 server example (Flask)", + userDefinedName="TNO S2 server example for frbc using Flask", + role=S2Role.RM, + deployment=Deployment.LAN, + ) + logger.info("http_port: %s", args.http_port) + logger.info("ws_port: %s", args.ws_port) + + # Setup signal handling + handler = create_signal_handler() + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + + if args.instance == "ws": + logger.info( + "Starting Flask WebSocket server. For async, run with 'hypercorn examples.example_flask_server:server_ws.app'" + ) + try: + server_instance.start() + except KeyboardInterrupt: + handler(signal.SIGINT, None) + else: + server_http = S2FlaskHTTPServer( + host=args.host, + http_port=args.http_port, + ws_port=args.ws_port, + instance=args.instance, + server_node_description=server_node_description, + token=PairingToken(token=args.pairing_token), + supported_protocols=[Protocols.WebSocketSecure], + ) + server_instance = server_http # type: ignore[assignment] + + logger.info("Starting Flask HTTP server.") + try: + # Note: server_http.start_server() uses Flask's development server. + server_http.start_server() + except KeyboardInterrupt: + handler(signal.SIGINT, None) diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 94d3638..1f3d424 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -5,6 +5,8 @@ import uuid import signal import datetime +import time +import threading from typing import Callable from s2python.common import ( @@ -17,6 +19,8 @@ NumberRange, PowerRange, CommodityQuantity, + Timer, + Transition, ) from s2python.frbc import ( FRBCInstruction, @@ -52,9 +56,22 @@ def handle_instruction( def activate(self, conn: S2Connection) -> None: print("The control type FRBC is now activated.") - print("Time to send a FRBC SystemDescription") + print("Creating a FRBC device with proper transitions and timers") + + # Create charge and off operation modes like in example_schedule_frbc.py actuator_id = uuid.uuid4() - operation_mode_id = uuid.uuid4() + charge_operation_mode_id = uuid.uuid4() + off_operation_mode_id = uuid.uuid4() + + # Create timers for transitions (needed for proper FRBC operation) + on_to_off_timer_id = uuid.uuid4() + off_to_on_timer_id = uuid.uuid4() + + # Create transitions between modes + transition_on_to_off_id = uuid.uuid4() + transition_off_to_on_id = uuid.uuid4() + + print("Time to send a FRBC SystemDescription") conn.send_msg_and_await_reception_status_sync( FRBCSystemDescription( message_id=uuid.uuid4(), @@ -63,31 +80,89 @@ def activate(self, conn: S2Connection) -> None: FRBCActuatorDescription( id=actuator_id, operation_modes=[ + # Charging mode - similar to recharge system description from example + FRBCOperationMode( + id=charge_operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_rate=NumberRange( + start_of_range=0.0, end_of_range=0.01099537114 # Charge power from example + ), + power_ranges=[ + PowerRange( + start_of_range=0.0, + end_of_range=57000.0, # 57kW from example + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="charge.on", + abnormal_condition_only=False, + ), + # Off mode - similar to driving/off system description from example FRBCOperationMode( - id=operation_mode_id, + id=off_operation_mode_id, elements=[ FRBCOperationModeElement( fill_level_range=NumberRange( start_of_range=0.0, end_of_range=100.0 ), fill_rate=NumberRange( - start_of_range=-5.0, end_of_range=5.0 + start_of_range=0.0, end_of_range=0.0 ), power_ranges=[ PowerRange( - start_of_range=-200.0, - end_of_range=200.0, + start_of_range=0.0, + end_of_range=0.0, commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, ) ], ) ], - diagnostic_label="Load & unload battery", + diagnostic_label="charge.off", abnormal_condition_only=False, ) ], - transitions=[], - timers=[], + transitions=[ + # Transition from charging to off + Transition( + id=transition_on_to_off_id, + **{"from": charge_operation_mode_id}, + to=off_operation_mode_id, + start_timers=[off_to_on_timer_id], + blocking_timers=[on_to_off_timer_id], + transition_duration=None, + abnormal_condition_only=False + ), + # Transition from off to charging + Transition( + id=transition_off_to_on_id, + **{"from": off_operation_mode_id}, + to=charge_operation_mode_id, + start_timers=[on_to_off_timer_id], + blocking_timers=[off_to_on_timer_id], + transition_duration=None, + abnormal_condition_only=False + ) + ], + timers=[ + # Timer for on to off transition + Timer( + id=on_to_off_timer_id, + diagnostic_label="charge_on.to.off.timer", + duration=Duration.from_milliseconds(30000) # 30 seconds + ), + # Timer for off to on transition + Timer( + id=off_to_on_timer_id, + diagnostic_label="charge_off.to.on.timer", + duration=Duration.from_milliseconds(30000) # 30 seconds + ) + ], supported_commodities=[Commodity.ELECTRICITY], ) ], @@ -95,52 +170,89 @@ def activate(self, conn: S2Connection) -> None: fill_level_range=NumberRange( start_of_range=0.0, end_of_range=100.0 ), - fill_level_label="%", - diagnostic_label="Imaginary battery", + fill_level_label="SoC %", + diagnostic_label="battery", provides_fill_level_target_profile=True, provides_leakage_behaviour=False, provides_usage_forecast=False, ), ) ) - print("Also send the target profile") - + + print("Send fill level target profile - similar to example pattern") + # Create a target profile similar to the example with charging goals conn.send_msg_and_await_reception_status_sync( FRBCFillLevelTargetProfile( message_id=uuid.uuid4(), start_time=datetime.datetime.now(tz=datetime.timezone.utc), elements=[ + # First period: charge to higher level (similar to recharge period from example) FRBCFillLevelTargetProfileElement( - duration=Duration.from_milliseconds(30_000), + duration=Duration.from_milliseconds(1800000), # 30 minutes fill_level_range=NumberRange( - start_of_range=20.0, end_of_range=30.0 + start_of_range=80.0, end_of_range=100.0 # Target high charge ), ), + # Second period: maintain level FRBCFillLevelTargetProfileElement( - duration=Duration.from_milliseconds(300_000), + duration=Duration.from_milliseconds(1800000), # 30 minutes fill_level_range=NumberRange( - start_of_range=40.0, end_of_range=50.0 + start_of_range=90.0, end_of_range=100.0 # Maintain high charge ), ), ], ) ) - - print("Also send the storage status.") + time.sleep(5) + print("Send storage status - current charge level") conn.send_msg_and_await_reception_status_sync( - FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=20.0) # Start at 20% like example ) - - print("Also send the actuator status.") + time.sleep(5) + print("Send actuator status - currently in charge mode") conn.send_msg_and_await_reception_status_sync( FRBCActuatorStatus( message_id=uuid.uuid4(), actuator_id=actuator_id, - active_operation_mode_id=operation_mode_id, - operation_mode_factor=0.5, + active_operation_mode_id=charge_operation_mode_id, # Start in charge mode + operation_mode_factor=0.0, # Will be set by CEM instructions ) ) + # Start the countdown loop for sending periodic actuator status + self._start_actuator_status_loop(conn, actuator_id, charge_operation_mode_id) + + def _start_actuator_status_loop(self, conn: S2Connection, actuator_id: uuid.UUID, operation_mode_id: uuid.UUID) -> None: + """Start a background thread that sends actuator status every 45 seconds with countdown display.""" + def countdown_and_send(): + while True: + try: + # 45 second countdown with display + for remaining in range(20, 0, -1): + print(f"\rNext actuator status in {remaining:2d} seconds...", end="", flush=True) + time.sleep(1) + + print("\rSending actuator status... ") + + # Send actuator status + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.0, + ) + ) + print("Actuator status sent successfully!") + + except Exception as e: + print(f"\nError sending actuator status: {e}") + break + + # Start the countdown thread as daemon so it stops when main program exits + countdown_thread = threading.Thread(target=countdown_and_send, daemon=True) + countdown_thread.start() + def deactivate(self, conn: S2Connection) -> None: print("The control type FRBC is now deactivated.") @@ -158,7 +270,7 @@ def stop(s2_connection, signal_num, _current_stack_frame): s2_connection.stop() -def start_s2_session(url, client_node_id=str(uuid.uuid4())): +def start_s2_session(url, client_node_id=str(uuid.uuid4()), bearer_token=None): s2_conn = S2Connection( url=url, role=EnergyManagementRole.RM, @@ -176,6 +288,7 @@ def start_s2_session(url, client_node_id=str(uuid.uuid4())): ), reconnect=True, verify_certificate=False, + bearer_token=bearer_token, ) # Create signal handlers @@ -194,7 +307,7 @@ def sigterm_handler(signum, frame): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="A simple S2 reseource manager example." + description="A simple S2 resource manager example." ) parser.add_argument( "--endpoint", @@ -202,6 +315,21 @@ def sigterm_handler(signum, frame): help="WebSocket endpoint uri for the server (CEM) e.g. " "ws://localhost:8080/", ) + parser.add_argument( + "--resource-id", + type=str, + required=False, + help="Resource that we want to manage. " + "Some UUID", + ) + parser.add_argument( + "--bearer-token", + type=str, + required=False, + help="Bearer token for testing." + ) args = parser.parse_args() - - start_s2_session(args.endpoint) + args.bearer_token = '' + # Use provided resource_id or generate a new UUID if None + resource_id = args.resource_id if args.resource_id is not None else str(uuid.uuid4()) + start_s2_session(args.endpoint, resource_id, args.bearer_token) diff --git a/examples/example_pairing_frbc_rm.py b/examples/example_pairing_frbc_rm.py index 5902473..5a3bb8c 100644 --- a/examples/example_pairing_frbc_rm.py +++ b/examples/example_pairing_frbc_rm.py @@ -154,7 +154,7 @@ def deactivate(self, conn: S2Connection) -> None: pairing_endpoint = args.pairing_endpoint pairing_token = args.pairing_token - + res_man = RM(S2Connection) # --- Client Setup --- # Create node description node_description = S2NodeDescription( diff --git a/examples/mock_s2_server.py b/examples/mock_s2_server.py deleted file mode 100644 index 1410a5a..0000000 --- a/examples/mock_s2_server.py +++ /dev/null @@ -1,121 +0,0 @@ -import socketserver -import json -from typing import Any -import uuid -import logging -import random -import string - -from s2python.authorization.default_server import S2DefaultHTTPHandler - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("mock_s2_server") - - -def generate_token() -> str: - """ - Generate a random alphanumeric token with exactly 32 characters. - - Returns: - str: A string of 32 random alphanumeric characters matching pattern ^[0-9a-zA-Z]{32}$ - """ - # Define the character set: uppercase letters, lowercase letters, and digits - chars = string.ascii_letters + string.digits - - # Generate a 32-character token by randomly selecting from the character set - token = "".join(random.choice(chars) for _ in range(32)) - - return token - - -# Generate random token for pairing -PAIRING_TOKEN = generate_token() -SERVER_NODE_ID = str(uuid.uuid4()) -WS_PORT = 8080 -HTTP_PORT = 8000 - - -class MockS2Handler(S2DefaultHTTPHandler): - def do_POST(self) -> None: # pylint: disable=C0103 - content_length = int(self.headers.get("Content-Length", 0)) - post_data = self.rfile.read(content_length).decode("utf-8") - - try: - request_json = json.loads(post_data) - logger.info("Received request at %s", self.path) - logger.debug("Request body: %s", request_json) - - if self.path == "/requestPairing": - # Handle pairing request - # The token in the S2 protocol is a PairingToken object with a token field - token_obj = request_json.get("token", {}) - - # Handle case where token is directly the string or a dict with token field - if isinstance(token_obj, dict) and "token" in token_obj: - request_token_string = token_obj["token"] - else: - request_token_string = token_obj - - logger.info("Extracted token: %s", request_token_string) - logger.info("Expected token: %s", PAIRING_TOKEN) - - if request_token_string == PAIRING_TOKEN: - # Create pairing response - response = { - "s2ServerNodeId": SERVER_NODE_ID, - "serverNodeDescription": { - "brand": "Mock S2 Server", - "type": "Test Server", - "modelName": "Mock Model", - "logoUri": "http://example.com/logo.png", - "userDefinedName": "Mock Server", - "role": "CEM", - "deployment": "LAN", - }, - "requestConnectionUri": f"http://localhost:{HTTP_PORT}/requestConnection", - } - self._send_json_response(200, response) - logger.info("Pairing request successful") - else: - self._send_json_response(401, {"error": "Invalid token"}) - logger.error("Invalid pairing token") - - elif self.path == "/requestConnection": - # Create challenge (normally would be a JWE) - challenge = "mock_challenge_string" - - # Create connection details response - response = { - "connectionUri": f"ws://localhost:{WS_PORT}/s2/mock-websocket", - "challenge": challenge, - "selectedProtocol": "WebSocketSecure", - } - - # Handle connection request - self._send_json_response(200, response) - logger.info("Connection request successful") - - else: - self._send_json_response(404, {"error": "Endpoint not found"}) - logger.error("Unknown endpoint: %s", self.path) - - except Exception as e: - self._send_json_response(500, {"error": str(e)}) - logger.error("Error handling request: %s", e) - raise e - - def log_message(self, format: str, *args: Any) -> None: # pylint: disable=W0622 - logger.info(format % args) # pylint: disable=W1201 - - -def run_server() -> None: - with socketserver.TCPServer(("localhost", HTTP_PORT), MockS2Handler) as httpd: - logger.info("Mock S2 Server running at: http://localhost:%s", HTTP_PORT) - logger.info("Use pairing token: %s", PAIRING_TOKEN) - logger.info("Pairing endpoint: http://localhost:%s/requestPairing", HTTP_PORT) - httpd.serve_forever() - - -if __name__ == "__main__": - run_server() diff --git a/examples/mock_s2_websocket.py b/examples/mock_s2_websocket.py deleted file mode 100644 index 41011f7..0000000 --- a/examples/mock_s2_websocket.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import logging -import json -import uuid -from datetime import datetime, timezone -import websockets - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("mock_s2_websocket") - -# WebSocket server port -WS_PORT = 8080 - - -# Handle client connection -async def handle_connection( - websocket: websockets.WebSocketServerProtocol, path: str -) -> None: - client_id = str(uuid.uuid4()) - logger.info("Client %s connected on path: %s", client_id, path) - - try: - # Send handshake message to client - handshake = { - "type": "Handshake", - "messageId": str(uuid.uuid4()), - "protocolVersion": "0.0.2-beta", - "timestamp": datetime.now(timezone.utc).isoformat(), - } - await websocket.send(json.dumps(handshake)) - logger.info("Sent handshake to client %s", client_id) - - # Listen for messages - async for message in websocket: - try: - data = json.loads(message) - logger.info("Received message from client %s: %s", client_id, data) - - # Extract message type - message_type = data.get("type", "") - message_id = data.get("messageId", str(uuid.uuid4())) - - # Send reception status - reception_status = { - "type": "ReceptionStatus", - "messageId": str(uuid.uuid4()), - "refMessageId": message_id, - "timestamp": datetime.now(timezone.utc).isoformat(), - "status": "OK", - } - await websocket.send(json.dumps(reception_status)) - logger.info("Sent reception status for message %s", message_id) - - # Handle specific message types - if message_type == "HandshakeResponse": - logger.info("Received handshake response") - - # For FRBC messages, you could add specific handling here - - except json.JSONDecodeError: - logger.error("Invalid JSON received from client %s", client_id) - except Exception as e: - logger.error( - "Error processing message from client %s: %s", client_id, e - ) - raise e - - except websockets.exceptions.ConnectionClosed: - logger.info("Connection with client %s closed", client_id) - except Exception as e: - logger.error("Error with client %s: %s", client_id, e) - raise e - finally: - logger.info("Client %s disconnected", client_id) - - -async def start_server() -> None: - server = await websockets.serve(handle_connection, "localhost", WS_PORT) - logger.info("WebSocket server started on ws://localhost:%s", WS_PORT) - - # Keep the server running - await server.wait_closed() - - -if __name__ == "__main__": - asyncio.run(start_server()) diff --git a/examples/test_ws.py b/examples/test_ws.py deleted file mode 100644 index 7b07b04..0000000 --- a/examples/test_ws.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import websockets -import uuid -from s2python.common import ( - EnergyManagementRole, - Handshake, - ReceptionStatus, - ReceptionStatusValues, -) -import json - - -async def hello(): - """ - Connects to a WebSocket server, sends a message, - and prints the response. - """ - uri = "ws://localhost:8080" # <-- Replace with your server's URI - try: - async with websockets.connect(uri) as websocket: - message = Handshake( - message_id=uuid.uuid4(), - role=EnergyManagementRole.RM, - supported_protocol_versions=["1.0"], - ) - - await websocket.send(message.to_json()) - print(f">>> {message.to_json()}") - - reception_status = await websocket.recv() - reception_status_json = json.loads(reception_status) - print(f"<<< {reception_status_json}") - - handshake_response = await websocket.recv() - handshake_response_json = json.loads(handshake_response) - print(f"<<< {handshake_response_json}") - - reception_status = ReceptionStatus( - subject_message_id=handshake_response_json["message_id"], - status=ReceptionStatusValues.OK, - diagnostic_label="Handshake received", - ) - await websocket.send(reception_status.to_json()) - print(f">>> {reception_status.to_json()}") - response = await websocket.recv() - - print(f"<<< {response}") - - except ConnectionRefusedError: - print(f"Connection to {uri} refused. Is the server running?") - except Exception as e: - print(f"An error occurred: {e}") - - -if __name__ == "__main__": - asyncio.run(hello()) diff --git a/src/s2python/authorization/client.py b/src/s2python/authorization/client.py index 994afef..fde1202 100644 --- a/src/s2python/authorization/client.py +++ b/src/s2python/authorization/client.py @@ -116,7 +116,9 @@ def store_connection_request_uri(self, uri: str) -> None: self._connection_request_uri = uri elif self.pairing_uri is not None and "requestPairing" in self.pairing_uri: # Fall back to constructing the URI from the pairing URI - self._connection_request_uri = self.pairing_uri.replace("requestPairing", "requestConnection") + self._connection_request_uri = self.pairing_uri.replace( + "requestPairing", "requestConnection" + ) else: # No valid URI could be determined self._connection_request_uri = None @@ -177,10 +179,14 @@ def request_pairing(self) -> PairingResponse: ValueError: If pairing_uri or token is not set, or if the request fails """ if not self.pairing_uri: - raise ValueError("Pairing URI not set. Set pairing_uri before calling request_pairing.") + raise ValueError( + "Pairing URI not set. Set pairing_uri before calling request_pairing." + ) if not self.token: - raise ValueError("Pairing token not set. Set token before calling request_pairing.") + raise ValueError( + "Pairing token not set. Set token before calling request_pairing." + ) # Ensure we have keys if not self._public_key: @@ -211,7 +217,9 @@ def request_pairing(self) -> PairingResponse: # Parse response if status_code != 200: - raise ValueError(f"Pairing request failed with status {status_code}: {response_text}") + raise ValueError( + f"Pairing request failed with status {status_code}: {response_text}" + ) pairing_response = PairingResponse.model_validate(json.loads(response_text)) @@ -231,7 +239,9 @@ def request_connection(self) -> ConnectionDetails: ValueError: If connection request URI is not set or if the request fails """ if not self._connection_request_uri: - raise ValueError("Connection request URI not set. Call request_pairing first.") + raise ValueError( + "Connection request URI not set. Call request_pairing first." + ) # Create connection request connection_request = ConnectionRequest( @@ -249,7 +259,9 @@ def request_connection(self) -> ConnectionDetails: # Parse response if status_code != 200: - raise ValueError(f"Connection request failed with status {status_code}: {response_text}") + raise ValueError( + f"Connection request failed with status {status_code}: {response_text}" + ) connection_details = ConnectionDetails.model_validate(json.loads(response_text)) @@ -283,13 +295,19 @@ def request_connection(self) -> ConnectionDetails: # Replace the URI with the full WebSocket URL connection_data["connectionUri"] = full_ws_url # Recreate the ConnectionDetails object - connection_details = ConnectionDetails.model_validate(connection_data) - logger.info("Updated relative WebSocket URI to absolute: %s", full_ws_url) + connection_details = ConnectionDetails.model_validate( + connection_data + ) + logger.info( + "Updated relative WebSocket URI to absolute: %s", full_ws_url + ) except (ValueError, TypeError, KeyError) as e: logger.info("Failed to update WebSocket URI: %s", e) else: # Log a warning but don't modify the URI if we can't create a proper absolute URI - logger.info("Received relative WebSocket URI but pairing_uri is not available to create absolute URL") + logger.info( + "Received relative WebSocket URI but pairing_uri is not available to create absolute URL" + ) # Store for later use self._connection_details = connection_details diff --git a/src/s2python/authorization/default_ws_server.py b/src/s2python/authorization/default_ws_server.py index 35ea0a1..ddb46ec 100644 --- a/src/s2python/authorization/default_ws_server.py +++ b/src/s2python/authorization/default_ws_server.py @@ -45,7 +45,10 @@ def __init__(self) -> None: self.handlers = {} async def handle_message( - self, server: "S2DefaultWSServer", msg: S2Message, websocket: WebSocketServerProtocol + self, + server: "S2DefaultWSServer", + msg: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle the S2 message using the registered handler. @@ -205,7 +208,11 @@ async def _connect_and_run(self) -> None: ) logger.info("S2 WebSocket server running at: ws://%s:%s", self.host, self.port) else: - logger.info("S2 WebSocket server already running at: ws://%s:%s", self.host, self.port) + logger.info( + "S2 WebSocket server already running at: ws://%s:%s", + self.host, + self.port, + ) async def wait_till_stop() -> None: await self._stop_event.wait() @@ -387,7 +394,10 @@ async def send_msg_and_await_reception_status_async( ) async def handle_handshake( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle handshake messages. @@ -423,7 +433,10 @@ async def handle_handshake( ) async def handle_reception_status( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle reception status messages.""" if not isinstance(message, ReceptionStatus): @@ -435,7 +448,10 @@ async def handle_reception_status( logger.info("Received ReceptionStatus in handle_reception_status: %s", message.to_json()) async def handle_handshake_response( - self, _: "S2DefaultWSServer", message: S2Message, websocket: WebSocketServerProtocol + self, + _: "S2DefaultWSServer", + message: S2Message, + websocket: WebSocketServerProtocol, ) -> None: """Handle handshake response messages. @@ -466,7 +482,10 @@ async def _send_and_forget(self, s2_msg: S2Message, websocket: WebSocketServerPr logger.warning("Connection closed while sending message") async def send_select_control_type( - self, control_type: ControlType, websocket: WebSocketServerProtocol, send_okay: Awaitable[None] + self, + control_type: ControlType, + websocket: WebSocketServerProtocol, + send_okay: Awaitable[None], ) -> None: """Select the control type. diff --git a/src/s2python/authorization/flask_http_server.py b/src/s2python/authorization/flask_http_server.py new file mode 100644 index 0000000..58f1ec6 --- /dev/null +++ b/src/s2python/authorization/flask_http_server.py @@ -0,0 +1,245 @@ +""" +Flask implementation of the S2 protocol server. +""" + +import logging +from datetime import datetime +from typing import Dict, Any, Tuple +import json + +from flask import Flask, request +from jwskate import Jwk, Jwt +from jwskate.jwe.compact import JweCompact + +from s2python.authorization.server import S2AbstractServer +from s2python.generated.gen_s2_pairing import ( + ConnectionRequest, + PairingRequest, +) + +from s2python.communication.s2_connection import MessageHandlers, S2Connection +from s2python.s2_parser import S2Parser + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2FlaskServer") + + +class S2FlaskHTTPServer(S2AbstractServer): + """Flask implementation of the S2 protocol server.""" + + def __init__( + self, + host: str = "localhost", + http_port: int = 8000, + ws_port: int = 8080, + instance: str = "http", + *args: Any, + **kwargs: Any, + ) -> None: + """Initialize the Flask server implementation. + + Args: + host: The host to bind to + http_port: The HTTP port to use + ws_port: The WebSocket port to use + instance: The instance type (http or ws) + """ + super().__init__(*args, **kwargs) + self.host = host + self.http_port = http_port + self.ws_port = ws_port + self.instance = instance + self._app = Flask(__name__) + self._connections: Dict[str, S2Connection] = {} + self._handlers = MessageHandlers() + self.s2_parser = S2Parser() + self._setup_routes() + + def _setup_routes(self) -> None: + """Set up Flask routes for the server.""" + self._app.add_url_rule( + "/requestPairing", + "requestPairing", + self._handle_pairing_request, + methods=["POST"], + ) + self._app.add_url_rule( + "/requestConnection", + "requestConnection", + self._handle_connection_request, + methods=["POST"], + ) + + def _handle_pairing_request(self) -> Tuple[dict, int]: + """Handle a pairing request. + + Returns: + Tuple[dict, int]: (response JSON, status code) + """ + try: + request_json = request.get_json() + logger.info("Received pairing request at /requestPairing") + logger.debug("Request body: %s", request_json) + + # Convert request to PairingRequest + pairing_request = PairingRequest.model_validate(request_json) + + # Process request using server instance + response = self.handle_pairing_request(pairing_request) + + # Send response + logger.info("Pairing request successful") + return response.model_dump(), 200 + + except ValueError as e: + logger.error("Invalid pairing request: %s", e) + return {"error": str(e)}, 400 + except Exception as e: + logger.error("Error handling pairing request: %s", e) + return {"error": str(e)}, 500 + + def _handle_connection_request(self) -> Tuple[dict, int]: + """Handle a connection request. + + Returns: + Tuple[dict, int]: (response JSON, status code) + """ + try: + request_json = request.get_json() + logger.info("Received connection request at /requestConnection") + logger.debug("Request body: %s", request_json) + + # Convert request to ConnectionRequest + connection_request = ConnectionRequest.model_validate(request_json) + + # Process request using server instance + response = self.handle_connection_request(connection_request) + + # Send response + logger.info("Connection request successful") + return response.model_dump(), 200 + + except ValueError as e: + logger.error("Invalid connection request: %s", e) + return {"error": str(e)}, 400 + except Exception as e: + logger.error("Error handling connection request: %s", e) + return {"error": str(e)}, 500 + + def generate_key_pair(self) -> Tuple[str, str]: + """Generate a public/private key pair for the server. + + Returns: + Tuple[str, str]: (public_key, private_key) pair as base64 encoded strings + """ + logger.info("Generating key pair") + self._key_pair = Jwk.generate_for_alg("RSA-OAEP-256").with_kid_thumbprint() + self._public_jwk = self._key_pair + self._private_jwk = self._key_pair + return ( + self._public_jwk.to_pem(), + self._private_jwk.to_pem(), + ) + + def store_key_pair(self, public_key: str, private_key: str) -> None: + """Store the server's public/private key pair. + + Args: + public_key: Base64 encoded public key + private_key: Base64 encoded private key + """ + self._private_key = private_key + # Convert to JWK for JWT operations + self._private_jwk = Jwk.from_pem(private_key) + + def _create_signed_token( + self, claims: Dict[str, Any], expiry_date: datetime + ) -> str: + """Create a signed JWT token. + + Args: + claims: The claims to include in the token + expiry_date: The token's expiration date + + Returns: + str: The signed JWT token + """ + if not self._private_jwk: + # Generate key pair with correct algorithm + self._key_pair = Jwk.generate_for_alg("RS256").with_kid_thumbprint() + self._private_jwk = self._key_pair + self._public_jwk = self._key_pair + + # Add expiration to claims + claims["exp"] = int(expiry_date.timestamp()) + + # Create JWT with claims using RS256 for signing + token = Jwt.sign(claims=claims, key=self._private_jwk, alg="RS256") + + return str(token) + + def _create_encrypted_challenge( + self, + client_public_key: str, + client_node_id: str, + nested_signed_token: str, + expiry_date: datetime, + ) -> str: + """Create an encrypted challenge for the client. + + Args: + client_public_key: The client's public key + client_node_id: The client's node ID + nested_signed_token: The nested signed token + expiry_date: The challenge's expiration date + + Returns: + str: The encrypted challenge + """ + # Convert client's public key to JWK + client_jwk = Jwk.from_pem(client_public_key) + + # Create the payload to encrypt - this will be decrypted and used as an unprotected JWT + payload = { + "S2ClientNodeId": client_node_id, + "signedToken": nested_signed_token, + "exp": int(expiry_date.timestamp()), + } + + # Create JWE with all required components + jwe = JweCompact.encrypt( + plaintext=json.dumps(payload).encode(), + key=client_jwk, # Using client's public key for encryption + alg="RSA-OAEP-256", + enc="A256GCM", + ) + + logger.info("JWE: %s", str(jwe)) + return str(jwe) + + def start_server(self) -> None: + """Start the HTTP or WebSocket server.""" + if self.instance == "http": + logger.info("Starting Flask HTTP server------>") + self._app.run(host=self.host, port=self.http_port) + else: + raise ValueError("Invalid instance type") + + def stop_server(self) -> None: + """Stop the server.""" + # Flask doesn't have a built-in way to stop the server + # This would typically be handled by the WSGI server + pass + + def _get_ws_url(self) -> str: + """Get the WebSocket URL for the server.""" + return f"ws://{self.host}:{self.ws_port}" + + def _get_base_url(self) -> str: + """Get the base URL for the server. + + Returns: + str: The base URL (e.g., "http://localhost:8000") + """ + return f"http://{self.host}:{self.http_port}" diff --git a/src/s2python/authorization/flask_service.py b/src/s2python/authorization/flask_service.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/s2python/authorization/flask_ws_server.py b/src/s2python/authorization/flask_ws_server.py new file mode 100644 index 0000000..bd9a698 --- /dev/null +++ b/src/s2python/authorization/flask_ws_server.py @@ -0,0 +1,317 @@ +""" +Flask implementation of the S2 protocol WebSocket server. +""" + +import asyncio +import json +import logging +import traceback +import uuid +from typing import Any, Callable, Dict, Optional, Type + +from flask import Flask +from flask_sock import ConnectionClosed, Sock + +from s2python.common import ( + ControlType, + EnergyManagementRole, + Handshake, + HandshakeResponse, + ReceptionStatus, + ReceptionStatusValues, + SelectControlType, +) +from s2python.communication.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.message import S2Message +from s2python.s2_parser import S2Parser +from s2python.s2_validation_error import S2ValidationError + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2FlaskWSServer") + + +class MessageHandlers: + """Class to manage message handlers for different message types.""" + + handlers: Dict[Type[S2Message], Callable] + + def __init__(self) -> None: + self.handlers = {} + + async def handle_message( + self, + server: "S2FlaskWSServer", + msg: S2Message, + websocket: Sock, + ) -> None: + """Handle the S2 message using the registered handler. + Args: + server: The server instance handling the message + msg: The S2 message to handle + websocket: The websocket connection to the client + """ + handler = self.handlers.get(type(msg)) + if handler is not None: + try: + if asyncio.iscoroutinefunction(handler): + await handler(server, msg, websocket) + else: + + def do_message() -> None: + handler(server, msg, websocket) + + eventloop = asyncio.get_event_loop() + await eventloop.run_in_executor(executor=None, func=do_message) + except Exception: + logger.error( + "While processing message %s an unrecoverable error occurred.", + msg.message_id, # type: ignore[attr-defined] + ) + logger.error("Error: %s", traceback.format_exc()) + await server.respond_with_reception_status( + subject_message_id=msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label=f"While processing message {msg.message_id} " # type: ignore[attr-defined] + f"an unrecoverable error occurred.", + websocket=websocket, + ) + raise + else: + logger.warning( + "Received a message of type %s but no handler is registered. Ignoring the message.", + type(msg), + ) + + def register_handler(self, msg_type: Type[S2Message], handler: Callable[..., Any]) -> None: + """Register a handler for a specific message type. + Args: + msg_type: The message type to handle + handler: The handler function + """ + self.handlers[msg_type] = handler + + +class S2FlaskWSServer: + """Flask-based WebSocket server implementation for S2 protocol.""" + + def __init__( + self, + host: str = "localhost", + port: int = 8080, + role: EnergyManagementRole = EnergyManagementRole.CEM, + ws_path: str = "/", + ) -> None: + """Initialize the WebSocket server. + Args: + host: The host to bind to + port: The port to listen on + role: The role of this server (CEM or RM) + ws_path: The path for the WebSocket endpoint. + """ + self.host = host + self.port = port + self.role = role + self.ws_path = ws_path + + self.app = Flask(__name__) + self.sock = Sock(self.app) + + self._handlers = MessageHandlers() + self.s2_parser = S2Parser() + self._connections: Dict[str, Sock] = {} + self.reception_status_awaiter = ReceptionStatusAwaiter() + + self._register_default_handlers() + self._setup_routes() + + def _setup_routes(self) -> None: + self.sock.route(self.ws_path)(self._ws_handler) + + def _register_default_handlers(self) -> None: + """Register default message handlers.""" + self._handlers.register_handler(Handshake, self.handle_handshake) + self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response) + self._handlers.register_handler(ReceptionStatus, self.handle_reception_status) + + def start(self) -> None: + """ + Start the WebSocket server using Flask's built-in development server. + + Note: This server is synchronous (WSGI) and is not suitable for production. + It runs each WebSocket connection in a separate thread. For true async + support and production use, you must run this application with an + ASGI server like Hypercorn. Example: + `hypercorn -b 127.0.0.1:8080 examples.example_flask_server:server_instance.app` + """ + logger.info( + "Starting S2 Flask WebSocket server with Flask's development WSGI server at: ws://%s:%s%s", + self.host, + self.port, + self.ws_path, + ) + self.app.run(host=self.host, port=self.port, debug=False) + + def stop(self) -> None: + """Stop the WebSocket server.""" + logger.warning("S2FlaskWSServer.stop() is a no-op. The server should be managed by a process manager.") + + def _ws_handler(self, ws: Sock) -> None: + """ + Wrapper to run the async websocket handler from a synchronous context. + This is required for Flask's development server. An ASGI server would + be able to run the async handler directly. + """ + try: + asyncio.run(self._handle_websocket_connection(ws)) + except Exception as e: + # The websocket is likely closed, or another network error occurred. + logger.error("Error in websocket handler: %s", e) + + async def _handle_websocket_connection(self, websocket: Sock) -> None: + """Handle incoming WebSocket connections.""" + client_id = str(uuid.uuid4()) + logger.info("Client %s connected.", client_id) + self._connections[client_id] = websocket + + try: + while True: + message = await websocket.receive() + try: + s2_msg = self.s2_parser.parse_as_any_message(message) + if isinstance(s2_msg, ReceptionStatus): + await self.reception_status_awaiter.receive_reception_status(s2_msg) + continue + except json.JSONDecodeError: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + websocket=websocket, + ) + continue + try: + await self._handlers.handle_message(self, s2_msg, websocket) + except json.JSONDecodeError: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + websocket=websocket, + ) + except S2ValidationError as e: + json_msg = json.loads(message) + message_id = json_msg.get("message_id") + if message_id: + await self.respond_with_reception_status( + subject_message_id=message_id, + status=ReceptionStatusValues.INVALID_MESSAGE, + diagnostic_label=str(e), + websocket=websocket, + ) + else: + await self.respond_with_reception_status( + subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Message appears valid json but could not find a message_id field.", + websocket=websocket, + ) + except Exception as e: + logger.error("Error processing message: %s", str(e)) + raise + except ConnectionClosed: + logger.info("Connection with client %s closed", client_id) + finally: + if client_id in self._connections: + del self._connections[client_id] + logger.info("Client %s disconnected", client_id) + + async def respond_with_reception_status( + self, + subject_message_id: uuid.UUID, + status: ReceptionStatusValues, + diagnostic_label: str, + websocket: Sock, + ) -> None: + """Send a reception status response.""" + response = ReceptionStatus( + subject_message_id=subject_message_id, status=status, diagnostic_label=diagnostic_label + ) + logger.info("Sending reception status %s for message %s", status, subject_message_id) + try: + await websocket.send(response.to_json()) + except ConnectionClosed: + logger.warning("Connection closed while sending reception status") + + async def send_msg_and_await_reception_status_async( + self, + s2_msg: S2Message, + websocket: Sock, + timeout_reception_status: float = 20.0, + raise_on_error: bool = True, + ) -> ReceptionStatus: + """Send a message and await a reception status.""" + await self._send_and_forget(s2_msg, websocket) + try: + response = await asyncio.wait_for(websocket.receive(), timeout=timeout_reception_status) + # Assuming the response is the correct reception status + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.OK, + diagnostic_label="Reception status received.", + ) + except asyncio.TimeoutError: + if raise_on_error: + raise TimeoutError(f"Did not receive a reception status on time for {s2_msg.message_id}") # type: ignore[attr-defined] + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label="Timeout waiting for reception status.", + ) + except ConnectionClosed: + return ReceptionStatus( + subject_message_id=s2_msg.message_id, # type: ignore[attr-defined] + status=ReceptionStatusValues.OK, + diagnostic_label="Connection closed, assuming OK status.", + ) + + async def handle_handshake(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle handshake messages.""" + if not isinstance(message, Handshake): + return + + handshake_response = HandshakeResponse( + message_id=message.message_id, selected_protocol_version=message.supported_protocol_versions + ) + await self.send_msg_and_await_reception_status_async(handshake_response, websocket) + + await self.respond_with_reception_status( + subject_message_id=message.message_id, + status=ReceptionStatusValues.OK, + diagnostic_label="Handshake received", + websocket=websocket, + ) + + async def handle_reception_status(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle reception status messages.""" + if not isinstance(message, ReceptionStatus): + return + logger.info("Received ReceptionStatus in handle_reception_status: %s", message.to_json()) + + async def handle_handshake_response(self, _: "S2FlaskWSServer", message: S2Message, websocket: Sock) -> None: + """Handle handshake response messages.""" + if not isinstance(message, HandshakeResponse): + return + logger.debug("Received HandshakeResponse: %s", message.to_json()) + + async def _send_and_forget(self, s2_msg: S2Message, websocket: Sock) -> None: + """Send a message and forget about it.""" + try: + await websocket.send(s2_msg.to_json()) + except ConnectionClosed: + logger.warning("Connection closed while sending message") + + async def send_select_control_type(self, control_type: ControlType, websocket: Sock) -> None: + """Select the control type.""" + select_control_type = SelectControlType(message_id=uuid.uuid4(), control_type=control_type) + await self._send_and_forget(select_control_type, websocket) diff --git a/src/s2python/communication/custom_thread_s2_connection.py b/src/s2python/communication/custom_thread_s2_connection.py new file mode 100644 index 0000000..8d8cc79 --- /dev/null +++ b/src/s2python/communication/custom_thread_s2_connection.py @@ -0,0 +1,48 @@ +import threading +import asyncio +from typing import Optional +from s2python.communication.s2_connection import S2Connection +from s2python.common import EnergyManagementRole + +class CustomThreadS2Connection(S2Connection): + """ + Extends S2Connection to allow running the event loop in a developer-supplied thread. + + If a thread is provided, the event loop will be run in that thread. The developer is responsible + for managing the thread's lifecycle and ensuring it is not used for other conflicting tasks. + """ + def __init__(self, *args, thread: Optional[threading.Thread] = None, **kwargs): + super().__init__(*args, **kwargs) + self._external_thread = thread + self._thread_started_by_user = False + + def start(self) -> None: + if self._external_thread: + # Only start the thread if it is not already running + if not self._external_thread.is_alive(): + def run_loop(): + asyncio.set_event_loop(self._eventloop) + if self.role == EnergyManagementRole.RM: + self._run_eventloop(self._run_as_rm()) + else: + self._run_eventloop(self._run_as_cem()) + self._external_thread.run = run_loop + self._external_thread.start() + self._thread_started_by_user = True + else: + raise RuntimeError("Provided thread is already running. Please provide a fresh thread.") + else: + # Default behavior: run in the current thread + super().start() + + def stop(self) -> None: + """ + Stops the S2 connection. If an external thread was provided and started by this class, + it will join that thread. Otherwise, uses the default stop behavior. + """ + if self._external_thread and self._thread_started_by_user: + if self._eventloop.is_running(): + asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop).result() + self._external_thread.join() + else: + super().stop() diff --git a/tests/unit/custom_thread_s2_connection_test.py b/tests/unit/custom_thread_s2_connection_test.py new file mode 100644 index 0000000..77d071c --- /dev/null +++ b/tests/unit/custom_thread_s2_connection_test.py @@ -0,0 +1,48 @@ +import unittest +import threading +import time +from s2python.communication.custom_thread_s2_connection import CustomThreadS2Connection +from s2python.common import EnergyManagementRole + +class DummyS2Connection(CustomThreadS2Connection): + def _run_eventloop(self, main_task): + # Simulate event loop running for a short time + time.sleep(0.1) + def _run_as_rm(self): + # Dummy awaitable + class DummyAwaitable: + def __await__(self): + yield + return DummyAwaitable() + def _run_as_cem(self): + class DummyAwaitable: + def __await__(self): + yield + return DummyAwaitable() + async def _do_stop(self): + pass + +class TestCustomThreadS2Connection(unittest.TestCase): + def test_start_with_external_thread(self): + thread = threading.Thread() + conn = DummyS2Connection( + url="ws", + role=EnergyManagementRole.RM, + thread=thread + ) + conn.start() + # Wait for thread to finish + conn.stop() + self.assertFalse(thread.is_alive()) + + def test_start_without_external_thread(self): + conn = DummyS2Connection( + url="ws://localhost:1234", + role=EnergyManagementRole.CEM + ) + # Should not raise + conn.start() + conn.stop() + +if __name__ == "__main__": + unittest.main()