From 62de161c3647cb3fa0698db9b3aed2e504339766 Mon Sep 17 00:00:00 2001 From: Irozuku Date: Thu, 26 Feb 2026 12:15:44 -0300 Subject: [PATCH] refactor: defer heavy/optional imports to functions and improve modularity Move several heavy or optional dependencies from module scope into local function scope to enable deferred loading. This reduces initial import time. --- .../dataloaders/classes/dashai_dataset.py | 97 +++++++++++++------ 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/DashAI/back/dataloaders/classes/dashai_dataset.py b/DashAI/back/dataloaders/classes/dashai_dataset.py index e5100a305..320e952e6 100644 --- a/DashAI/back/dataloaders/classes/dashai_dataset.py +++ b/DashAI/back/dataloaders/classes/dashai_dataset.py @@ -1,19 +1,11 @@ """DashAI Dataset implementation.""" -import json import logging import os -from contextlib import suppress -import numpy as np -import pyarrow as pa -import pyarrow.ipc as ipc from beartype import beartype from beartype.typing import Dict, List, Literal, Optional, Tuple, Union -from datasets import ClassLabel, Dataset, DatasetDict, Value, concatenate_datasets -from datasets.table import InMemoryTable -from pandas import DataFrame -from sklearn.model_selection import train_test_split +from datasets import Dataset from DashAI.back.types.categorical import Categorical from DashAI.back.types.dashai_data_type import DashAIDataType @@ -29,7 +21,7 @@ log = logging.getLogger(__name__) -def get_arrow_table(ds: Dataset) -> pa.Table: +def get_arrow_table(ds: Dataset) -> object: """ Retrieve the underlying PyArrow table from a Hugging Face Dataset. This function abstracts away the need to access private attributes. @@ -38,7 +30,7 @@ def get_arrow_table(ds: Dataset) -> pa.Table: ds (Dataset): A Hugging Face Dataset. Returns: - pa.Table: The underlying PyArrow table. + object: The underlying PyArrow table. Raises: ValueError: If the arrow table cannot be retrieved. @@ -57,7 +49,7 @@ class DashAIDataset(Dataset): @beartype def __init__( self, - table: Union[pa.Table, InMemoryTable], + table: object, splits: dict = None, types: Optional[Dict[str, DashAIDataType]] = None, *args, @@ -100,12 +92,12 @@ def cast(self, *args, **kwargs) -> "DashAIDataset": return DashAIDataset(arrow_tbl, splits=self.splits, types=self._types) @property - def arrow_table(self) -> pa.Table: + def arrow_table(self) -> object: """ Provides a clean way to access the underlying PyArrow table. Returns: - pa.Table: The underlying PyArrow table. + object: The underlying PyArrow table. """ try: # Now we reference the pa.table from here (DashAIDataset) @@ -147,6 +139,8 @@ def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset": DashAIDataset The dataset after columns type changes. """ + from datasets import Value as HFValue + if not isinstance(column_types, dict): raise TypeError(f"types should be a dict, got {type(column_types)}") @@ -163,7 +157,7 @@ def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset": if column_types[column] == "Categorical": new_features[column] = encode_labels(self, column) elif column_types[column] == "Numerical": - new_features[column] = Value("float32") + new_features[column] = HFValue("float32") dataset = self.cast(new_features) return dataset @@ -529,6 +523,8 @@ def sample( Dict A dictionary with selected samples. """ + import numpy as np + if n > len(self): raise ValueError( "Number of samples must be less than or equal to the length " @@ -630,7 +626,7 @@ def select(self, *args, **kwargs) -> "DashAIDataset": @beartype -def merge_splits_with_metadata(dataset_dict: DatasetDict) -> DashAIDataset: +def merge_splits_with_metadata(dataset_dict: object) -> DashAIDataset: """ Merges the splits from a DatasetDict into a single DashAIDataset and records the original indices for each split in the metadata. @@ -644,6 +640,8 @@ def merge_splits_with_metadata(dataset_dict: DatasetDict) -> DashAIDataset: original split indices. """ + from datasets import concatenate_datasets # local import + concatenated_datasets = [] split_index = {} current_index = 0 @@ -699,6 +697,8 @@ def transform_dataset_with_schema( DashAIDataset - The updated dataset with new type information """ + import pyarrow as pa # local import + table = get_arrow_table(dataset) dai_table = {} my_schema = pa.schema([]) @@ -787,10 +787,14 @@ def save_dataset( if schema is not None: dataset = transform_dataset_with_schema(dataset, schema) + import json + + import pyarrow as pa # local import + table = get_arrow_table(dataset) data_filepath = os.path.join(path, "data.arrow") with pa.OSFile(data_filepath, "wb") as sink: - writer = ipc.new_file(sink, table.schema) + writer = pa.ipc.new_file(sink, table.schema) writer.write_table(table) writer.close() @@ -823,9 +827,13 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset: DashAIDataset: The loaded dataset with data and metadata. """ + import json + + import pyarrow as pa # local import + data_filepath = os.path.join(dataset_path, "data.arrow") with pa.OSFile(data_filepath, "rb") as source: - reader = ipc.open_file(source) + reader = pa.ipc.open_file(source) data = reader.read_all() metadata_filepath = os.path.join(dataset_path, "splits.json") if os.path.exists(metadata_filepath): @@ -845,7 +853,7 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset: def encode_labels( dataset: DashAIDataset, column_name: str, -) -> ClassLabel: +) -> object: """Encode a categorical column into numerical labels and return the ClassLabel feature. @@ -864,6 +872,8 @@ def encode_labels( if column_name not in dataset.column_names: raise ValueError(f"Column '{column_name}' does not exist in the dataset.") + from datasets import ClassLabel # local import + names = list(set(dataset[column_name])) class_label_feature = ClassLabel(names=names) return class_label_feature @@ -950,6 +960,9 @@ def split_indexes( # Generate shuffled indexes if seed is None: seed = 42 + import numpy as np + from sklearn.model_selection import train_test_split + indexes = np.arange(total_rows) stratify_labels = np.array(labels) if stratify else None @@ -1005,7 +1018,7 @@ def split_dataset( train_indexes: List = None, test_indexes: List = None, val_indexes: List = None, -) -> DatasetDict: +) -> object: """ Split the dataset in train, test and validation subsets. If indexes are not provided, it will use the split indices @@ -1032,7 +1045,11 @@ def split_dataset( ValueError Must provide all indexes or none. """ + import numpy as np + if all(idx is None for idx in [train_indexes, test_indexes, val_indexes]): + from datasets import DatasetDict + train_dataset = dataset.get_split("train") test_dataset = dataset.get_split("test") val_dataset = dataset.get_split("validation") @@ -1055,6 +1072,8 @@ def split_dataset( val_mask = np.isin(np.arange(n), val_indexes) # Get the underlying table + import pyarrow as pa # local import + table = dataset.arrow_table dataset.splits["split_indices"] = { @@ -1069,6 +1088,8 @@ def split_dataset( val_table = table.filter(pa.array(val_mask)) # Preserve types from the original dataset to maintain categorical mappings + from datasets import DatasetDict # local import + separate_dataset_dict = DatasetDict( { "train": DashAIDataset(train_table, types=dataset.types), @@ -1081,7 +1102,7 @@ def split_dataset( def to_dashai_dataset( - dataset: Union[DatasetDict, Dataset, DashAIDataset, DataFrame], + dataset: object, types: Optional[Dict[str, DashAIDataType]] = None, ) -> DashAIDataset: """ @@ -1102,19 +1123,27 @@ def to_dashai_dataset( if isinstance(dataset, DashAIDataset): return dataset - if isinstance(dataset, Dataset): + from datasets import Dataset as HFDataset # local import + + if isinstance(dataset, HFDataset): arrow_tbl = get_arrow_table(dataset) if types: types_serialized = {col: types[col].to_string() for col in types} arrow_tbl = save_types_in_arrow_metadata(arrow_tbl, types_serialized) return DashAIDataset(arrow_tbl, types=types) - if isinstance(dataset, DataFrame): - hf_dataset = Dataset.from_pandas(dataset, preserve_index=False) + try: + from pandas import DataFrame as PDDataFrame # local import + except Exception: + PDDataFrame = None + if PDDataFrame is not None and isinstance(dataset, PDDataFrame): + hf_dataset = HFDataset.from_pandas(dataset, preserve_index=False) arrow_tbl = get_arrow_table(hf_dataset) if types: types_serialized = {col: types[col].to_string() for col in types} arrow_tbl = save_types_in_arrow_metadata(arrow_tbl, types_serialized) return DashAIDataset(arrow_tbl, types=types) + from datasets import DatasetDict # local import + if isinstance(dataset, DatasetDict) and len(dataset) == 1: key = list(dataset.keys())[0] ds = dataset[key] @@ -1131,7 +1160,7 @@ def to_dashai_dataset( @beartype def validate_inputs_outputs( - datasetdict: Union[DatasetDict, DashAIDataset], + datasetdict: object, inputs: List[str], outputs: List[str], ) -> None: @@ -1170,9 +1199,7 @@ def validate_inputs_outputs( @beartype -def get_column_names_from_indexes( - dataset: Union[DashAIDataset, DatasetDict], indexes: List[int] -) -> List[str]: +def get_column_names_from_indexes(dataset: object, indexes: List[int]) -> List[str]: """Obtain the column labels that correspond to the provided indexes. Note: indexing starts from 1. @@ -1206,7 +1233,7 @@ def get_column_names_from_indexes( @beartype def select_columns( - dataset: Union[DatasetDict, DashAIDataset], + dataset: object, input_columns: List[str], output_columns: List[str], ) -> Tuple[DashAIDataset, DashAIDataset]: @@ -1252,6 +1279,8 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]: """ data_filepath = os.path.join(dataset_path, "data.arrow") + import pyarrow as pa # local import + with pa.OSFile(data_filepath, "rb") as source: reader = pa.ipc.open_file(source) schema = reader.schema @@ -1334,6 +1363,8 @@ def get_dataset_info(dataset_path: str) -> object: object Dictionary with the information of the dataset """ + import json + metadata_filepath = os.path.join(dataset_path, "splits.json") if os.path.exists(metadata_filepath): with open(metadata_filepath, "r", encoding="utf-8") as f: @@ -1404,8 +1435,10 @@ def update_dataset_splits( # I think it could be simplified since DashAITypes, but I don't want to break anything def prepare_for_model_session( dataset: DashAIDataset, splits: dict, output_columns: List[str] -) -> DatasetDict: +) -> object: """Prepare the dataset for a model session by updating the splits configuration""" + from contextlib import suppress + splitType = splits.get("splitType") if splitType == "manual" or splitType == "predefined": splits_index = splits @@ -1468,7 +1501,7 @@ def prepare_for_model_session( def modify_table( dataset: DashAIDataset, - columns: Dict[str, pa.Array], + columns: Dict[str, object], types: Optional[Dict[str, DashAIDataType]] = None, ) -> DashAIDataset: """ @@ -1487,6 +1520,8 @@ def modify_table( DashAIDataset The modified dataset with the updated column type. """ + import pyarrow as pa + original_table = dataset.arrow_table updated_columns = {}