Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 66 additions & 31 deletions DashAI/back/dataloaders/classes/dashai_dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
"""DashAI Dataset implementation."""

import json
import logging
import os
from contextlib import suppress

import numpy as np
import pyarrow as pa
import pyarrow.ipc as ipc
from beartype import beartype
from beartype.typing import Dict, List, Literal, Optional, Tuple, Union
from datasets import ClassLabel, Dataset, DatasetDict, Value, concatenate_datasets
from datasets.table import InMemoryTable
from pandas import DataFrame
from sklearn.model_selection import train_test_split
from datasets import Dataset

from DashAI.back.types.categorical import Categorical
from DashAI.back.types.dashai_data_type import DashAIDataType
Expand All @@ -29,7 +21,7 @@
log = logging.getLogger(__name__)


def get_arrow_table(ds: Dataset) -> pa.Table:
def get_arrow_table(ds: Dataset) -> object:
"""
Retrieve the underlying PyArrow table from a Hugging Face Dataset.
This function abstracts away the need to access private attributes.
Expand All @@ -38,7 +30,7 @@ def get_arrow_table(ds: Dataset) -> pa.Table:
ds (Dataset): A Hugging Face Dataset.

Returns:
pa.Table: The underlying PyArrow table.
object: The underlying PyArrow table.

Raises:
ValueError: If the arrow table cannot be retrieved.
Expand All @@ -57,7 +49,7 @@ class DashAIDataset(Dataset):
@beartype
def __init__(
self,
table: Union[pa.Table, InMemoryTable],
table: object,
splits: dict = None,
types: Optional[Dict[str, DashAIDataType]] = None,
*args,
Expand Down Expand Up @@ -100,12 +92,12 @@ def cast(self, *args, **kwargs) -> "DashAIDataset":
return DashAIDataset(arrow_tbl, splits=self.splits, types=self._types)

@property
def arrow_table(self) -> pa.Table:
def arrow_table(self) -> object:
"""
Provides a clean way to access the underlying PyArrow table.

Returns:
pa.Table: The underlying PyArrow table.
object: The underlying PyArrow table.
"""
try:
# Now we reference the pa.table from here (DashAIDataset)
Expand Down Expand Up @@ -147,6 +139,8 @@ def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset":
DashAIDataset
The dataset after columns type changes.
"""
from datasets import Value as HFValue

if not isinstance(column_types, dict):
raise TypeError(f"types should be a dict, got {type(column_types)}")

Expand All @@ -163,7 +157,7 @@ def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset":
if column_types[column] == "Categorical":
new_features[column] = encode_labels(self, column)
elif column_types[column] == "Numerical":
new_features[column] = Value("float32")
new_features[column] = HFValue("float32")
dataset = self.cast(new_features)
return dataset

