def _get_endpoint_url(self, url: str, app: Optional[str] = None) -> str:
    """Return ``url`` with the SDK identification query parameters applied.

    Query parameters already present on ``url`` are kept. ``sm-app`` is set
    to ``app`` when given, otherwise to the value already on the URL,
    otherwise to ``voice-sdk/<version>``. ``sm-voice-sdk`` is always set to
    the SDK version.

    Args:
        url: The endpoint URL, optionally carrying its own query string.
        app: The application name to use in the endpoint URL.

    Returns:
        str: The formatted endpoint URL.
    """
    parts = urlparse(url)

    # Blank-valued parameters are kept so they survive the round trip.
    query = parse_qs(parts.query, keep_blank_values=True)

    # Precedence: explicit argument, then the value already on the URL,
    # then the SDK default. Empty strings count as "not provided".
    url_app = query.get("sm-app", [None])[0]
    if app:
        chosen_app = app
    elif url_app:
        chosen_app = url_app
    else:
        chosen_app = f"voice-sdk/{__version__}"

    query["sm-app"] = [chosen_app]
    query["sm-voice-sdk"] = [__version__]

    # doseq=True expands each list of values back into repeated key=value pairs.
    return urlunparse(parts._replace(query=urlencode(query, doseq=True)))
from dataclasses import dataclass
from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlparse

import pytest
from _utils import get_client

from speechmatics.voice import __version__


@dataclass
class URLExample:
    """One endpoint-URL scenario: the input URL and an optional app name."""

    input_url: str
    input_app: Optional[str] = None


# Scenarios cover: bare URLs, URLs with unrelated params, URLs that already
# carry `sm-app`, explicit ports, non-websocket schemes, and app names with
# characters that need percent-encoding.
URLS: list[URLExample] = [
    URLExample(input_url="wss://dummy/ep", input_app="dummy-0.1.2"),
    URLExample(input_url="wss://dummy:1234/ep?client=amz", input_app="dummy-0.1.2"),
    URLExample(input_url="wss://dummy/ep?sm-app=dummy"),
    URLExample(input_url="ws://localhost:8080/ep?sm-app=dummy", input_app="dummy-0.1.2"),
    URLExample(input_url="http://dummy/ep/v1/", input_app="dummy-0.1.2"),
    URLExample(input_url="wss://dummy/ep"),
    URLExample(input_url="wss://dummy/ep", input_app="client/a#b:c^d"),
]


@pytest.mark.asyncio
@pytest.mark.parametrize("test", URLS, ids=lambda s: s.input_url)
async def test_url_endpoints(test: URLExample):
    """Check the generated endpoint URL keeps the input and adds the SDK params."""

    # Client (never connects; only the URL helper is exercised)
    client = await get_client(
        api_key="DUMMY",
        connect=False,
    )

    # Break the input URL into its parts for comparison
    source = urlparse(test.input_url)
    source_params = parse_qs(source.query, keep_blank_values=True)

    # Build the endpoint URL and break it into its parts
    result = urlparse(client._get_endpoint_url(test.input_url, test.input_app))
    result_params = parse_qs(result.query, keep_blank_values=True)

    # Scheme, netloc and path must pass through untouched
    assert result.scheme == source.scheme
    assert result.netloc == source.netloc
    assert result.path == source.path

    # `sm-app`: explicit app wins, then the URL's own value, then the default
    if test.input_app:
        expected_app = test.input_app
    elif "sm-app" in source_params:
        expected_app = source_params["sm-app"][0]
    else:
        expected_app = f"voice-sdk/{__version__}"
    assert result_params["sm-app"] == [expected_app]

    # `sm-voice-sdk` is always the SDK version
    assert result_params["sm-voice-sdk"] == [__version__]

    # Every other original parameter must be preserved as-is
    managed_keys = ["sm-app", "sm-voice-sdk"]
    passthrough = {k: v for k, v in source_params.items() if k not in managed_keys}
    for name, values in passthrough.items():
        assert result_params[name] == values