diff --git a/src/dataworkbench/datacatalogue.py b/src/dataworkbench/datacatalogue.py
index befcc6a..c7c2a69 100644
--- a/src/dataworkbench/datacatalogue.py
+++ b/src/dataworkbench/datacatalogue.py
@@ -39,9 +39,21 @@ def __init__(self) -> None:
         self.gateway: Gateway = Gateway()
         self.storage_base_url: str = get_secret("StorageBaseUrl")
 
-    def __build_storage_url(self, folder_id: uuid.UUID) -> str:
+    def __build_storage_table_root_url(self, folder_id: uuid.UUID) -> str:
         """
-        Build the ABFSS URL for the target storage location.
+        Build the ABFSS URL for the root location of the table.
+        """
+        if not isinstance(folder_id, uuid.UUID):
+            raise TypeError("folder_id must be uuid")
+
+        if not folder_id:
+            raise ValueError("folder_id cannot be empty")
+
+        return f"{self.storage_base_url}/{folder_id}"
+
+    def __build_storage_table_processed_url(self, folder_id: uuid.UUID) -> str:
+        """
+        Build the ABFSS URL for the processed table storage location.
 
         Args:
             folder_id: Unique identifier for the storage folder
@@ -51,15 +63,10 @@ def __build_storage_url(self, folder_id: uuid.UUID) -> str:
 
         Example:
            >>> catalogue = DataCatalogue()
-            >>> catalogue._build_storage_url("abc123")
+            >>> catalogue.__build_storage_table_processed_url(uuid.uuid4())
         """
-        if not isinstance(folder_id, uuid.UUID):
-            raise TypeError("folder_id must be uuid")
-
-        if not folder_id:
-            raise ValueError("folder_id cannot be empty")
-
-        return f"{self.storage_base_url}/{folder_id}/Processed"
+        table_root_url = self.__build_storage_table_root_url(folder_id)
+        return f"{table_root_url}/Processed"
 
     def save(
         self,
@@ -118,15 +125,45 @@ def save(
 
         # Generate folder_id
         folder_id = uuid.uuid4()
-        target_path = self.__build_storage_url(folder_id)
+        target_path = self.__build_storage_table_processed_url(folder_id)
 
         try:
             # Write data using the specified or defaulted mode
             self.storage.write(df, target_path, mode=WriteMode.OVERWRITE.value)
 
-            return self.gateway.import_dataset(
-                dataset_name, dataset_description, schema_id, tags or {}, folder_id
-            )
+            try:
+                # Register the dataset with the Gateway API
+                return self.gateway.import_dataset(
+                    dataset_name, dataset_description, schema_id, tags or {}, folder_id
+                )
+            except Exception as e:
+                self._rollback_write(folder_id)
+
+                # Raise the original API error with additional context
+                error_msg = (
+                    f"Gateway API call failed and storage was rolled back: {str(e)}"
+                )
+                raise type(e)(error_msg) from e
         except Exception as e:
             return {"error": str(e), "error_type": type(e).__name__}
 
+
+    def _rollback_write(self, folder_id: uuid.UUID) -> None:
+        """
+        Delete the table from storage to roll back changes when an operation fails.
+
+        Args:
+            folder_id: Unique identifier of the storage folder to delete
+        """
+        target_path = self.__build_storage_table_root_url(folder_id)
+        logger.info("Rolling back data write operation to storage")
+        try:
+            self.storage.delete(target_path, recursive=True)
+            logger.info(
+                f"Successfully rolled back data write operation by deleting: {target_path}"
+            )
+        except Exception as rollback_error:
+            logger.error(
+                f"Failed to roll back storage operation at {target_path}: "
+                f"{str(rollback_error)}"
+            )
diff --git a/src/dataworkbench/gateway.py b/src/dataworkbench/gateway.py
index 0e3132b..88a06b6 100644
--- a/src/dataworkbench/gateway.py
+++ b/src/dataworkbench/gateway.py
@@ -138,4 +138,4 @@ def import_dataset(
             return self.__send_request(url, payload)
         except requests.exceptions.RequestException as e:
             logger.error(f"Error creating data catalog entry: {e}")
-            return {"error": f"Failed to create data catalog entry: {str(e)}"}
+            raise
diff --git a/src/dataworkbench/storage.py b/src/dataworkbench/storage.py
index 928020f..2a20174 100644
--- a/src/dataworkbench/storage.py
+++ b/src/dataworkbench/storage.py
@@ -1,10 +1,12 @@
-from typing import Any, Literal
+from typing import Literal
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.utils import AnalysisException
 from abc import ABC, abstractmethod
 
+from dataworkbench.utils import get_dbutils, PrimitiveType, is_databricks
 from dataworkbench.log import setup_logger
 
+
 # Configure logging
 logger = setup_logger(__name__)
 
@@ -24,7 +26,7 @@ def write(
         df: DataFrame,
         target_path: str,
         mode: Literal["overwrite", "append", "error", "ignore"] = "overwrite",
-        **options: dict[str, Any],
+        **options: PrimitiveType | None,
     ) -> None:
         """
         Write a DataFrame to storage.
@@ -60,7 +62,7 @@ def check_path_exists(self, path: str) -> bool:
         pass
 
     @abstractmethod
-    def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
+    def read(self, source_path: str, **options: PrimitiveType | None) -> DataFrame:
         """
         Read data from storage into a DataFrame.
 
@@ -98,6 +100,7 @@ def __init__(self, spark_session: SparkSession | None = None):
             raise TypeError("spark_session must be a SparkSession or None")
 
         self._spark = spark_session
+        self._dbutils = get_dbutils(self._spark)
 
     @property
     def spark(self) -> SparkSession:
@@ -127,7 +130,8 @@ def write(
         df: DataFrame,
         target_path: str,
         mode: Literal["overwrite", "append", "error", "ignore"] = "overwrite",
-        **options: dict[str, Any],
+        partition_by: str | list[str] | None = None,
+        **options: PrimitiveType | None,
     ) -> None:
         """
         Write a DataFrame to storage in Delta format.
@@ -172,8 +176,11 @@
             writer = df.write.format("delta").mode(mode)
 
             # Apply options if provided
-            for key, value in options.items():
-                writer = writer.option(key, value)
+            if options:
+                writer = writer.options(**options)
+
+            if partition_by:
+                writer = writer.partitionBy(partition_by)
 
             # Save the data
             writer.save(target_path)
@@ -189,7 +196,7 @@ def append(
         df: DataFrame,
         target_path: str,
         partition_by: str | list[str] | None = None,
-        **options: dict[str, Any],
+        **options: PrimitiveType | None,
     ) -> None:
         """
         Append a DataFrame to existing data in Delta format.
@@ -213,7 +220,13 @@ def append(
             >>> new_records = spark.createDataFrame([("Charlie", 35)], ["name", "age"])
             >>> storage.append(new_records, "abfss://container@account.dfs.core.windows.net/path/to/data")
         """
-        self.write(df, target_path, mode="append", partition_by=partition_by, **options)
+        self.write(
+            df=df,
+            target_path=target_path,
+            mode="append",
+            partition_by=partition_by,
+            **options,
+        )
 
     def check_path_exists(self, path: str) -> bool:
         """
@@ -247,7 +260,7 @@ def check_path_exists(self, path: str) -> bool:
             logger.warning(f"Error checking path existence: {e}")
             return False
 
-    def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
+    def read(self, source_path: str, **options: PrimitiveType | None) -> DataFrame:
         """
         Read a Delta table from storage into a DataFrame.
 
@@ -274,8 +287,8 @@
             reader = self.spark.read.format("delta")
 
             # Apply options if provided
-            for key, value in options.items():
-                reader = reader.option(key, value)
+            if options:
+                reader = reader.options(**options)
 
             # Load the data
             return reader.load(source_path)
@@ -284,3 +297,52 @@
             error_msg = f"Failed to read data from {source_path}: {e}"
             logger.error(error_msg)
             raise RuntimeError(error_msg) from e
+
+    def file_exists(self, path: str) -> bool:
+        """Check whether a path exists in storage (Databricks only)."""
+        if is_databricks():
+            try:
+                self._dbutils.fs.ls(path)
+                return True
+            except Exception as e:
+                if "java.io.FileNotFoundException" in str(e):
+                    return False
+                raise
+        logger.info("file_exists is not implemented outside Databricks")
+        return False
+
+    def delete(self, path: str, recursive: bool = True) -> None:
+        """
+        Delete a file or directory from Azure Storage using dbutils.
+
+        Args:
+            path: The path to the file / directory in Azure Storage to delete
+            recursive: If True, recursively delete all subdirectories and files
+
+        Raises:
+            TypeError: If path is not a string
+            ValueError: If path is empty
+            Exception: If any error occurs during deletion
+        """
+        if not is_databricks():
+            raise RuntimeError("Delete does not work outside Databricks")
+
+        if not isinstance(path, str):
+            raise TypeError("path must be a non-empty string")
+
+        if not path:
+            raise ValueError("path cannot be empty")
+
+        try:
+            logger.info(f"Deleting path: {path}, recursive={recursive}")
+
+            if not self.file_exists(path):
+                logger.warning(f"Path does not exist, nothing to delete: {path}")
+                return
+
+            # Delete the path, recursing into subdirectories when requested
+            self._dbutils.fs.rm(path, recurse=recursive)
+
+        except Exception as e:
+            logger.error(f"Failed to delete {path}: {str(e)}")
+            raise Exception(f"Failed to delete: {str(e)}") from e
diff --git a/src/dataworkbench/utils.py b/src/dataworkbench/utils.py
index 5670a7a..e69746e 100644
--- a/src/dataworkbench/utils.py
+++ b/src/dataworkbench/utils.py
@@ -1,6 +1,14 @@
 import os
 
 from pyspark.sql import SparkSession
+from dataworkbench.log import setup_logger
+
+# Configure logging
+logger = setup_logger(__name__)
+
+
+PrimitiveType = str | int | float | bool
+
 
 def get_spark() -> SparkSession:
     """
@@ -23,24 +31,35 @@ def is_databricks():
     return os.getenv("DATABRICKS_RUNTIME_VERSION") is not None
 
 
-def get_secret(key: str, scope: str = "dwsecrets") -> str:
+def get_dbutils(spark: SparkSession | None = None):
     """
-    Retrieve a secret from dbutils if running on Databricks, otherwise fallback to env variables.
+    Return a DBUtils instance when running on Databricks, otherwise None.
     """
-
-    secret = None  # Default value
-
     if is_databricks():
         try:
             from pyspark.dbutils import DBUtils  # type: ignore
-
-            spark = get_spark()
-            dbutils = DBUtils(spark)
-            secret = dbutils.secrets.get(scope, key)
         except ImportError:
             raise RuntimeError(
                 "dbutils module not found. Ensure this is running on Databricks."
             )
+        try:
+            return DBUtils(spark or get_spark())
+        except Exception as e:
+            logger.error(f"Failed to create dbutils: {e}")
+            raise RuntimeError("No dbutils available") from e
+    else:
+        return None
+
+
+def get_secret(key: str, scope: str = "dwsecrets") -> str:
+    """
+    Retrieve a secret from dbutils if running on Databricks, otherwise fall back to environment variables.
+    """
+
+    dbutils = get_dbutils()
+
+    if dbutils:
+        secret = dbutils.secrets.get(scope, key)
     else:
         secret = os.getenv(key)
 
diff --git a/tests/test_gateway.py b/tests/test_gateway.py
index 87e3f4c..9428f13 100644
--- a/tests/test_gateway.py
+++ b/tests/test_gateway.py
@@ -2,6 +2,7 @@
 import requests
 from unittest.mock import patch, MagicMock
 from dataworkbench.gateway import Gateway
+from requests.exceptions import RequestException
 
 @pytest.fixture
 def mock_gateway():
@@ -35,7 +36,7 @@ def test_import_dataset_failure(mock_gateway, mock_post):
     """Test dataset import failure."""
     mock_post.side_effect = requests.exceptions.RequestException("Request failed")
 
-    result = mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")
+    with pytest.raises(RequestException):
+        mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")
 
-    assert result == {"error": "Failed to create data catalog entry: Request failed"}
    mock_post.assert_called_once()
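
Below is a minimal, hedged sketch (not part of the patch) of how the new rollback helper could be unit-tested. It assumes pyspark is importable in the test environment and that the module-level logger the patch relies on exists in datacatalogue.py; it bypasses DataCatalogue.__init__ so no Gateway, Spark session, or Databricks secrets are needed, and the storage attribute and base URL are illustrative stand-ins.

import uuid
from unittest.mock import MagicMock

from dataworkbench.datacatalogue import DataCatalogue


def test_rollback_write_deletes_table_root():
    # Bypass __init__ to avoid constructing Gateway and reading secrets.
    catalogue = DataCatalogue.__new__(DataCatalogue)
    catalogue.storage_base_url = "abfss://container@account.dfs.core.windows.net/catalogue"
    catalogue.storage = MagicMock()

    folder_id = uuid.uuid4()
    catalogue._rollback_write(folder_id)

    # The rollback deletes the table root, not just the /Processed sub-folder.
    catalogue.storage.delete.assert_called_once_with(
        f"abfss://container@account.dfs.core.windows.net/catalogue/{folder_id}",
        recursive=True,
    )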
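A hedged usage sketch of the updated storage API follows. The concrete class name is an assumption (called DeltaStorage below; substitute whatever storage.py actually exports), the path and column names are illustrative, a live Spark session is required, and file_exists / delete additionally require a Databricks runtime.

from dataworkbench.storage import DeltaStorage  # assumed class name
from dataworkbench.utils import get_spark

spark = get_spark()
storage = DeltaStorage(spark)

df = spark.createDataFrame([("Alice", 30), ("Bob", 25)], ["name", "age"])
path = "abfss://container@account.dfs.core.windows.net/path/to/data"

# partition_by is now an explicit keyword argument; any remaining **options must be
# primitive values and are forwarded to the Delta writer in a single options() call.
storage.write(df, path, mode="overwrite", partition_by="name")
storage.append(df, path, partition_by="name", mergeSchema=True)

# New helpers (Databricks only): check that the path exists, then remove it recursively.
if storage.file_exists(path):
    storage.delete(path, recursive=True)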