Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions kloppy/_providers/cdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional

from kloppy.domain import TrackingDataset
from kloppy.infra.serializers.tracking.cdf import (
CDFTrackingDataInputs,
CDFTrackingDeserializer,
)
from kloppy.io import FileLike, open_as_file


def load_tracking(
meta_data: FileLike,
raw_data: FileLike,
sample_rate: Optional[float] = None,
limit: Optional[int] = None,
coordinates: Optional[str] = None,
include_empty_frames: Optional[bool] = False,
only_alive: Optional[bool] = True,
) -> TrackingDataset:
"""
Load Common Data Format broadcast tracking data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this only work for broadcast tracking data? I guess this is a copy-paste error.


Args:
meta_data: A JSON feed containing the meta data.
raw_data: A JSONL feed containing the raw tracking data.
sample_rate: Sample the data at a specific rate.
limit: Limit the number of frames to load to the first `limit` frames.
coordinates: The coordinate system to use.
include_empty_frames: Include frames in which no objects were tracked.
only_alive: Only include frames in which the game is not paused.

Returns:
The parsed tracking data.
"""
deserializer = CDFTrackingDeserializer(
sample_rate=sample_rate,
limit=limit,
coordinate_system=coordinates,
include_empty_frames=include_empty_frames,
only_alive=only_alive,
)
with (
open_as_file(meta_data) as meta_data_fp,
open_as_file(raw_data) as raw_data_fp,
):
return deserializer.deserialize(
inputs=CDFTrackingDataInputs(
meta_data=meta_data_fp, raw_data=raw_data_fp
)
)


# def load_event(
# event_data: FileLike,
# meta_data: FileLike,
# event_types: Optional[list[str]] = None,
# coordinates: Optional[str] = None,
# event_factory: Optional[EventFactory] = None,
# ) -> EventDataset:
# """
# Load Common Data Format event data.

# Args:
# event_data: JSON feed with the raw event data of a game.
# meta_data: JSON feed with the corresponding lineup information of the game.
# event_types: A list of event types to load.
# coordinates: The coordinate system to use.
# event_factory: A custom event factory.

# Returns:
# The parsed event data.
# """
# deserializer = StatsBombDeserializer(
# event_types=event_types,
# coordinate_system=coordinates,
# event_factory=event_factory
# or get_config("event_factory")
# or StatsBombEventFactory(),
# )
# with (
# open_as_file(event_data) as event_data_fp,
# open_as_file(meta_data) as meta_data_fp,
# ):
# return deserializer.deserialize(
# inputs=StatsBombInputs(
# event_data=event_data_fp,
# lineup_data=meta_data_fp,
# )
# )
5 changes: 5 additions & 0 deletions kloppy/cdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Functions for loading SkillCorner broadcast tracking data."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Functions for loading SkillCorner broadcast tracking data."""
"""Functions for loading data in the Common Data Format (CDF) standard."""
  • I think it would be good to add a sentence explaining the CDF with a reference to the arxiv paper.


from ._providers.cdf import load_tracking

