diff --git a/kloppy/_providers/sportscode.py b/kloppy/_providers/sportscode.py index 3ce84e3cf..d403cba61 100644 --- a/kloppy/_providers/sportscode.py +++ b/kloppy/_providers/sportscode.py @@ -2,6 +2,7 @@ from kloppy.infra.serializers.code.sportscode import ( SportsCodeDeserializer, SportsCodeInputs, + SportsCodeOutputs, SportsCodeSerializer, ) from kloppy.io import FileLike, open_as_file @@ -31,6 +32,6 @@ def save(dataset: CodeDataset, output_filename: str) -> None: dataset: The SportsCode dataset to save. output_filename: The output filename. """ - with open(output_filename, "wb") as fp: + with open_as_file(output_filename, "wb") as data_fp: serializer = SportsCodeSerializer() - fp.write(serializer.serialize(dataset)) + serializer.serialize(dataset, outputs=SportsCodeOutputs(data=data_fp)) diff --git a/kloppy/infra/io/adapters/adapter.py b/kloppy/infra/io/adapters/adapter.py index 83839956e..c90305823 100644 --- a/kloppy/infra/io/adapters/adapter.py +++ b/kloppy/infra/io/adapters/adapter.py @@ -1,27 +1,61 @@ from abc import ABC, abstractmethod -from typing import BinaryIO + +from kloppy.infra.io.buffered_stream import BufferedStream class Adapter(ABC): @abstractmethod def supports(self, url: str) -> bool: - pass + """Returns True if this adapter supports the given URL, False otherwise.""" @abstractmethod def is_directory(self, url: str) -> bool: - pass + """Returns True if the given URL points to a directory, False otherwise.""" @abstractmethod def is_file(self, url: str) -> bool: - pass + """Returns True if the given URL points to a file, False otherwise.""" @abstractmethod - def read_to_stream(self, url: str, output: BinaryIO): - pass + def read_to_stream(self, url: str, output: BufferedStream): + """Read content from the given URL into the BufferedStream. 
+ + Args: + url: The source URL + output: BufferedStream to write to + """ + + def write_from_stream(self, url: str, input: BufferedStream, mode: str): # noqa: A002 + """Write content from BufferedStream to the given URL. + + Args: + url: The destination URL + input: BufferedStream to read from + mode: Write mode ('wb' for write/overwrite or 'ab' for append) + + Raises: + NotImplementedError: If write operations are not supported by this adapter + """ + raise NotImplementedError( + f"Write operations not supported for {url}. " + f"Adapter {self.__class__.__name__} does not implement write_from_stream." + ) @abstractmethod def list_directory(self, url: str, recursive: bool = True) -> list[str]: - pass + """Lists the contents of a directory. + + Args: + url: The directory URL + recursive: Whether to list contents recursively + + Returns: + A list of files in the directory + + Example: + >>> adapter.list_directory("s3://my-bucket/data/", recursive=False) + ['s3://my-bucket/data/file1.csv', 's3://my-bucket/data/file2.csv'] + """ __all__ = ["Adapter"] diff --git a/kloppy/infra/io/adapters/file.py b/kloppy/infra/io/adapters/file.py index 72ddd99f2..9fca448ac 100644 --- a/kloppy/infra/io/adapters/file.py +++ b/kloppy/infra/io/adapters/file.py @@ -1,3 +1,5 @@ +import os + import fsspec from .fsspec import FSSpecAdapter @@ -11,3 +13,14 @@ def _get_filesystem( self, url: str, no_cache: bool = False ) -> fsspec.AbstractFileSystem: return fsspec.filesystem("file") + + def list_directory(self, url: str, recursive: bool = True) -> list[str]: + """ + Lists the contents of a directory. 
+ """ + fs = self._get_filesystem(url) + if recursive: + files = fs.find(url, detail=False) + else: + files = fs.listdir(url, detail=False) + return [os.path.normpath(fp) for fp in files] diff --git a/kloppy/infra/io/adapters/fsspec.py b/kloppy/infra/io/adapters/fsspec.py index b26217084..44a436c87 100644 --- a/kloppy/infra/io/adapters/fsspec.py +++ b/kloppy/infra/io/adapters/fsspec.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod import re -from typing import BinaryIO, Optional +from typing import Optional import fsspec from kloppy.config import get_config from kloppy.exceptions import InputNotFoundError +from kloppy.infra.io.buffered_stream import BufferedStream from .adapter import Adapter @@ -28,7 +29,6 @@ def _get_filesystem( Get the appropriate fsspec filesystem for the given URL, with caching enabled. """ protocol = self._infer_protocol(url) - if no_cache: return fsspec.filesystem(protocol) @@ -38,6 +38,16 @@ def _get_filesystem( cache_storage=get_config("cache"), ) + def _get_filesystem_for_reading( + self, url: str + ) -> fsspec.AbstractFileSystem: + return self._get_filesystem(url, no_cache=False) + + def _get_filesystem_for_writing( + self, url: str + ) -> fsspec.AbstractFileSystem: + return self._get_filesystem(url, no_cache=True) + def _detect_compression(self, url: str) -> Optional[str]: """ Detect the compression type based on the file extension. @@ -60,20 +70,36 @@ def supports(self, url: str) -> bool: Check if the adapter can handle the URL. """ - def read_to_stream(self, url: str, output: BinaryIO): + def read_to_stream(self, url: str, output: BufferedStream): """ - Reads content from the given URL and writes it to the provided binary stream. - Uses caching for remote files. + Reads content from the given URL and writes it to the provided BufferedStream. + Uses caching for remote files. Copies data in chunks. 
""" - fs = self._get_filesystem(url) + fs = self._get_filesystem_for_reading(url) compression = self._detect_compression(url) try: with fs.open(url, "rb", compression=compression) as source_file: - output.write(source_file.read()) + output.read_from(source_file) except FileNotFoundError as e: raise InputNotFoundError(f"Input file not found: {url}") from e + def write_from_stream(self, url: str, input: BufferedStream, mode: str): # noqa: A002 + """ + Writes content from BufferedStream to the given URL. + Does not use caching for writes. Copies data in chunks. + + Args: + url: The destination URL + input: BufferedStream to read from + mode: Write mode ('wb' for write/overwrite or 'ab' for append) + """ + fs = self._get_filesystem_for_writing(url) + compression = self._detect_compression(url) + + with fs.open(url, mode, compression=compression) as dest_file: + input.write_to(dest_file) + def list_directory(self, url: str, recursive: bool = True) -> list[str]: """ Lists the contents of a directory. @@ -84,12 +110,7 @@ def list_directory(self, url: str, recursive: bool = True) -> list[str]: files = fs.find(url, detail=False) else: files = fs.listdir(url, detail=False) - return [ - f"{protocol}://{fp}" - if protocol != "file" and not fp.startswith(protocol) - else fp - for fp in files - ] + return [f"{protocol}://{fp}" for fp in files] def is_directory(self, url: str) -> bool: """ diff --git a/kloppy/infra/io/adapters/http.py b/kloppy/infra/io/adapters/http.py index c7279c428..e4bd156d4 100644 --- a/kloppy/infra/io/adapters/http.py +++ b/kloppy/infra/io/adapters/http.py @@ -65,3 +65,14 @@ def is_directory(self, url: str) -> bool: """ fs = self._get_filesystem(url, no_cache=True) return fs.isdir(url) + + def list_directory(self, url: str, recursive: bool = True) -> list[str]: + """ + Lists the contents of a directory. 
+ """ + fs = self._get_filesystem(url) + if recursive: + files = fs.find(url, detail=False) + else: + files = fs.listdir(url, detail=False) + return files # already includes the http(s):// prefix diff --git a/kloppy/infra/io/adapters/zip.py b/kloppy/infra/io/adapters/zip.py index f74f1612a..71a174eab 100644 --- a/kloppy/infra/io/adapters/zip.py +++ b/kloppy/infra/io/adapters/zip.py @@ -16,8 +16,19 @@ class ZipAdapter(FSSpecAdapter): def supports(self, url: str) -> bool: return url.startswith("zip://") - def _get_filesystem( - self, url: str, no_cache: bool = False + def _get_filesystem_for_reading( + self, url: str + ) -> fsspec.AbstractFileSystem: + fo = get_config("adapters.zip.fo") + if fo is None: + raise AdapterError( + "No zip archive provided for the zip adapter." + " Please provide one using the 'adapters.zip.fo' config." + ) + return fsspec.filesystem(protocol="zip", fo=fo, mode="r") + + def _get_filesystem_for_writing( + self, url: str ) -> fsspec.AbstractFileSystem: fo = get_config("adapters.zip.fo") if fo is None: @@ -25,10 +36,12 @@ def _get_filesystem( "No zip archive provided for the zip adapter." " Please provide one using the 'adapters.zip.fo' config." 
) - return fsspec.filesystem( - protocol="zip", - fo=fo, - ) + return fsspec.filesystem(protocol="zip", fo=fo, mode="a") + + def _get_filesystem( + self, url: str, no_cache: bool = False + ) -> fsspec.AbstractFileSystem: + return self._get_filesystem_for_reading(url) def list_directory(self, url: str, recursive: bool = True) -> list[str]: """ diff --git a/kloppy/infra/io/buffered_stream.py b/kloppy/infra/io/buffered_stream.py new file mode 100644 index 000000000..5bc1ba4fa --- /dev/null +++ b/kloppy/infra/io/buffered_stream.py @@ -0,0 +1,76 @@ +"""Buffered stream utilities for efficient I/O operations.""" + +import shutil +import tempfile +from typing import BinaryIO, Protocol + +DEFAULT_BUFFER_SIZE = 5 * 1024 * 1024 # 5MB before spilling to disk + + +class SupportsWrite(Protocol): + """Protocol for objects that support write operations.""" + + def write(self, data: bytes) -> int: ... + + +class SupportsRead(Protocol): + """Protocol for objects that support read operations.""" + + def read(self, n: int) -> bytes: ... + + +class BufferedStream(tempfile.SpooledTemporaryFile): + """A spooled temporary file that can efficiently copy from streams in chunks.""" + + def __init__(self, max_size: int = DEFAULT_BUFFER_SIZE, mode: str = "w+b"): + super().__init__(max_size=max_size, mode=mode) + + def write(self, data: bytes) -> int: # make it clearly bytes-only + return super().write(data) + + def read(self, n: int = -1) -> bytes: # make it clearly bytes-only + return super().read(n) + + @classmethod + def from_stream( + cls, + source: BinaryIO, + max_size: int = DEFAULT_BUFFER_SIZE, + chunk_size: int = 0, + ) -> "BufferedStream": + """ + Create a BufferedStream by copying data from source stream in chunks. 
+ + Args: + source: The source binary stream to read from + max_size: Maximum size to keep in memory before spilling to disk + chunk_size: Size of chunks to keep in memory before spilling to disk + + Returns: + A BufferedStream containing the copied data + """ + buffer = cls(max_size=max_size) + buffer.read_from(source, chunk_size) + return buffer + + def read_from(self, source: SupportsRead, chunk_size: int = 0): + """ + Read data from source into this BufferedStream in chunks. + + Args: + source: The source that supports read() method + chunk_size: Size of chunks to copy at a time (0 uses default) + """ + shutil.copyfileobj(source, self, chunk_size) + self.seek(0) + + def write_to(self, output: SupportsWrite, chunk_size: int = 0) -> None: + """ + Write all contents of this BufferedStream to the output in chunks. + + Args: + output: The destination that supports write() method + chunk_size: Size of chunks to keep in memory before spilling to disk + """ + self.seek(0) + shutil.copyfileobj(self, output, chunk_size) diff --git a/kloppy/infra/serializers/code/base.py b/kloppy/infra/serializers/code/base.py index d726dca18..039c9a249 100644 --- a/kloppy/infra/serializers/code/base.py +++ b/kloppy/infra/serializers/code/base.py @@ -3,16 +3,17 @@ from kloppy.domain import CodeDataset -T = TypeVar("T") +T_I = TypeVar("T_I") +T_O = TypeVar("T_O") -class CodeDataDeserializer(ABC, Generic[T]): +class CodeDataDeserializer(ABC, Generic[T_I]): @abstractmethod - def deserialize(self, inputs: T) -> CodeDataset: + def deserialize(self, inputs: T_I) -> CodeDataset: raise NotImplementedError -class CodeDataSerializer(ABC): +class CodeDataSerializer(ABC, Generic[T_O]): @abstractmethod - def serialize(self, dataset: CodeDataset) -> bytes: + def serialize(self, dataset: CodeDataset, outputs: T_O) -> bool: raise NotImplementedError diff --git a/kloppy/infra/serializers/code/sportscode.py b/kloppy/infra/serializers/code/sportscode.py index 698de6675..bd063ddc1 100644 --- 
a/kloppy/infra/serializers/code/sportscode.py +++ b/kloppy/infra/serializers/code/sportscode.py @@ -45,6 +45,10 @@ class SportsCodeInputs(NamedTuple): data: IO[bytes] +class SportsCodeOutputs(NamedTuple): + data: IO[bytes] + + class SportsCodeDeserializer(CodeDataDeserializer[SportsCodeInputs]): def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset: all_instances = objectify.fromstring(inputs.data.read()) @@ -88,8 +92,10 @@ def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset: ) -class SportsCodeSerializer(CodeDataSerializer): - def serialize(self, dataset: CodeDataset) -> bytes: +class SportsCodeSerializer(CodeDataSerializer[SportsCodeOutputs]): + def serialize( + self, dataset: CodeDataset, outputs: SportsCodeOutputs + ) -> bool: root = etree.Element("file") all_instances = etree.SubElement(root, "ALL_INSTANCES") for i, code in enumerate(dataset.codes): @@ -138,10 +144,12 @@ def serialize(self, dataset: CodeDataset) -> bytes: text_ = etree.SubElement(label, "text") text_.text = str(text) - return etree.tostring( - root, - pretty_print=True, - xml_declaration=True, - encoding="utf-8", # This might not work with some tools because they expected 'ascii'. - method="xml", + outputs.data.write( + etree.tostring( + root, + pretty_print=True, + xml_declaration=True, + encoding="utf-8", # This might not work with some tools because they expected 'ascii'. 
+ method="xml", + ) ) diff --git a/kloppy/io.py b/kloppy/io.py index ba4f7eefc..270dbc88b 100644 --- a/kloppy/io.py +++ b/kloppy/io.py @@ -11,18 +11,11 @@ import lzma import os import re -from typing import ( - IO, - Any, - BinaryIO, - Callable, - Optional, - TextIO, - Union, -) +from typing import IO, Any, BinaryIO, Callable, Optional, Union, cast from kloppy.exceptions import AdapterError, InputNotFoundError from kloppy.infra.io.adapters import get_adapter +from kloppy.infra.io.buffered_stream import BufferedStream logger = logging.getLogger(__name__) @@ -316,15 +309,35 @@ def get_file_extension(file_or_path: FileLike) -> str: @contextlib.contextmanager -def dummy_context_mgr() -> Generator[None, None, None]: - yield +def _write_context_manager( + uri: str, mode: str +) -> Generator[BinaryIO, None, None]: + """ + Context manager for write operations that buffers writes and flushes to adapter on exit. + + Args: + uri: The destination URI + mode: Write mode ('wb' or 'ab') + + Yields: + A BufferedStream for writing + """ + buffer = BufferedStream() + try: + yield buffer + finally: + adapter = get_adapter(uri) + if adapter: + adapter.write_from_stream(uri, buffer, mode) + else: + raise AdapterError(f"No adapter found for {uri}") def open_as_file( input_: FileLike, - encoding: Optional[str] = None, -) -> AbstractContextManager[Union[BinaryIO, TextIO, None]]: - """Open a byte stream to the given input object. + mode: str = "rb", +) -> AbstractContextManager[Optional[BinaryIO]]: + """Open a byte stream to/from the given input object. The following input types are supported: - A string or `pathlib.Path` object representing a local file path. @@ -340,83 +353,138 @@ def open_as_file( input types. Args: - input_ (FileLike): The input object to be opened. - encoding (str, optional): The name of the encoding used to decode or encode the - file. This should only be used in text mode. + input_ (FileLike): The input/output object to be opened. 
+ mode (str): File mode - 'rb' (read), 'wb' (write), or 'ab' (append). + Defaults to 'rb'. Returns: - Union[BinaryIO, TextIO]: A stream to the input object. + BinaryIO: A binary stream to/from the input object. Raises: - ValueError: If the input is required but not provided. + ValueError: If the input is required but not provided, or invalid mode. InputNotFoundError: If the input file is not found and should not be skipped. TypeError: If the input type is not supported. + NotImplementedError: If write mode is used with unsupported input types. Example: + >>> # Reading >>> with open_as_file("example.txt") as f: ... contents = f.read() + >>> + >>> # Writing + >>> with open_as_file("output.txt", mode="wb") as f: + ... f.write(b"Hello, world!") Note: To support reading data from other sources, see the [Adapter](`kloppy.io.adapters.Adapter`) class. If the given file path or URL ends with '.gz', '.xz', or '.bz2', the - file will be decompressed before being read. + file will be automatically compressed/decompressed. + + Write mode limitations: + - HTTP/HTTPS URLs: Not supported + - Inline strings/bytes: Not supported (invalid output destination) """ + # 1. Handle Source wrapper logic first if isinstance(input_, Source): - if input_.data is None and input_.optional: - # This saves us some additional code in every vendor specific code - return dummy_context_mgr() - elif input_.data is None: + if input_.data is None: + if input_.optional: + return contextlib.nullcontext(None) raise ValueError("Input required but not provided.") - else: - try: - return open_as_file(input_.data, encoding=encoding) - except InputNotFoundError as exc: - if input_.skip_if_missing: - logging.info(f"Input {input_.data} not found. 
Skipping") - return dummy_context_mgr() - else: - raise exc - - stream: Union[BinaryIO, TextIO] - - if isinstance(input_, str) and ("{" in input_ or "<" in input_): - # If input_ is a JSON or XML string, return it as a binary stream - stream = BytesIO(input_.encode("utf8")) - - elif isinstance(input_, bytes): - # If input_ is a bytes object, return it as a binary stream - stream = BytesIO(input_) - - elif isinstance(input_, str) or hasattr(input_, "__fspath__"): - # If input_ is a path-like object, open it and return the binary stream - uri = _filepath_from_path_or_filelike(input_) + try: + return open_as_file(input_.data, mode=mode) + except InputNotFoundError: + if input_.skip_if_missing: + logger.info(f"Input {input_.data} not found. Skipping") + return contextlib.nullcontext(None) + raise + + # 2. Validate input for Write Modes + if mode in ("wb", "ab"): + if isinstance(input_, str) and ("{" in input_ or "<" in input_): + raise TypeError("Cannot write to inline JSON/XML string.") + if isinstance(input_, bytes): + raise TypeError( + "Cannot write to bytes object. Use BytesIO instead." + ) + + # 3. Handle Inline Data (Read Mode) + if mode == "rb": + if isinstance(input_, str) and ("{" in input_ or "<" in input_): + return contextlib.nullcontext(BytesIO(input_.encode("utf8"))) + if isinstance(input_, bytes): + return contextlib.nullcontext(BytesIO(input_)) + + # 4. 
Handle Adapter-based URIs/Paths + # Check if input looks like a path or string URI + if isinstance(input_, (str, os.PathLike)): + uri = _filepath_from_path_or_filelike(input_) adapter = get_adapter(uri) + if adapter: - stream = BytesIO() - adapter.read_to_stream(uri, stream) - stream.seek(0) - else: - raise AdapterError(f"No adapter found for {uri}") + if mode == "rb": + stream = BufferedStream() + adapter.read_to_stream(uri, stream) + stream.seek(0) + return contextlib.nullcontext(stream) + else: + return _write_context_manager(uri, mode) + + # check if the uri is a string with adapter prefix + elif isinstance(input_, str): + prefix_match = re.match(r"^([a-zA-Z0-9+.-]+)://", input_) + if prefix_match: + raise AdapterError( + f"No adapter found for {prefix_match.group(1)}://" + ) - elif isinstance(input_, TextIOWrapper): - # If file_or_path is a TextIOWrapper, return its underlying binary buffer - stream = input_.buffer + # If no adapter found, fall through to standard _open (local file handling) - elif hasattr(input_, "readinto"): - # If file_or_path is a file-like object, return it as is - stream = _open(input_) # type: ignore + # 5. Handle File Objects or Standard Local Files + if ( + hasattr(input_, "readinto") + or hasattr(input_, "write") + or isinstance(input_, (str, os.PathLike)) + ): + # --- Validation: Check mode compatibility for existing file objects --- + if not isinstance(input_, (str, os.PathLike)): + input_mode = getattr(input_, "mode", None) + if input_mode and input_mode != mode: + raise ValueError( + f"File opened in mode '{input_mode}' but '{mode}' requested" + ) - else: - raise TypeError(f"Unsupported input type: {type(input_)}") + # --- Processing: Open or wrap the input --- + # _open handles: + # 1. Opening paths + # 2. Extracting binary buffers from TextIOWrapper + # 3. 
Detecting compression (gzip, etc) and returning a Decompressor wrapper + opened = _open(input_, mode) + + # --- Ownership: Decide if we should close the file on exit --- + + # Case A: We created a new wrapper (e.g. opened a path, or wrapped BytesIO in GzipFile) + # We return the object directly so its __exit__ cleans up the wrapper. + # Note: We check if `opened` is different from `input_` AND different from `input_.buffer` + # (the latter handles the TextIOWrapper case where we don't want to close the wrapper). + is_transformed = opened is not input_ + if hasattr(input_, "buffer"): + is_transformed = is_transformed and opened is not input_.buffer - if encoding is not None: - stream = TextIOWrapper(stream, encoding=encoding) + if is_transformed: + # Exception: If the original input was a file object, and _open returned a + # compression wrapper (like GzipFile), closing GzipFile usually closes the + # underlying file. + return cast(AbstractContextManager, opened) - return stream + # Case B: It is the exact same raw stream (e.g. plain BytesIO) + # We wrap in nullcontext so we don't close the user's object. + return contextlib.nullcontext(opened) + + raise TypeError(f"Unsupported input type: {type(input_)}") def _natural_sort_key(path: str) -> list[Union[int, str]]: @@ -456,71 +524,41 @@ def expand_inputs( An iterator over the resolved file paths or stream content. """ - def is_file(uri): + def _get_adapter_safe(uri): adapter = get_adapter(uri) - if adapter: - return adapter.is_file(uri) - raise AdapterError(f"No adapter found for {uri}") - - def is_directory(uri): - adapter = get_adapter(uri) - if adapter: - return adapter.is_directory(uri) - raise AdapterError(f"No adapter found for {uri}") - - def process_expansion(files): - """ - Process a list of files by filtering and sorting them. - - Args: - files: List of file URIs to process. 
+ if not adapter: + raise AdapterError(f"No adapter found for {uri}") + return adapter - Returns: - A sorted and filtered list of file URIs. - """ - files = [f for f in files if not is_directory(f)] + # 1. Handle Single String/Path Input + if isinstance(inputs, (str, os.PathLike)): + uri = _filepath_from_path_or_filelike(inputs) + adapter = _get_adapter_safe(uri) - if regex_filter: - pattern = re.compile(regex_filter) - files = [f for f in files if pattern.search(f)] + if adapter.is_directory(uri): + # Recursively expand directory contents + all_files = adapter.list_directory(uri, recursive=True) - files.sort(key=sort_key or _natural_sort_key) - return files + # Apply Filter + if regex_filter: + pattern = re.compile(regex_filter) + all_files = [f for f in all_files if pattern.search(f)] - if isinstance(inputs, (str, os.PathLike)): - uri = _filepath_from_path_or_filelike(inputs) + # Apply Sort + all_files.sort(key=sort_key or _natural_sort_key) - if is_directory(uri): - adapter = get_adapter(uri) - if adapter: - yield from process_expansion( - adapter.list_directory(uri, recursive=True) - ) - else: - raise AdapterError(f"No adapter found for {uri}") - elif is_file(uri): + yield from all_files + elif adapter.is_file(uri): yield uri else: raise InputNotFoundError(f"Invalid path or file: {inputs}") - elif isinstance(inputs, Iterable): + # 2. 
Handle Iterable Input + elif isinstance(inputs, Iterable) and not isinstance(inputs, (str, bytes)): for item in inputs: - if isinstance(item, (str, os.PathLike)): - uri = _filepath_from_path_or_filelike(item) - if is_file(uri): - yield uri - elif is_directory(uri): - adapter = get_adapter(uri) - if adapter: - yield from process_expansion( - adapter.list_directory(uri, recursive=True) - ) - else: - raise AdapterError(f"No adapter found for {uri}") - else: - raise InputNotFoundError(f"Invalid path or file: {item}") - else: - yield item + # Recursive call allows mixed lists of directories and files + yield from expand_inputs(item, regex_filter, sort_key) + # 3. Handle Single Object Input (BytesIO, etc) else: yield inputs diff --git a/kloppy/tests/test_io.py b/kloppy/tests/test_io.py index 00834ce26..87574a4a6 100644 --- a/kloppy/tests/test_io.py +++ b/kloppy/tests/test_io.py @@ -6,6 +6,7 @@ import os from pathlib import Path import sys +from typing import BinaryIO, Optional import zipfile from botocore.session import Session @@ -14,174 +15,242 @@ from kloppy.config import set_config from kloppy.exceptions import InputNotFoundError +from kloppy.infra.io import adapters +from kloppy.infra.io.adapters import Adapter +from kloppy.infra.io.buffered_stream import BufferedStream from kloppy.io import expand_inputs, get_file_extension, open_as_file +# --- Shared Helpers --- -@pytest.fixture() -def filesystem_content(tmp_path: Path) -> Path: - """Set up the content to be read from a local filesystem.""" - content = "Hello, world!" 
- content_bytes = content.encode("utf-8") - # Create a regular text file - text_file = tmp_path / "testfile.txt" - text_file.write_text(content) +def create_test_files(base_path: Path, content: str = "Hello, world!"): + """Helper to generate standard test files (plain and compressed).""" + # Plain text + (base_path / "testfile.txt").write_text(content) - # Create a gzip-compressed file - gz_file = tmp_path / "testfile.txt.gz" - with gzip.open(gz_file, "wb") as f_out: - f_out.write(content_bytes) + # Compressed formats + compressors = { + ".gz": gzip.open, + ".xz": lzma.open, + ".bz2": bz2.open, + } - # Create a xz-compressed file - xz_file = tmp_path / "testfile.txt.xz" - with lzma.open(xz_file, "wb") as f_out: - f_out.write(content_bytes) + for ext, opener in compressors.items(): + with opener(base_path / f"testfile.txt{ext}", "wb") as f: + f.write(content.encode("utf-8")) - # Create a bzip2-compressed file - bz2_file = tmp_path / "testfile.txt.bz2" - with bz2.open(bz2_file, "wb") as f_out: - f_out.write(content_bytes) +@pytest.fixture +def populated_dir(tmp_path: Path) -> Path: + """Fixture that returns a directory populated with standard test files.""" + create_test_files(tmp_path) return tmp_path +# --- Core IO Unit Tests --- + + +class TestBufferedStream: + """Tests for BufferedStream chunked copying.""" + + def test_from_stream_small_data(self): + """It should copy small data in chunks and keep in memory.""" + source = BytesIO(b"Small data content") + buffer = BufferedStream.from_stream(source, chunk_size=8) + + assert buffer.read() == b"Small data content" + assert buffer._rolled is False # Still in memory + + def test_from_stream_large_data(self): + """It should spill large data to disk.""" + buffer_size = 5 * 1024 * 1024 # 5MB + large_data = b"x" * (buffer_size + 1000) + source = BytesIO(large_data) + buffer = BufferedStream.from_stream(source, max_size=buffer_size) + + assert buffer._rolled is True # Spilled to disk + assert buffer.read() == large_data + 
+ class TestOpenAsFile: - """Tests for the open_as_file function.""" + """Tests for core open_as_file read/write functionality.""" + + @pytest.fixture(params=[True, False], ids=["with_adapters", "no_adapters"]) + def setup_adapters(self, request, monkeypatch): + """ + Fixture that runs tests in two states: + 1. Default state (adapters enabled). + 2. Patched state (adapters list empty). + """ + if not request.param: + monkeypatch.setattr(adapters, "adapters", []) - def test_bytes(self): + # --- Read Tests --- + + def test_read_bytes(self): """It should be able to open a bytes object as a file.""" with open_as_file(b"Hello, world!") as fp: - assert fp is not None assert fp.read() == b"Hello, world!" - def test_data_string(self): + def test_read_data_string(self): """It should be able to open a json/xml string as a file.""" with open_as_file('{"msg": "Hello, world!"}') as fp: - assert fp is not None assert json.load(fp) == {"msg": "Hello, world!"} - def test_stream(self): + def test_read_stream(self): """It should be able to open a byte stream as a file.""" data = b"Hello, world!" with open_as_file(BytesIO(data)) as fp: - assert fp is not None assert fp.read() == data @pytest.mark.parametrize( "compress_func", - [ - gzip.compress, - bz2.compress, - lzma.compress, - ], + [gzip.compress, bz2.compress, lzma.compress], ids=["gzip", "bz2", "xz"], ) - def test_compressed_stream(self, compress_func): + def test_read_compressed_stream(self, compress_func): """It should be able to open a compressed byte stream as a file.""" data = compress_func(b"Hello, world!") with open_as_file(BytesIO(data)) as fp: - assert fp is not None - assert fp.read() == b"Hello, world!" - - def test_path_str(self, filesystem_content: Path): - """It should be able to open a file from a string path.""" - path = str(filesystem_content / "testfile.txt") - with open_as_file(path) as fp: - assert fp is not None assert fp.read() == b"Hello, world!" 
- def test_path_obj(self, filesystem_content: Path): - """It should be able to open a file from a Path object.""" - path = filesystem_content / "testfile.txt" + @pytest.mark.parametrize( + "path_type", [str, Path], ids=["str_path", "Path_obj"] + ) + def test_read_local_file_paths( + self, populated_dir, path_type, setup_adapters + ): + """It should be able to open a local file (with and without adapters).""" + path = path_type(populated_dir / "testfile.txt") with open_as_file(path) as fp: - assert fp is not None assert fp.read() == b"Hello, world!" @pytest.mark.parametrize("ext", ["gz", "xz", "bz2"]) - def test_path_compressed(self, filesystem_content: Path, ext: str): - """It should be able to open a compressed local file.""" - path = filesystem_content / f"testfile.txt.{ext}" + def test_read_compressed_local_file( + self, populated_dir, ext, setup_adapters + ): + """It should be able to open a compressed local file (with and without adapters).""" + path = populated_dir / f"testfile.txt.{ext}" with open_as_file(path) as fp: - assert fp is not None assert fp.read() == b"Hello, world!" - def test_path_missing(self, filesystem_content: Path): + def test_read_missing_file(self, tmp_path): """It should raise an error if the file is not found.""" - path = filesystem_content / "missing.txt" with pytest.raises(InputNotFoundError): - with open_as_file(path) as _: - pass + open_as_file(tmp_path / "missing.txt") + + def test_read_opened_file(self, populated_dir): + """It should return the same file object if already opened.""" + path = populated_dir / "testfile.txt" + with open_as_file(path.open("rb")) as fp: + assert fp.read() == b"Hello, world!" 
+ + # --- Write Tests --- + + def test_write_stream(self): + """It should be able to write to a byte stream.""" + buffer = BytesIO() + with open_as_file(buffer, mode="wb") as fp: + fp.write(b"In-memory write") + + buffer.seek(0) + assert buffer.read() == b"In-memory write" + + @pytest.mark.parametrize( + "path_type", [str, Path], ids=["str_path", "Path_obj"] + ) + def test_write_local_file(self, tmp_path, path_type, setup_adapters): + """It should be able to write to a local file (with and without adapters).""" + output_path = path_type(tmp_path / "output.txt") + with open_as_file(output_path, mode="wb") as fp: + fp.write(b"Hello, write!") + + assert (tmp_path / "output.txt").read_bytes() == b"Hello, write!" + + @pytest.mark.parametrize( + "ext, opener", + [("gz", gzip.open), ("bz2", bz2.open), ("xz", lzma.open)], + ids=["gzip", "bz2", "xz"], + ) + def test_write_compressed_file(self, tmp_path, ext, opener, setup_adapters): + """It should be able to write compressed files (with and without adapters).""" + output_path = tmp_path / f"output.txt.{ext}" + content = b"Compressed content" + + with open_as_file(output_path, mode="wb") as fp: + fp.write(content) + + # Verify by reading back + with opener(output_path, "rb") as f: + assert f.read() == content + + def test_write_opened_file(self, tmp_path): + """It should write to the same file object if already opened.""" + output_path = tmp_path / "output.txt" + output_file = output_path.open("wb") + with open_as_file(output_file, mode="wb") as fp: + fp.write(b"Hello, opened write!") + output_file.close() + + assert output_path.read_bytes() == b"Hello, opened write!" 
+ + def test_mode_conflict(self, populated_dir): + """It should raise an error if mode conflicts with opened file.""" + path = populated_dir / "testfile.txt" + with pytest.raises(ValueError): + open_as_file(path.open("r"), mode="wb") + with pytest.raises(ValueError): + open_as_file(path.open("wb"), mode="rb") + with pytest.raises(ValueError): + open_as_file(path.open("rb"), mode="wb") class TestExpandInputs: @pytest.fixture - def mock_filesystem(self, tmp_path): + def mock_fs(self, tmp_path): # Create a temporary directory structure - file1 = tmp_path / "file1.txt" - file2 = tmp_path / "file2.log" - subdir = tmp_path / "subdir" - subdir.mkdir() - file3 = subdir / "file3.txt" - - file1.write_text("Content of file1") - file2.write_text("Content of file2") - file3.write_text("Content of file3") + (tmp_path / "file1.txt").touch() + (tmp_path / "file2.log").touch() + (tmp_path / "subdir").mkdir() + (tmp_path / "subdir" / "file3.txt").touch() + # Return dict mapping keys to absolute string paths return { - "root": str(tmp_path.as_posix()), - "file1": str(file1.as_posix()), - "file2": str(file2.as_posix()), - "subdir": str(subdir.as_posix()), - "file3": str(file3.as_posix()), + "root": os.fspath(tmp_path), + "file1": os.fspath(tmp_path / "file1.txt"), + "file2": os.fspath(tmp_path / "file2.log"), + "file3": os.fspath(tmp_path / "subdir" / "file3.txt"), } - def test_single_file(self, mock_filesystem): - files = list(expand_inputs(mock_filesystem["file1"])) - assert files == [mock_filesystem["file1"]] - - def test_directory_expansion(self, mock_filesystem): - files = sorted(expand_inputs(mock_filesystem["root"])) - expected_files = sorted( - [ - mock_filesystem["file1"], - mock_filesystem["file2"], - mock_filesystem["file3"], - ] - ) - assert files == expected_files + def test_single_file(self, mock_fs): + assert list(expand_inputs(mock_fs["file1"])) == [mock_fs["file1"]] - def test_regex_filter(self, mock_filesystem): - files = list( - 
expand_inputs(mock_filesystem["root"], regex_filter=r".*.txt$")
-        )
-        expected_files = [
-            mock_filesystem["file1"],
-            mock_filesystem["file3"],
-        ]
-        assert sorted(files) == sorted(expected_files)
-
-    def test_sort_key(self, mock_filesystem):
-        files = list(
-            expand_inputs(mock_filesystem["root"], sort_key=lambda x: x[::-1])
+    def test_directory_expansion(self, mock_fs):
+        expected = sorted(
+            [mock_fs["file1"], mock_fs["file2"], mock_fs["file3"]]
         )
-        expected_files = sorted(
-            [
-                mock_filesystem["file1"],
-                mock_filesystem["file2"],
-                mock_filesystem["file3"],
-            ],
+        assert sorted(expand_inputs(mock_fs["root"])) == expected
+
+    def test_regex_filter(self, mock_fs):
+        expected = sorted([mock_fs["file1"], mock_fs["file3"]])
+        files = list(expand_inputs(mock_fs["root"], regex_filter=r".*\.txt$"))
+        assert sorted(files) == expected
+
+    def test_sort_key(self, mock_fs):
+        expected = sorted(
+            [mock_fs["file1"], mock_fs["file2"], mock_fs["file3"]],
             key=lambda x: x[::-1],
         )
-        assert files == expected_files
+        files = list(expand_inputs(mock_fs["root"], sort_key=lambda x: x[::-1]))
+        assert files == expected
 
-    def test_list_of_files(self, mock_filesystem):
-        input_list = [mock_filesystem["file1"], mock_filesystem["file2"]]
-        files = list(expand_inputs(input_list))
-        assert files == input_list
+    def test_list_of_files(self, mock_fs):
+        inputs = [mock_fs["file1"], mock_fs["file2"]]
+        assert list(expand_inputs(inputs)) == inputs
 
     def test_invalid_path(self):
         with pytest.raises(InputNotFoundError):
-            list(expand_inputs("nonexistent_file.txt"))
+            list(expand_inputs("nonexistent.txt"))
 
 
 def test_get_file_extension():
@@ -191,185 +260,325 @@ def test_get_file_extension():
     assert get_file_extension("data") == ""
 
 
+# --- Adapter Integration Tests ---
+
+
+class MockAdapter(Adapter):
+    """
+    Generic Mock adapter storing data in memory.
+    Supports both read and write testing. 
+ """ + + def __init__(self, initial_data: Optional[dict[str, bytes]] = None): + self.storage = initial_data if initial_data else {} + + def supports(self, url: str) -> bool: + return url.startswith("mock://") + + def is_directory(self, url: str) -> bool: + return url not in self.storage and url.endswith("/") + + def is_file(self, url: str) -> bool: + return url in self.storage + + def read_to_stream(self, url: str, output: BinaryIO): + if url in self.storage: + output.write(self.storage[url]) + else: + raise FileNotFoundError(f"Mock file not found: {url}") + + def write_from_stream(self, url: str, input: BinaryIO, mode: str): # noqa: A002 + input.seek(0) + self.storage[url] = input.read() + + def list_directory(self, url: str, recursive: bool = True) -> list[str]: + return [k for k in self.storage.keys() if k.startswith(url)] + + +class TestMockAdapter: + """Tests for generic Adapter logic using the in-memory MockAdapter.""" + + @pytest.fixture + def adapter_setup(self, monkeypatch): + # Pre-seed some data + mock_adapter = MockAdapter( + { + "mock://read/data.txt": b"Pre-existing content", + "mock://read/config.json": b'{"foo": "bar"}', + } + ) + + # Inject adapter + from kloppy.infra.io import adapters + + monkeypatch.setattr( + adapters, "adapters", [mock_adapter] + adapters.adapters + ) + return mock_adapter + + def test_expand_inputs(self, adapter_setup): + expected = {"mock://read/data.txt", "mock://read/config.json"} + assert set(expand_inputs("mock://read/")) == expected + + def test_read_via_adapter(self, adapter_setup): + with open_as_file("mock://read/data.txt") as fp: + assert fp.read() == b"Pre-existing content" + + def test_write_via_adapter(self, adapter_setup): + with open_as_file("mock://write/new.txt", mode="wb") as fp: + fp.write(b"New data") + + # Verify directly in storage + assert adapter_setup.storage["mock://write/new.txt"] == b"New data" + + # Verify via read + with open_as_file("mock://write/new.txt") as fp: + assert fp.read() == b"New 
data" + + +class TestFileAdapter: + """Tests for FileAdapter.""" + + @pytest.fixture(autouse=True) + def setup_files(self, populated_dir): + self.root_dir = populated_dir + + def test_expand_inputs(self): + """It should be able to list the contents of a local directory.""" + found = set(expand_inputs(str(self.root_dir))) + assert found == { + str(self.root_dir / f) + for f in [ + "testfile.txt", + "testfile.txt.gz", + "testfile.txt.bz2", + "testfile.txt.xz", + ] + } + + def test_read_via_adapter(self): + """It should be able to open a file from the local filesystem.""" + path = self.root_dir / "testfile.txt" + with open_as_file(str(path)) as fp: + assert fp.read() == b"Hello, world!" + + def test_read_compressed_via_adapter(self): + """It should be able to open and decompress a file from the local filesystem.""" + path = self.root_dir / "testfile.txt.gz" + with open_as_file(str(path)) as fp: + assert fp.read() == b"Hello, world!" + + def test_write_via_adapter(self): + """It should be able to write a file to the local filesystem.""" + path = self.root_dir / "new_file.txt" + with open_as_file(str(path), mode="wb") as fp: + fp.write(b"New written data") + + assert path.exists() + with open(path, "rb") as f: + assert f.read() == b"New written data" + + def test_write_compressed_via_adapter(self): + """It should be able to write a compressed file to the local filesystem.""" + path = self.root_dir / "new_file.txt.gz" + with open_as_file(str(path), mode="wb") as fp: + fp.write(b"New compressed data") + + assert path.exists() + with gzip.open(path, "rb") as f: + assert f.read() == b"New compressed data" + + class TestHTTPAdapter: + """Tests for HTTPAdapter.""" + @pytest.fixture(autouse=True) - def httpserver_content(self, httpserver): + def httpserver_content(self, httpserver, tmp_path): """Set up the content to be read from an HTTP server.""" - # Define the content - content = "Hello, world!" - compressed_content = gzip.compress(b"Hello, world!") + # 1. 
Generate standard files + create_test_files(tmp_path) - # Serve the plain text file - httpserver.expect_request("/testfile.txt").respond_with_data(content) + # 2. Read binaries to serve + txt_content = (tmp_path / "testfile.txt").read_bytes() + gz_content = (tmp_path / "testfile.txt.gz").read_bytes() + + # 3. Configure Server + httpserver.expect_request("/testfile.txt").respond_with_data( + txt_content + ) - # Serve the compressed text file with Content-Encoding header - httpserver.expect_request("/compressed_testfile.txt").respond_with_data( - compressed_content, + # Serve compressed content with explicit headers + httpserver.expect_request("/compressed_endpoint").respond_with_data( + gz_content, headers={"Content-Encoding": "gzip", "Content-Type": "text/plain"}, ) - # Serve the gzip file with application/x-gzip content type + # Serve generic .gz file httpserver.expect_request("/testfile.txt.gz").respond_with_data( - compressed_content, - headers={"Content-Type": "application/x-gzip"}, + gz_content, headers={"Content-Type": "application/x-gzip"} ) - # Generate the index.html content with links to all resources - index_html = f""" - -