-
Notifications
You must be signed in to change notification settings - Fork 85
[CDF] to_cdf #519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[CDF] to_cdf #519
Changes from all commits
e252c91
1e2d6e6
8eaafbb
fbf20c6
b5d7708
3e1073a
e249156
f419105
7993ec8
9e00f8e
6b904b2
87883a2
c731a76
3100eb5
b72f03e
c4055e7
5ffa0d0
766e2f0
769ab73
9ba3f97
75da5d6
f9b7e84
31f0c26
7d2ee8d
a954f49
29ab3e1
e1bb431
f256970
b33bd48
f846aa3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| from typing import Optional | ||
|
|
||
| from kloppy.domain import TrackingDataset | ||
| from kloppy.infra.serializers.tracking.cdf import ( | ||
| CDFTrackingDataInputs, | ||
| CDFTrackingDeserializer, | ||
| ) | ||
| from kloppy.io import FileLike, open_as_file | ||
|
|
||
|
|
||
| def load_tracking( | ||
| meta_data: FileLike, | ||
| raw_data: FileLike, | ||
| sample_rate: Optional[float] = None, | ||
| limit: Optional[int] = None, | ||
| coordinates: Optional[str] = None, | ||
| include_empty_frames: Optional[bool] = False, | ||
| only_alive: Optional[bool] = True, | ||
| ) -> TrackingDataset: | ||
| """ | ||
| Load Common Data Format (CDF) tracking data. | ||
|
|
||
| Args: | ||
| meta_data: A JSON feed containing the meta data. | ||
| raw_data: A JSONL feed containing the raw tracking data. | ||
| sample_rate: Sample the data at a specific rate. | ||
| limit: Limit the number of frames to load to the first `limit` frames. | ||
| coordinates: The coordinate system to use. | ||
| include_empty_frames: Include frames in which no objects were tracked. | ||
| only_alive: Only include frames in which the game is not paused. | ||
|
|
||
| Returns: | ||
| The parsed tracking data. | ||
| """ | ||
| deserializer = CDFTrackingDeserializer( | ||
| sample_rate=sample_rate, | ||
| limit=limit, | ||
| coordinate_system=coordinates, | ||
| include_empty_frames=include_empty_frames, | ||
| only_alive=only_alive, | ||
| ) | ||
| with ( | ||
| open_as_file(meta_data) as meta_data_fp, | ||
| open_as_file(raw_data) as raw_data_fp, | ||
| ): | ||
| return deserializer.deserialize( | ||
| inputs=CDFTrackingDataInputs( | ||
| meta_data=meta_data_fp, raw_data=raw_data_fp | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| # def load_event( | ||
| # event_data: FileLike, | ||
| # meta_data: FileLike, | ||
| # event_types: Optional[list[str]] = None, | ||
| # coordinates: Optional[str] = None, | ||
| # event_factory: Optional[EventFactory] = None, | ||
| # ) -> EventDataset: | ||
| # """ | ||
| # Load Common Data Format event data. | ||
|
|
||
| # Args: | ||
| # event_data: JSON feed with the raw event data of a game. | ||
| # meta_data: JSON feed with the corresponding lineup information of the game. | ||
| # event_types: A list of event types to load. | ||
| # coordinates: The coordinate system to use. | ||
| # event_factory: A custom event factory. | ||
|
|
||
| # Returns: | ||
| # The parsed event data. | ||
| # """ | ||
| # deserializer = StatsBombDeserializer( | ||
| # event_types=event_types, | ||
| # coordinate_system=coordinates, | ||
| # event_factory=event_factory | ||
| # or get_config("event_factory") | ||
| # or StatsBombEventFactory(), | ||
| # ) | ||
| # with ( | ||
| # open_as_file(event_data) as event_data_fp, | ||
| # open_as_file(meta_data) as meta_data_fp, | ||
| # ): | ||
| # return deserializer.deserialize( | ||
| # inputs=StatsBombInputs( | ||
| # event_data=event_data_fp, | ||
| # lineup_data=meta_data_fp, | ||
| # ) | ||
| # ) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,5 @@ | ||||||
| """Functions for loading Common Data Format (CDF) tracking data.""" | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| from ._providers.cdf import load_tracking | ||||||
|
|
||||||
| __all__ = ["load_tracking"] | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,6 +109,7 @@ class Provider(Enum): | |
| HAWKEYE (Provider): | ||
| SPORTVU (Provider): | ||
| IMPECT (Provider): | ||
| CDF (Provider): | ||
| OTHER (Provider): | ||
| """ | ||
|
|
||
|
|
@@ -126,8 +127,9 @@ class Provider(Enum): | |
| STATSPERFORM = "statsperform" | ||
| HAWKEYE = "hawkeye" | ||
| SPORTVU = "sportvu" | ||
| SIGNALITY = "signality" | ||
| IMPECT = "impect" | ||
| CDF = "common_data_format" | ||
| SIGNALITY = "signality" | ||
| OTHER = "other" | ||
|
|
||
| def __str__(self): | ||
|
|
@@ -679,12 +681,16 @@ def to_mplsoccer(self): | |
| dim = BaseDims( | ||
| left=self.pitch_dimensions.x_dim.min, | ||
| right=self.pitch_dimensions.x_dim.max, | ||
| bottom=self.pitch_dimensions.y_dim.min | ||
| if not invert_y | ||
| else self.pitch_dimensions.y_dim.max, | ||
| top=self.pitch_dimensions.y_dim.max | ||
| if not invert_y | ||
| else self.pitch_dimensions.y_dim.min, | ||
| bottom=( | ||
| self.pitch_dimensions.y_dim.min | ||
| if not invert_y | ||
| else self.pitch_dimensions.y_dim.max | ||
| ), | ||
| top=( | ||
| self.pitch_dimensions.y_dim.max | ||
| if not invert_y | ||
| else self.pitch_dimensions.y_dim.min | ||
| ), | ||
| width=self.pitch_dimensions.x_dim.max | ||
| - self.pitch_dimensions.x_dim.min, | ||
| length=self.pitch_dimensions.y_dim.max | ||
|
|
@@ -733,14 +739,16 @@ def to_mplsoccer(self): | |
| - self.pitch_dimensions.x_dim.min | ||
| ), | ||
| pad_multiplier=1, | ||
| aspect_equal=False | ||
| if self.pitch_dimensions.unit == Unit.NORMED | ||
| else True, | ||
| aspect_equal=( | ||
| False if self.pitch_dimensions.unit == Unit.NORMED else True | ||
| ), | ||
| pitch_width=pitch_width, | ||
| pitch_length=pitch_length, | ||
| aspect=pitch_width / pitch_length | ||
| if self.pitch_dimensions.unit == Unit.NORMED | ||
| else 1.0, | ||
| aspect=( | ||
| pitch_width / pitch_length | ||
| if self.pitch_dimensions.unit == Unit.NORMED | ||
| else 1.0 | ||
| ), | ||
| ) | ||
| return dim | ||
|
|
||
|
|
@@ -1184,6 +1192,58 @@ def pitch_dimensions(self) -> PitchDimensions: | |
| ) | ||
|
|
||
|
|
||
| class CDFCoordinateSystem(ProviderCoordinateSystem): | ||
| """ | ||
| Common Data Format (CDF) coordinate system. | ||
|
|
||
| Uses a pitch with the origin at the center and the y-axis oriented | ||
| from bottom to top. The coordinates are in meters. | ||
| """ | ||
|
|
||
| @property | ||
| def provider(self) -> Provider: | ||
| return Provider.CDF | ||
|
|
||
| @property | ||
| def origin(self) -> Origin: | ||
| return Origin.CENTER | ||
|
|
||
| @property | ||
| def vertical_orientation(self) -> VerticalOrientation: | ||
| return VerticalOrientation.BOTTOM_TO_TOP | ||
|
|
||
| @property | ||
| def pitch_dimensions(self) -> PitchDimensions: | ||
| return NormalizedPitchDimensions( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think these should be |
||
| x_dim=Dimension( | ||
| -1 * self._pitch_length / 2, self._pitch_length / 2 | ||
| ), | ||
| y_dim=Dimension(-1 * self._pitch_width / 2, self._pitch_width / 2), | ||
| pitch_length=self._pitch_length, | ||
| pitch_width=self._pitch_width, | ||
| standardized=False, | ||
| ) | ||
|
|
||
| def __init__( | ||
| self, | ||
| base_coordinate_system: ProviderCoordinateSystem | None = None, | ||
| pitch_length: float | None = None, | ||
| pitch_width: float | None = None, | ||
|
Comment on lines
+1229
to
+1231
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should use |
||
| ): | ||
| if base_coordinate_system is not None: | ||
| # Used by serializer - derive dimensions from source coordinate system | ||
| self._pitch_length = ( | ||
| base_coordinate_system.pitch_dimensions.pitch_length | ||
| ) | ||
| self._pitch_width = ( | ||
| base_coordinate_system.pitch_dimensions.pitch_width | ||
| ) | ||
| else: | ||
| # Used by deserializer - direct pitch dimensions | ||
| self._pitch_length = pitch_length | ||
| self._pitch_width = pitch_width | ||
|
|
||
|
|
||
| class SignalityCoordinateSystem(ProviderCoordinateSystem): | ||
| @property | ||
| def provider(self) -> Provider: | ||
|
|
@@ -1414,6 +1474,7 @@ def build_coordinate_system( | |
| Provider.SPORTVU: SportVUCoordinateSystem, | ||
| Provider.SIGNALITY: SignalityCoordinateSystem, | ||
| Provider.IMPECT: ImpectCoordinateSystem, | ||
| Provider.CDF: CDFCoordinateSystem, | ||
| } | ||
|
|
||
| if provider in coordinate_systems: | ||
|
|
@@ -1944,6 +2005,10 @@ def to_df( | |
| else: | ||
| raise KloppyParameterError(f"Engine {engine} is not valid") | ||
|
|
||
| def to_cdf(self): | ||
| if self.dataset_type != DatasetType.TRACKING: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove the |
||
| raise ValueError("to_cdf() is only supported for TrackingDataset") | ||
|
|
||
| def __repr__(self): | ||
| return f"<{self.__class__.__name__} record_count={len(self.records)}>" | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Optional | ||
| from typing import TYPE_CHECKING, Any, Optional, Union | ||
|
|
||
| from kloppy.domain.models.common import DatasetType | ||
| from kloppy.utils import ( | ||
|
|
@@ -9,6 +9,11 @@ | |
| from .common import DataRecord, Dataset, Player | ||
| from .pitch import Point, Point3D | ||
|
|
||
| if TYPE_CHECKING: | ||
| from cdf.domain import CdfMetaDataSchema | ||
|
|
||
| from kloppy.io import FileLike | ||
|
|
||
|
|
||
| @dataclass | ||
| class PlayerData: | ||
|
|
@@ -79,5 +84,74 @@ def frames(self): | |
| def frame_rate(self): | ||
| return self.metadata.frame_rate | ||
|
|
||
| # Update the to_cdf method in Dataset class | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove / address this comment |
||
| def to_cdf( | ||
| self, | ||
| metadata_output_file: "FileLike", | ||
| tracking_output_file: "FileLike", | ||
| additional_metadata: Optional[Union[dict, "CdfMetaDataSchema"]] = None, | ||
| ) -> None: | ||
| """ | ||
| Export dataset to Common Data Format (CDF). | ||
|
|
||
| Args: | ||
| metadata_output_file: File path or file-like object for metadata JSON output. | ||
| Must have .json extension if a string path. | ||
| tracking_output_file: File path or file-like object for tracking JSONL output. | ||
| Must have .jsonl extension if a string path. | ||
| additional_metadata: Additional metadata to include in the CDF output. | ||
| Can be a complete CdfMetaDataSchema TypedDict or a partial dict. | ||
| Supported top-level keys: 'competition', 'season', 'stadium', 'meta', 'match'. | ||
| Supports nested updates like {'stadium': {'id': '123'}}. | ||
|
|
||
| Raises: | ||
| KloppyError: If the dataset is not a TrackingDataset. | ||
| ValueError: If file extensions are invalid. | ||
|
|
||
| Examples: | ||
| >>> # Export to local files | ||
| >>> dataset.to_cdf( | ||
| ... metadata_output_file='metadata.json', | ||
| ... tracking_output_file='tracking.jsonl' | ||
| ... ) | ||
|
|
||
| >>> # Export to S3 | ||
| >>> dataset.to_cdf( | ||
| ... metadata_output_file='s3://bucket/metadata.json', | ||
| ... tracking_output_file='s3://bucket/tracking.jsonl' | ||
| ... ) | ||
|
|
||
| >>> # Export with partial metadata updates | ||
| >>> dataset.to_cdf( | ||
| ... metadata_output_file='metadata.json', | ||
| ... tracking_output_file='tracking.jsonl', | ||
| ... additional_metadata={ | ||
| ... 'competition': {'id': '123'}, | ||
| ... 'season': {'id': '2024'}, | ||
| ... 'stadium': {'id': '456', 'name': 'Stadium Name'} | ||
| ... } | ||
| ... ) | ||
| """ | ||
| from kloppy.infra.serializers.tracking.cdf import ( | ||
| CDFOutputs, | ||
| CDFTrackingSerializer, | ||
| ) | ||
| from kloppy.io import open_as_file | ||
|
|
||
| serializer = CDFTrackingSerializer() | ||
|
|
||
| # TODO: write files but also support non-local files, similar to how open_as_file supports non-local files | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is done now. |
||
|
|
||
| # Use open_as_file with mode="wb" for writing | ||
| with ( | ||
| open_as_file(metadata_output_file, mode="wb") as metadata_fp, | ||
| open_as_file(tracking_output_file, mode="wb") as tracking_fp, | ||
| ): | ||
| serializer.serialize( | ||
| dataset=self, | ||
| outputs=CDFOutputs(meta_data=metadata_fp, raw_data=tracking_fp), | ||
| additional_metadata=additional_metadata, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = ["Frame", "TrackingDataset", "PlayerData"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from kloppy.domain.models.common import CDFCoordinateSystem | ||
|
|
||
| from .deserializer import CDFTrackingDataInputs, CDFTrackingDeserializer | ||
| from .serializer import CDFOutputs, CDFTrackingSerializer | ||
|
|
||
| __all__ = ["CDFCoordinateSystem", "CDFTrackingSerializer", "CDFOutputs"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this only work for broadcast tracking data? I guess this is a copy-paste error.