Expand Down Expand Up @@ -529,6 +523,8 @@ def sample(
Dict
A dictionary with selected samples.
"""
import numpy as np

if n > len(self):
raise ValueError(
"Number of samples must be less than or equal to the length "
Expand Down Expand Up @@ -630,7 +626,7 @@ def select(self, *args, **kwargs) -> "DashAIDataset":


@beartype
def merge_splits_with_metadata(dataset_dict: DatasetDict) -> DashAIDataset:
def merge_splits_with_metadata(dataset_dict: object) -> DashAIDataset:
"""
Merges the splits from a DatasetDict into a single DashAIDataset and records
the original indices for each split in the metadata.
Expand All @@ -644,6 +640,8 @@ def merge_splits_with_metadata(dataset_dict: DatasetDict) -> DashAIDataset:
original split indices.
"""

from datasets import concatenate_datasets # local import

concatenated_datasets = []
split_index = {}
current_index = 0
Expand Down Expand Up @@ -699,6 +697,8 @@ def transform_dataset_with_schema(
DashAIDataset
- The updated dataset with new type information
"""
import pyarrow as pa # local import

table = get_arrow_table(dataset)
dai_table = {}
my_schema = pa.schema([])
Expand Down Expand Up @@ -787,10 +787,14 @@ def save_dataset(
if schema is not None:
dataset = transform_dataset_with_schema(dataset, schema)

import json

import pyarrow as pa # local import

table = get_arrow_table(dataset)
data_filepath = os.path.join(path, "data.arrow")
with pa.OSFile(data_filepath, "wb") as sink:
writer = ipc.new_file(sink, table.schema)
writer = pa.ipc.new_file(sink, table.schema)
writer.write_table(table)
writer.close()

Expand Down Expand Up @@ -823,9 +827,13 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset:
DashAIDataset: The loaded dataset with data and metadata.
"""

import json

import pyarrow as pa # local import

data_filepath = os.path.join(dataset_path, "data.arrow")
with pa.OSFile(data_filepath, "rb") as source:
reader = ipc.open_file(source)
reader = pa.ipc.open_file(source)
data = reader.read_all()
metadata_filepath = os.path.join(dataset_path, "splits.json")
if os.path.exists(metadata_filepath):
Expand All @@ -845,7 +853,7 @@ def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset:
def encode_labels(
dataset: DashAIDataset,
column_name: str,
) -> ClassLabel:
) -> object:
"""Encode a categorical column into numerical labels and
return the ClassLabel feature.

Expand All @@ -864,6 +872,8 @@ def encode_labels(
if column_name not in dataset.column_names:
raise ValueError(f"Column '{column_name}' does not exist in the dataset.")

from datasets import ClassLabel # local import

names = list(set(dataset[column_name]))
class_label_feature = ClassLabel(names=names)
return class_label_feature
Expand Down Expand Up @@ -950,6 +960,9 @@ def split_indexes(
# Generate shuffled indexes
if seed is None:
seed = 42
import numpy as np
from sklearn.model_selection import train_test_split

indexes = np.arange(total_rows)
stratify_labels = np.array(labels) if stratify else None

Expand Down Expand Up @@ -1005,7 +1018,7 @@ def split_dataset(
train_indexes: List = None,
test_indexes: List = None,
val_indexes: List = None,
) -> DatasetDict:
) -> object:
"""
Split the dataset in train, test and validation subsets.
If indexes are not provided, it will use the split indices
Expand All @@ -1032,7 +1045,11 @@ def split_dataset(
ValueError
Must provide all indexes or none.
"""
import numpy as np

if all(idx is None for idx in [train_indexes, test_indexes, val_indexes]):
from datasets import DatasetDict

train_dataset = dataset.get_split("train")
test_dataset = dataset.get_split("test")
val_dataset = dataset.get_split("validation")
Expand All @@ -1055,6 +1072,8 @@ def split_dataset(
val_mask = np.isin(np.arange(n), val_indexes)

# Get the underlying table
import pyarrow as pa # local import

table = dataset.arrow_table

dataset.splits["split_indices"] = {
Expand All @@ -1069,6 +1088,8 @@ def split_dataset(
val_table = table.filter(pa.array(val_mask))

# Preserve types from the original dataset to maintain categorical mappings
from datasets import DatasetDict # local import

separate_dataset_dict = DatasetDict(
{
"train": DashAIDataset(train_table, types=dataset.types),
Expand All @@ -1081,7 +1102,7 @@ def split_dataset(


def to_dashai_dataset(
dataset: Union[DatasetDict, Dataset, DashAIDataset, DataFrame],
dataset: object,
types: Optional[Dict[str, DashAIDataType]] = None,
) -> DashAIDataset:
"""
Expand All @@ -1102,19 +1123,27 @@ def to_dashai_dataset(
if isinstance(dataset, DashAIDataset):
return dataset

if isinstance(dataset, Dataset):
from datasets import Dataset as HFDataset # local import

if isinstance(dataset, HFDataset):
arrow_tbl = get_arrow_table(dataset)
if types:
types_serialized = {col: types[col].to_string() for col in types}
arrow_tbl = save_types_in_arrow_metadata(arrow_tbl, types_serialized)
return DashAIDataset(arrow_tbl, types=types)
if isinstance(dataset, DataFrame):
hf_dataset = Dataset.from_pandas(dataset, preserve_index=False)
try:
from pandas import DataFrame as PDDataFrame # local import
except Exception:
PDDataFrame = None
if PDDataFrame is not None and isinstance(dataset, PDDataFrame):
hf_dataset = HFDataset.from_pandas(dataset, preserve_index=False)
arrow_tbl = get_arrow_table(hf_dataset)
if types:
types_serialized = {col: types[col].to_string() for col in types}
arrow_tbl = save_types_in_arrow_metadata(arrow_tbl, types_serialized)
return DashAIDataset(arrow_tbl, types=types)
from datasets import DatasetDict # local import

if isinstance(dataset, DatasetDict) and len(dataset) == 1:
key = list(dataset.keys())[0]
ds = dataset[key]
Expand All @@ -1131,7 +1160,7 @@ def to_dashai_dataset(

@beartype
def validate_inputs_outputs(
datasetdict: Union[DatasetDict, DashAIDataset],
datasetdict: object,
inputs: List[str],
outputs: List[str],
) -> None:
Expand Down Expand Up @@ -1170,9 +1199,7 @@ def validate_inputs_outputs(


@beartype
def get_column_names_from_indexes(
dataset: Union[DashAIDataset, DatasetDict], indexes: List[int]
) -> List[str]:
def get_column_names_from_indexes(dataset: object, indexes: List[int]) -> List[str]:
"""Obtain the column labels that correspond to the provided indexes.

Note: indexing starts from 1.
Expand Down Expand Up @@ -1206,7 +1233,7 @@ def get_column_names_from_indexes(

@beartype
def select_columns(
dataset: Union[DatasetDict, DashAIDataset],
dataset: object,
input_columns: List[str],
output_columns: List[str],
) -> Tuple[DashAIDataset, DashAIDataset]:
Expand Down Expand Up @@ -1252,6 +1279,8 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]:
"""
data_filepath = os.path.join(dataset_path, "data.arrow")

import pyarrow as pa # local import

with pa.OSFile(data_filepath, "rb") as source:
reader = pa.ipc.open_file(source)
schema = reader.schema
Expand Down Expand Up @@ -1334,6 +1363,8 @@ def get_dataset_info(dataset_path: str) -> object:
object
Dictionary with the information of the dataset
"""
import json

metadata_filepath = os.path.join(dataset_path, "splits.json")
if os.path.exists(metadata_filepath):
with open(metadata_filepath, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -1404,8 +1435,10 @@ def update_dataset_splits(
# I think this could be simplified now that DashAITypes exist, but I don't want to break anything
def prepare_for_model_session(
dataset: DashAIDataset, splits: dict, output_columns: List[str]
) -> DatasetDict:
) -> object:
"""Prepare the dataset for a model session by updating the splits configuration"""
from contextlib import suppress

splitType = splits.get("splitType")
if splitType == "manual" or splitType == "predefined":
splits_index = splits
Expand Down Expand Up @@ -1468,7 +1501,7 @@ def prepare_for_model_session(

def modify_table(
dataset: DashAIDataset,
columns: Dict[str, pa.Array],
columns: Dict[str, object],
types: Optional[Dict[str, DashAIDataType]] = None,
) -> DashAIDataset:
"""
Expand All @@ -1487,6 +1520,8 @@ def modify_table(
DashAIDataset
The modified dataset with the updated column type.
"""
import pyarrow as pa

original_table = dataset.arrow_table
updated_columns = {}

Expand Down