diff --git a/config/pull_request.yaml b/config/pull_request.yaml index 7687f1f..ffea157 100644 --- a/config/pull_request.yaml +++ b/config/pull_request.yaml @@ -32,6 +32,14 @@ blue_sky_air_traffic_simulator_settings: single_or_multiple_sensors: "multiple" # this setting specifiies if the traffic data is submitted from a single sensor or multiple sensors sensor_ids: ["562e6297036a4adebb4848afcd1ede90"] # List of sensor IDs to use when 'multiple' is selected +# Bayesian Air traffic data configuration +bayesian_air_traffic_simulator_settings: + number_of_aircraft: 3 + simulation_duration_seconds: 30 + single_or_multiple_sensors: "multiple" # this setting specifies if the traffic data is submitted from a single sensor or multiple sensors + sensor_ids: ["562e6297036a4adebb4848afcd1ede90"] # List of sensor IDs to use when 'multiple' is selected + session_ids: ["ee9405e564ea4373823e37d950858e6a"] # List of session IDs to use when 'multiple' is selected, a session id is needed in Flight Blender to depict a period of time these observations were made (this assumes the observations may not be continuous); if empty, random UUIDs will be generated + data_files: trajectory: "config/bern/trajectory_f1.json" # Path to flight declarations JSON file flight_declaration: "config/bern/flight_declaration.json" # Path to flight declarations JSON file diff --git a/pyproject.toml b/pyproject.toml index 4db1e9e..c2819a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "uvicorn[standard]>=0.38.0", # includes watchfiles for efficient reload "bluesky-simulator==1.1.0", "rtree==1.4.1", + "cam-track-gen @ git+https://github.com/openutm-labs/Canadian-Airspace-Models.git" ] [project.scripts] @@ -70,9 +71,6 @@ dev = [ "types-requests", "pytest-cov>=7.0.0", ] -bayesian-track-generation = [ - "cam-track-gen @ git+https://github.com/openutm-labs/Canadian-Airspace-Models.git", -] [build-system] requires = ["hatchling"] @@ -85,6 +83,9 @@ packages = ["src/openutm_verification"] [tool.hatch.build.targets.wheel.force-include] "docs/scenarios" = "openutm_verification/docs/scenarios" +[tool.hatch.metadata] +allow-direct-references = true + [tool.pytest.ini_options] pythonpath = [".", "src/openutm_verification"] testpaths = ["tests"] diff --git a/scenarios/stream_air_traffic_example.yaml b/scenarios/stream_air_traffic_example.yaml new file mode 100644 index 0000000..b7fde40 --- /dev/null +++ b/scenarios/stream_air_traffic_example.yaml @@ -0,0 +1,53 @@ +name: stream_air_traffic_example +version: "1.0" +description: | + Example scenario demonstrating the unified "Stream Air Traffic" step. + + This step replaces the provider-specific steps: + - "Generate Simulated Air Traffic Data" (geojson) + - "Generate BlueSky Simulation Air Traffic Data" (bluesky) + - "Generate Bayesian Simulation Air Traffic Data" (bayesian) + - "Fetch OpenSky Data" (opensky) + + All providers can now be used with a single consistent interface. + +steps: + # Example 1: GeoJSON provider with data generation only (no delivery) + - step: Stream Air Traffic + id: geojson_only + arguments: + provider: geojson + duration: 10 + target: none # Don't send anywhere, just return data + config_path: config/bern/trajectory_f1.json + number_of_aircraft: 2 + + # Example 2: GeoJSON with delivery to Flight Blender + - step: Stream Air Traffic + id: geojson_stream + arguments: + provider: geojson + duration: 30 + target: flight_blender + config_path: config/bern/trajectory_f1.json + number_of_aircraft: 2 + session_ids: + - "550e8400-e29b-41d4-a716-446655440001" + + # Example 3: BlueSky simulator (requires bluesky-simulator package) + # - step: Stream Air Traffic + # id: bluesky_stream + # arguments: + # provider: bluesky + # duration: 30 + # target: flight_blender + # config_path: config/bern/blue_sky_sim_bern.scn + + # Example 4: Live OpenSky data for Switzerland region + # - step: Stream Air Traffic + # id: opensky_live + # arguments: + # provider: opensky + # duration: 30 + # target: flight_blender + # viewport: [45.8389, 47.8229, 5.9962, 10.5226] diff --git a/src/openutm_verification/core/execution/dependencies.py b/src/openutm_verification/core/execution/dependencies.py index f27e8ef..76f9b77 100644 --- a/src/openutm_verification/core/execution/dependencies.py +++ b/src/openutm_verification/core/execution/dependencies.py @@ -12,7 +12,11 @@ from openutm_verification.core.clients.air_traffic.air_traffic_client import ( AirTrafficClient, ) -from openutm_verification.core.clients.air_traffic.base_client import AirTrafficSettings, BayesianAirTrafficSettings, BlueSkyAirTrafficSettings +from openutm_verification.core.clients.air_traffic.base_client import ( + AirTrafficSettings, + BayesianAirTrafficSettings, + BlueSkyAirTrafficSettings, +) from openutm_verification.core.clients.air_traffic.bayesian_air_traffic_client import ( BayesianTrafficClient, ) @@ -41,6 +45,7 @@ CONTEXT, dependency, ) +from openutm_verification.core.steps.air_traffic_step import AirTrafficStepClient from openutm_verification.server.runner import SessionManager from openutm_verification.utils.paths import get_docs_directory @@ -226,3 +231,10 @@ async def amqp_client(config: AppConfig) -> AsyncGenerator[AMQPClient, None]: settings = AMQPSettings.from_config(config.amqp) if config.amqp else AMQPSettings() async with AMQPClient(settings) as client: yield client + + +@dependency(AirTrafficStepClient) +async def air_traffic_step_client() -> AsyncGenerator[AirTrafficStepClient, None]: + """Provides an AirTrafficStepClient instance for the unified Stream Air Traffic step.""" + async with AirTrafficStepClient() as client: + yield client diff --git a/src/openutm_verification/core/providers/__init__.py b/src/openutm_verification/core/providers/__init__.py new file mode 100644 index 0000000..6ce8c3e --- /dev/null +++ b/src/openutm_verification/core/providers/__init__.py @@ -0,0 +1,15 @@ +"""Air traffic providers module. + +Providers generate or fetch air traffic observation data from various sources. +""" + +from .factory import ProviderType, create_provider +from .opensky_provider import DEFAULT_SWITZERLAND_VIEWPORT +from .protocol import AirTrafficProvider + +__all__ = [ + "AirTrafficProvider", + "DEFAULT_SWITZERLAND_VIEWPORT", + "ProviderType", + "create_provider", +] diff --git a/src/openutm_verification/core/providers/bayesian_provider.py b/src/openutm_verification/core/providers/bayesian_provider.py new file mode 100644 index 0000000..4facc68 --- /dev/null +++ b/src/openutm_verification/core/providers/bayesian_provider.py @@ -0,0 +1,98 @@ +"""Bayesian air traffic provider - wraps BayesianTrafficClient.""" + +from __future__ import annotations + +from openutm_verification.core.clients.air_traffic.base_client import ( + BayesianAirTrafficSettings, +) +from openutm_verification.core.clients.air_traffic.bayesian_air_traffic_client import ( + BayesianTrafficClient, +) +from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, +) + + +class BayesianProvider: + """Provider that generates air traffic using Bayesian track generation. + + Wraps the existing BayesianTrafficClient to provide a consistent interface. + Note: Requires the cam-track-gen package to be installed. + """ + + def __init__( + self, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + ): + """Initialize the Bayesian provider. + + Args: + config_path: Path to config (currently unused by Bayesian client). + number_of_aircraft: Number of aircraft to simulate. + duration: Simulation duration in seconds. + sensor_ids: List of sensor UUID strings. + session_ids: List of session UUID strings. + """ + self._config_path = config_path or "" + self._number_of_aircraft = number_of_aircraft or 2 + self._duration = duration or 30 + self._sensor_ids = sensor_ids or [] + self._session_ids = session_ids or [] + + @property + def name(self) -> str: + """Provider identifier.""" + return "bayesian" + + @classmethod + def from_kwargs( + cls, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + **_kwargs, # Ignore unknown kwargs for flexibility + ) -> "BayesianProvider": + """Factory method to create provider from keyword arguments.""" + return cls( + config_path=config_path, + number_of_aircraft=number_of_aircraft, + duration=duration, + sensor_ids=sensor_ids, + session_ids=session_ids, + ) + + async def get_observations( + self, + duration: int | None = None, + ) -> list[list[FlightObservationSchema]]: + """Generate observations using the underlying BayesianTrafficClient. + + Args: + duration: Override duration in seconds. + + Returns: + List of observation lists per aircraft. + """ + effective_duration = duration or self._duration + + settings = BayesianAirTrafficSettings( + simulation_config_path=self._config_path, + simulation_duration_seconds=effective_duration, + number_of_aircraft=self._number_of_aircraft, + sensor_ids=self._sensor_ids, + session_ids=self._session_ids, + ) + + async with BayesianTrafficClient(settings) as client: + result = await client.generate_bayesian_sim_air_traffic_data( + config_path=self._config_path, + duration=effective_duration, + ) + # Handle case where Bayesian client returns None or empty + return result if result else [] diff --git a/src/openutm_verification/core/providers/bluesky_provider.py b/src/openutm_verification/core/providers/bluesky_provider.py new file mode 100644 index 0000000..a69ecd2 --- /dev/null +++ b/src/openutm_verification/core/providers/bluesky_provider.py @@ -0,0 +1,96 @@ +"""BlueSky simulation air traffic provider - wraps BlueSkyClient.""" + +from __future__ import annotations + +from openutm_verification.core.clients.air_traffic.base_client import ( + BlueSkyAirTrafficSettings, +) +from openutm_verification.core.clients.air_traffic.blue_sky_client import ( + BlueSkyClient, +) +from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, +) + + +class BlueSkyProvider: + """Provider that generates air traffic from BlueSky simulator scenarios. + + Wraps the existing BlueSkyClient to provide a consistent interface. + Note: Requires the bluesky-simulator package to be installed. + """ + + def __init__( + self, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + ): + """Initialize the BlueSky provider. + + Args: + config_path: Path to the BlueSky .scn scenario file. + number_of_aircraft: Number of aircraft to simulate. + duration: Simulation duration in seconds. + sensor_ids: List of sensor UUID strings. + session_ids: List of session UUID strings. + """ + self._config_path = config_path or "" + self._number_of_aircraft = number_of_aircraft or 2 + self._duration = duration or 30 + self._sensor_ids = sensor_ids or [] + self._session_ids = session_ids or [] + + @property + def name(self) -> str: + """Provider identifier.""" + return "bluesky" + + @classmethod + def from_kwargs( + cls, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + **_kwargs, # Ignore unknown kwargs for flexibility + ) -> "BlueSkyProvider": + """Factory method to create provider from keyword arguments.""" + return cls( + config_path=config_path, + number_of_aircraft=number_of_aircraft, + duration=duration, + sensor_ids=sensor_ids, + session_ids=session_ids, + ) + + async def get_observations( + self, + duration: int | None = None, + ) -> list[list[FlightObservationSchema]]: + """Generate observations using the underlying BlueSkyClient. + + Args: + duration: Override duration in seconds. + + Returns: + List of observation lists per aircraft. + """ + effective_duration = duration or self._duration + + settings = BlueSkyAirTrafficSettings( + simulation_config_path=self._config_path, + simulation_duration_seconds=effective_duration, + number_of_aircraft=self._number_of_aircraft, + sensor_ids=self._sensor_ids, + session_ids=self._session_ids, + ) + + async with BlueSkyClient(settings) as client: + return await client.generate_bluesky_sim_air_traffic_data( + config_path=self._config_path, + duration=effective_duration, + ) diff --git a/src/openutm_verification/core/providers/factory.py b/src/openutm_verification/core/providers/factory.py new file mode 100644 index 0000000..b494c72 --- /dev/null +++ b/src/openutm_verification/core/providers/factory.py @@ -0,0 +1,61 @@ +"""Factory for creating air traffic providers.""" + +from typing import Literal + +from .bayesian_provider import BayesianProvider +from .bluesky_provider import BlueSkyProvider +from .geojson_provider import GeoJSONProvider +from .opensky_provider import OpenSkyProvider +from .protocol import AirTrafficProvider + +ProviderType = Literal["geojson", "bluesky", "bayesian", "opensky"] + + +def create_provider( + name: ProviderType, + *, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + viewport: tuple[float, float, float, float] | None = None, + **kwargs, +) -> AirTrafficProvider: + """Factory function to create providers by name. + + Args: + name: Provider type - geojson, bluesky, bayesian, or opensky. + config_path: Path to configuration file (provider-specific). + number_of_aircraft: Number of aircraft to simulate. + duration: Simulation/fetch duration in seconds. + sensor_ids: List of sensor UUID strings. + session_ids: List of session UUID strings. + viewport: Geographic bounds for OpenSky (lat_min, lat_max, lon_min, lon_max). + **kwargs: Additional provider-specific arguments. + + Returns: + An AirTrafficProvider instance. + + Raises: + ValueError: If the provider name is not recognized. + """ + providers: dict[str, type] = { + "geojson": GeoJSONProvider, + "bluesky": BlueSkyProvider, + "bayesian": BayesianProvider, + "opensky": OpenSkyProvider, + } + + if name not in providers: + raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}") + + return providers[name].from_kwargs( + config_path=config_path, + number_of_aircraft=number_of_aircraft, + duration=duration, + sensor_ids=sensor_ids, + session_ids=session_ids, + viewport=viewport, + **kwargs, + ) diff --git a/src/openutm_verification/core/providers/geojson_provider.py b/src/openutm_verification/core/providers/geojson_provider.py new file mode 100644 index 0000000..8b59e99 --- /dev/null +++ b/src/openutm_verification/core/providers/geojson_provider.py @@ -0,0 +1,99 @@ +"""GeoJSON air traffic provider - wraps AirTrafficClient.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from openutm_verification.core.clients.air_traffic.air_traffic_client import ( + AirTrafficClient, +) +from openutm_verification.core.clients.air_traffic.base_client import ( + AirTrafficSettings, +) + +if TYPE_CHECKING: + from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, + ) + + +class GeoJSONProvider: + """Provider that generates air traffic from GeoJSON trajectory files. + + Wraps the existing AirTrafficClient to provide a consistent interface. + """ + + def __init__( + self, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + ): + """Initialize the GeoJSON provider. + + Args: + config_path: Path to the GeoJSON trajectory file. + number_of_aircraft: Number of aircraft to simulate. + duration: Simulation duration in seconds. + sensor_ids: List of sensor UUID strings. + session_ids: List of session UUID strings. + """ + self._config_path = config_path or "" + self._number_of_aircraft = number_of_aircraft or 2 + self._duration = duration or 30 + self._sensor_ids = sensor_ids or [] + self._session_ids = session_ids or [] + + @property + def name(self) -> str: + """Provider identifier.""" + return "geojson" + + @classmethod + def from_kwargs( + cls, + config_path: str | None = None, + number_of_aircraft: int | None = None, + duration: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + **_kwargs, # Ignore unknown kwargs for flexibility + ) -> "GeoJSONProvider": + """Factory method to create provider from keyword arguments.""" + return cls( + config_path=config_path, + number_of_aircraft=number_of_aircraft, + duration=duration, + sensor_ids=sensor_ids, + session_ids=session_ids, + ) + + async def get_observations( + self, + duration: int | None = None, + ) -> list[list["FlightObservationSchema"]]: + """Generate observations using the underlying AirTrafficClient. + + Args: + duration: Override duration in seconds. + + Returns: + List of observation lists per aircraft. + """ + effective_duration = duration or self._duration + + settings = AirTrafficSettings( + simulation_config_path=self._config_path, + simulation_duration=effective_duration, + number_of_aircraft=self._number_of_aircraft, + sensor_ids=self._sensor_ids, + session_ids=self._session_ids, + ) + + async with AirTrafficClient(settings) as client: + return await client.generate_simulated_air_traffic_data( + config_path=self._config_path, + duration=effective_duration, + ) diff --git a/src/openutm_verification/core/providers/opensky_provider.py b/src/openutm_verification/core/providers/opensky_provider.py new file mode 100644 index 0000000..64ba48d --- /dev/null +++ b/src/openutm_verification/core/providers/opensky_provider.py @@ -0,0 +1,88 @@ +"""OpenSky Network live air traffic provider - wraps OpenSkyClient.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from openutm_verification.core.clients.opensky.base_client import OpenSkySettings +from openutm_verification.core.clients.opensky.opensky_client import OpenSkyClient +from openutm_verification.core.execution.config_models import get_settings + +if TYPE_CHECKING: + from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, + ) + +# Default viewport covering Switzerland (lat_min, lat_max, lon_min, lon_max) +DEFAULT_SWITZERLAND_VIEWPORT: tuple[float, float, float, float] = (45.8389, 47.8229, 5.9962, 10.5226) + + +class OpenSkyProvider: + """Provider that fetches live air traffic from OpenSky Network. + + Wraps the existing OpenSkyClient to provide a consistent interface. + Note: OpenSky returns flat observation lists (not per-aircraft), so we + wrap them in an outer list for consistency with other providers. + """ + + def __init__( + self, + viewport: tuple[float, float, float, float] | None = None, + duration: int | None = None, + ): + """Initialize the OpenSky provider. + + Args: + viewport: Geographic bounds (lat_min, lat_max, lon_min, lon_max). + duration: Poll duration in seconds (how long to fetch data). + """ + self._viewport = viewport or DEFAULT_SWITZERLAND_VIEWPORT + self._duration = duration or 30 + + @property + def name(self) -> str: + """Provider identifier.""" + return "opensky" + + @classmethod + def from_kwargs( + cls, + viewport: tuple[float, float, float, float] | None = None, + duration: int | None = None, + **_kwargs, # Ignore unknown kwargs for flexibility + ) -> "OpenSkyProvider": + """Factory method to create provider from keyword arguments.""" + return cls( + viewport=viewport, + duration=duration, + ) + + async def get_observations( + self, + duration: int | None = None, + ) -> list[list["FlightObservationSchema"]]: + """Fetch observations from OpenSky Network. + + Args: + duration: Override duration in seconds (currently single fetch). + + Returns: + List containing a single observation list (all aircraft in one batch). + Returns empty list if no data available. + """ + # Get OpenSky config from application settings + config = get_settings() + + settings = OpenSkySettings( + client_id=config.opensky.auth.client_id, + client_secret=config.opensky.auth.client_secret, + viewport=self._viewport, + ) + + async with OpenSkyClient(settings) as client: + observations = await client.fetch_data() + if observations is None: + return [] + # Wrap flat list in outer list for interface consistency + # OpenSky returns all aircraft in a single list, not grouped by aircraft + return [observations] diff --git a/src/openutm_verification/core/providers/protocol.py b/src/openutm_verification/core/providers/protocol.py new file mode 100644 index 0000000..0a0f4a8 --- /dev/null +++ b/src/openutm_verification/core/providers/protocol.py @@ -0,0 +1,45 @@ +"""Protocol definitions for air traffic providers. + +Providers are responsible for generating or fetching air traffic observation data. +They abstract the data source (GeoJSON files, simulators, live APIs) behind a common interface. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, + ) + + +@runtime_checkable +class AirTrafficProvider(Protocol): + """Protocol for air traffic data sources. + + Providers generate or fetch batches of flight observations. Each provider + wraps a specific data source (GeoJSON, BlueSky, Bayesian, OpenSky) and + exposes a uniform interface for getting observations. + """ + + @property + def name(self) -> str: + """Provider identifier (e.g., 'geojson', 'bluesky', 'opensky').""" + ... + + async def get_observations( + self, + duration: int | None = None, + ) -> list[list["FlightObservationSchema"]]: + """Get observation batches for the configured duration. + + Args: + duration: Override for simulation/fetch duration in seconds. + If None, uses provider's default configuration. + + Returns: + List of observation lists - outer list is per aircraft/track, + inner list is the time series of observations. + """ + ... diff --git a/src/openutm_verification/core/steps/__init__.py b/src/openutm_verification/core/steps/__init__.py new file mode 100644 index 0000000..f1fbeee --- /dev/null +++ b/src/openutm_verification/core/steps/__init__.py @@ -0,0 +1,8 @@ +"""Scenario steps module. + +Contains unified scenario step implementations. +""" + +from .air_traffic_step import AirTrafficStepClient + +__all__ = ["AirTrafficStepClient"] diff --git a/src/openutm_verification/core/steps/air_traffic_step.py b/src/openutm_verification/core/steps/air_traffic_step.py new file mode 100644 index 0000000..5b44882 --- /dev/null +++ b/src/openutm_verification/core/steps/air_traffic_step.py @@ -0,0 +1,88 @@ +"""Unified air traffic streaming step. + +This module provides a single scenario step for all air traffic operations, +replacing the multiple provider-specific steps with a unified interface. +""" + +from openutm_verification.core.execution.scenario_runner import scenario_step +from openutm_verification.core.providers import ProviderType, create_provider +from openutm_verification.core.streamers import StreamResult, TargetType, create_streamer + + +class AirTrafficStepClient: + """Client providing the unified Stream Air Traffic step. + + This client wraps the provider/streamer architecture to expose a single + scenario step that can handle all air traffic generation and streaming + operations. + """ + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + @scenario_step("Stream Air Traffic") + async def stream_air_traffic( + self, + provider: ProviderType, + duration: int, + target: TargetType = "flight_blender", + *, + # Provider settings (optional overrides) + config_path: str | None = None, + number_of_aircraft: int | None = None, + sensor_ids: list[str] | None = None, + session_ids: list[str] | None = None, + viewport: tuple[float, float, float, float] | None = None, + ) -> StreamResult: + """Stream air traffic data from a provider to a target system. + + Unified step for all air traffic generation and streaming operations. + Supports synthetic data generation (GeoJSON, BlueSky, Bayesian) and + live data fetching (OpenSky Network). + + Args: + provider: Data source - geojson, bluesky, bayesian, or opensky. + duration: Streaming duration in seconds. + target: Delivery target - flight_blender, amqp, or none (default: flight_blender). + config_path: Path to configuration file (provider-specific). + number_of_aircraft: Number of aircraft to simulate. + sensor_ids: Sensor UUIDs for observations. + session_ids: Session UUIDs for grouping. + viewport: Geographic bounds for OpenSky (lat_min, lat_max, lon_min, lon_max). + + Returns: + StreamResult with success status, counts, and optionally the observations. + + Example YAML: + - step: Stream Air Traffic + arguments: + provider: geojson + duration: 30 + target: flight_blender + config_path: config/bern/trajectory.geojson + """ + # Build provider from arguments + provider_instance = create_provider( + name=provider, + config_path=config_path, + number_of_aircraft=number_of_aircraft, + duration=duration, + sensor_ids=sensor_ids, + session_ids=session_ids, + viewport=viewport, + ) + + # Build streamer (or null streamer for target=none) + streamer_instance = create_streamer( + name=target, + session_ids=session_ids, + ) + + # Execute streaming + return await streamer_instance.stream_from_provider( + provider=provider_instance, + duration_seconds=duration, + ) diff --git a/src/openutm_verification/core/streamers/__init__.py b/src/openutm_verification/core/streamers/__init__.py new file mode 100644 index 0000000..beabeb3 --- /dev/null +++ b/src/openutm_verification/core/streamers/__init__.py @@ -0,0 +1,14 @@ +"""Air traffic streamers module. + +Streamers deliver air traffic observations to target systems. +""" + +from .factory import TargetType, create_streamer +from .protocol import AirTrafficStreamer, StreamResult + +__all__ = [ + "AirTrafficStreamer", + "StreamResult", + "TargetType", + "create_streamer", +] diff --git a/src/openutm_verification/core/streamers/amqp_streamer.py b/src/openutm_verification/core/streamers/amqp_streamer.py new file mode 100644 index 0000000..7bc684c --- /dev/null +++ b/src/openutm_verification/core/streamers/amqp_streamer.py @@ -0,0 +1,70 @@ +"""AMQP streamer - sends observations to AMQP/RabbitMQ. + +Wraps the existing AMQPClient to provide a consistent streaming interface. +This is a placeholder implementation for future expansion. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from loguru import logger + +from .protocol import StreamResult + +if TYPE_CHECKING: + from openutm_verification.core.providers.protocol import AirTrafficProvider + + +class AMQPStreamer: + """Streamer that sends observations to AMQP/RabbitMQ. + + Note: This is currently a placeholder implementation. + Full AMQP streaming would require additional configuration + and message formatting logic. + """ + + @property + def name(self) -> str: + """Target identifier.""" + return "amqp" + + @classmethod + def from_kwargs(cls, **_kwargs) -> "AMQPStreamer": + """Factory method to create streamer from keyword arguments.""" + return cls() + + async def stream_from_provider( + self, + provider: "AirTrafficProvider", + duration_seconds: int, + ) -> StreamResult: + """Stream observations from provider to AMQP. + + Currently a placeholder - collects observations but logs a warning + that AMQP streaming is not fully implemented. + + Args: + provider: The air traffic provider to get observations from. + duration_seconds: Duration for observation generation. + + Returns: + StreamResult with observations (not actually sent to AMQP yet). + """ + logger.warning("AMQP streaming is not fully implemented. Observations will be collected but not sent to AMQP.") + + # Get observations from provider + observations = await provider.get_observations(duration=duration_seconds) + + total_observations = sum(len(batch) for batch in observations) + + return StreamResult( + success=True, + provider=provider.name, + target=self.name, + duration_seconds=duration_seconds, + total_observations=total_observations, + total_batches=len(observations), + errors=["AMQP streaming not fully implemented - data collected but not sent"], + observations=observations, + ) diff --git a/src/openutm_verification/core/streamers/factory.py b/src/openutm_verification/core/streamers/factory.py new file mode 100644 index 0000000..a1b14a2 --- /dev/null +++ b/src/openutm_verification/core/streamers/factory.py @@ -0,0 +1,41 @@ +"""Factory for creating air traffic streamers.""" + +from typing import Literal + +from .amqp_streamer import AMQPStreamer +from .flight_blender_streamer import FlightBlenderStreamer +from .null_streamer import NullStreamer +from .protocol import AirTrafficStreamer + +TargetType = Literal["flight_blender", "amqp", "none"] + + +def create_streamer( + name: TargetType, + *, + session_ids: list[str] | None = None, + **kwargs, +) -> AirTrafficStreamer: + """Factory function to create streamers by name. + + Args: + name: Target type - flight_blender, amqp, or none. + session_ids: Optional list of session UUID strings (for flight_blender). + **kwargs: Additional streamer-specific arguments. + + Returns: + An AirTrafficStreamer instance. + + Raises: + ValueError: If the streamer name is not recognized. + """ + streamers: dict[str, type] = { + "flight_blender": FlightBlenderStreamer, + "amqp": AMQPStreamer, + "none": NullStreamer, + } + + if name not in streamers: + raise ValueError(f"Unknown streamer: {name}. Available: {list(streamers.keys())}") + + return streamers[name].from_kwargs(session_ids=session_ids, **kwargs) diff --git a/src/openutm_verification/core/streamers/flight_blender_streamer.py b/src/openutm_verification/core/streamers/flight_blender_streamer.py new file mode 100644 index 0000000..e435cf9 --- /dev/null +++ b/src/openutm_verification/core/streamers/flight_blender_streamer.py @@ -0,0 +1,173 @@ +"""Flight Blender streamer - sends observations to Flight Blender API. + +Wraps the existing FlightBlenderClient's submit methods to provide +a consistent streaming interface. +""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING + +from loguru import logger + +from openutm_verification.core.clients.flight_blender.flight_blender_client import ( + FlightBlenderClient, +) +from openutm_verification.core.execution.config_models import get_settings + +from .protocol import StreamResult + +if TYPE_CHECKING: + from openutm_verification.core.providers.protocol import AirTrafficProvider + + +class FlightBlenderStreamer: + """Streamer that sends observations to Flight Blender via HTTP API. + + Wraps the existing FlightBlenderClient's submit_simulated_air_traffic + method to provide the unified streaming interface. + """ + + def __init__(self, session_ids: list[uuid.UUID] | None = None): + """Initialize the Flight Blender streamer. + + Args: + session_ids: Optional list of session UUIDs for grouping observations. + """ + self._session_ids = session_ids + + @property + def name(self) -> str: + """Target identifier.""" + return "flight_blender" + + @classmethod + def from_kwargs( + cls, + session_ids: list[str] | None = None, + **_kwargs, + ) -> "FlightBlenderStreamer": + """Factory method to create streamer from configuration. + + Args: + session_ids: Optional list of session UUID strings. + """ + parsed_ids = None + if session_ids: + try: + parsed_ids = [uuid.UUID(sid) for sid in session_ids] + except ValueError: + logger.warning("Invalid session ID format detected, will auto-generate. Ensure session IDs are valid UUIDs.") + return cls(session_ids=parsed_ids) + + def _make_result( + self, + *, + success: bool, + provider_name: str, + duration_seconds: int, + total_observations: int = 0, + total_batches: int = 0, + errors: list[str] | None = None, + observations: list | None = None, + ) -> StreamResult: + """Helper to construct StreamResult with common fields.""" + return StreamResult( + success=success, + provider=provider_name, + target=self.name, + duration_seconds=duration_seconds, + total_observations=total_observations, + total_batches=total_batches, + errors=errors or [], + observations=observations or [], + ) + + async def stream_from_provider( + self, + provider: "AirTrafficProvider", + duration_seconds: int, + ) -> StreamResult: + """Stream observations from provider to Flight Blender. + + Gets observations from the provider, then submits them to Flight Blender + in real-time playback mode (one observation per second per aircraft). + + Args: + provider: The air traffic provider to get observations from. + duration_seconds: Duration for observation generation. + + Returns: + StreamResult with submission statistics. + """ + # Get observations from provider + observations = await provider.get_observations(duration=duration_seconds) + + if not observations: + return self._make_result( + success=True, + provider_name=provider.name, + duration_seconds=duration_seconds, + ) + + # Get and validate Flight Blender configuration + config = get_settings() + + if not config.flight_blender.url: + error_msg = "Flight Blender URL is not configured. Please set 'flight_blender.url' in your configuration." + logger.error(error_msg) + return self._make_result( + success=False, + provider_name=provider.name, + duration_seconds=duration_seconds, + errors=[error_msg], + observations=observations, + ) + + username = config.flight_blender.auth.username + password = config.flight_blender.auth.password + + if not username or not password: + error_msg = ( + "Flight Blender credentials are not configured. " + "Please set 'flight_blender.auth.username' and " + "'flight_blender.auth.password' in your configuration." + ) + logger.error(error_msg) + return self._make_result( + success=False, + provider_name=provider.name, + duration_seconds=duration_seconds, + errors=[error_msg], + observations=observations, + ) + + try: + async with FlightBlenderClient( + base_url=config.flight_blender.url, + credentials={"username": username, "password": password}, + ) as client: + result = await client.submit_simulated_air_traffic( + observations=observations, + session_ids=self._session_ids, + ) + + return self._make_result( + success=result.get("success", False), + provider_name=provider.name, + duration_seconds=duration_seconds, + total_observations=sum(len(batch) for batch in observations), + total_batches=len(observations), + observations=observations, + ) + + except Exception as e: + logger.error(f"Flight Blender streaming failed: {e}") + return self._make_result( + success=False, + provider_name=provider.name, + duration_seconds=duration_seconds, + errors=[str(e)], + observations=observations, + ) diff --git a/src/openutm_verification/core/streamers/null_streamer.py b/src/openutm_verification/core/streamers/null_streamer.py new file mode 100644 index 0000000..ea7cc24 --- /dev/null +++ b/src/openutm_verification/core/streamers/null_streamer.py @@ -0,0 +1,63 @@ +"""Null streamer - collects data without sending anywhere. + +Useful for testing, data generation without delivery, or when the target +system is not available. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .protocol import StreamResult + +if TYPE_CHECKING: + from openutm_verification.core.providers.protocol import AirTrafficProvider + + +class NullStreamer: + """Streamer that collects observations without sending them anywhere. + + Useful for: + - Testing provider implementations + - Generating data for later processing + - Scenarios where you want to capture data without delivery + """ + + @property + def name(self) -> str: + """Target identifier.""" + return "none" + + @classmethod + def from_kwargs(cls, **_kwargs) -> "NullStreamer": + """Factory method to create streamer from keyword arguments.""" + return cls() + + async def stream_from_provider( + self, + provider: "AirTrafficProvider", + duration_seconds: int, + ) -> StreamResult: + """Collect observations from provider without sending. + + Args: + provider: The air traffic provider to get observations from. + duration_seconds: Duration passed to the provider. + + Returns: + StreamResult with collected observations. + """ + observations = await provider.get_observations(duration=duration_seconds) + + total_observations = sum(len(batch) for batch in observations) + + return StreamResult( + success=True, + provider=provider.name, + target=self.name, + duration_seconds=duration_seconds, + total_observations=total_observations, + total_batches=len(observations), + errors=[], + observations=observations, + ) diff --git a/src/openutm_verification/core/streamers/protocol.py b/src/openutm_verification/core/streamers/protocol.py new file mode 100644 index 0000000..21b1bd5 --- /dev/null +++ b/src/openutm_verification/core/streamers/protocol.py @@ -0,0 +1,66 @@ +"""Protocol definitions and data models for air traffic streamers. + +Streamers are responsible for delivering air traffic observations to target systems. +They abstract the delivery mechanism (HTTP, AMQP, etc.) behind a common interface. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from openutm_verification.core.providers.protocol import AirTrafficProvider + from openutm_verification.simulator.models.flight_data_types import ( + FlightObservationSchema, + ) + + +@dataclass +class StreamResult: + """Result of a streaming operation. + + Contains success status, statistics, and optionally the observation data + for use by downstream steps. + """ + + success: bool + provider: str + target: str + duration_seconds: int + total_observations: int + total_batches: int + errors: list[str] = field(default_factory=list) + + # For downstream steps - stores the observations that were streamed + observations: list[list["FlightObservationSchema"]] | None = None + + +@runtime_checkable +class AirTrafficStreamer(Protocol): + """Protocol for delivering observations to a target system. + + Streamers take observations from a provider and deliver them to a specific + target (Flight Blender, AMQP, or nowhere for testing). + """ + + @property + def name(self) -> str: + """Target identifier (e.g., 'flight_blender', 'amqp', 'none').""" + ... + + async def stream_from_provider( + self, + provider: "AirTrafficProvider", + duration_seconds: int, + ) -> StreamResult: + """Stream all data from a provider to this target. + + Args: + provider: The air traffic provider to get observations from. + duration_seconds: Duration for the streaming operation. + + Returns: + StreamResult with statistics and optionally the observations. + """ + ... diff --git a/tests/test_stream_air_traffic.py b/tests/test_stream_air_traffic.py new file mode 100644 index 0000000..84cfb9c --- /dev/null +++ b/tests/test_stream_air_traffic.py @@ -0,0 +1,684 @@ +"""Tests for the unified Stream Air Traffic step.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openutm_verification.core.providers import ProviderType, create_provider +from openutm_verification.core.providers.geojson_provider import GeoJSONProvider +from openutm_verification.core.providers.opensky_provider import OpenSkyProvider +from openutm_verification.core.steps import AirTrafficStepClient +from openutm_verification.core.streamers import StreamResult, TargetType, create_streamer +from openutm_verification.core.streamers.null_streamer import NullStreamer +from openutm_verification.simulator.models.flight_data_types import FlightObservationSchema + + +class TestProviderFactory: + """Tests for the provider factory.""" + + def test_create_geojson_provider(self): + """Test creating a GeoJSON provider.""" + provider = create_provider( + name="geojson", + config_path="/some/path.geojson", + duration=60, + number_of_aircraft=3, + ) + assert isinstance(provider, GeoJSONProvider) + assert provider.name == "geojson" + + def test_create_opensky_provider(self): + """Test creating an OpenSky provider.""" + viewport = (45.0, 48.0, 6.0, 11.0) + provider = create_provider( + name="opensky", + viewport=viewport, + duration=30, + ) + assert isinstance(provider, OpenSkyProvider) + assert provider.name == "opensky" + + def test_create_unknown_provider_raises(self): + """Test that unknown provider name raises ValueError.""" + with pytest.raises(ValueError, match="Unknown provider"): + create_provider(name="unknown") # type: ignore + + +class TestStreamerFactory: + """Tests for the streamer factory.""" + + def test_create_null_streamer(self): + """Test creating a null streamer.""" + streamer = create_streamer(name="none") + assert isinstance(streamer, NullStreamer) + assert streamer.name == "none" + + def test_create_unknown_streamer_raises(self): + """Test that unknown streamer name raises ValueError.""" + with pytest.raises(ValueError, match="Unknown streamer"): + create_streamer(name="unknown") # type: ignore + + +class TestStreamResult: + """Tests for StreamResult dataclass.""" + + def test_stream_result_creation(self): + """Test creating a StreamResult.""" + result = StreamResult( + success=True, + provider="geojson", + target="none", + duration_seconds=30, + total_observations=100, + total_batches=2, + ) + assert result.success is True + assert result.provider == "geojson" + assert result.target == "none" + assert result.total_observations == 100 + assert result.errors == [] + assert result.observations is None + + def test_stream_result_with_errors(self): + """Test creating a StreamResult with errors.""" + result = StreamResult( + success=False, + provider="opensky", + target="flight_blender", + duration_seconds=30, + total_observations=0, + total_batches=0, + errors=["Connection failed", "Timeout"], + ) + assert result.success is False + assert len(result.errors) == 2 + + +class TestAirTrafficStepClient: + """Tests for the AirTrafficStepClient.""" + + @pytest.mark.asyncio + async def test_client_context_manager_returns_self(self): + """Test that __aenter__ returns the client instance.""" + client = AirTrafficStepClient() + async with client as entered_client: + assert entered_client is client + + @pytest.mark.asyncio + async def test_client_context_manager_calls_aexit(self): + """Test that __aexit__ is called when exiting the context.""" + aexit_called = False + original_aexit = AirTrafficStepClient.__aexit__ + + async def tracking_aexit(self, *args): + nonlocal aexit_called + aexit_called = True + return await original_aexit(self, *args) + + with patch.object(AirTrafficStepClient, "__aexit__", tracking_aexit): + async with AirTrafficStepClient(): + pass + + assert aexit_called, "__aexit__ should be called when exiting context" + + @pytest.mark.asyncio + async def test_client_context_manager_calls_aexit_on_exception(self): + """Test that __aexit__ is called even when an exception occurs.""" + aexit_called = False + original_aexit = AirTrafficStepClient.__aexit__ + + async def tracking_aexit(self, *args): + nonlocal aexit_called + aexit_called = True + return await original_aexit(self, *args) + + with patch.object(AirTrafficStepClient, "__aexit__", tracking_aexit): + with pytest.raises(ValueError): + async with AirTrafficStepClient(): + raise ValueError("Test exception") + + assert aexit_called, "__aexit__ should be called even when exception occurs" + + def test_step_registration(self): + """Test that the step is registered in STEP_REGISTRY.""" + from openutm_verification.core.execution.scenario_runner import STEP_REGISTRY + + # Force registration by importing the client + _ = AirTrafficStepClient + + assert "Stream Air Traffic" in STEP_REGISTRY + entry = STEP_REGISTRY["Stream Air Traffic"] + assert entry.client_class == AirTrafficStepClient + assert entry.method_name == "stream_air_traffic" + + def test_step_param_model_has_required_fields(self): + """Test that the parameter model has the expected fields.""" + from openutm_verification.core.execution.scenario_runner import STEP_REGISTRY + + entry = STEP_REGISTRY["Stream Air Traffic"] + param_model = entry.param_model + + # Check required fields are present + fields = param_model.model_fields + assert "provider" in fields + assert "duration" in fields + assert "target" in fields + + def test_provider_type_literal(self): + """Test that ProviderType includes expected values.""" + from typing import get_args + + expected = {"geojson", "bluesky", "bayesian", "opensky"} + actual = set(get_args(ProviderType)) + assert actual == expected + + def test_target_type_literal(self): + """Test that TargetType includes expected values.""" + from typing import get_args + + expected = {"flight_blender", "amqp", "none"} + actual = set(get_args(TargetType)) + assert actual == expected + + +# ============================================================================ +# Integration Tests with Mocked Clients +# ============================================================================ + + +def _create_mock_observations(): + """Helper to create mock flight observations.""" + return [ + [ + FlightObservationSchema( + lat_dd=46.9, + lon_dd=7.4, + altitude_mm=1000000, + traffic_source=0, + source_type=0, + icao_address="ABC123", + timestamp=1234567890, + ), + ], + [ + FlightObservationSchema( + lat_dd=46.95, + lon_dd=7.45, + altitude_mm=1100000, + traffic_source=0, + source_type=0, + icao_address="DEF456", + timestamp=1234567891, + ), + ], + ] + + +class TestGeoJSONProviderIntegration: + """Integration tests for GeoJSONProvider with mocked AirTrafficClient.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.geojson_provider.AirTrafficClient") + async def test_geojson_provider_instantiates_client_with_correct_settings(self, mock_client_class): + """Test that GeoJSONProvider passes correct settings to AirTrafficClient.""" + mock_observations = _create_mock_observations() + + # Setup mock client instance + mock_client_instance = AsyncMock() + mock_client_instance.generate_simulated_air_traffic_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + # Create provider with specific settings + provider = GeoJSONProvider( + config_path="/path/to/trajectory.geojson", + number_of_aircraft=3, + duration=60, + sensor_ids=["sensor-uuid-1", "sensor-uuid-2"], + session_ids=["session-uuid-1"], + ) + + # Call get_observations + result = await provider.get_observations(duration=45) + + # Verify AirTrafficClient was instantiated with correct settings + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + settings = call_args[0][0] # First positional argument + + assert settings.simulation_config_path == "/path/to/trajectory.geojson" + assert settings.simulation_duration == 45 # Should use the override duration + assert settings.number_of_aircraft == 3 + assert settings.sensor_ids == ["sensor-uuid-1", "sensor-uuid-2"] + assert settings.session_ids == ["session-uuid-1"] + + # Verify the method was called with correct arguments + mock_client_instance.generate_simulated_air_traffic_data.assert_called_once_with( + config_path="/path/to/trajectory.geojson", + duration=45, + ) + + # Verify result + assert result == mock_observations + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.geojson_provider.AirTrafficClient") + async def test_geojson_provider_uses_default_duration_when_not_overridden(self, mock_client_class): + """Test that GeoJSONProvider uses constructor duration when not overridden.""" + mock_observations = _create_mock_observations() + + mock_client_instance = AsyncMock() + mock_client_instance.generate_simulated_air_traffic_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + provider = GeoJSONProvider(config_path="/test.geojson", duration=120) + + # Call without duration override + await provider.get_observations() + + # Should use constructor duration (120) + settings = mock_client_class.call_args[0][0] + assert settings.simulation_duration == 120 + + +class TestBlueSkyProviderIntegration: + """Integration tests for BlueSkyProvider with mocked BlueSkyClient.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.bluesky_provider.BlueSkyClient") + async def test_bluesky_provider_instantiates_client_with_correct_settings(self, mock_client_class): + """Test that BlueSkyProvider passes correct settings to BlueSkyClient.""" + from openutm_verification.core.providers.bluesky_provider import BlueSkyProvider + + mock_observations = [ + [ + FlightObservationSchema( + lat_dd=46.9, + lon_dd=7.4, + altitude_mm=5000000, + traffic_source=0, + source_type=0, + icao_address="BLUESKY1", + timestamp=1234567890, + ), + ], + ] + + mock_client_instance = AsyncMock() + mock_client_instance.generate_bluesky_sim_air_traffic_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + provider = BlueSkyProvider( + config_path="/path/to/scenario.scn", + number_of_aircraft=2, + duration=30, + sensor_ids=["sensor-1"], + session_ids=["session-1"], + ) + + result = await provider.get_observations(duration=25) + + # Verify BlueSkyClient was instantiated with correct settings + mock_client_class.assert_called_once() + settings = mock_client_class.call_args[0][0] + + assert settings.simulation_config_path == "/path/to/scenario.scn" + assert settings.simulation_duration_seconds == 25 + assert settings.number_of_aircraft == 2 + assert settings.sensor_ids == ["sensor-1"] + assert settings.session_ids == ["session-1"] + + # Verify method call + mock_client_instance.generate_bluesky_sim_air_traffic_data.assert_called_once_with( + config_path="/path/to/scenario.scn", + duration=25, + ) + + assert result == mock_observations + + +class TestBayesianProviderIntegration: + """Integration tests for BayesianProvider with mocked BayesianTrafficClient.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.bayesian_provider.BayesianTrafficClient") + async def test_bayesian_provider_instantiates_client_with_correct_settings(self, mock_client_class): + """Test that BayesianProvider passes correct settings to BayesianTrafficClient.""" + from openutm_verification.core.providers.bayesian_provider import BayesianProvider + + mock_observations = [ + [ + FlightObservationSchema( + lat_dd=47.0, + lon_dd=8.0, + altitude_mm=3000000, + traffic_source=0, + source_type=0, + icao_address="BAYES1", + timestamp=1234567890, + ), + ], + ] + + mock_client_instance = AsyncMock() + mock_client_instance.generate_bayesian_sim_air_traffic_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + provider = BayesianProvider( + config_path="/path/to/model.mat", + number_of_aircraft=5, + duration=100, + sensor_ids=["sensor-bayesian"], + session_ids=["session-bayesian"], + ) + + result = await provider.get_observations(duration=80) + + # Verify BayesianTrafficClient was instantiated with correct settings + mock_client_class.assert_called_once() + settings = mock_client_class.call_args[0][0] + + assert settings.simulation_config_path == "/path/to/model.mat" + assert settings.simulation_duration_seconds == 80 + assert settings.number_of_aircraft == 5 + assert settings.sensor_ids == ["sensor-bayesian"] + assert settings.session_ids == ["session-bayesian"] + + assert result == mock_observations + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.bayesian_provider.BayesianTrafficClient") + async def test_bayesian_provider_handles_none_result(self, mock_client_class): + """Test that BayesianProvider returns empty list when client returns None.""" + from openutm_verification.core.providers.bayesian_provider import BayesianProvider + + mock_client_instance = AsyncMock() + mock_client_instance.generate_bayesian_sim_air_traffic_data = AsyncMock(return_value=None) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + provider = BayesianProvider() + result = await provider.get_observations() + + assert result == [] + + +class TestOpenSkyProviderIntegration: + """Integration tests for OpenSkyProvider with mocked OpenSkyClient.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.opensky_provider.OpenSkyClient") + @patch("openutm_verification.core.providers.opensky_provider.get_settings") + async def test_opensky_provider_instantiates_client_with_correct_viewport(self, mock_get_settings, mock_client_class): + """Test that OpenSkyProvider passes correct viewport settings to OpenSkyClient.""" + mock_observations = [ + FlightObservationSchema( + lat_dd=46.5, + lon_dd=7.0, + altitude_mm=10000000, + traffic_source=2, + source_type=1, + icao_address="LIVE123", + timestamp=1234567890, + ), + FlightObservationSchema( + lat_dd=47.0, + lon_dd=8.0, + altitude_mm=11000000, + traffic_source=2, + source_type=1, + icao_address="LIVE456", + timestamp=1234567890, + ), + ] + + # Mock get_settings + mock_config = MagicMock() + mock_config.opensky.auth.client_id = "test-client-id" + mock_config.opensky.auth.client_secret = "test-client-secret" + mock_get_settings.return_value = mock_config + + mock_client_instance = AsyncMock() + mock_client_instance.fetch_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + custom_viewport = (44.0, 49.0, 5.0, 12.0) + provider = OpenSkyProvider(viewport=custom_viewport, duration=60) + + result = await provider.get_observations() + + # Verify OpenSkyClient was instantiated with correct settings + mock_client_class.assert_called_once() + settings = mock_client_class.call_args[0][0] + + assert settings.client_id == "test-client-id" + assert settings.client_secret == "test-client-secret" + assert settings.viewport == custom_viewport + + # Verify fetch_data was called + mock_client_instance.fetch_data.assert_called_once() + + # Result should be wrapped in outer list for consistency + assert result == [mock_observations] + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.opensky_provider.OpenSkyClient") + @patch("openutm_verification.core.providers.opensky_provider.get_settings") + async def test_opensky_provider_returns_empty_list_when_no_data(self, mock_get_settings, mock_client_class): + """Test that OpenSkyProvider returns empty list when no data available.""" + mock_config = MagicMock() + mock_config.opensky.auth.client_id = "test-id" + mock_config.opensky.auth.client_secret = "test-secret" + mock_get_settings.return_value = mock_config + + mock_client_instance = AsyncMock() + mock_client_instance.fetch_data = AsyncMock(return_value=None) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + provider = OpenSkyProvider() + result = await provider.get_observations() + + assert result == [] + + +class TestFlightBlenderStreamerIntegration: + """Integration tests for FlightBlenderStreamer with mocked FlightBlenderClient.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.streamers.flight_blender_streamer.FlightBlenderClient") + @patch("openutm_verification.core.streamers.flight_blender_streamer.get_settings") + async def test_flight_blender_streamer_submits_to_client(self, mock_get_settings, mock_fb_class): + """Test that FlightBlenderStreamer properly submits observations to FlightBlenderClient.""" + from openutm_verification.core.streamers.flight_blender_streamer import FlightBlenderStreamer + + mock_observations = _create_mock_observations() + + # Mock get_settings + mock_config = MagicMock() + mock_config.flight_blender.url = "http://test-flight-blender:8080" + mock_config.flight_blender.auth.username = "test-user" + mock_config.flight_blender.auth.password = "test-pass" + mock_get_settings.return_value = mock_config + + mock_fb_client = AsyncMock() + mock_fb_client.submit_simulated_air_traffic = AsyncMock(return_value={"success": True, "observations_submitted": 1}) + mock_fb_class.return_value.__aenter__ = AsyncMock(return_value=mock_fb_client) + mock_fb_class.return_value.__aexit__ = AsyncMock(return_value=None) + + # Create a mock provider + mock_provider = AsyncMock() + mock_provider.name = "geojson" + mock_provider.get_observations = AsyncMock(return_value=mock_observations) + + streamer = FlightBlenderStreamer() + result = await streamer.stream_from_provider(mock_provider, duration_seconds=30) + + # Verify FlightBlenderClient was instantiated with correct credentials + mock_fb_class.assert_called_once() + call_kwargs = mock_fb_class.call_args[1] + assert call_kwargs["base_url"] == "http://test-flight-blender:8080" + assert call_kwargs["credentials"]["username"] == "test-user" + assert call_kwargs["credentials"]["password"] == "test-pass" + + # Verify submit was called with observations + mock_fb_client.submit_simulated_air_traffic.assert_called_once() + call_args = mock_fb_client.submit_simulated_air_traffic.call_args + assert call_args[1]["observations"] == mock_observations + + # Verify result + assert result.success is True + assert result.provider == "geojson" + assert result.target == "flight_blender" + assert result.total_batches == 2 + + @pytest.mark.asyncio + async def test_flight_blender_streamer_handles_empty_observations(self): + """Test that FlightBlenderStreamer handles empty observations gracefully.""" + from openutm_verification.core.streamers.flight_blender_streamer import FlightBlenderStreamer + + mock_provider = AsyncMock() + mock_provider.name = "geojson" + mock_provider.get_observations = AsyncMock(return_value=[]) + + streamer = FlightBlenderStreamer() + result = await streamer.stream_from_provider(mock_provider, duration_seconds=10) + + # Should succeed with zero observations, without calling FlightBlenderClient + assert result.success is True + assert result.total_observations == 0 + assert result.total_batches == 0 + + @pytest.mark.asyncio + @patch("openutm_verification.core.streamers.flight_blender_streamer.FlightBlenderClient") + @patch("openutm_verification.core.streamers.flight_blender_streamer.get_settings") + async def test_flight_blender_streamer_handles_client_error(self, mock_get_settings, mock_fb_class): + """Test that FlightBlenderStreamer handles client errors gracefully.""" + from openutm_verification.core.streamers.flight_blender_streamer import FlightBlenderStreamer + + mock_observations = _create_mock_observations() + + mock_config = MagicMock() + mock_config.flight_blender.url = "http://test:8080" + mock_config.flight_blender.auth.username = "user" + mock_config.flight_blender.auth.password = "pass" + mock_get_settings.return_value = mock_config + + mock_fb_client = AsyncMock() + mock_fb_client.submit_simulated_air_traffic = AsyncMock(side_effect=Exception("Connection refused")) + mock_fb_class.return_value.__aenter__ = AsyncMock(return_value=mock_fb_client) + mock_fb_class.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_provider = AsyncMock() + mock_provider.name = "geojson" + mock_provider.get_observations = AsyncMock(return_value=mock_observations) + + streamer = FlightBlenderStreamer() + result = await streamer.stream_from_provider(mock_provider, duration_seconds=30) + + assert result.success is False + assert "Connection refused" in result.errors[0] + + +class TestNullStreamerIntegration: + """Integration tests for NullStreamer.""" + + @pytest.mark.asyncio + async def test_null_streamer_collects_observations_without_sending(self): + """Test that NullStreamer collects all observations and returns them.""" + mock_observations = [ + [ + FlightObservationSchema( + lat_dd=46.9, + lon_dd=7.4, + altitude_mm=1000000, + traffic_source=0, + source_type=0, + icao_address="NULL1", + timestamp=1234567890, + ), + ], + [ + FlightObservationSchema( + lat_dd=47.0, + lon_dd=7.5, + altitude_mm=1100000, + traffic_source=0, + source_type=0, + icao_address="NULL2", + timestamp=1234567891, + ), + ], + ] + + mock_provider = AsyncMock() + mock_provider.name = "geojson" + mock_provider.get_observations = AsyncMock(return_value=mock_observations) + + streamer = NullStreamer() + result = await streamer.stream_from_provider(mock_provider, duration_seconds=30) + + # Verify provider was called with correct duration + mock_provider.get_observations.assert_called_once_with(duration=30) + + # Verify result contains all observations + assert result.success is True + assert result.target == "none" + assert result.total_batches == 2 + assert result.total_observations == 2 + assert result.observations == mock_observations + + +class TestEndToEndStreamAirTraffic: + """End-to-end tests for the Stream Air Traffic step.""" + + @pytest.mark.asyncio + @patch("openutm_verification.core.providers.geojson_provider.AirTrafficClient") + async def test_stream_air_traffic_with_null_target(self, mock_client_class): + """Test complete flow: provider -> null streamer.""" + from openutm_verification.core.reporting.reporting_models import Status + + mock_observations = [ + [ + FlightObservationSchema( + lat_dd=46.9, + lon_dd=7.4, + altitude_mm=1000000, + traffic_source=0, + source_type=0, + icao_address="E2E1", + timestamp=1234567890, + ), + ], + ] + + # Mock the AirTrafficClient used by GeoJSONProvider + mock_client_instance = AsyncMock() + mock_client_instance.generate_simulated_air_traffic_data = AsyncMock(return_value=mock_observations) + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) + + async with AirTrafficStepClient() as client: + step_result = await client.stream_air_traffic( + provider="geojson", + duration=10, + target="none", + config_path="/test/path.geojson", + number_of_aircraft=1, + ) + + # The @scenario_step decorator wraps the result in a StepResult + assert step_result.status == Status.PASS + assert step_result.name == "Stream Air Traffic" + + # The inner StreamResult is in step_result.result + stream_result = step_result.result + assert stream_result.success is True + assert stream_result.provider == "geojson" + assert stream_result.target == "none" + assert stream_result.total_observations == 1 + assert stream_result.observations == mock_observations diff --git a/uv.lock b/uv.lock index 4b2300f..ab71ad3 100644 --- a/uv.lock +++ b/uv.lock @@ -1406,6 +1406,7 @@ source = { editable = "." } dependencies = [ { name = "arrow" }, { name = "bluesky-simulator" }, + { name = "cam-track-gen" }, { name = "cryptography" }, { name = "dacite" }, { name = "earcut" }, @@ -1444,9 +1445,6 @@ dependencies = [ ] [package.dev-dependencies] -bayesian-track-generation = [ - { name = "cam-track-gen" }, -] dev = [ { name = "mypy" }, { name = "pre-commit" }, @@ -1467,6 +1465,7 @@ dev = [ requires-dist = [ { name = "arrow", specifier = "==1.3.0" }, { name = "bluesky-simulator", specifier = "==1.1.0" }, + { name = "cam-track-gen", git = "https://github.com/openutm-labs/Canadian-Airspace-Models.git" }, { name = "cryptography", specifier = "==44.0.3" }, { name = "dacite", specifier = ">=1.9.2" }, { name = "earcut", specifier = ">=1.1.5" }, @@ -1505,7 +1504,6 @@ requires-dist = [ ] [package.metadata.requires-dev] -bayesian-track-generation = [{ name = "cam-track-gen", git = "https://github.com/openutm-labs/Canadian-Airspace-Models.git" }] dev = [ { name = "mypy", specifier = ">=1.11.2" }, { name = "pre-commit", specifier = ">=4.3.0" },