diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7ca3906..e55622e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,11 +57,11 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" - - "3.12" # newest Python that is stable + - "3.12" + - "3.13" # newest Python that is stable platform: - ubuntu-latest # - macos-latest @@ -79,8 +79,7 @@ jobs: - name: Run tests run: >- pipx run --python '${{ steps.setup-python.outputs.python-path }}' - tox --installpkg '${{ needs.prepare.outputs.wheel-distribution }}' - -- -rFEx --durations 10 --color yes # pytest args + tox --develop # - name: Generate coverage report # run: pipx run coverage lcov -o coverage.lcov # - name: Upload partial coverage report @@ -96,11 +95,11 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" - - "3.12" # newest Python that is stable + - "3.12" + - "3.13" # newest Python that is stable platform: - ubuntu-latest # - macos-latest @@ -118,18 +117,18 @@ jobs: - name: Run tests run: >- pipx run --python '${{ steps.setup-python.outputs.python-path }}' - tox -e lint --installpkg '${{ needs.prepare.outputs.wheel-distribution }}' + tox -e lint --develop typecheck: needs: prepare strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" - - "3.12" # newest Python that is stable + - "3.12" + - "3.13" # newest Python that is stable platform: - ubuntu-latest # - macos-latest @@ -147,7 +146,7 @@ jobs: - name: Run tests run: >- pipx run --python '${{ steps.setup-python.outputs.python-path }}' - tox -e typecheck --installpkg '${{ needs.prepare.outputs.wheel-distribution }}' + tox -e typecheck --develop finalize: needs: [test, lint, typecheck] diff --git a/.gitignore b/.gitignore index 3f96451..5baf340 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,3 @@ venv dist/ build/ %LOCALAPPDATA% -s2-python/src/s2python/specification/s2-pairing/s2-over-ip-pairing.yaml diff --git a/dev-requirements.txt b/dev-requirements.txt index 5b3daa2..2abfdc1 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -238,6 +238,16 @@ wheel==0.45.1 # via pip-tools zipp==3.20.2 # via importlib-metadata +jwskate==0.11.1 +binapy==0.8.0 + # via jwskate +cffi==1.17.1 + # via jwskate +cryptography==44.0.2 + # via jwskate +pycparser==2.22 + # via jwskate +types-requests==2.32.0.20250328 # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/pyproject.toml b/pyproject.toml index 239a3de..9e2c664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ version = "0.5.0" readme = "README.rst" license = "Apache-2.0" license-files = ["LICENSE"] -requires-python = ">=3.8, < 3.13" +requires-python = ">=3.9, < 3.14" dependencies = [ "pydantic>=2.8.2", "pytz", @@ -20,11 +20,11 @@ dependencies = [ ] classifiers = [ "Development Status :: 4 - Beta", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] [project.urls] "Source code" = "https://github.com/flexiblepower/s2-ws-json-python" @@ -32,6 +32,8 @@ classifiers = [ [project.optional-dependencies] ws = [ "websockets~=13.1", + "jwskate~=0.11", + "requests~=2.32.3", ] fastapi = [ "fastapi", diff --git a/specification/s2-over-ip-pairing.yaml b/specification/s2-over-ip-pairing.yaml new file mode 100644 index 0000000..0657095 --- /dev/null +++ b/specification/s2-over-ip-pairing.yaml @@ -0,0 +1,136 @@ +openapi: 3.0.3 +info: + version: "0.1" + title: s2-over-ip pairing and connection initiation + description: "Description of the pairing process over IP for S2" +paths: + /requestPairing: + post: + description: Initiate pairing + requestBody: + description: TODO + content: + application/json: + schema: + $ref: '#/components/schemas/PairingRequest' + responses: + '200': + description: TODO + content: + application/json: + schema: + $ref: '#/components/schemas/PairingResponse' + /requestConnection: + post: + description: TODO + requestBody: + description: TODO + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionRequest' + responses: + '200': + description: TODO + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionDetails' +components: + schemas: + PairingInfo: + type: object + properties: + pairingUri: + type: string + format: uri + token: + $ref: "#/components/schemas/PairingToken" + validUntil: + type: string + format: date-time + PairingRequest: + type: object + properties: + token: + $ref: "#/components/schemas/PairingToken" + publicKey: + type: string + format: byte + s2ClientNodeId: + type: string + format: uuid + s2ClientNodeDescription: + $ref: "#/components/schemas/S2NodeDescription" + supportedProtocols: + type: array + items: + $ref: "#/components/schemas/Protocols" + PairingResponse: + type: object + properties: + s2ServerNodeId: + type: string + format: uuid + serverNodeDescription: + $ref: "#/components/schemas/S2NodeDescription" + requestConnectionUri: + type: string + format: uri + ConnectionRequest: + type: object + properties: + s2ClientNodeId: + type: string + format: uuid + supportedProtocols: + type: array + items: + $ref: "#/components/schemas/Protocols" + ConnectionDetails: + type: object + properties: + selectedProtocol: + $ref: "#/components/schemas/Protocols" + challenge: + type: string + format: byte + connectionUri: + type: string + format: uri + S2NodeDescription: + type: object + description: TODO nog even over nadenken + properties: + brand: + type: string + logoUri: + type: string + format: uri + type: + type: string + modelName: + type: string + userDefinedName: + type: string + role: + $ref: "#/components/schemas/S2Role" + deployment: + $ref: "#/components/schemas/Deployment" + Protocols: + type: string + enum: + - WebSocketSecure + S2Role: + type: string + enum: + - CEM + - RM + Deployment: + type: string + enum: + - WAN + - LAN + PairingToken: + type: string + pattern: "^[0-9a-zA-Z]{32}$" diff --git a/src/s2python/__init__.py b/src/s2python/__init__.py index 0ab0a42..64bfb48 100644 --- a/src/s2python/__init__.py +++ b/src/s2python/__init__.py @@ -1,3 +1,4 @@ +import sys # pragma: no cover from importlib.metadata import PackageNotFoundError, version # pragma: no cover try: @@ -8,3 +9,6 @@ __version__ = "unknown" finally: del version, PackageNotFoundError + + from s2python.communication.s2_connection import S2Connection, AssetDetails # pragma: no cover + sys.modules['s2python.s2_connection'] = sys.modules['s2python.communication.s2_connection'] # pragma: no cover diff --git a/src/s2python/authorization/__init__.py b/src/s2python/authorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/s2python/authorization/client.py b/src/s2python/authorization/client.py index 90e3a17..a5d5422 100644 --- a/src/s2python/authorization/client.py +++ b/src/s2python/authorization/client.py @@ -1,67 +1,345 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict +""" +S2 protocol client for handling pairing and secure connections. +""" +import abc +import json +import uuid +import datetime +import logging +from typing import Dict, Optional, Tuple, Union, List, Any -class AbstractConnectionClient(ABC): - """Abstract class for handling the /requestConnection endpoint.""" - def request_connection(self) -> Any: - """Orchestrate the connection request flow: build → execute → handle.""" - request_data = self.build_connection_request() - response_data = self.execute_connection_request(request_data) - return self.handle_connection_response(response_data) +from jwskate import Jwk +from pydantic import BaseModel + +from s2python.generated.gen_s2_pairing import ( + ConnectionDetails, + ConnectionRequest, + PairingRequest, + PairingResponse, + PairingToken, + S2NodeDescription, + Protocols, +) + + +REQTEST_TIMEOUT = 10 +PAIRING_TIMEOUT = datetime.timedelta(minutes=5) +KEY_ALGORITHM = "RSA-OAEP-256" + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2AbstractClient") + +class PairingDetails(BaseModel): + """Contains all details from the pairing process.""" + + pairing_response: PairingResponse + connection_details: ConnectionDetails + decrypted_challenge_str: Optional[str] = None - @abstractmethod - def build_connection_request(self) -> Dict: - """ - Build the payload for the ConnectionRequest schema. - Returns a dictionary with keys: s2ClientNodeId, supportedProtocols. - """ - @abstractmethod - def execute_connection_request(self, request_data: Dict) -> Dict: +class S2AbstractClient(abc.ABC): + """Abstract client for handling S2 protocol pairing and connections. + + Client handles: + - HTTP client with TLS + - Storage of connection request URI + - Storage of public/private key pairs + - Challenge solving + + This class serves as an interface that developers can extend to implement + S2 protocol functionality with their preferred technology stack. + Concrete implementations should override the abstract methods marked + with @abc.abstractmethod. + """ + + # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-arguments + def __init__( + self, + pairing_uri: Optional[str] = None, + token: Optional[PairingToken] = None, + node_description: Optional[S2NodeDescription] = None, + verify_certificate: Union[bool, str] = False, + client_node_id: Optional[uuid.UUID] = None, + supported_protocols: Optional[List[Protocols]] = None, + ) -> None: + """Initialize the client with configuration parameters. + + Args: + pairing_uri: URI for the pairing request + token: Pairing token for authentication + node_description: S2 node description + verify_certificate: Whether to verify SSL certificates (or path to CA cert) + client_node_id: Client node UUID (generated if not provided) + supported_protocols: List of supported protocols """ - Execute the POST request to /requestConnection. - Implementations should send the request_data to the endpoint - and return the JSON response as a dictionary. + # Connection and authentication info + self.pairing_uri = pairing_uri + self.token = token + self.node_description = node_description + self.verify_certificate = verify_certificate + self.client_node_id = client_node_id if client_node_id else uuid.uuid4() + self.supported_protocols = supported_protocols or [Protocols.WebSocketSecure] + + # Internal state + self._connection_request_uri: Optional[str] = None + self._public_key: Optional[str] = None + self._private_key: Optional[str] = None + self._public_jwk: Optional[Jwk] = None + self._private_jwk: Optional[Jwk] = None + self._key_pair: Optional[Jwk] = None + self._pairing_response: Optional[PairingResponse] = None + self._connection_details: Optional[ConnectionDetails] = None + self._pairing_details: Optional[PairingDetails] = None + + @property + def connection_request_uri(self) -> Optional[str]: + """Get the stored connection request URI.""" + return self._connection_request_uri + + def store_connection_request_uri(self, uri: str) -> None: + """Store the connection request URI. + + If the provided URI is empty, None, or doesn't contain 'requestConnection', + it will attempt to derive it from the pairing URI by replacing 'requestPairing' + with 'requestConnection'. + + Args: + uri: The connection request URI from the pairing response """ + if uri is not None and uri.strip() != "" and "requestConnection" in uri: + self._connection_request_uri = uri + elif self.pairing_uri is not None and "requestPairing" in self.pairing_uri: + # Fall back to constructing the URI from the pairing URI + self._connection_request_uri = self.pairing_uri.replace("requestPairing", "requestConnection") + else: + # No valid URI could be determined + self._connection_request_uri = None + + @abc.abstractmethod + def generate_key_pair(self) -> Tuple[str, str]: + """Generate a public/private key pair. + + This method should be implemented by concrete subclasses to use their + preferred cryptographic libraries or key management systems. - @abstractmethod - def handle_connection_response(self, response_data: Dict) -> Any: + Returns: + Tuple[str, str]: (public_key, private_key) pair as base64 encoded strings """ - Process the ConnectionDetails response (e.g., extract challenge and connection URI). - The response_data contains keys: selectedProtocol, challenge, connectionUri. + + @abc.abstractmethod + def store_key_pair(self, public_key: str, private_key: str) -> None: + """Store the public/private key pair. + + This method should be implemented by concrete subclasses to store keys + according to their security requirements (e.g., secure storage, HSM, etc.). + + Args: + public_key: Base64 encoded public key + private_key: Base64 encoded private key """ + @abc.abstractmethod + def _make_https_request( + self, + url: str, + method: str = "GET", + data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Tuple[int, str]: + """Make an HTTPS request. -class AbstractPairingClient(ABC): - """Abstract class for handling the /requestPairing endpoint.""" + This method should be implemented by concrete subclasses to use their + preferred HTTP client library or framework. - def request_pairing(self) -> Any: - """Orchestrate the pairing request flow: build → execute → handle.""" - request_data = self.build_pairing_request() - response_data = self.execute_pairing_request(request_data) - return self.handle_pairing_response(response_data) + Args: + url: Target URL + method: HTTP method (GET, POST, etc.) + data: Request body data + headers: HTTP headers - @abstractmethod - def build_pairing_request(self) -> Dict: + Returns: + Tuple[int, str]: (status_code, response_text) """ - Build the payload for the PairingRequest schema. - Returns a dictionary with keys: token, publicKey, s2ClientNodeId, - s2ClientNodeDescription, supportedProtocols. + + def request_pairing(self) -> PairingResponse: + """Send a pairing request to the server using client configuration. + + Returns: + PairingResponse: The server's response to the pairing request + + Raises: + ValueError: If pairing_uri or token is not set, or if the request fails """ + if not self.pairing_uri: + raise ValueError("Pairing URI not set. Set pairing_uri before calling request_pairing.") + + if not self.token: + raise ValueError("Pairing token not set. Set token before calling request_pairing.") + + # Ensure we have keys + if not self._public_key: + public_key, private_key = self.generate_key_pair() + self.store_key_pair(public_key, private_key) + + # Create pairing request + logger.info("Creating pairing request") + pairing_request = PairingRequest( + token=self.token, + publicKey=self._public_key, + s2ClientNodeId=str(self.client_node_id), + s2ClientNodeDescription=self.node_description, + supportedProtocols=self.supported_protocols, + ) + + # Make pairing request + logger.info("Making pairing request") + status_code, response_text = self._make_https_request( + url=self.pairing_uri, + method="POST", + data=pairing_request.model_dump(exclude_none=True), + headers={"Content-Type": "application/json"}, + ) + logger.info('Pairing request response: %s %s', status_code, response_text) - @abstractmethod - def execute_pairing_request(self, request_data: Dict) -> Dict: + # Parse response + if status_code != 200: + raise ValueError(f"Pairing request failed with status {status_code}: {response_text}") + + pairing_response = PairingResponse.model_validate(json.loads(response_text)) + + # Store for later use + self._pairing_response = pairing_response + self.store_connection_request_uri(str(pairing_response.requestConnectionUri)) + + return pairing_response + + def request_connection(self) -> ConnectionDetails: + """Request connection details from the server. + + Returns: + ConnectionDetails: The connection details returned by the server + + Raises: + ValueError: If connection request URI is not set or if the request fails """ - Execute the POST request to /requestPairing. - Implementations should send the request_data to the endpoint - and return the JSON response as a dictionary. + if not self._connection_request_uri: + raise ValueError("Connection request URI not set. Call request_pairing first.") + + # Create connection request + connection_request = ConnectionRequest( + s2ClientNodeId=str(self.client_node_id), + supportedProtocols=self.supported_protocols, + ) + + # Make a POST request to the connection request URI + status_code, response_text = self._make_https_request( + url=self._connection_request_uri, + method="POST", + data=connection_request.model_dump(exclude_none=True), + headers={"Content-Type": "application/json"}, + ) + + # Parse response + if status_code != 200: + raise ValueError(f"Connection request failed with status {status_code}: {response_text}") + + connection_details = ConnectionDetails.model_validate(json.loads(response_text)) + + # Handle relative WebSocket URI paths + if ( + connection_details.connectionUri is not None + and not str(connection_details.connectionUri).startswith("ws://") + and not str(connection_details.connectionUri).startswith("wss://") + ): + + # If websocket address doesn't start with ws:// or wss:// assume it's relative to the pairing URI + if self.pairing_uri: + base_uri = self.pairing_uri + # Convert to WebSocket protocol and remove the requestPairing path + ws_base = ( + base_uri.replace("http://", "ws://") + .replace("https://", "wss://") + .replace("requestPairing", "") + .rstrip("/") + ) + + # Combine with the relative path from connectionUri + relative_path = str(connection_details.connectionUri).lstrip("/") + + # Create complete URL + full_ws_url = f"{ws_base}/{relative_path}" + + try: + # Update the connection details with the new URL + connection_data = connection_details.model_dump() + # Replace the URI with the full WebSocket URL + connection_data["connectionUri"] = full_ws_url + # Recreate the ConnectionDetails object + connection_details = ConnectionDetails.model_validate( + connection_data + ) + logger.info('Updated relative WebSocket URI to absolute: %s', full_ws_url) + except (ValueError, TypeError, KeyError) as e: + logger.info('Failed to update WebSocket URI: %s', e) + else: + # Log a warning but don't modify the URI if we can't create a proper absolute URI + logger.info('Received relative WebSocket URI but pairing_uri is not available to create absolute URL') + + # Store for later use + self._connection_details = connection_details + + return connection_details + + @abc.abstractmethod + def solve_challenge(self, challenge: Optional[str] = None) -> str: + """Solve the connection challenge using the public key. + + If no challenge is provided, uses the challenge from connection_details. + + The challenge is a JWE (JSON Web Encryption) that must be decrypted using + the client's public key, then encoded as a base64 string. + + Args: + challenge: The challenge string from the server (optional) + + Returns: + str: The solution to the challenge (base64 encoded decrypted challenge) + + Raises: + ValueError: If no challenge is provided and connection_details is not set + ValueError: If the public key is not available + RuntimeError: If challenge decryption fails """ - @abstractmethod - def handle_pairing_response(self, response_data: Dict) -> Any: + @abc.abstractmethod + def establish_secure_connection(self) -> Any: + """Establish a secure connection to the server. + + This method should be implemented by concrete subclasses to establish + a secure connection using the connection details and solved challenge. + Implementations needs to use WebSocket Secure. + + Returns: + Any: A connection object or handler specific to the implementation + + Raises: + ValueError: If connection details or solved challenge are not available + RuntimeError: If connection establishment fails """ - Process the PairingResponse (e.g., extract server details). - The response_data contains keys: s2ServerNodeId, serverNodeDescription, requestConnectionUri. + + @abc.abstractmethod + def close_connection(self) -> None: + """Close the connection to the server. + + This method should be implemented by concrete subclasses to properly + close the connection established by establish_secure_connection. """ + + @property + def pairing_details(self) -> Optional[PairingDetails]: + """Get the stored pairing details.""" + return self._pairing_details diff --git a/src/s2python/authorization/default_client.py b/src/s2python/authorization/default_client.py new file mode 100644 index 0000000..68a280c --- /dev/null +++ b/src/s2python/authorization/default_client.py @@ -0,0 +1,244 @@ +""" +Default implementation of the S2 protocol client. + +This module provides a concrete implementation of the S2AbstractClient +for developers to use directly or as a reference for their own implementations. +""" + +import base64 +import json +import uuid +import logging +from typing import Dict, Optional, Tuple, Union, List, Any, Mapping + +import requests +from requests import Response + +from jwskate import JweCompact, Jwk, Jwt + +from s2python.generated.gen_s2_pairing import ( + PairingToken, + S2NodeDescription, + Protocols, +) +from s2python.authorization.client import ( + S2AbstractClient, + REQTEST_TIMEOUT, + KEY_ALGORITHM, + PairingDetails, +) + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("S2DefaultClient") + +class S2DefaultClient(S2AbstractClient): + """Default implementation of the S2AbstractClient using the requests library for HTTP + and jwskate for cryptographic operations. + + This implementation can be used directly or as a reference for custom implementations. + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + pairing_uri: Optional[str] = None, + token: Optional[PairingToken] = None, + node_description: Optional[S2NodeDescription] = None, + verify_certificate: Union[bool, str] = False, + client_node_id: Optional[uuid.UUID] = None, + supported_protocols: Optional[List[Protocols]] = None, + ) -> None: + """Initialize the default client with configuration parameters.""" + super().__init__( + pairing_uri, + token, + node_description, + verify_certificate, + client_node_id, + supported_protocols, + ) + # Additional state for this implementation + self._ws_connection: Optional[Dict[str, Any]] = None + + def generate_key_pair(self) -> Tuple[str, str]: + """Generate a public/private key pair using jwskate library. + + Returns: + Tuple[str, str]: (public_key, private_key) pair as PEM encoded strings + """ + logger.info("Generating key pair") + self._key_pair = Jwk.generate_for_alg(KEY_ALGORITHM).with_kid_thumbprint() + self._public_jwk = self._key_pair + self._private_jwk = self._key_pair + return ( + self._public_jwk.to_pem(), + self._private_jwk.to_pem(), + ) + + def store_key_pair(self, public_key: str, private_key: str) -> None: + """Store the public/private key pair in memory. + + In a production implementation, this might use a secure storage mechanism + like a keystore, HSM, or encrypted database. + + Args: + public_key: PEM encoded public key + private_key: PEM encoded private key + """ + logger.info("Storing key pair") + self._public_key = public_key + self._private_key = private_key + + def _make_https_request( + self, + url: str, + method: str = "GET", + data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Tuple[int, str]: + """Make an HTTPS request using the requests library. + + Args: + url: Target URL + method: HTTP method (GET, POST, etc.) + data: Request body data + headers: HTTP headers + + Returns: + Tuple[int, str]: (status_code, response_text) + """ + # Using requests library with verification settings from instance + response: Response = requests.request( + method=method, + url=url, + json=data, + headers=headers or {"Content-Type": "application/json"}, + verify=self.verify_certificate, + timeout=REQTEST_TIMEOUT, + ) + return response.status_code, response.text + + def solve_challenge(self, challenge: Optional[str] = None) -> str: + """Solve the connection challenge using the public key. + + If no challenge is provided, uses the challenge from connection_details. + + Args: + challenge: The challenge string from the server (optional) + + Returns: + str: The solution to the challenge (base64 encoded decrypted challenge) + + Raises: + ValueError: If no challenge is provided and connection_details is not set + ValueError: If the public key is not available + RuntimeError: If challenge decryption fails + """ + if challenge is None: + if not self._connection_details or not self._connection_details.challenge: + raise ValueError( + "Challenge not provided and not available in connection details" + ) + challenge = self._connection_details.challenge + + if not self._key_pair and not self._public_key: + raise ValueError( + "Public key is not available. Generate or load a key pair first." + ) + + try: + # If we have a jwskate Jwk object, use it directly + if self._key_pair: + rsa_key_pair = self._key_pair + # Otherwise try to parse the public key + elif self._public_key: + rsa_key_pair = Jwk.from_pem(self._public_key) + else: + raise ValueError("No public key available") + + # Decrypt the JWE challenge - get result as bytes and convert to string + jwe_compact = JweCompact(challenge) + decrypted_bytes = jwe_compact.decrypt(rsa_key_pair) + # Make sure we have a proper string + if hasattr(decrypted_bytes, "decode"): + decrypted_string = decrypted_bytes.decode("utf-8") + else: + decrypted_string = str(decrypted_bytes) + + # Parse the JSON payload + challenge_mapping: Mapping[str, Any] = json.loads(decrypted_string) + + # Create an unprotected JWT from the challenge + jwt_token = Jwt.unprotected(challenge_mapping) + jwt_token_str = str(jwt_token) + + # Encode the token as base64 + decrypted_challenge_str: str = base64.b64encode( + jwt_token_str.encode("utf-8") + ).decode("utf-8") + + # Store the pairing details if we have all required components + if self._pairing_response and self._connection_details: + self._pairing_details = PairingDetails( + pairing_response=self._pairing_response, + connection_details=self._connection_details, + decrypted_challenge_str=decrypted_challenge_str, + ) + + logger.info('Decrypted challenge: %s', decrypted_challenge_str) + return decrypted_challenge_str + + except (ValueError, TypeError, KeyError, json.JSONDecodeError) as e: + error_msg = f"Failed to solve challenge: {e}" + logger.info(error_msg) + raise RuntimeError(error_msg) from e + + def establish_secure_connection(self) -> Dict[str, Any]: + """Establish a secure WebSocket connection. + + This implementation establishes a WebSocket connection + using the connection details and solved challenge. + + Note: This is a placeholder implementation. In a real implementation, + this would use a WebSocket library like websocket-client or websockets. + + Returns: + Dict[str, Any]: A WebSocket connection object + + Raises: + ValueError: If connection details or solved challenge are not available + RuntimeError: If connection establishment fails + """ + if not self._connection_details: + raise ValueError( + "Connection details not available. Call request_connection first." + ) + + if ( + not self._pairing_details + or not self._pairing_details.decrypted_challenge_str + ): + raise ValueError( + "Challenge solution not available. Call solve_challenge first." + ) + + logger.info('Establishing WebSocket connection to %s,', self._connection_details.connectionUri) + logger.info('Using solved challenge: %s', self._pairing_details.decrypted_challenge_str) + + # Placeholder for the connection object + self._ws_connection = { + "status": "connected", + "uri": str(self._connection_details.connectionUri), + } + + return self._ws_connection + + def close_connection(self) -> None: + """Close the WebSocket connection. + + """ + if self._ws_connection: + + logger.info("Would close WebSocket connection") + self._ws_connection = None diff --git a/src/s2python/communication/__init__.py b/src/s2python/communication/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/s2python/communication/examples/example_frbc_rm.py b/src/s2python/communication/examples/example_frbc_rm.py new file mode 100644 index 0000000..7b2b6ff --- /dev/null +++ b/src/s2python/communication/examples/example_frbc_rm.py @@ -0,0 +1,204 @@ +import argparse +from functools import partial +import logging +import sys +import uuid +import signal +import datetime +from typing import Any, Callable, Optional + +from s2python.common import ( + EnergyManagementRole, + Duration, + Role, + RoleType, + Commodity, + Currency, + NumberRange, + PowerRange, + CommodityQuantity, +) +from s2python.frbc import ( + FRBCInstruction, + FRBCSystemDescription, + FRBCActuatorDescription, + FRBCStorageDescription, + FRBCOperationMode, + FRBCOperationModeElement, + FRBCFillLevelTargetProfile, + FRBCFillLevelTargetProfileElement, + FRBCStorageStatus, + FRBCActuatorStatus, +) +from s2python.communication.s2_connection import S2Connection, AssetDetails +from s2python.s2_control_type import FRBCControlType, NoControlControlType +from s2python.message import S2Message + +logger = logging.getLogger("s2python") +logger.addHandler(logging.StreamHandler(sys.stdout)) +logger.setLevel(logging.DEBUG) + + +class MyFRBCControlType(FRBCControlType): + def handle_instruction( + self, conn: S2Connection, msg: S2Message, send_okay: Callable[[], None] + ) -> None: + if not isinstance(msg, FRBCInstruction): + raise RuntimeError( + f"Expected an FRBCInstruction but received a message of type {type(msg)}." + ) + print(f"I have received the message {msg} from {conn}") + + def activate(self, conn: S2Connection) -> None: + print("The control type FRBC is now activated.") + + print("Time to send a FRBC SystemDescription") + actuator_id = uuid.uuid4() + operation_mode_id = uuid.uuid4() + conn.send_msg_and_await_reception_status_sync( + FRBCSystemDescription( + message_id=uuid.uuid4(), + valid_from=datetime.datetime.now(tz=datetime.timezone.utc), + actuators=[ + FRBCActuatorDescription( + id=actuator_id, + operation_modes=[ + FRBCOperationMode( + id=operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_rate=NumberRange( + start_of_range=-5.0, end_of_range=5.0 + ), + power_ranges=[ + PowerRange( + start_of_range=-200.0, + end_of_range=200.0, + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="Load & unload battery", + abnormal_condition_only=False, + ) + ], + transitions=[], + timers=[], + supported_commodities=[Commodity.ELECTRICITY], + ) + ], + storage=FRBCStorageDescription( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_level_label="%", + diagnostic_label="Imaginary battery", + provides_fill_level_target_profile=True, + provides_leakage_behaviour=False, + provides_usage_forecast=False, + ), + ) + ) + print("Also send the target profile") + + conn.send_msg_and_await_reception_status_sync( + FRBCFillLevelTargetProfile( + message_id=uuid.uuid4(), + start_time=datetime.datetime.now(tz=datetime.timezone.utc), + elements=[ + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(30_000), + fill_level_range=NumberRange( + start_of_range=20.0, end_of_range=30.0 + ), + ), + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(300_000), + fill_level_range=NumberRange( + start_of_range=40.0, end_of_range=50.0 + ), + ), + ], + ) + ) + + print("Also send the storage status.") + conn.send_msg_and_await_reception_status_sync( + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + ) + + print("Also send the actuator status.") + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.5, + ) + ) + + def deactivate(self, conn: S2Connection) -> None: + print("The control type FRBC is now deactivated.") + + +class MyNoControlControlType(NoControlControlType): + def activate(self, conn: S2Connection) -> None: + print("The control type NoControl is now activated.") + + def deactivate(self, conn: S2Connection) -> None: + print("The control type NoControl is now deactivated.") + + +def stop( + s2_connection: S2Connection, signal_num: int, _current_stack_frame: Any +) -> None: + print(f"Received signal {signal_num}. Will stop S2 connection.") + s2_connection.stop() + + +def start_s2_session( + url: str, + client_node_id: uuid.UUID = uuid.uuid4(), + bearer_token: Optional[str] = None, +) -> None: + s2_conn = S2Connection( + url=url, + role=EnergyManagementRole.RM, + control_types=[MyFRBCControlType(), MyNoControlControlType()], + asset_details=AssetDetails( + resource_id=client_node_id, + name="Some asset", + instruction_processing_delay=Duration.from_milliseconds(20), + roles=[ + Role(role=RoleType.ENERGY_CONSUMER, commodity=Commodity.ELECTRICITY) + ], + currency=Currency.EUR, + provides_forecast=False, + provides_power_measurements=[CommodityQuantity.ELECTRIC_POWER_L1], + ), + reconnect=True, + verify_certificate=False, + bearer_token=bearer_token, + ) + signal.signal(signal.SIGINT, partial(stop, s2_conn)) + signal.signal(signal.SIGTERM, partial(stop, s2_conn)) + + s2_conn.start_as_rm() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A simple S2 reseource manager example." + ) + parser.add_argument( + "endpoint", + type=str, + help="WebSocket endpoint uri for the server (CEM) e.g. ws://localhost:8080/websocket/s2/my-first-websocket-rm", + ) + args = parser.parse_args() + + start_s2_session(args.endpoint) diff --git a/src/s2python/communication/examples/example_pairing_frbc_rm.py b/src/s2python/communication/examples/example_pairing_frbc_rm.py new file mode 100644 index 0000000..f192281 --- /dev/null +++ b/src/s2python/communication/examples/example_pairing_frbc_rm.py @@ -0,0 +1,90 @@ +import argparse +import logging + +from s2python.communication.examples.example_frbc_rm import start_s2_session +from s2python.authorization.default_client import S2DefaultClient +from s2python.generated.gen_s2_pairing import ( + S2NodeDescription, + Deployment, + PairingToken, + S2Role, + Protocols, +) + +logger = logging.getLogger("s2python") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A simple S2 resource manager example." + ) + parser.add_argument( + "--endpoint", + type=str, + help="Rest endpoint to start S2 pairing. E.g. https://localhost/requestPairing", + ) + parser.add_argument( + "--pairing_token", + type=str, + help="The pairing token for the endpoint. You should get this from the S2 server e.g. ca14fda4", + ) + parser.add_argument( + "--verify-ssl", + action="store_true", + help="Verify SSL certificates (default: False)", + default=False, + ) + args = parser.parse_args() + + # Configure logging + logging.basicConfig(level=logging.INFO) + + # Create node description + node_description = S2NodeDescription( + brand="TNO", + logoUri="https://www.tno.nl/publish/pages/5604/tno-logo-1484x835_003_.jpg", + type="demo frbc example", + modelName="S2 pairing example stub", + userDefinedName="TNO S2 pairing example for frbc", + role=S2Role.RM, + deployment=Deployment.LAN, + ) + + # Create a client to perform the pairing + client = S2DefaultClient( + pairing_uri=args.endpoint, + token=PairingToken( + token=args.pairing_token, + ), + node_description=node_description, + verify_certificate=args.verify_ssl, + supported_protocols=[Protocols.WebSocketSecure], + ) + + try: + # Request pairing + logger.info("Initiating pairing with endpoint: %s", args.endpoint) + pairing_response = client.request_pairing() + logger.info("Pairing request successful, requesting connection...") + + # Request connection details + connection_details = client.request_connection() + logger.info("Connection request successful") + + # Solve challenge + challenge_result = client.solve_challenge() + logger.info("Challenge solved successfully") + + # Log connection details + logger.info("Connection URI: %s", connection_details.connectionUri) + + # Start S2 session with the connection details + logger.info("Starting S2 session...") + start_s2_session( + str(connection_details.connectionUri), + bearer_token=challenge_result, + ) + + except Exception as e: + logger.error("Error during pairing process: %s", e) + raise e diff --git a/src/s2python/communication/examples/mock_s2_server.py b/src/s2python/communication/examples/mock_s2_server.py new file mode 100644 index 0000000..c085b63 --- /dev/null +++ b/src/s2python/communication/examples/mock_s2_server.py @@ -0,0 +1,138 @@ +import http.server +import socketserver +import json +from typing import Any +import uuid +import logging +import random +import string + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mock_s2_server") + + +def generate_token() -> str: + """ + Generate a random alphanumeric token with exactly 32 characters. + + Returns: + str: A string of 32 random alphanumeric characters matching pattern ^[0-9a-zA-Z]{32}$ + """ + # Define the character set: uppercase letters, lowercase letters, and digits + chars = string.ascii_letters + string.digits + + # Generate a 32-character token by randomly selecting from the character set + token = "".join(random.choice(chars) for _ in range(32)) + + return token + + +# Generate random token for pairing +PAIRING_TOKEN = generate_token() +SERVER_NODE_ID = str(uuid.uuid4()) +WS_PORT = 8080 +HTTP_PORT = 8000 + + +class MockS2Handler(http.server.BaseHTTPRequestHandler): + def do_POST(self) -> None: # pylint: disable=C0103 + content_length = int(self.headers.get("Content-Length", 0)) + post_data = self.rfile.read(content_length).decode("utf-8") + + try: + request_json = json.loads(post_data) + logger.info('Received request at %s', self.path) + logger.debug('Request body: %s', request_json) + + if self.path == "/requestPairing": + # Handle pairing request + # The token in the S2 protocol is a PairingToken object with a token field + token_obj = request_json.get("token", {}) + + # Handle case where token is directly the string or a dict with token field + if isinstance(token_obj, dict) and "token" in token_obj: + request_token_string = token_obj["token"] + else: + request_token_string = token_obj + + logger.info('Extracted token: %s', request_token_string) + logger.info('Expected token: %s', PAIRING_TOKEN) + + if request_token_string == PAIRING_TOKEN: + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + + # Create pairing response + response = { + "s2ServerNodeId": SERVER_NODE_ID, + "serverNodeDescription": { + "brand": "Mock S2 Server", + "type": "Test Server", + "modelName": "Mock Model", + "logoUri": "http://example.com/logo.png", + "userDefinedName": "Mock Server", + "role": "CEM", + "deployment": "LAN", + }, + "requestConnectionUri": f"http://localhost:{HTTP_PORT}/requestConnection", + } + + self.wfile.write(json.dumps(response).encode()) + logger.info("Pairing request successful") + else: + self.send_response(401) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": "Invalid token"}).encode()) + logger.error("Invalid pairing token") + + elif self.path == "/requestConnection": + # Handle connection request + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + + # Create challenge (normally would be a JWE) + challenge = "mock_challenge_string" + + # Create connection details response + response = { + "connectionUri": f"ws://localhost:{WS_PORT}/s2/mock-websocket", + "challenge": challenge, + "selectedProtocol": "WebSocketSecure", + } + + self.wfile.write(json.dumps(response).encode()) + logger.info("Connection request successful") + + else: + self.send_response(404) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": "Endpoint not found"}).encode()) + logger.error('Unknown endpoint: %s', self.path) + + except Exception as e: + self.send_response(500) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": str(e)}).encode()) + logger.error('Error handling request: %s', e) + raise e + + def log_message(self, format: str, *args: Any) -> None: # pylint: disable=W0622 + logger.info(format % args) # pylint: disable=W1201 + + +def run_server() -> None: + with socketserver.TCPServer(("localhost", HTTP_PORT), MockS2Handler) as httpd: + logger.info('Mock S2 Server running at: http://localhost:%s', HTTP_PORT) + logger.info('Use pairing token: %s', PAIRING_TOKEN) + logger.info('Pairing endpoint: http://localhost:%s/requestPairing', HTTP_PORT) + httpd.serve_forever() + + +if __name__ == "__main__": + run_server() diff --git a/src/s2python/communication/examples/mock_s2_websocket.py b/src/s2python/communication/examples/mock_s2_websocket.py new file mode 100644 index 0000000..e8fd991 --- /dev/null +++ b/src/s2python/communication/examples/mock_s2_websocket.py @@ -0,0 +1,85 @@ +import asyncio +import logging +import json +import uuid +from datetime import datetime, timezone +import websockets + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mock_s2_websocket") + +# WebSocket server port +WS_PORT = 8080 + + +# Handle client connection +async def handle_connection( + websocket: websockets.WebSocketServerProtocol, path: str +) -> None: + client_id = str(uuid.uuid4()) + logger.info('Client %s connected on path: %s', client_id, path) + + try: + # Send handshake message to client + handshake = { + "type": "Handshake", + "messageId": str(uuid.uuid4()), + "protocolVersion": "0.0.2-beta", + "timestamp": datetime.now(timezone.utc).isoformat(), + } + await websocket.send(json.dumps(handshake)) + logger.info('Sent handshake to client %s', client_id) + + # Listen for messages + async for message in websocket: + try: + data = json.loads(message) + logger.info('Received message from client %s: %s', client_id, data) + + # Extract message type + message_type = data.get("type", "") + message_id = data.get("messageId", str(uuid.uuid4())) + + # Send reception status + reception_status = { + "type": "ReceptionStatus", + "messageId": str(uuid.uuid4()), + "refMessageId": message_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + "status": "OK", + } + await websocket.send(json.dumps(reception_status)) + logger.info('Sent reception status for message %s', message_id) + + # Handle specific message types + if message_type == "HandshakeResponse": + logger.info("Received handshake response") + + # For FRBC messages, you could add specific handling here + + except json.JSONDecodeError: + logger.error('Invalid JSON received from client %s', client_id) + except Exception as e: + logger.error('Error processing message from client %s: %s', client_id, e) + raise e + + except websockets.exceptions.ConnectionClosed: + logger.info('Connection with client %s closed', client_id) + except Exception as e: + logger.error('Error with client %s: %s', client_id, e) + raise e + finally: + logger.info('Client %s disconnected', client_id) + + +async def start_server() -> None: + server = await websockets.serve(handle_connection, "localhost", WS_PORT) + logger.info('WebSocket server started on ws://localhost:%s', WS_PORT) + + # Keep the server running + await server.wait_closed() + + +if __name__ == "__main__": + asyncio.run(start_server()) diff --git a/src/s2python/reception_status_awaiter.py b/src/s2python/communication/reception_status_awaiter.py similarity index 100% rename from src/s2python/reception_status_awaiter.py rename to src/s2python/communication/reception_status_awaiter.py diff --git a/src/s2python/s2_connection.py b/src/s2python/communication/s2_connection.py similarity index 96% rename from src/s2python/s2_connection.py rename to src/s2python/communication/s2_connection.py index efbd366..71e3326 100644 --- a/src/s2python/s2_connection.py +++ b/src/s2python/communication/s2_connection.py @@ -33,7 +33,7 @@ SelectControlType, ) from s2python.generated.gen_s2 import CommodityQuantity -from s2python.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.communication.reception_status_awaiter import ReceptionStatusAwaiter from s2python.s2_control_type import S2ControlType from s2python.s2_parser import S2Parser from s2python.s2_validation_error import S2ValidationError @@ -244,7 +244,9 @@ def __init__( # pylint: disable=too-many-arguments SelectControlType, self.handle_select_control_type_as_rm ) self._handlers.register_handler(Handshake, self.handle_handshake) - self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response_as_rm) + self._handlers.register_handler( + HandshakeResponse, self.handle_handshake_response_as_rm + ) self._bearer_token = bearer_token def start_as_rm(self) -> None: @@ -431,9 +433,9 @@ async def handle_select_control_type_as_rm( control_types_by_protocol_name = { c.get_protocol_control_type(): c for c in self.control_types } - selected_control_type: Optional[S2ControlType] = ( - control_types_by_protocol_name.get(message.control_type) - ) + selected_control_type: Optional[ + S2ControlType + ] = control_types_by_protocol_name.get(message.control_type) if self._current_control_type is not None: await self._eventloop.run_in_executor( @@ -467,7 +469,9 @@ async def _receive_messages(self) -> None: except json.JSONDecodeError: await self._send_and_forget( ReceptionStatus( - subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + subject_message_id=uuid.UUID( + "00000000-0000-0000-0000-000000000000" + ), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Not valid json.", ) @@ -483,7 +487,9 @@ async def _receive_messages(self) -> None: ) else: await self.respond_with_reception_status( - subject_message_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + subject_message_id=uuid.UUID( + "00000000-0000-0000-0000-000000000000" + ), status=ReceptionStatusValues.INVALID_DATA, diagnostic_label="Message appears valid json but could not find a message_id field.", ) @@ -514,7 +520,10 @@ async def _send_and_forget(self, s2_msg: S2Message) -> None: self._restart_connection_event.set() async def respond_with_reception_status( - self, subject_message_id: uuid.UUID, status: ReceptionStatusValues, diagnostic_label: str + self, + subject_message_id: uuid.UUID, + status: ReceptionStatusValues, + diagnostic_label: str, ) -> None: logger.debug( "Responding to message %s with status %s", subject_message_id, status @@ -528,7 +537,10 @@ async def respond_with_reception_status( ) def respond_with_reception_status_sync( - self, subject_message_id: uuid.UUID, status: ReceptionStatusValues, diagnostic_label: str + self, + subject_message_id: uuid.UUID, + status: ReceptionStatusValues, + diagnostic_label: str, ) -> None: asyncio.run_coroutine_threadsafe( self.respond_with_reception_status( diff --git a/src/s2python/generated/gen_s2_pairing.py b/src/s2python/generated/gen_s2_pairing.py index 15a6b05..afb5493 100644 --- a/src/s2python/generated/gen_s2_pairing.py +++ b/src/s2python/generated/gen_s2_pairing.py @@ -1,70 +1,88 @@ # generated by datamodel-codegen: -# filename: s2-over-ip-pairing.yaml -# timestamp: 2025-04-15T14:41:29+00:00 +# filename: s2-over-ip-pairing +# timestamp: 2025-02-28T14:52:45+00:00 from __future__ import annotations from enum import Enum -from typing import List, Optional -from uuid import UUID +from typing import List -from pydantic import AnyUrl, AwareDatetime, BaseModel, RootModel, constr +from pydantic import BaseModel, ConfigDict -class Protocols(Enum): - WebSocketSecure = 'WebSocketSecure' +class S2Role(str, Enum): + CEM = "CEM" + RM = "RM" -class S2Role(Enum): - CEM = 'CEM' - RM = 'RM' +class Deployment(str, Enum): + WAN = "WAN" + LAN = "LAN" -class Deployment(Enum): - WAN = 'WAN' - LAN = 'LAN' - - -class PairingToken(RootModel[constr(pattern=r'^[0-9a-zA-Z]{32}$')]): - root: constr(pattern=r'^[0-9a-zA-Z]{32}$') +class Protocols(str, Enum): + WebSocketSecure = "WebSocketSecure" +class PairingToken(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + token: str class PairingInfo(BaseModel): - pairingUri: Optional[AnyUrl] = None - token: Optional[PairingToken] = None - validUntil: Optional[AwareDatetime] = None + model_config = ConfigDict( + extra="forbid", + ) + pairingUri: str + token: str + validUntil: str -class ConnectionRequest(BaseModel): - s2ClientNodeId: Optional[UUID] = None - supportedProtocols: Optional[List[Protocols]] = None +class S2NodeDescription(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + brand: str + logoUri: str + type: str + modelName: str + userDefinedName: str + role: S2Role + deployment: Deployment -class ConnectionDetails(BaseModel): - selectedProtocol: Optional[Protocols] = None - challenge: Optional[str] = None - connectionUri: Optional[AnyUrl] = None +class PairingRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + token: PairingToken + publicKey: str + s2ClientNodeId: str + s2ClientNodeDescription: S2NodeDescription + supportedProtocols: List[Protocols] -class S2NodeDescription(BaseModel): - brand: Optional[str] = None - logoUri: Optional[AnyUrl] = None - type: Optional[str] = None - modelName: Optional[str] = None - userDefinedName: Optional[str] = None - role: Optional[S2Role] = None - deployment: Optional[Deployment] = None +class PairingResponse(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + s2ServerNodeId: str + serverNodeDescription: S2NodeDescription + requestConnectionUri: str -class PairingRequest(BaseModel): - token: Optional[PairingToken] = None - publicKey: Optional[str] = None - s2ClientNodeId: Optional[UUID] = None - s2ClientNodeDescription: Optional[S2NodeDescription] = None - supportedProtocols: Optional[List[Protocols]] = None +class ConnectionRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + s2ClientNodeId: str + supportedProtocols: List[Protocols] -class PairingResponse(BaseModel): - s2ServerNodeId: Optional[UUID] = None - serverNodeDescription: Optional[S2NodeDescription] = None - requestConnectionUri: Optional[AnyUrl] = None +class ConnectionDetails(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + selectedProtocol: Protocols + challenge: str + connectionUri: str diff --git a/src/s2python/s2_control_type.py b/src/s2python/s2_control_type.py index 135f775..32c520c 100644 --- a/src/s2python/s2_control_type.py +++ b/src/s2python/s2_control_type.py @@ -8,7 +8,7 @@ from s2python.message import S2Message if typing.TYPE_CHECKING: - from s2python.s2_connection import S2Connection, MessageHandlers + from s2python.communication.s2_connection import S2Connection, MessageHandlers class S2ControlType(abc.ABC): diff --git a/tests/unit/reception_status_awaiter_test.py b/tests/unit/reception_status_awaiter_test.py index fb06630..22e9ef4 100644 --- a/tests/unit/reception_status_awaiter_test.py +++ b/tests/unit/reception_status_awaiter_test.py @@ -16,7 +16,7 @@ InstructionStatus, InstructionStatusUpdate, ) -from s2python.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.communication.reception_status_awaiter import ReceptionStatusAwaiter class ReceptionStatusAwaiterTest(IsolatedAsyncioTestCase):