__all__ = ["load_tracking"]
91 changes: 78 additions & 13 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class Provider(Enum):
HAWEKEYE (Provider):
SPORTVU (Provider):
IMPECT (Provider):
CDF (Provider):
OTHER (Provider):
"""

Expand All @@ -126,8 +127,9 @@ class Provider(Enum):
STATSPERFORM = "statsperform"
HAWKEYE = "hawkeye"
SPORTVU = "sportvu"
SIGNALITY = "signality"
IMPECT = "impect"
CDF = "common_data_format"
SIGNALITY = "signality"
OTHER = "other"

def __str__(self):
Expand Down Expand Up @@ -679,12 +681,16 @@ def to_mplsoccer(self):
dim = BaseDims(
left=self.pitch_dimensions.x_dim.min,
right=self.pitch_dimensions.x_dim.max,
bottom=self.pitch_dimensions.y_dim.min
if not invert_y
else self.pitch_dimensions.y_dim.max,
top=self.pitch_dimensions.y_dim.max
if not invert_y
else self.pitch_dimensions.y_dim.min,
bottom=(
self.pitch_dimensions.y_dim.min
if not invert_y
else self.pitch_dimensions.y_dim.max
),
top=(
self.pitch_dimensions.y_dim.max
if not invert_y
else self.pitch_dimensions.y_dim.min
),
width=self.pitch_dimensions.x_dim.max
- self.pitch_dimensions.x_dim.min,
length=self.pitch_dimensions.y_dim.max
Expand Down Expand Up @@ -733,14 +739,16 @@ def to_mplsoccer(self):
- self.pitch_dimensions.x_dim.min
),
pad_multiplier=1,
aspect_equal=False
if self.pitch_dimensions.unit == Unit.NORMED
else True,
aspect_equal=(
False if self.pitch_dimensions.unit == Unit.NORMED else True
),
pitch_width=pitch_width,
pitch_length=pitch_length,
aspect=pitch_width / pitch_length
if self.pitch_dimensions.unit == Unit.NORMED
else 1.0,
aspect=(
pitch_width / pitch_length
if self.pitch_dimensions.unit == Unit.NORMED
else 1.0
),
)
return dim

Expand Down Expand Up @@ -1184,6 +1192,58 @@ def pitch_dimensions(self) -> PitchDimensions:
)


class CDFCoordinateSystem(ProviderCoordinateSystem):
"""
CDFCoordinateSystem coordinate system.

Uses a pitch with the origin at the center and the y-axis oriented
from bottom to top. The coordinates are in meters.
"""

@property
def provider(self) -> Provider:
return Provider.CDF

@property
def origin(self) -> Origin:
return Origin.CENTER

@property
def vertical_orientation(self) -> VerticalOrientation:
return VerticalOrientation.BOTTOM_TO_TOP

@property
def pitch_dimensions(self) -> PitchDimensions:
return NormalizedPitchDimensions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be MetricPitchDimensions

x_dim=Dimension(
-1 * self._pitch_length / 2, self._pitch_length / 2
),
y_dim=Dimension(-1 * self._pitch_width / 2, self._pitch_width / 2),
pitch_length=self._pitch_length,
pitch_width=self._pitch_width,
standardized=False,
)

def __init__(
self,
base_coordinate_system: ProviderCoordinateSystem | None = None,
pitch_length: float | None = None,
pitch_width: float | None = None,
Comment on lines +1229 to +1231
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use Optional[T] here instead of T | None to be compatible with Python 3.9

):
if base_coordinate_system is not None:
# Used by serializer - derive dimensions from source coordinate system
self._pitch_length = (
base_coordinate_system.pitch_dimensions.pitch_length
)
self._pitch_width = (
base_coordinate_system.pitch_dimensions.pitch_width
)
else:
# Used by deserializer - direct pitch dimensions
self._pitch_length = pitch_length
self._pitch_width = pitch_width


class SignalityCoordinateSystem(ProviderCoordinateSystem):
@property
def provider(self) -> Provider:
Expand Down Expand Up @@ -1414,6 +1474,7 @@ def build_coordinate_system(
Provider.SPORTVU: SportVUCoordinateSystem,
Provider.SIGNALITY: SignalityCoordinateSystem,
Provider.IMPECT: ImpectCoordinateSystem,
Provider.CDF: CDFCoordinateSystem,
}

if provider in coordinate_systems:
Expand Down Expand Up @@ -1944,6 +2005,10 @@ def to_df(
else:
raise KloppyParameterError(f"Engine {engine} is not valid")

def to_cdf(self):
if self.dataset_type != DatasetType.TRACKING:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the if here and override in the TrackingDataset. A parent class should not be aware of its children.

raise ValueError("to_cdf() is only supported for TrackingDataset")

def __repr__(self):
return f"<{self.__class__.__name__} record_count={len(self.records)}>"

Expand Down
76 changes: 75 additions & 1 deletion kloppy/domain/models/tracking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

from kloppy.domain.models.common import DatasetType
from kloppy.utils import (
Expand All @@ -9,6 +9,11 @@
from .common import DataRecord, Dataset, Player
from .pitch import Point, Point3D

if TYPE_CHECKING:
from cdf.domain import CdfMetaDataSchema

from kloppy.io import FileLike


@dataclass
class PlayerData:
Expand Down Expand Up @@ -79,5 +84,74 @@ def frames(self):
def frame_rate(self):
return self.metadata.frame_rate

# Update the to_cdf method in Dataset class
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove / address this comment

def to_cdf(
self,
metadata_output_file: "FileLike",
tracking_output_file: "FileLike",
additional_metadata: Optional[Union[dict, "CdfMetaDataSchema"]] = None,
) -> None:
"""
Export dataset to Common Data Format (CDF).

