diff --git a/CHANGELOG.md b/CHANGELOG.md index 78a53d7ae..5803437b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -### Changed +- Azure Storage Account json Check +- Azure token `{year}, {month}, {day}, {date}, {day-1}, {quarter}` ### Fixed diff --git a/datacontract/catalog/catalog.py b/datacontract/catalog/catalog.py index 99de87dd3..306df38a2 100644 --- a/datacontract/catalog/catalog.py +++ b/datacontract/catalog/catalog.py @@ -19,7 +19,7 @@ def create_data_contract_html(contracts, file: Path, path: Path, schema: str): file_without_suffix = file.with_suffix(".html") html_filepath = path / file_without_suffix html_filepath.parent.mkdir(parents=True, exist_ok=True) - with open(html_filepath, "w") as f: + with open(html_filepath, "w", encoding='utf-8') as f: f.write(html) contracts.append( DataContractView( @@ -42,7 +42,7 @@ class DataContractView: def create_index_html(contracts, path): index_filepath = path / "index.html" - with open(index_filepath, "w") as f: + with open(index_filepath, "w", encoding='utf-8') as f: # Load templates from templates folder package_loader = PackageLoader("datacontract", "templates") env = Environment( diff --git a/datacontract/engines/data_contract_test.py b/datacontract/engines/data_contract_test.py index ae3b84b1f..35ca12920 100644 --- a/datacontract/engines/data_contract_test.py +++ b/datacontract/engines/data_contract_test.py @@ -43,9 +43,9 @@ def execute_data_contract_test( run.checks.extend(create_checks(data_contract_specification, server)) # TODO check server is supported type for nicer error messages - # TODO check server credentials are complete for nicer error messages - if server.format == "json" and server.type != "kafka": - check_jsonschema(run, data_contract_specification, server) + #if server.format == "json" and server.type in ("azure", "s3"): + # check_jsonschema(run, data_contract_specification, server) + # with soda check_soda_execute(run, data_contract_specification, server, spark, duckdb_connection) diff --git a/datacontract/engines/fastjsonschema/az/az_read_files.py b/datacontract/engines/fastjsonschema/az/az_read_files.py new file mode 100644 index 000000000..ff04d53ad --- /dev/null +++ b/datacontract/engines/fastjsonschema/az/az_read_files.py @@ -0,0 +1,68 @@ +import os + +from datacontract.model.exceptions import DataContractException +from datacontract.model.run import Run, ResultEnum + + +def yield_az_files(run: Run, az_storageAccount, az_location): + fs = az_fs(az_storageAccount) + files = fs.glob(az_location) + for file in files: + with fs.open(file) as f: + run.log_info(f"Downloading file {file}") + yield f.read() + + +def az_fs(az_storageAccount): + try: + import adlfs + except ImportError as e: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="az extra missing", + reason="Install the extra datacontract-cli\\[azure] to use az", + engine="datacontract", + original_exception=e, + ) + + az_client_id = os.getenv("DATACONTRACT_AZURE_CLIENT_ID") + if az_client_id is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="az env. 
variable DATACONTRACT_AZURE_CLIENT_ID missing", + reason="configure export DATACONTRACT_AZURE_CLIENT_ID=*** ", + engine="datacontract", + original_exception=e, + ) + + az_client_secret = os.getenv("DATACONTRACT_AZURE_CLIENT_SECRET") + if az_client_secret is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="az env. variable DATACONTRACT_AZURE_CLIENT_SECRET missing", + reason="configure export DATACONTRACT_AZURE_CLIENT_SECRET=*** ", + engine="datacontract", + original_exception=e, + ) + + az_tenant_id = os.getenv("DATACONTRACT_AZURE_TENANT_ID") + if az_tenant_id is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="az env. variable DATACONTRACT_AZURE_TENANT_ID missing", + reason="configure export DATACONTRACT_AZURE_TENANT_ID=*** ", + engine="datacontract", + original_exception=e, + ) + + return adlfs.AzureBlobFileSystem( + account_name=az_storageAccount, + client_id=az_client_id, + client_secret=az_client_secret, + tenant_id=az_tenant_id, + anon=az_client_id is None, + ) diff --git a/datacontract/engines/fastjsonschema/check_jsonschema.py b/datacontract/engines/fastjsonschema/check_jsonschema.py index 5ea79caad..6bdf9e6e8 100644 --- a/datacontract/engines/fastjsonschema/check_jsonschema.py +++ b/datacontract/engines/fastjsonschema/check_jsonschema.py @@ -2,12 +2,14 @@ import logging import os import threading +from datetime import datetime from typing import List, Optional import fastjsonschema from fastjsonschema import JsonSchemaValueException from datacontract.engines.fastjsonschema.s3.s3_read_files import yield_s3_files +from datacontract.engines.fastjsonschema.az.az_read_files import yield_az_files from datacontract.export.jsonschema_converter import to_jsonschema from datacontract.model.data_contract_specification import DataContractSpecification, Server from datacontract.model.exceptions import DataContractException @@ -85,15 +87,15 @@ def process_exceptions(run, exceptions: List[DataContractException]): def validate_json_stream( - schema: dict, model_name: str, validate: callable, json_stream: list[dict] + run: Run, schema: dict, model_name: str, validate: callable, json_stream: list[dict] ) -> List[DataContractException]: - logging.info(f"Validating JSON stream for model: '{model_name}'.") + run.log_info(f"Validating JSON stream for model: '{model_name}'.") exceptions: List[DataContractException] = [] for json_obj in json_stream: try: validate(json_obj) except JsonSchemaValueException as e: - logging.warning(f"Validation failed for JSON object with type: '{model_name}'.") + run.log_warn(f"Validation failed for JSON object with type: '{model_name}'.") primary_key_value = get_primary_key_value(schema, model_name, json_obj) exceptions.append( DataContractException( @@ -107,7 +109,7 @@ def validate_json_stream( ) ) if not exceptions: - logging.info(f"All JSON objects in the stream passed validation for model: '{model_name}'.") + run.log_info(f"All JSON objects in the stream passed validation for model: '{model_name}'.") return exceptions @@ -151,7 +153,7 @@ def process_json_file(run, schema, model_name, validate, file, delimiter): json_stream = read_json_file(file) # Validate the JSON stream and collect exceptions. - exceptions = validate_json_stream(schema, model_name, validate, json_stream) + exceptions = validate_json_stream(run, schema, model_name, validate, json_stream) # Handle all errors from schema validation. 
process_exceptions(run, exceptions) @@ -165,7 +167,7 @@ def process_local_file(run, server, schema, model_name, validate): if os.path.isdir(path): return process_directory(run, path, server, model_name, validate) else: - logging.info(f"Processing file {path}") + run.log_info(f"Processing file {path}") with open(path, "r") as file: process_json_file(run, schema, model_name, validate, file, server.delimiter) @@ -189,7 +191,7 @@ def process_s3_file(run, server, schema, model_name, validate): s3_location = s3_location.format(model=model_name) json_stream = None - for file_content in yield_s3_files(s3_endpoint_url, s3_location): + for file_content in yield_s3_files(run, s3_endpoint_url, s3_location): if server.delimiter == "new_line": json_stream = read_json_lines_content(file_content) elif server.delimiter == "array": @@ -207,11 +209,62 @@ def process_s3_file(run, server, schema, model_name, validate): ) # Validate the JSON stream and collect exceptions. - exceptions = validate_json_stream(schema, model_name, validate, json_stream) + exceptions = validate_json_stream(run, schema, model_name, validate, json_stream) # Handle all errors from schema validation. process_exceptions(run, exceptions) +def process_azure_file(run, server, schema, model_name, validate): + + if server.storageAccount is None: + raise DataContractException( + type="schema", + name="Check that JSON has valid schema", + result="warning", + reason=f"Cannot retrieve storageAccount in Server config", + engine="datacontract", + ) + + az_storageAccount = server.storageAccount + az_location = server.location + date = datetime.today() + + if "{model}" in az_location: + date = datetime.today() + month_to_quarter = { 1: "Q1", 2: "Q1", 3: "Q1", 4: "Q2", 5: "Q2", 6: "Q2", + 7: "Q3", 8: "Q3", 9: "Q3",10: "Q4", 11: "Q4", 12: "Q4" } + + az_location = az_location.format(model=model_name, + year=date.strftime('%Y'), + month=date.strftime('%m'), + day=date.strftime('%d'), + date=date.strftime('%Y-%m-%d'), + quarter=month_to_quarter.get(date.month)) + + json_stream = None + + for file_content in yield_az_files(run, az_storageAccount, az_location): + if server.delimiter == "new_line": + json_stream = read_json_lines_content(file_content) + elif server.delimiter == "array": + json_stream = read_json_array_content(file_content) + else: + json_stream = read_json_file_content(file_content) + + if json_stream is None: + raise DataContractException( + type="schema", + name="Check that JSON has valid schema", + result="warning", + reason=f"Cannot find any file in {az_location}", + engine="datacontract", + ) + + # Validate the JSON stream and collect exceptions. + exceptions = validate_json_stream(run, schema, model_name, validate, json_stream) + + # Handle all errors from schema validation. 
+ process_exceptions(run, exceptions) def check_jsonschema(run: Run, data_contract: DataContractSpecification, server: Server): run.log_info("Running engine jsonschema") @@ -262,16 +315,7 @@ def check_jsonschema(run: Run, data_contract: DataContractSpecification, server: ) ) elif server.type == "azure": - run.checks.append( - Check( - type="schema", - name="Check that JSON has valid schema", - model=model_name, - result=ResultEnum.info, - reason="JSON Schema check skipped for azure, as azure is currently not supported", - engine="jsonschema", - ) - ) + process_azure_file(run, server, schema, model_name, validate) else: run.checks.append( Check( diff --git a/datacontract/engines/fastjsonschema/s3/s3_read_files.py b/datacontract/engines/fastjsonschema/s3/s3_read_files.py index 87447f2ed..1dc44d333 100644 --- a/datacontract/engines/fastjsonschema/s3/s3_read_files.py +++ b/datacontract/engines/fastjsonschema/s3/s3_read_files.py @@ -1,16 +1,15 @@ -import logging import os from datacontract.model.exceptions import DataContractException -from datacontract.model.run import ResultEnum +from datacontract.model.run import Run, ResultEnum -def yield_s3_files(s3_endpoint_url, s3_location): +def yield_s3_files(run: Run, s3_endpoint_url, s3_location): fs = s3_fs(s3_endpoint_url) files = fs.glob(s3_location) for file in files: with fs.open(file) as f: - logging.info(f"Downloading file {file}") + run.log_info(f"Downloading file {file}") yield f.read() @@ -28,8 +27,38 @@ def s3_fs(s3_endpoint_url): ) aws_access_key_id = os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID") + if aws_access_key_id is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="s3 env. variable DATACONTRACT_S3_ACCESS_KEY_ID missing", + reason="configure export DATACONTRACT_S3_ACCESS_KEY_ID=*** ", + engine="datacontract", + original_exception=e, + ) + aws_secret_access_key = os.getenv("DATACONTRACT_S3_SECRET_ACCESS_KEY") + if aws_secret_access_key is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="s3 env. variable DATACONTRACT_S3_SECRET_ACCESS_KEY missing", + reason="configure export DATACONTRACT_S3_SECRET_ACCESS_KEY=*** ", + engine="datacontract", + original_exception=e, + ) + aws_session_token = os.getenv("DATACONTRACT_S3_SESSION_TOKEN") + if aws_session_token is None: + raise DataContractException( + type="schema", + result=ResultEnum.failed, + name="s3 env. 
variable DATACONTRACT_S3_SESSION_TOKEN missing", + reason="configure export DATACONTRACT_S3_SESSION_TOKEN=*** ", + engine="datacontract", + original_exception=e, + ) + return s3fs.S3FileSystem( key=aws_access_key_id, secret=aws_secret_access_key, diff --git a/datacontract/engines/soda/check_soda_execute.py b/datacontract/engines/soda/check_soda_execute.py index 3f536e37e..56ffd46d5 100644 --- a/datacontract/engines/soda/check_soda_execute.py +++ b/datacontract/engines/soda/check_soda_execute.py @@ -14,6 +14,7 @@ from datacontract.engines.soda.connections.postgres import to_postgres_soda_configuration from datacontract.engines.soda.connections.snowflake import to_snowflake_soda_configuration from datacontract.engines.soda.connections.sqlserver import to_sqlserver_soda_configuration +from datacontract.engines.soda.connections.db2 import to_db2_soda_configuration from datacontract.engines.soda.connections.trino import to_trino_soda_configuration from datacontract.export.sodacl_converter import to_sodacl_yaml from datacontract.model.data_contract_specification import DataContractSpecification, Server @@ -92,6 +93,10 @@ def check_soda_execute( logging.info("Use Spark to connect to data source") scan.add_spark_session(spark, data_source_name="datacontract-cli") scan.set_data_source_name("datacontract-cli") + elif server.type == "db2": + soda_configuration_str = to_db2_soda_configuration(server) + scan.add_configuration_yaml_str(soda_configuration_str) + scan.set_data_source_name(server.type) elif server.type == "kafka": if spark is None: spark = create_spark_session() diff --git a/datacontract/engines/soda/connections/db2.py b/datacontract/engines/soda/connections/db2.py new file mode 100644 index 000000000..f60f55d88 --- /dev/null +++ b/datacontract/engines/soda/connections/db2.py @@ -0,0 +1,35 @@ +import os + +import yaml + +from datacontract.model.data_contract_specification import Server + + +def to_db2_soda_configuration(server: Server) -> str: + """Serialize server config to soda configuration. 
+ + + ### Example: + type: DB2 + host: 127.0.0.1 + port: '50000' + username: simple + password: simple_pass + database: database + schema: public + """ + # with service account key, using an external json file + soda_configuration = { + f"data_source {server.type}": { + "type": "db2", + "host": server.host, + "port": str(server.port), + "username": os.getenv("DATACONTRACT_DB2_USERNAME", ""), + "password": os.getenv("DATACONTRACT_DB2_PASSWORD", ""), + "database": server.database, + "schema": server.schema_, + } + } + + soda_configuration_str = yaml.dump(soda_configuration) + return soda_configuration_str diff --git a/datacontract/engines/soda/connections/duckdb_connection.py b/datacontract/engines/soda/connections/duckdb_connection.py index f05fce2f6..8dfe5ae60 100644 --- a/datacontract/engines/soda/connections/duckdb_connection.py +++ b/datacontract/engines/soda/connections/duckdb_connection.py @@ -1,12 +1,11 @@ import os from typing import Any - import duckdb from datacontract.export.csv_type_converter import convert_to_duckdb_csv_type from datacontract.model.data_contract_specification import DataContractSpecification, Server from datacontract.model.run import Run - +from datetime import datetime def get_duckdb_connection( data_contract: DataContractSpecification, @@ -33,40 +32,85 @@ def get_duckdb_connection( setup_azure_connection(con, server) for model_name, model in data_contract.models.items(): model_path = path - if "{model}" in model_path: - model_path = model_path.format(model=model_name) - run.log_info(f"Creating table {model_name} for {model_path}") - - if server.format == "json": - json_format = "auto" - if server.delimiter == "new_line": - json_format = "newline_delimited" - elif server.delimiter == "array": - json_format = "array" - con.sql(f""" - CREATE VIEW "{model_name}" AS SELECT * FROM read_json_auto('{model_path}', format='{json_format}', hive_partitioning=1); - """) - elif server.format == "parquet": - con.sql(f""" - CREATE VIEW "{model_name}" AS SELECT * FROM read_parquet('{model_path}', hive_partitioning=1); - """) - elif server.format == "csv": - columns = to_csv_types(model) - run.log_info("Using columns: " + str(columns)) - if columns is None: - con.sql( - f"""CREATE VIEW "{model_name}" AS SELECT * FROM read_csv('{model_path}', hive_partitioning=1);""" - ) - else: - con.sql( - f"""CREATE VIEW "{model_name}" AS SELECT * FROM read_csv('{model_path}', hive_partitioning=1, columns={columns});""" - ) - elif server.format == "delta": - con.sql("update extensions;") # Make sure we have the latest delta extension - con.sql(f"""CREATE VIEW "{model_name}" AS SELECT * FROM delta_scan('{model_path}');""") + try: + if "{model}" in model_path: + date = datetime.today() + month_to_quarter = { 1: "Q1", 2: "Q1", 3: "Q1", 4: "Q2", 5: "Q2", 6: "Q2", + 7: "Q3", 8: "Q3", 9: "Q3",10: "Q4", 11: "Q4", 12: "Q4" } + + model_path = model_path.format(model=model_name, + year=date.strftime('%Y'), + month=date.strftime('%m'), + day=date.strftime('%d'), + date=date.strftime('%Y-%m-%d'), + quarter=month_to_quarter.get(date.month)) + run.log_info(f"Creating table {model_name} for {model_path}") + view_ddl= "" + if server.format == "json": + json_format = "auto" + if server.delimiter == "new_line": + json_format = "newline_delimited" + elif server.delimiter == "array": + json_format = "array" + view_ddl=f""" + CREATE VIEW "{model_name}" AS SELECT * FROM read_json_auto('{model_path}', format='{json_format}', hive_partitioning=1); + """ + elif server.format == "parquet": + view_ddl=f""" + CREATE 
VIEW "{model_name}" AS SELECT * FROM read_parquet('{model_path}', hive_partitioning=1); + """ + elif server.format == "csv": + columns = to_csv_types(model) + run.log_info("Using columns: " + str(columns)) + # Start with the required parameter. + params = ["hive_partitioning=1"] + + # Define a mapping for CSV parameters: server attribute -> read_csv parameter name. + param_mapping = { + "delimiter": "delim", # Map server.delimiter to 'delim' + "header": "header", + "escape": "escape", + "allVarchar": "all_varchar", + "allowQuotedNulls": "allow_quoted_nulls", + "dateformat": "dateformat", + "decimalSeparator": "decimal_separator", + "newLine": "new_line", + "timestampformat": "timestampformat", + "quote": "quote", + + } + for server_attr, read_csv_param in param_mapping.items(): + value = getattr(server, server_attr, None) + if value is not None: + # Wrap string values in quotes. + if isinstance(value, str): + params.append(f"{read_csv_param}='{value}'") + else: + params.append(f"{read_csv_param}={value}") + + # Add columns if they exist. + if columns is not None: + params.append(f"columns={columns}") + + # Build the parameter string. + params_str = ", ".join(params) + + # Create the view with the assembled parameters. + view_ddl = f""" + CREATE VIEW "{model_name}" AS + SELECT * FROM read_csv('{model_path}', {params_str}); + """ + elif server.format == "delta": + con.sql("update extensions;") # Make sure we have the latest delta extension + view_ddl=f"""CREATE VIEW "{model_name}" AS SELECT * FROM delta_scan('{model_path}');""" + + run.log_info("Active view ddl: " +view_ddl) + con.sql(view_ddl) + except Exception as inst: + print(inst) + continue return con - def to_csv_types(model) -> dict[Any, str | None] | None: if model is None: return None @@ -76,7 +120,6 @@ def to_csv_types(model) -> dict[Any, str | None] | None: columns[field_name] = convert_to_duckdb_csv_type(field) return columns - def setup_s3_connection(con, server): s3_region = os.getenv("DATACONTRACT_S3_REGION") s3_access_key_id = os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID") diff --git a/datacontract/export/csv_type_converter.py b/datacontract/export/csv_type_converter.py index 79dfe1668..9c5c3948e 100644 --- a/datacontract/export/csv_type_converter.py +++ b/datacontract/export/csv_type_converter.py @@ -33,4 +33,6 @@ def convert_to_duckdb_csv_type(field) -> None | str: return "VARCHAR" if type.lower() in ["null"]: return "SQLNULL" + if type.lower() in ["json"]: + return "JSON" return "VARCHAR" diff --git a/datacontract/export/odcs_v3_exporter.py b/datacontract/export/odcs_v3_exporter.py index 25f52004e..96d5a0373 100644 --- a/datacontract/export/odcs_v3_exporter.py +++ b/datacontract/export/odcs_v3_exporter.py @@ -217,7 +217,7 @@ def to_property(field_name: str, field: Field) -> dict: if field.description is not None: property["description"] = field.description if field.required is not None: - property["nullable"] = not field.required + property["required"] = field.required if field.unique is not None: property["unique"] = field.unique if field.classification is not None: @@ -259,6 +259,8 @@ def to_property(field_name: str, field: Field) -> dict: property["logicalTypeOptions"]["minimum"] = field.minimum if field.maximum is not None: property["logicalTypeOptions"]["maximum"] = field.maximum + if field.enum is not None: + property["logicalTypeOptions"]["validValues"] = field.enum if field.exclusiveMinimum is not None: property["logicalTypeOptions"]["exclusiveMinimum"] = field.exclusiveMinimum if field.exclusiveMaximum is not None: 
diff --git a/datacontract/export/sql_converter.py b/datacontract/export/sql_converter.py index 2aabe111d..8b6e02913 100644 --- a/datacontract/export/sql_converter.py +++ b/datacontract/export/sql_converter.py @@ -100,31 +100,59 @@ def to_sql_ddl( def _to_sql_table(model_name, model, server_type="snowflake"): - if server_type == "databricks": + result = "init" + if server_type in ("databricks","snowflake") and model.type.lower() == "table": # Databricks recommends to use the CREATE OR REPLACE statement for unity managed tables # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-table-using.html + # the same for Snowflake + # https://docs.snowflake.com/en/sql-reference/sql/create-table result = f"CREATE OR REPLACE TABLE {model_name} (\n" - else: + elif model.type.lower() == "table": result = f"CREATE TABLE {model_name} (\n" + elif server_type == "snowflake" and model.type.lower() == "view": + # https://docs.snowflake.com/en/sql-reference/sql/create-view + result = f"CREATE OR ALTER VIEW {model_name} (\n" + fields = len(model.fields) current_field_index = 1 - for field_name, field in iter(model.fields.items()): - type = convert_to_sql_type(field, server_type) - result += f" {field_name} {type}" - if field.required: - result += " not null" - if field.primaryKey or field.primary: - result += " primary key" - if server_type == "databricks" and field.description is not None: - result += f' COMMENT "{_escape(field.description)}"' - if current_field_index < fields: - result += "," - result += "\n" - current_field_index += 1 - result += ")" - if server_type == "databricks" and model.description is not None: - result += f' COMMENT "{_escape(model.description)}"' - result += ";\n" + if model.type.lower() == "table": + for field_name, field in iter(model.fields.items()): + type = convert_to_sql_type(field, server_type) + result += f" {field_name} {type}" + if field.required: + result += " not null" + if (field.primaryKey or field.primary) and field.c: + result += " primary key" + if server_type in ("snowflake","databricks") and field.description is not None: + result += f' COMMENT "{_escape(field.description)}"' + if current_field_index < fields: + result += "," + result += "\n" + current_field_index += 1 + result += ")" + elif model.type.lower() == "view": + field_list = '' + lineage_list = set() + for field_name, field in iter(model.fields.items()): + type = convert_to_sql_type(field, server_type) + result += f" {field_name}" + field_list += f" \n\t\t{field_name}," + if server_type in ("databricks","snowflake") and field.description is not None: + result += f' COMMENT "{_escape(field.description)}"' + if current_field_index < fields: + result += "," + result += "\n" + current_field_index += 1 + if field.lineages: + lineage_list = lineage_list.union(set(field.lineages)) + + result += f")" + if server_type in ( "snowflake","databricks") and model.description is not None: + result += f'\nCOMMENT "{_escape(model.description)}"\nAS \n\tSELECT {field_list} \n\tFROM {''.join(lineage_list)}' + else: + result += f')\nAS \n\tSELECT {field_list} \n\tFROM {''.join(lineage_list)}' + + result += ";\n\n" return result diff --git a/datacontract/imports/odcs_v3_importer.py b/datacontract/imports/odcs_v3_importer.py index f19073e43..ee68bd13d 100644 --- a/datacontract/imports/odcs_v3_importer.py +++ b/datacontract/imports/odcs_v3_importer.py @@ -17,6 +17,7 @@ Quality, Retention, Server, + ServerRole, ServiceLevel, Terms, ) @@ -98,6 +99,7 @@ def import_servers(odcs_contract: Dict[str, Any]) 
-> Dict[str, Server] | None: continue server = Server() + server.name = server_name server.type = odcs_server.get("type") server.description = odcs_server.get("description") server.environment = odcs_server.get("environment") @@ -120,9 +122,13 @@ def import_servers(odcs_contract: Dict[str, Any]) -> Dict[str, Server] | None: server.dataProductId = odcs_server.get("dataProductId") server.outputPortId = odcs_server.get("outputPortId") server.driver = odcs_server.get("driver") - server.roles = odcs_server.get("roles") + server.roles = [ServerRole(name = role.get("role"), + description = role.get("description"), + model_config = role + ) for role in odcs_server.get("roles")] if odcs_server.get("roles") is not None else None + server.storageAccount = odcs_server.get("storageAccount") - servers[server_name] = server + servers[server.name] = server return servers @@ -190,7 +196,8 @@ def import_models(odcs_contract: Dict[str, Any]) -> Dict[str, Model]: schema_physical_name = odcs_schema.get("physicalName") schema_description = odcs_schema.get("description") if odcs_schema.get("description") is not None else "" model_name = schema_physical_name if schema_physical_name is not None else schema_name - model = Model(description=" ".join(schema_description.splitlines()), type="table") + type = odcs_schema.get("physicalType") if odcs_schema.get("physicalType") is not None else "table" + model = Model(description=" ".join(schema_description.splitlines()), type=type) model.fields = import_fields( odcs_schema.get("properties"), custom_type_mappings, server_type=get_server_type(odcs_contract) ) @@ -257,14 +264,16 @@ def import_fields( for odcs_property in odcs_properties: mapped_type = map_type(odcs_property.get("logicalType"), custom_type_mappings) + if mapped_type is not None: property_name = odcs_property["name"] description = odcs_property.get("description") if odcs_property.get("description") is not None else None + field = Field( description=" ".join(description.splitlines()) if description is not None else None, type=mapped_type, title=odcs_property.get("businessName"), - required=not odcs_property.get("nullable") if odcs_property.get("nullable") is not None else False, + required=odcs_property.get("required") if odcs_property.get("required") is not None else False, primaryKey=odcs_property.get("primaryKey") if not has_composite_primary_key(odcs_properties) and odcs_property.get("primaryKey") is not None else False, @@ -276,8 +285,32 @@ def import_fields( tags=odcs_property.get("tags") if odcs_property.get("tags") is not None else None, quality=odcs_property.get("quality") if odcs_property.get("quality") is not None else [], config=import_field_config(odcs_property, server_type), + lineage=odcs_property.get("transformSourceObjects") if odcs_property.get("transformSourceObjects") is not None else None, + references=odcs_property.get("references") if odcs_property.get("references") is not None else None, + #nested object + fields= import_fields(odcs_property.get("properties"), custom_type_mappings, server_type) + if odcs_property.get("properties") is not None else {}, + format=odcs_property.get("format") if odcs_property.get("format") is not None else None, ) + + #mapped_type is array + if field.type == "array" and odcs_property.get("items") is not None : + #nested array object + if odcs_property.get("items").get("logicalType") == "object": + field.items= Field(type="object", + fields=import_fields(odcs_property.get("items").get("properties"), custom_type_mappings, server_type)) + #array of simple type + 
elif odcs_property.get("items").get("logicalType") is not None: + field.items= Field(type = odcs_property.get("items").get("logicalType")) + + # enum from quality validValues as enum + if field.type is "string": + for q in field.quality: + if hasattr(q,"validValues"): + field.enum = q.validValues + result[property_name] = field + else: logger.info( f"Can't map {odcs_property.get('column')} to the Datacontract Mapping types, as there is no equivalent or special mapping. Consider introducing a customProperty 'dc_mapping_{odcs_property.get('logicalName')}' that defines your expected type as the 'value'" diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index c51e4272c..6a4af41cf 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -1,8 +1,7 @@ import logging import os -import sqlglot -from sqlglot.dialects.dialect import Dialects +from simple_ddl_parser import parse_from_file from datacontract.imports.importer import Importer from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server @@ -20,12 +19,11 @@ def import_source( def import_sql( data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None ) -> DataContractSpecification: - sql = read_file(source) dialect = to_dialect(import_args) try: - parsed = sqlglot.parse_one(sql=sql, read=dialect) + ddl = parse_from_file(source, group_by_type=True, encoding = "cp1252", output_mode = dialect ) except Exception as e: logging.error(f"Error parsing SQL: {str(e)}") raise DataContractException( @@ -36,104 +34,72 @@ def import_sql( result=ResultEnum.error, ) - server_type: str | None = to_server_type(source, dialect) + server_type: str | None = dialect if server_type is not None: data_contract_specification.servers[server_type] = Server(type=server_type) - tables = parsed.find_all(sqlglot.expressions.Table) + tables = ddl["tables"] for table in tables: if data_contract_specification.models is None: data_contract_specification.models = {} - table_name = table.this.name + table_name = table["table_name"] fields = {} - for column in parsed.find_all(sqlglot.exp.ColumnDef): - if column.parent.this.name != table_name: - continue - + for column in table["columns"]: field = Field() - col_name = column.this.name - col_type = to_col_type(column, dialect) - field.type = map_type_from_sql(col_type) - col_description = get_description(column) - field.description = col_description - field.maxLength = get_max_length(column) - precision, scale = get_precision_scale(column) - field.precision = precision - field.scale = scale - field.primaryKey = get_primary_key(column) - field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None - physical_type_key = to_physical_type_key(dialect) - field.config = { - physical_type_key: col_type, - } - - fields[col_name] = field - + field.type = map_type_from_sql(map_type_from_sql(column["type"])) + if not column["nullable"]: + field.required = True + if column["unique"]: + field.unique = True + + if column["size"] is not None and column["size"] and not isinstance(column["size"], tuple): + field.maxLength = column["size"] + elif isinstance(column["size"], tuple): + field.precision = column["size"][0] + field.scale = column["size"][1] + + field.description = column["comment"][1:-1] if column.get("comment") else None + field.required = column["nullable"] + if column.get("with_tag"): + field.tags = ", ".join(column["with_tag"]) + if 
column.get("with_masking_policy"): + field.classification = ", ".join(column["with_masking_policy"]) + if column.get("generated"): + field.examples = str(column["generated"]) + field.unique = column["unique"] + + fields[column["name"]] = field + + if table.get("constraints"): + if table["constraints"].get("primary_key"): + for primary_key in table["constraints"]["primary_key"]["columns"]: + if primary_key in fields: + fields[primary_key].unique = True + fields[primary_key].required = True + fields[primary_key].primaryKey = True + + table_description = table["comment"][1:-1] if table.get("comment") else None + table_tags = table["with_tag"][1:-1] if table.get("with_tag") else None + data_contract_specification.models[table_name] = Model( type="table", + description=table_description, + tags=table_tags, fields=fields, ) return data_contract_specification - -def get_primary_key(column) -> bool | None: - if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None: - return True - if column.find(sqlglot.exp.PrimaryKey) is not None: - return True - return None - - -def to_dialect(import_args: dict) -> Dialects | None: +def to_dialect(import_args: dict) -> str | None: if import_args is None: return None if "dialect" not in import_args: return None dialect = import_args.get("dialect") - if dialect is None: - return None - if dialect == "sqlserver": - return Dialects.TSQL - if dialect.upper() in Dialects.__members__: - return Dialects[dialect.upper()] - if dialect == "sqlserver": - return Dialects.TSQL - return None - - -def to_physical_type_key(dialect: Dialects | None) -> str: - dialect_map = { - Dialects.TSQL: "sqlserverType", - Dialects.POSTGRES: "postgresType", - Dialects.BIGQUERY: "bigqueryType", - Dialects.SNOWFLAKE: "snowflakeType", - Dialects.REDSHIFT: "redshiftType", - Dialects.ORACLE: "oracleType", - Dialects.MYSQL: "mysqlType", - Dialects.DATABRICKS: "databricksType", - } - return dialect_map.get(dialect, "physicalType") - - -def to_server_type(source, dialect: Dialects | None) -> str | None: - if dialect is None: - return None - dialect_map = { - Dialects.TSQL: "sqlserver", - Dialects.POSTGRES: "postgres", - Dialects.BIGQUERY: "bigquery", - Dialects.SNOWFLAKE: "snowflake", - Dialects.REDSHIFT: "redshift", - Dialects.ORACLE: "oracle", - Dialects.MYSQL: "mysql", - Dialects.DATABRICKS: "databricks", - } - return dialect_map.get(dialect, None) - + return dialect def to_col_type(column, dialect): col_type_kind = column.args["kind"] @@ -142,62 +108,12 @@ def to_col_type(column, dialect): return col_type_kind.sql(dialect) - def to_col_type_normalized(column): col_type = column.args["kind"].this.name if col_type is None: return None return col_type.lower() - -def get_description(column: sqlglot.expressions.ColumnDef) -> str | None: - if column.comments is None: - return None - return " ".join(comment.strip() for comment in column.comments) - - -def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None: - col_type = to_col_type_normalized(column) - if col_type is None: - return None - if col_type not in ["varchar", "char", "nvarchar", "nchar"]: - return None - col_params = list(column.args["kind"].find_all(sqlglot.expressions.DataTypeParam)) - max_length_str = None - if len(col_params) == 0: - return None - if len(col_params) == 1: - max_length_str = col_params[0].name - if len(col_params) == 2: - max_length_str = col_params[1].name - if max_length_str is not None: - return int(max_length_str) if max_length_str.isdigit() else None - - -def get_precision_scale(column): - 
col_type = to_col_type_normalized(column) - if col_type is None: - return None, None - if col_type not in ["decimal", "numeric", "float", "number"]: - return None, None - col_params = list(column.args["kind"].find_all(sqlglot.expressions.DataTypeParam)) - if len(col_params) == 0: - return None, None - if len(col_params) == 1: - if not col_params[0].name.isdigit(): - return None, None - precision = int(col_params[0].name) - scale = 0 - return precision, scale - if len(col_params) == 2: - if not col_params[0].name.isdigit() or not col_params[1].name.isdigit(): - return None, None - precision = int(col_params[0].name) - scale = int(col_params[1].name) - return precision, scale - return None, None - - def map_type_from_sql(sql_type: str): if sql_type is None: return None @@ -218,14 +134,16 @@ def map_type_from_sql(sql_type: str): return "string" elif sql_type_normed.startswith("ntext"): return "string" + elif sql_type_normed.startswith("number"): + return "decimal" elif sql_type_normed.startswith("int"): - return "int" + return "decimal" elif sql_type_normed.startswith("bigint"): return "long" elif sql_type_normed.startswith("tinyint"): - return "int" + return "decimal" elif sql_type_normed.startswith("smallint"): - return "int" + return "decimal" elif sql_type_normed.startswith("float"): return "float" elif sql_type_normed.startswith("decimal"): diff --git a/datacontract/model/data_contract_specification.py b/datacontract/model/data_contract_specification.py index dcfdd94ec..1b33fa917 100644 --- a/datacontract/model/data_contract_specification.py +++ b/datacontract/model/data_contract_specification.py @@ -28,6 +28,8 @@ "record", "struct", "null", + "geography", + "geometry", ] @@ -50,6 +52,7 @@ class ServerRole(pyd.BaseModel): class Server(pyd.BaseModel): + name: str | None = None type: str | None = None description: str | None = None environment: str | None = None @@ -58,6 +61,15 @@ class Server(pyd.BaseModel): dataset: str | None = None path: str | None = None delimiter: str | None = None + header: bool | None = None + escape: str | None = None + allVarchar: bool | None = None + allowQuotedNulls: bool | None = None + dateformat: str | None = None + decimalSeparator: str | None = None + newLine: str | None = None + timestampformat: str | None = None + quote: str | None = None endpointUrl: str | None = None location: str | None = None account: str | None = None @@ -181,7 +193,7 @@ class Field(pyd.BaseModel): examples: List[Any] | None = None quality: List[Quality] | None = [] config: Dict[str, Any] | None = None - + model_config = pyd.ConfigDict( extra="allow", ) @@ -325,3 +337,29 @@ def to_yaml(self): sort_keys=False, allow_unicode=True, ) + + def to_mermaid(self) -> str | None: + mmd_entity = "erDiagram\n\t" + mmd_references = [] + mmd_relations = [] + try: + for model_name, model in self.models.items(): + entity_block="" + for field_name, field in model.fields.items(): + entity_block += f"\t{ field_name.replace('#','Nb').replace(' ','_').replace('/','by')}{'🔑' if field.primaryKey or (field.unique and field.required) else ''}{'⌘' if field.references else''} {field.type}\n" + if field.references: + mmd_references.append(f'"📑{field.references.split(".")[0] if "." 
in field.references else ""}"' + "||--o{" +f'"📑{model_name}"') + mmd_entity+= f'\t"📑{model_name}"'+'{\n' + entity_block + '}\n' + mmd_relations.append(model_name) + for i in range(len(mmd_relations)): + for j in range(i + 1, len(mmd_relations)): + mmd_entity+= f'\t"📑{mmd_relations[i]}"'+ " ||--o{ " + f'"📑{mmd_relations[j]}"' + ' : "applies to"' + '\n' + if mmd_entity == "": + return None + else: + return f"{mmd_entity}\n" + + except Exception as e: + print(f"error : {e}, {self}") + return None + diff --git a/datacontract/templates/datacontract.html b/datacontract/templates/datacontract.html index 4bd126906..446a34485 100644 --- a/datacontract/templates/datacontract.html +++ b/datacontract/templates/datacontract.html @@ -4,10 +4,24 @@ Data Contract + {# #} + @@ -73,6 +87,15 @@

+ {% if datacontract.to_mermaid() %}
+ {# "Diagram" section header and "Entity relationship diagram" caption (surrounding markup truncated in this diff) #}
+ {{ render_partial('partials/erdiagram.html', datacontract = datacontract) }} +
+ {% endif %}
{{ render_partial('partials/datacontract_information.html', datacontract = datacontract) }}
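A rough usage sketch for the new `DataContractSpecification.to_mermaid()` method that this template block renders. The model and field names are invented, and the remaining specification fields are assumed to keep their defaults:

```python
from datacontract.model.data_contract_specification import (
    DataContractSpecification,
    Field,
    Model,
)

# Two hypothetical models; primaryKey adds the key marker to a field line,
# and `references` adds the reference marker in the entity block.
spec = DataContractSpecification(
    models={
        "customers": Model(type="table", fields={
            "customer_id": Field(type="string", primaryKey=True, required=True),
            "name": Field(type="string"),
        }),
        "orders": Model(type="table", fields={
            "order_id": Field(type="string", primaryKey=True, required=True),
            "customer_id": Field(type="string", references="customers.customer_id"),
        }),
    },
)

# Returns an `erDiagram` block (or None on error) that partials/erdiagram.html can embed.
print(spec.to_mermaid())
```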
diff --git a/datacontract/templates/index.html b/datacontract/templates/index.html
index 8c15c221d..0c22b2563 100644
--- a/datacontract/templates/index.html
+++ b/datacontract/templates/index.html
@@ -4,6 +4,7 @@
     Data Contract
+    {# #}
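For the new DB2 connection, here is a small sketch of how the Soda configuration is assembled from a server definition plus the two environment variables introduced in this PR; the host, database, and schema values are made up:

```python
import os

from datacontract.engines.soda.connections.db2 import to_db2_soda_configuration
from datacontract.model.data_contract_specification import Server

# Credentials come from the environment, as with the other connection types.
os.environ.setdefault("DATACONTRACT_DB2_USERNAME", "db2inst1")
os.environ.setdefault("DATACONTRACT_DB2_PASSWORD", "changeme")

server = Server(type="db2", host="127.0.0.1", port=50000, database="SAMPLE")
server.schema_ = "DB2INST1"  # serialized as the `schema` key in the generated YAML

# Produces a YAML document with a single `data_source db2` entry,
# which check_soda_execute passes to scan.add_configuration_yaml_str().
print(to_db2_soda_configuration(server))
```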