diff --git a/kloppy/_providers/datafactory.py b/kloppy/_providers/datafactory.py index fc4afb815..5eaf9e5ff 100644 --- a/kloppy/_providers/datafactory.py +++ b/kloppy/_providers/datafactory.py @@ -12,6 +12,7 @@ def load( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load DataFactory event data. @@ -21,6 +22,7 @@ def load( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -29,6 +31,7 @@ def load( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp: return deserializer.deserialize( diff --git a/kloppy/_providers/impect.py b/kloppy/_providers/impect.py index 418f5eedb..008396f00 100644 --- a/kloppy/_providers/impect.py +++ b/kloppy/_providers/impect.py @@ -18,6 +18,7 @@ def load( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Impect event data into a [`EventDataset`][kloppy.domain.models.event.EventDataset] @@ -30,6 +31,7 @@ def load( event_types: A list of event types to load. When set, only the specified event types will be loaded. coordinates: The coordinate system to use. Defaults to "impect". See [`kloppy.domain.models.common.Provider`][kloppy.domain.models.common.Provider] for available options. event_factory: A custom event factory. When set, the factory is used to create event instances. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. 
@@ -38,6 +40,7 @@ def load( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp, open_as_file( lineup_data @@ -62,6 +65,7 @@ def load_open_data( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Impect open data. @@ -75,6 +79,7 @@ def load_open_data( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -104,4 +109,5 @@ def load_open_data( event_types=event_types, coordinates=coordinates, event_factory=event_factory, + exclude_penalty_shootouts=exclude_penalty_shootouts, ) diff --git a/kloppy/_providers/metrica.py b/kloppy/_providers/metrica.py index 660d71fa1..032869c47 100644 --- a/kloppy/_providers/metrica.py +++ b/kloppy/_providers/metrica.py @@ -91,6 +91,7 @@ def load_event( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """Load Metrica Sports JSON event data. @@ -100,6 +101,7 @@ def load_event( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. 
@@ -108,6 +110,7 @@ def load_event( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp, open_as_file( diff --git a/kloppy/_providers/opta.py b/kloppy/_providers/opta.py index 80fff92b0..e12db66fb 100644 --- a/kloppy/_providers/opta.py +++ b/kloppy/_providers/opta.py @@ -13,6 +13,7 @@ def load( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Opta event data. @@ -23,6 +24,7 @@ def load( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -31,6 +33,7 @@ def load( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(f7_data) as f7_data_fp, open_as_file( f24_data diff --git a/kloppy/_providers/sportec.py b/kloppy/_providers/sportec.py index 91438e097..d99813648 100644 --- a/kloppy/_providers/sportec.py +++ b/kloppy/_providers/sportec.py @@ -20,6 +20,7 @@ def load_event( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Sportec Solutions event data. @@ -30,6 +31,7 @@ def load_event( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. 
@@ -38,6 +40,7 @@ def load_event( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp, open_as_file( meta_data @@ -94,9 +97,15 @@ def load( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: return load_event( - event_data, meta_data, event_types, coordinates, event_factory + event_data, + meta_data, + event_types, + coordinates, + event_factory, + exclude_penalty_shootouts, ) @@ -133,6 +142,7 @@ def load_open_event_data( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load event data for a game from the IDSSE dataset. @@ -149,6 +159,7 @@ def load_open_event_data( event_types: coordinates: event_factory: + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Notes: The dataset contains seven full matches of raw event and position data @@ -180,6 +191,7 @@ def load_open_event_data( event_types=event_types, coordinates=coordinates, event_factory=event_factory, + exclude_penalty_shootouts=exclude_penalty_shootouts, ) diff --git a/kloppy/_providers/statsbomb.py b/kloppy/_providers/statsbomb.py index 9f9f6e66f..e398825b6 100644 --- a/kloppy/_providers/statsbomb.py +++ b/kloppy/_providers/statsbomb.py @@ -20,6 +20,7 @@ def load( coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, additional_metadata: dict = {}, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load StatsBomb event data. @@ -34,6 +35,7 @@ def load( additional_metadata: A dict with additional data that will be added to the metadata. 
See the [`Metadata`][kloppy.domain.Metadata] entity for a list of possible keys. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -44,6 +46,7 @@ def load( event_factory=event_factory or get_config("event_factory") or StatsBombEventFactory(), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp, open_as_file( lineup_data @@ -65,6 +68,7 @@ def load_open_data( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load StatsBomb open data. @@ -77,6 +81,7 @@ def load_open_data( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -110,4 +115,5 @@ def load_open_data( event_types=event_types, coordinates=coordinates, event_factory=event_factory, + exclude_penalty_shootouts=exclude_penalty_shootouts, ) diff --git a/kloppy/_providers/statsperform.py b/kloppy/_providers/statsperform.py index c9c2abeff..fde899427 100644 --- a/kloppy/_providers/statsperform.py +++ b/kloppy/_providers/statsperform.py @@ -58,6 +58,7 @@ def load_event( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """Load Stats Perform event data. @@ -69,6 +70,7 @@ def load_event( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. 
@@ -77,6 +79,7 @@ def load_event( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), # type: ignore + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(ma1_data) as ma1_data_fp, open_as_file( ma3_data diff --git a/kloppy/_providers/wyscout.py b/kloppy/_providers/wyscout.py index 49e5b2b1b..b492b52c7 100644 --- a/kloppy/_providers/wyscout.py +++ b/kloppy/_providers/wyscout.py @@ -18,6 +18,7 @@ def load( coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, data_version: Optional[str] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Wyscout event data. @@ -28,6 +29,7 @@ def load( coordinates: The coordinate system to use. event_factory: A custom event factory. data_version: The version of the Wyscout data. Supported versions are "V2" and "V3". + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. @@ -43,6 +45,7 @@ def load( event_types=event_types, coordinate_system=coordinates, event_factory=event_factory or get_config("event_factory"), + exclude_penalty_shootouts=exclude_penalty_shootouts, ) with open_as_file(event_data) as event_data_fp: @@ -56,6 +59,7 @@ def load_open_data( event_types: Optional[List[str]] = None, coordinates: Optional[str] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ) -> EventDataset: """ Load Wyscout open data. @@ -71,6 +75,7 @@ def load_open_data( event_types: A list of event types to load. coordinates: The coordinate system to use. event_factory: A custom event factory. + exclude_penalty_shootouts: If True, excludes events from penalty shootouts (period 5). Returns: The parsed event data. 
@@ -87,6 +92,7 @@ def load_open_data( event_types=event_types, coordinates=coordinates, event_factory=event_factory, + exclude_penalty_shootouts=exclude_penalty_shootouts, ) diff --git a/kloppy/infra/serializers/event/datafactory/deserializer.py b/kloppy/infra/serializers/event/datafactory/deserializer.py index 1a871106a..7913e63ee 100644 --- a/kloppy/infra/serializers/event/datafactory/deserializer.py +++ b/kloppy/infra/serializers/event/datafactory/deserializer.py @@ -624,7 +624,12 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset: game_id=game_id, ) - return EventDataset( + dataset = EventDataset( metadata=metadata, records=events, ) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/infra/serializers/event/deserializer.py b/kloppy/infra/serializers/event/deserializer.py index b24c52cf5..5a04d5e5e 100644 --- a/kloppy/infra/serializers/event/deserializer.py +++ b/kloppy/infra/serializers/event/deserializer.py @@ -21,6 +21,7 @@ def __init__( event_types: Optional[List[Union[EventType, str]]] = None, coordinate_system: Optional[Union[str, Provider]] = None, event_factory: Optional[EventFactory] = None, + exclude_penalty_shootouts: bool = False, ): if not event_types: event_types = [] @@ -39,12 +40,72 @@ def __init__( if not event_factory: event_factory = EventFactory() self.event_factory = event_factory + self.exclude_penalty_shootouts = exclude_penalty_shootouts def should_include_event(self, event: Event) -> bool: + if ( + self.exclude_penalty_shootouts + and event.period + and event.period.id == 5 + ): + return False if not self.event_types: return True return event.event_type in self.event_types + def remove_penalty_shootout_data( + self, dataset: EventDataset + ) -> EventDataset: + """ + Remove all penalty shootout data from the dataset including: + - Period 5 from metadata.periods + - Player positions associated with period 5 + - Team formations 
associated with period 5 + """ + if not self.exclude_penalty_shootouts: + return dataset + + # Remove period 5 from metadata.periods + dataset.metadata.periods = [ + period for period in dataset.metadata.periods if period.id != 5 + ] + + # Update period references (prev/next) after removal + for i, period in enumerate(dataset.metadata.periods): + period.set_refs( + prev=dataset.metadata.periods[i - 1] if i > 0 else None, + next_=( + dataset.metadata.periods[i + 1] + if i + 1 < len(dataset.metadata.periods) + else None + ), + ) + + # Remove player positions and team formations associated with period 5 + for team in dataset.metadata.teams: + # Filter out formations for period 5 + if team.formations.items: + times_to_remove = [ + time + for time in team.formations.items.keys() + if time.period and time.period.id == 5 + ] + for time in times_to_remove: + del team.formations.items[time] + + # Filter out player positions for period 5 + for player in team.players: + if player.positions.items: + times_to_remove = [ + time + for time in player.positions.items.keys() + if time.period and time.period.id == 5 + ] + for time in times_to_remove: + del player.positions.items[time] + + return dataset + def get_transformer( self, pitch_length: Optional[float] = None, diff --git a/kloppy/infra/serializers/event/impect/deserializer.py b/kloppy/infra/serializers/event/impect/deserializer.py index a55dd709a..0822c68fd 100644 --- a/kloppy/infra/serializers/event/impect/deserializer.py +++ b/kloppy/infra/serializers/event/impect/deserializer.py @@ -157,6 +157,9 @@ def deserialize(self, inputs: ImpectInputs) -> EventDataset: ) dataset = EventDataset(metadata=metadata, records=events) + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + return dataset @staticmethod diff --git a/kloppy/infra/serializers/event/metrica/json_deserializer.py b/kloppy/infra/serializers/event/metrica/json_deserializer.py index 70475814f..1d6a286f7 100644 --- 
a/kloppy/infra/serializers/event/metrica/json_deserializer.py +++ b/kloppy/infra/serializers/event/metrica/json_deserializer.py @@ -401,7 +401,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: if self.should_include_event(event): events.append(transformer.transform_event(event)) - return EventDataset( + dataset = EventDataset( metadata=replace( metadata, pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, @@ -409,3 +409,8 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: ), records=events, ) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/infra/serializers/event/sportec/deserializer.py b/kloppy/infra/serializers/event/sportec/deserializer.py index 8bf8df59e..88c47bb54 100644 --- a/kloppy/infra/serializers/event/sportec/deserializer.py +++ b/kloppy/infra/serializers/event/sportec/deserializer.py @@ -717,7 +717,12 @@ def deserialize(self, inputs: SportecEventDataInputs) -> EventDataset: officials=sportec_metadata.officials, ) - return EventDataset( + dataset = EventDataset( metadata=metadata, records=events, ) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/infra/serializers/event/statsbomb/deserializer.py b/kloppy/infra/serializers/event/statsbomb/deserializer.py index ff007af07..e60de3c5f 100644 --- a/kloppy/infra/serializers/event/statsbomb/deserializer.py +++ b/kloppy/infra/serializers/event/statsbomb/deserializer.py @@ -117,6 +117,10 @@ def deserialize( visible_area=freeze_frame["visible_area"], ) ) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + return dataset def load_data(self, inputs: StatsBombInputs): diff --git a/kloppy/infra/serializers/event/statsperform/deserializer.py b/kloppy/infra/serializers/event/statsperform/deserializer.py 
index e033ec5c7..61bf9a809 100644 --- a/kloppy/infra/serializers/event/statsperform/deserializer.py +++ b/kloppy/infra/serializers/event/statsperform/deserializer.py @@ -1012,7 +1012,12 @@ def deserialize(self, inputs: StatsPerformInputs) -> EventDataset: game_id=game_id, ) - return EventDataset( + dataset = EventDataset( metadata=metadata, records=events, ) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py index 98a359941..189c412bb 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py @@ -736,4 +736,9 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: game_id=game_id, ) - return EventDataset(metadata=metadata, records=events) + dataset = EventDataset(metadata=metadata, records=events) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py index 60153d36a..8a41aec9f 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py @@ -1055,4 +1055,9 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: away_coach=away_coach, ) - return EventDataset(metadata=metadata, records=events) + dataset = EventDataset(metadata=metadata, records=events) + + # Remove penalty shootout data if requested + dataset = self.remove_penalty_shootout_data(dataset) + + return dataset diff --git a/kloppy/tests/test_event.py b/kloppy/tests/test_event.py index c89c214bf..c613757d4 100644 --- a/kloppy/tests/test_event.py +++ b/kloppy/tests/test_event.py @@ -1,6 +1,6 @@ import pytest -from kloppy import statsbomb +from kloppy 
import opta, statsbomb from kloppy.domain import EventDataset @@ -87,3 +87,118 @@ def test_find_all(self, dataset: EventDataset): assert goals[0].next("shot.goal") == goals[1] assert goals[0].next("shot.goal") == goals[2].prev("shot.goal") assert goals[2].next("shot.goal") is None + + +class TestExcludePenaltyShootouts: + """Tests for excluding penalty shootout data, exercised against Opta sample data""" + + @pytest.fixture(scope="class") + def dataset_with_shootout(self, base_dir) -> EventDataset: + """Load Opta data including penalty shootout (period 5)""" + return opta.load( + f7_data=base_dir / "files" / "opta_f7.xml", + f24_data=base_dir / "files" / "opta_f24.xml", + coordinates="opta", + exclude_penalty_shootouts=False, + ) + + @pytest.fixture(scope="class") + def dataset_without_shootout(self, base_dir) -> EventDataset: + """Load Opta data excluding penalty shootout (period 5)""" + return opta.load( + f7_data=base_dir / "files" / "opta_f7.xml", + f24_data=base_dir / "files" / "opta_f24.xml", + coordinates="opta", + exclude_penalty_shootouts=True, + ) + + def test_periods_with_shootout(self, dataset_with_shootout: EventDataset): + """It should include all 5 periods when penalty shootouts are not excluded""" + assert len(dataset_with_shootout.metadata.periods) == 5 + period_ids = [p.id for p in dataset_with_shootout.metadata.periods] + assert period_ids == [1, 2, 3, 4, 5] + + def test_periods_without_shootout( + self, dataset_without_shootout: EventDataset + ): + """It should only include 4 periods when penalty shootouts are excluded""" + assert len(dataset_without_shootout.metadata.periods) == 4 + period_ids = [p.id for p in dataset_without_shootout.metadata.periods] + assert period_ids == [1, 2, 3, 4] + # Ensure period 5 is not present + assert 5 not in period_ids + + def test_events_with_shootout(self, dataset_with_shootout: EventDataset): + """It should include penalty shootout events when not excluded""" + period_5_events = [ + e + for e in dataset_with_shootout.events 
+ if e.period and e.period.id == 5 + ] + # The test file has 1 parsed penalty shootout event (a shot) + # Note: There are 5 events with period_id="5" in the XML, but only + # events that are currently supported by the parser are deserialized + assert len(period_5_events) == 1 + + def test_events_without_shootout( + self, dataset_without_shootout: EventDataset + ): + """It should exclude all penalty shootout events when excluded""" + period_5_events = [ + e + for e in dataset_without_shootout.events + if e.period and e.period.id == 5 + ] + assert len(period_5_events) == 0 + + def test_event_count_difference( + self, + dataset_with_shootout: EventDataset, + dataset_without_shootout: EventDataset, + ): + """It should have exactly 1 fewer event when penalty shootouts are excluded""" + count_with = len(dataset_with_shootout.events) + count_without = len(dataset_without_shootout.events) + # The difference is 1 because only 1 penalty shootout event is parsed + assert count_with - count_without == 1 + + def test_player_positions_without_shootout( + self, dataset_without_shootout: EventDataset + ): + """It should not have any player positions referencing period 5""" + for team in dataset_without_shootout.metadata.teams: + for player in team.players: + if player.positions.items: + for time in player.positions.items.keys(): + assert ( + time.period is None or time.period.id != 5 + ), f"Player {player.full_name} has position at period 5" + + def test_team_formations_without_shootout( + self, dataset_without_shootout: EventDataset + ): + """It should not have any team formations referencing period 5""" + for team in dataset_without_shootout.metadata.teams: + if team.formations.items: + for time in team.formations.items.keys(): + assert ( + time.period is None or time.period.id != 5 + ), f"Team {team.name} has formation at period 5" + + def test_period_references_without_shootout( + self, dataset_without_shootout: EventDataset + ): + """It should correctly update period 
references after removing period 5""" + periods = dataset_without_shootout.metadata.periods + for i, period in enumerate(periods): + # Check prev_period reference + if i > 0: + assert period.prev_period == periods[i - 1] + else: + assert period.prev_period is None + + # Check next_period reference + if i < len(periods) - 1: + assert period.next_period == periods[i + 1] + else: + assert period.next_period is None