Args:
metadata_output_file: File path or file-like object for metadata JSON output.
Must have .json extension if a string path.
tracking_output_file: File path or file-like object for tracking JSONL output.
Must have .jsonl extension if a string path.
additional_metadata: Additional metadata to include in the CDF output.
Can be a complete CdfMetaDataSchema TypedDict or a partial dict.
Supported top-level keys: 'competition', 'season', 'stadium', 'meta', 'match'.
Supports nested updates like {'stadium': {'id': '123'}}.

Raises:
KloppyError: If the dataset is not a TrackingDataset.
ValueError: If file extensions are invalid.

Examples:
>>> # Export to local files
>>> dataset.to_cdf(
... metadata_output_file='metadata.json',
... tracking_output_file='tracking.jsonl'
... )

>>> # Export to S3
>>> dataset.to_cdf(
... metadata_output_file='s3://bucket/metadata.json',
... tracking_output_file='s3://bucket/tracking.jsonl'
... )

>>> # Export with partial metadata updates
>>> dataset.to_cdf(
... metadata_output_file='metadata.json',
... tracking_output_file='tracking.jsonl',
... additional_metadata={
... 'competition': {'id': '123'},
... 'season': {'id': '2024'},
... 'stadium': {'id': '456', 'name': 'Stadium Name'}
... }
... )
"""
from kloppy.infra.serializers.tracking.cdf import (
CDFOutputs,
CDFTrackingSerializer,
)
from kloppy.io import open_as_file

serializer = CDFTrackingSerializer()

# TODO: write files but also support non-local files, similar to how open_as_file supports non-local files
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is done now.


# Use open_as_file with mode="wb" for writing
with (
open_as_file(metadata_output_file, mode="wb") as metadata_fp,
open_as_file(tracking_output_file, mode="wb") as tracking_fp,
):
serializer.serialize(
dataset=self,
outputs=CDFOutputs(meta_data=metadata_fp, raw_data=tracking_fp),
additional_metadata=additional_metadata,
)


__all__ = ["Frame", "TrackingDataset", "PlayerData"]
8 changes: 5 additions & 3 deletions kloppy/infra/io/adapters/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def list_directory(self, url: str, recursive: bool = True) -> list[str]:
else:
files = fs.listdir(url, detail=False)
return [
f"{protocol}://{fp}"
if protocol != "file" and not fp.startswith(protocol)
else fp
(
f"{protocol}://{fp}"
if protocol != "file" and not fp.startswith(protocol)
else fp
)
for fp in files
]
6 changes: 6 additions & 0 deletions kloppy/infra/serializers/tracking/cdf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from kloppy.domain.models.common import CDFCoordinateSystem

from .deserializer import CDFTrackingDataInputs, CDFTrackingDeserializer
from .serializer import CDFOutputs, CDFTrackingSerializer

__all__ = ["CDFCoordinateSystem", "CDFTrackingSerializer", "CDFOutputs"]
Loading
Loading