diff --git a/dev-requirements.txt b/dev-requirements.txt index 183afa3..9426c39 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -239,6 +239,18 @@ wheel==0.44.0 # via pip-tools zipp==3.20.1 # via importlib-metadata +jwskate==0.11.1 +binapy==0.8.0 + # via jwskate +cffi==1.17.1 + # via jwskate +cryptography==44.0.2 + # via jwskate +pycparser==2.22 + # via jwskate +requests==2.32.3 +types-requests==2.32.0.20241016 + # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py index 36d8fb7..ff10fd4 100644 --- a/examples/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -1,5 +1,4 @@ import argparse -import re from functools import partial import logging import sys @@ -157,7 +156,7 @@ def stop(s2_connection, signal_num, _current_stack_frame): print(f"Received signal {signal_num}. Will stop S2 connection.") s2_connection.stop() -def start_s2_session(url, client_node_id=str(uuid.uuid4())): +def start_s2_session(url, client_node_id=str(uuid.uuid4()), bearer_token=None): s2_conn = S2Connection( url=url, role=EnergyManagementRole.RM, @@ -172,7 +171,8 @@ def start_s2_session(url, client_node_id=str(uuid.uuid4())): provides_power_measurements=[CommodityQuantity.ELECTRIC_POWER_L1] ), reconnect=True, - verify_certificate=False + verify_certificate=False, + bearer_token=bearer_token ) signal.signal(signal.SIGINT, partial(stop, s2_conn)) signal.signal(signal.SIGTERM, partial(stop, s2_conn)) @@ -181,7 +181,11 @@ def start_s2_session(url, client_node_id=str(uuid.uuid4())): if __name__ == "__main__": parser = argparse.ArgumentParser(description="A simple S2 reseource manager example.") - parser.add_argument('endpoint', type=str, help="WebSocket endpoint uri for the server (CEM) e.g. ws://localhost:8080/backend/rm/s2python-frbc/cem/dummy_model/ws") + parser.add_argument( + 'endpoint', + type=str, + help="WebSocket endpoint uri for the server (CEM) e.h. ws://localhost:8080/websocket/s2/my-first-websocket-rm" + ) args = parser.parse_args() start_s2_session(args.endpoint) diff --git a/examples/example_with_pairing_frbc_rm.py b/examples/example_with_pairing_frbc_rm.py new file mode 100644 index 0000000..aef4de4 --- /dev/null +++ b/examples/example_with_pairing_frbc_rm.py @@ -0,0 +1,41 @@ +import argparse +import uuid +import logging + +from example_frbc_rm import start_s2_session +from s2python.s2_pairing import S2Pairing +from s2python.generated.gen_s2_pairing import S2NodeDescription, Deployment +from s2python.generated.gen_s2 import EnergyManagementRole + +logger = logging.getLogger("s2python") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="A simple S2 reseource manager example.") + parser.add_argument('endpoint', + type=str, + help="Rest endpoint to start S2 pairing. E.g. https://localhost/requestPairing") + parser.add_argument('pairing_token', + type=str, + help="The pairing toekn for teh endpoint. You should get this from the S2 server e.g. ca14fda4") + args = parser.parse_args() + + nodeDescription: S2NodeDescription = \ + S2NodeDescription(brand="TNO", + logoUri = "https://www.tno.nl/publish/pages/5604/tno-logo-1484x835_003_.jpg", + type = "demo frbc example", + modelName = "S2 pairing example stub", + userDefinedName = "TNO S2 pairing example for frbc", + role = EnergyManagementRole.RM, + deployment = Deployment.LAN) + client_node_id: str = str(uuid.uuid4()) + + pairing: S2Pairing = S2Pairing(request_pairing_endpoint = args.endpoint, + token = args.pairing_token, + s2_client_node_description = nodeDescription, + client_node_id = client_node_id) + + logger.info("Pairing details: \n%s", pairing.pairing_details) + + start_s2_session(pairing.pairing_details.connection_details.connectionUri, + bearer_token=pairing.pairing_details.decrypted_challenge_base64) diff --git a/setup.cfg b/setup.cfg index b23fdfb..2395f6a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,8 @@ install_requires = pytz click websockets~=13.1 + jwskate~=0.11 + requests~=2.32.3 [options.packages.find] where = src diff --git a/src/s2python/generated/gen_s2_pairing.py b/src/s2python/generated/gen_s2_pairing.py new file mode 100644 index 0000000..df45a08 --- /dev/null +++ b/src/s2python/generated/gen_s2_pairing.py @@ -0,0 +1,82 @@ +# generated by datamodel-codegen: +# filename: s2-over-ip-pairing +# timestamp: 2025-02-28T14:52:45+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import List + +from pydantic import BaseModel, ConfigDict, Field + + +class S2Role(str, Enum): + CEM = 'CEM' + RM = 'RM' + +class Deployment(str, Enum): + WAN = 'WAN' + LAN = 'LAN' + + +class Protocols(str, Enum): + WebSocketSecure = 'WebSocketSecure' + + +class PairingInfo(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + pairingUri: str + token: str + validUntil: str + + +class S2NodeDescription(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + brand: str + logoUri: str + type: str + modelName: str + userDefinedName: str + role: S2Role + deployment: Deployment + + +class PairingRequest(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + token: str + publicKey: str + s2ClientNodeId: str + s2ClientNodeDescription: S2NodeDescription + supportedProtocols: List[Protocols] + + +class PairingResponse(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + s2ServerNodeId: str + serverNodeDescription: S2NodeDescription + requestConnectionUri: str + + +class ConnectionRequest(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + s2ClientNodeId: str + supportedProtocols: List[Protocols] + + +class ConnectionDetails(BaseModel): + model_config = ConfigDict( + extra='forbid', + ) + selectedProtocol: Protocols + challenge: str + connectionUri: str diff --git a/src/s2python/s2_pairing.py b/src/s2python/s2_pairing.py new file mode 100644 index 0000000..1b9c413 --- /dev/null +++ b/src/s2python/s2_pairing.py @@ -0,0 +1,135 @@ +import base64 +import logging +import uuid +import datetime +from dataclasses import dataclass +from typing import Tuple, Union, Mapping, Any +import json +import requests + +from jwskate import JweCompact, Jwk, Jwt, SignedJwt + +from s2python.generated.gen_s2_pairing import (Protocols, + PairingRequest, + S2NodeDescription, + PairingResponse, + ConnectionRequest, + ConnectionDetails) + + +logger = logging.getLogger("s2python") + + +REQTEST_TIMEOUT = 10 +PAIRING_TIMEOUT = datetime.timedelta(minutes=5) +KEY_ALGORITHM = "RSA-OAEP-256" + +@dataclass(frozen=True) +class PairingDetails: + """The result of an S2 pairing + :param pairing_response: Details about the server. + :param connection_details: Details about how to connect. + :param decrypted_challenge_base64: The decrypted challenge needed as bearer token.""" + pairing_response: PairingResponse + connection_details: ConnectionDetails + decrypted_challenge_base64: str + +class S2Pairing: # pylint: disable=too-many-instance-attributes + _pairing_details: PairingDetails + _paring_timestamp: datetime.datetime + _request_pairing_endpoint: str + _token: str + _s2_client_node_description: S2NodeDescription + _verify_certificate: Union[bool, str] + _client_node_id: str + _supported_protocols: Tuple[Protocols] + def __init__( # pylint: disable=too-many-arguments + self, + request_pairing_endpoint: str, + token: str, + s2_client_node_description: S2NodeDescription, + verify_certificate: Union[bool, str] = False, + client_node_id: str = str(uuid.uuid4()), + supported_protocols: Tuple[Protocols] = (Protocols.WebSocketSecure, ) + ) -> None: + """Creates an S2 pairing for the device and holds the challenge needed to be provided as bearer token + when setting up an S2 (websockets) communication session + :param request_pairing_endpoint: The full uri endpoint to request pairing from. + :param token: The token that needs to be provided to the server in teh pairing process. + :param s2_client_node_description: The descriptin ofr the client as a S2NodeDescription. + :param verify_certificate: Either a boolean whether or not to verify the server's SSL certificate + (defaults to False), or a path to a certificate file to use for verification purposes. + :param client_node_id: UUID for the client. If none is given, one will be generated. + :param supported_protocols: The protocols supported by the client (defaults: Protocols.WebSocketSecure).""" + self._paring_timestamp = datetime.datetime(year = datetime.MINYEAR, month = 1, day = 1) + self._request_pairing_endpoint = request_pairing_endpoint + self._token = token + self._s2_client_node_description = s2_client_node_description + self._verify_certificate = verify_certificate + self._client_node_id = client_node_id + self._supported_protocols = supported_protocols + + def _pair(self) -> None: + """Private method establishing pairing""" + # If pairing has been established recently we don't need to do it again + if datetime.datetime.now() < (self._paring_timestamp + PAIRING_TIMEOUT): + return + + self._paring_timestamp = datetime.datetime.now() + + rsa_key_pair = Jwk.generate_for_alg(KEY_ALGORITHM).with_kid_thumbprint() + pairing_request: PairingRequest = PairingRequest(token=self._token, + publicKey=rsa_key_pair.public_jwk().to_pem(), + s2ClientNodeId=self._client_node_id, + s2ClientNodeDescription=self._s2_client_node_description, + supportedProtocols=self._supported_protocols) + + response = requests.post(self._request_pairing_endpoint, + json = pairing_request.dict(), + timeout = REQTEST_TIMEOUT, + verify = self._verify_certificate) + response.raise_for_status() + pairing_response: PairingResponse = PairingResponse.parse_raw(response.text) + + connection_request: ConnectionRequest = ConnectionRequest(s2ClientNodeId=self._client_node_id, + supportedProtocols=self._supported_protocols) + + + restest_pairing_uri: str = \ + pairing_response.requestConnectionUri if hasattr(pairing_response, 'requestConnectionUri') \ + and pairing_response.requestConnectionUri is not None \ + and pairing_response.requestConnectionUri != "" \ + else self._request_pairing_endpoint.replace('requestPairing', + 'requestConnection') + + logger.info('restest_pairing_uri %s ', restest_pairing_uri) + + response = requests.post(restest_pairing_uri, + json = connection_request.dict(), + timeout = REQTEST_TIMEOUT, + verify = self._verify_certificate) + response.raise_for_status() + connection_details: ConnectionDetails = ConnectionDetails.parse_raw(response.text) + # if websocket address doesn't start with ws:// or wss:// assume it's relative to the requestPairing + if not connection_details.connectionUri.startswith('ws://') \ + and not connection_details.connectionUri.startswith('wss://'): + connection_details.connectionUri = \ + self._request_pairing_endpoint.replace('http://', 'ws://') \ + .replace('https://', 'wss://') \ + .replace('requestPairing', '') \ + .rstrip('/') \ + + '/' + connection_details.connectionUri.lstrip('/') + logger.info('connectionUri %s ', connection_details.connectionUri) + + challenge: Mapping[str, Any] = json.loads(JweCompact(connection_details.challenge).decrypt(rsa_key_pair)) + decrypted_challenge_token: SignedJwt = Jwt.unprotected(challenge) + decrypted_challenge_str: str = base64.b64encode(bytes(decrypted_challenge_token)).decode('utf-8') + self._pairing_details = PairingDetails(pairing_response, connection_details, decrypted_challenge_str) + + + @property + def pairing_details(self) -> PairingDetails: + """:raises: requests.exceptions.HTTPError, requests.exceptions.JSONDecodeError + :return: PairingDetails object that's the result of the latest pairing.""" + self._pair() + return self._pairing_details diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index cc9c6fd..fa4a8d7 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -59,7 +59,10 @@ def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return inner -def catch_and_convert_exceptions(input_class: Type[S2MessageComponent[B_co]]) -> Type[S2MessageComponent[B_co]]: +S = TypeVar("S", bound=S2MessageComponent) + + +def catch_and_convert_exceptions(input_class: Type[S]) -> Type[S]: input_class.__init__ = convert_to_s2exception(input_class.__init__) # type: ignore[method-assign] input_class.__setattr__ = convert_to_s2exception(input_class.__setattr__) # type: ignore[method-assign] input_class.model_validate_json = convert_to_s2exception( # type: ignore[method-assign]