diff --git a/kloppy/_providers/hawkeye.py b/kloppy/_providers/hawkeye.py index b09f56d5a..22d145e00 100644 --- a/kloppy/_providers/hawkeye.py +++ b/kloppy/_providers/hawkeye.py @@ -18,6 +18,7 @@ def load( limit: Optional[int] = None, coordinates: Optional[str] = None, show_progress: Optional[bool] = False, + object_id: Optional[str] = None, ) -> TrackingDataset: """ Load HawkEye tracking data. @@ -62,6 +63,7 @@ def load( sample_rate=sample_rate, limit=limit, coordinate_system=coordinates, + object_id=object_id, ) return deserializer.deserialize( inputs=HawkEyeInputs( diff --git a/kloppy/infra/serializers/tracking/hawkeye/deserializer.py b/kloppy/infra/serializers/tracking/hawkeye/deserializer.py index 2bb09a9f5..823e45a30 100644 --- a/kloppy/infra/serializers/tracking/hawkeye/deserializer.py +++ b/kloppy/infra/serializers/tracking/hawkeye/deserializer.py @@ -73,6 +73,13 @@ def get_identifier_variable(cls, player_tracking_data): return identifier return object_id + @classmethod + def is_priority(cls, object_id): + if object_id in cls.PRIORITY_IDS: + return True + else: + return False + class HawkEyeDeserializer(TrackingDataDeserializer[HawkEyeInputs]): def __init__( @@ -82,9 +89,12 @@ def __init__( limit: Optional[int] = None, sample_rate: Optional[float] = None, coordinate_system: Optional[Union[str, Provider]] = None, + object_id: Optional[str] = None, ): super().__init__(limit, sample_rate, coordinate_system) - self.object_id: HawkEyeObjectIdentifier = None + self.object_id: HawkEyeObjectIdentifier = ( + None if object_id is None else object_id + ) self.pitch_width = pitch_width self.pitch_length = pitch_length @@ -270,17 +280,19 @@ def deserialize(self, inputs: HawkEyeInputs) -> TrackingDataset: with open_as_file(player_centroid_feed) as player_centroid_data_fp: player_tracking_data = json.load(player_centroid_data_fp) - self.object_id = HawkEyeObjectIdentifier.get_identifier_variable( + _object_id = HawkEyeObjectIdentifier.get_identifier_variable( player_tracking_data ) + if HawkEyeObjectIdentifier.is_priority(_object_id): + self.object_id = _object_id if frame_rate is None: frame_rate = self.__infer_frame_rate(ball_tracking_data) if not self._game_id: - self._game_id = ball_tracking_data["details"]["match"]["id"][ - self.object_id - ] + self._game_id = ball_tracking_data["details"]["match"][ + "id" + ].get(self.object_id, None) # Parse the teams, players and periods. A value can be added by # later feeds, but we will not overwrite existing values.