diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index c08efaee6..fde47a619 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -1,11 +1,17 @@ import logging import os +import re import sqlglot from sqlglot.dialects.dialect import Dialects from datacontract.imports.importer import Importer -from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server +from datacontract.model.data_contract_specification import ( +    DataContractSpecification, +    Field, +    Model, +    Server, +) from datacontract.model.exceptions import DataContractException from datacontract.model.run import ResultEnum @@ -18,16 +24,28 @@ def import_source( def import_sql( -    data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None +    data_contract_specification: DataContractSpecification, +    format: str, +    source: str, +    import_args: dict = None, ) -> DataContractSpecification: +    dialect = to_dialect(import_args) + +    server_type: str | None = to_server_type(source, dialect) +    if server_type is not None: +        data_contract_specification.servers[server_type] = Server(type=server_type) +     sql = read_file(source) -    dialect = to_dialect(import_args) +    parsed = None     try: -        parsed = sqlglot.parse_one(sql=sql, read=dialect) +        parsed = sqlglot.parse_one(sql=sql, read=dialect.lower() if dialect else None) + +        tables = parsed.find_all(sqlglot.expressions.Table) +     except Exception as e: -        logging.error(f"Error parsing SQL: {str(e)}") +        logging.error(f"Error parsing SQL: {str(e)}")         raise DataContractException(             type="import",             name=f"Reading source from {source}", @@ -36,50 +54,92 @@ def import_sql(             result=ResultEnum.error,         ) -    server_type: str | None = to_server_type(source, dialect) -    if server_type is not None: -        data_contract_specification.servers[server_type] = Server(type=server_type) - -    tables = parsed.find_all(sqlglot.expressions.Table) -    for table in 
tables: if data_contract_specification.models is None: data_contract_specification.models = {} - table_name = table.this.name - - fields = {} - for column in parsed.find_all(sqlglot.exp.ColumnDef): - if column.parent.this.name != table_name: - continue - - field = Field() - col_name = column.this.name - col_type = to_col_type(column, dialect) - field.type = map_type_from_sql(col_type) - col_description = get_description(column) - field.description = col_description - field.maxLength = get_max_length(column) - precision, scale = get_precision_scale(column) - field.precision = precision - field.scale = scale - field.primaryKey = get_primary_key(column) - field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None - physical_type_key = to_physical_type_key(dialect) - field.config = { - physical_type_key: col_type, - } - - fields[col_name] = field + table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect) data_contract_specification.models[table_name] = Model( type="table", + description=table_description, + tags=table_tags, fields=fields, ) return data_contract_specification +def sqlglot_model_wrapper(table, parsed, dialect): + table_description = None + table_tag = None + + table_name = table.this.name + + table_comment_property = parsed.find(sqlglot.expressions.SchemaCommentProperty) + if table_comment_property: + table_description = table_comment_property.this.this + + prop = parsed.find(sqlglot.expressions.Properties) + if prop: + tags = prop.find(sqlglot.expressions.Tags) + if tags: + tag_enum = tags.find(sqlglot.expressions.Property) + table_tag = [str(t) for t in tag_enum] + + fields = {} + for column in parsed.find_all(sqlglot.exp.ColumnDef): + if column.parent.this.name != table_name: + continue + + field = Field() + col_name = column.this.name + col_type = to_col_type(column, dialect) + field.type = map_type_from_sql(col_type) + col_description = get_description(column) + field.description = 
col_description +        field.maxLength = get_max_length(column) +        precision, scale = get_precision_scale(column) +        field.precision = precision +        field.scale = scale +        field.primaryKey = get_primary_key(column) +        field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None +        physical_type_key = to_physical_type_key(dialect) +        field.tags = get_tags(column) +        field.config = { +            physical_type_key: col_type, +        } + +        fields[col_name] = field + +    return table_name, fields, table_description, table_tag + + +def map_physical_type(column, dialect) -> str | None: +    autoincrement = "" +    if column.get("autoincrement") and dialect == Dialects.SNOWFLAKE: +        autoincrement = " AUTOINCREMENT" + (" START " + str(column.get("start")) if column.get("start") else "") +        autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else "" +        autoincrement += " NOORDER" if not column.get("increment_order") else "" +    elif column.get("autoincrement"): +        autoincrement = " IDENTITY" + +    if column.get("size") and isinstance(column.get("size"), tuple): +        return ( +            column.get("type") +            + "(" +            + str(column.get("size")[0]) +            + "," +            + str(column.get("size")[1]) +            + ")" +            + autoincrement +        ) +    elif column.get("size"): +        return column.get("type") + "(" + str(column.get("size")) + ")" + autoincrement +    else: +        return column.get("type") + autoincrement + + def get_primary_key(column) -> bool | None:     if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None:         return True @@ -100,9 +160,7 @@ def to_dialect(import_args: dict) -> Dialects | None:         return Dialects.TSQL     if dialect.upper() in Dialects.__members__:         return Dialects[dialect.upper()] -    if dialect == "sqlserver": -        return Dialects.TSQL -    return None +    return None def to_physical_type_key(dialect: Dialects | str | None) -> str: @@ -154,9 +212,22 @@ def to_col_type_normalized(column): def get_description(column: sqlglot.expressions.ColumnDef) -> str | None:     if column.comments is None: + 
description = column.find(sqlglot.expressions.CommentColumnConstraint) + if description: + return description.this.this + else: + return None + return " ".join(comment.strip() for comment in column.comments) + + +def get_tags(column: sqlglot.expressions.ColumnDef) -> str | None: + tags = column.find(sqlglot.expressions.Tags) + if tags: + tag_enum = tags.find(sqlglot.expressions.Property) + return [str(t) for t in tag_enum] + else: return None - return " ".join(comment.strip() for comment in column.comments) - + def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None: col_type = to_col_type_normalized(column) @@ -222,18 +293,20 @@ def map_type_from_sql(sql_type: str) -> str | None: return "string" elif sql_type_normed.startswith("int"): return "int" - elif sql_type_normed.startswith("bigint"): - return "long" elif sql_type_normed.startswith("tinyint"): return "int" elif sql_type_normed.startswith("smallint"): return "int" - elif sql_type_normed.startswith("float"): + elif sql_type_normed.startswith("bigint"): + return "long" + elif sql_type_normed.startswith("float") or sql_type_normed.startswith("double") or sql_type_normed == "real": return "float" - elif sql_type_normed.startswith("decimal"): + elif sql_type_normed.startswith("number"): return "decimal" elif sql_type_normed.startswith("numeric"): return "decimal" + elif sql_type_normed.startswith("decimal"): + return "decimal" elif sql_type_normed.startswith("bool"): return "boolean" elif sql_type_normed.startswith("bit"): @@ -252,6 +325,7 @@ def map_type_from_sql(sql_type: str) -> str | None: sql_type_normed == "timestamptz" or sql_type_normed == "timestamp_tz" or sql_type_normed == "timestamp with time zone" + or sql_type_normed == "timestamp_ltz" ): return "timestamp_tz" elif sql_type_normed == "timestampntz" or sql_type_normed == "timestamp_ntz": @@ -271,7 +345,7 @@ def map_type_from_sql(sql_type: str) -> str | None: elif sql_type_normed == "xml": # tsql return "string" else: - return 
"variant" + return "object" def read_file(path): @@ -285,4 +359,5 @@ def read_file(path): ) with open(path, "r") as file: file_content = file.read() - return file_content + + return re.sub(r'\$\{(\w+)\}', r'\1', file_content) diff --git a/tests/fixtures/dbml/import/datacontract.yaml b/tests/fixtures/dbml/import/datacontract.yaml index 00d2a4440..d6b40c980 100644 --- a/tests/fixtures/dbml/import/datacontract.yaml +++ b/tests/fixtures/dbml/import/datacontract.yaml @@ -22,7 +22,7 @@ models: description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. order_total: - type: variant + type: object required: true primaryKey: false unique: false diff --git a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml index fe011d539..2f9855d62 100644 --- a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml +++ b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml @@ -22,7 +22,7 @@ models: description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. 
order_total: - type: variant + type: object required: true primaryKey: false unique: false diff --git a/tests/fixtures/snowflake/import/ddl.sql b/tests/fixtures/snowflake/import/ddl.sql new file mode 100644 index 000000000..c458db7e5 --- /dev/null +++ b/tests/fixtures/snowflake/import/ddl.sql @@ -0,0 +1,42 @@ +CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table ( + -- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types + field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 COMMENT 'Primary key', + field_not_null INT NOT NULL COMMENT 'Not null', + field_char CHAR(10) COMMENT 'Fixed-length string', + field_character CHARACTER(10) COMMENT 'Fixed-length string', + field_varchar VARCHAR(100) WITH TAG (SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER', SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME') COMMENT 'Variable-length string', + + field_text TEXT COMMENT 'Large variable-length string', + field_string STRING COMMENT 'Large variable-length Unicode string', + + field_tinyint TINYINT COMMENT 'Integer (0-255)', + field_smallint SMALLINT COMMENT 'Integer (-32,768 to 32,767)', + field_int INT COMMENT 'Integer (-2.1B to 2.1B)', + field_integer INTEGER COMMENT 'Integer full name(-2.1B to 2.1B)', + field_bigint BIGINT COMMENT 'Large integer (-9 quintillion to 9 quintillion)', + + field_decimal DECIMAL(10, 2) COMMENT 'Fixed precision decimal', + field_numeric NUMERIC(10, 2) COMMENT 'Same as DECIMAL', + + field_float FLOAT COMMENT 'Approximate floating-point', + field_float4 FLOAT4 COMMENT 'Approximate floating-point 4', + field_float8 FLOAT8 COMMENT 'Approximate floating-point 8', + field_real REAL COMMENT 'Smaller floating-point', + + field_boulean BOOLEAN COMMENT 'Boolean-like (0 or 1)', + + field_date DATE COMMENT 'Date only (YYYY-MM-DD)', + field_time TIME COMMENT 'Time only (HH:MM:SS)', + field_timestamp TIMESTAMP COMMENT 'More precise datetime', + field_timestamp_ltz TIMESTAMP_LTZ COMMENT 'More precise datetime with local time zone; 
time zone, if provided, isn`t stored.', + field_timestamp_ntz TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() COMMENT 'More precise datetime with no time zone; time zone, if provided, isn`t stored.', + field_timestamp_tz TIMESTAMP_TZ COMMENT 'More precise datetime with time zone.', + + field_binary BINARY(16) COMMENT 'Fixed-length binary', + field_varbinary VARBINARY(100) COMMENT 'Variable-length binary', + + field_variant VARIANT COMMENT 'VARIANT data', + field_json OBJECT COMMENT 'JSON (Stored as text)', + UNIQUE(field_not_null), + PRIMARY KEY (field_primary_key) +) COMMENT = 'My Comment' diff --git a/tests/test_import_sql_postgres.py b/tests/test_import_sql_postgres.py index edc18b729..8efa07e0d 100644 --- a/tests/test_import_sql_postgres.py +++ b/tests/test_import_sql_postgres.py @@ -20,6 +20,8 @@ def test_cli(): "sql", "--source", sql_file_path, + "--dialect", + "postgres" ], ) assert result.exit_code == 0 diff --git a/tests/test_import_sql_snowflake.py b/tests/test_import_sql_snowflake.py new file mode 100644 index 000000000..b598b251e --- /dev/null +++ b/tests/test_import_sql_snowflake.py @@ -0,0 +1,186 @@ +import yaml + +from datacontract.data_contract import DataContract + +sql_file_path = "fixtures/snowflake/import/ddl.sql" + + +def test_import_sql_snowflake(): + result = DataContract().import_from_source("sql", sql_file_path, dialect="snowflake") + + expected = """ +dataContractSpecification: 1.1.0 +id: my-data-contract-id +info: + title: My Data Contract + version: 0.0.1 +servers: + snowflake: + type: snowflake +models: + my_table: + description: My Comment + type: table + fields: + field_primary_key: + type: decimal + required: true + description: Primary key + precision: 38 + scale: 0 + config: + snowflakeType: DECIMAL(38, 0) + field_not_null: + type: int + required: true + description: Not null + config: + snowflakeType: INT + field_char: + type: string + description: Fixed-length string + maxLength: 10 + config: + snowflakeType: CHAR(10) + 
field_character: + type: string + description: Fixed-length string + maxLength: 10 + config: + snowflakeType: CHAR(10) + field_varchar: + type: string + description: Variable-length string + maxLength: 100 + tags: ["SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER'", "SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME'"] + config: + snowflakeType: VARCHAR(100) + field_text: + type: string + description: Large variable-length string + config: + snowflakeType: TEXT + field_string: + type: string + description: Large variable-length Unicode string + config: + snowflakeType: TEXT + field_tinyint: + type: int + description: Integer (0-255) + config: + snowflakeType: TINYINT + field_smallint: + type: int + description: Integer (-32,768 to 32,767) + config: + snowflakeType: SMALLINT + field_int: + type: int + description: Integer (-2.1B to 2.1B) + config: + snowflakeType: INT + field_integer: + type: int + description: Integer full name(-2.1B to 2.1B) + config: + snowflakeType: INT + field_bigint: + type: long + description: Large integer (-9 quintillion to 9 quintillion) + config: + snowflakeType: BIGINT + field_decimal: + type: decimal + description: Fixed precision decimal + precision: 10 + scale: 2 + config: + snowflakeType: DECIMAL(10, 2) + field_numeric: + type: decimal + description: Same as DECIMAL + precision: 10 + scale: 2 + config: + snowflakeType: DECIMAL(10, 2) + field_float: + type: float + description: Approximate floating-point + config: + snowflakeType: FLOAT + field_float4: + type: float + description: Approximate floating-point 4 + config: + snowflakeType: FLOAT + field_float8: + type: float + description: Approximate floating-point 8 + config: + snowflakeType: DOUBLE + field_real: + type: float + description: Smaller floating-point + config: + snowflakeType: FLOAT + field_boulean: + type: boolean + description: Boolean-like (0 or 1) + config: + snowflakeType: BOOLEAN + field_date: + type: date + description: Date only (YYYY-MM-DD) + config: + snowflakeType: DATE + 
field_time: + type: string + description: Time only (HH:MM:SS) + config: + snowflakeType: TIME + field_timestamp: + type: timestamp_ntz + description: More precise datetime + config: + snowflakeType: TIMESTAMP + field_timestamp_ltz: + type: object + description: More precise datetime with local time zone; time zone, if provided, isn`t stored. + config: + snowflakeType: TIMESTAMPLTZ + field_timestamp_ntz: + type: timestamp_ntz + description: More precise datetime with no time zone; time zone, if provided, isn`t stored. + config: + snowflakeType: TIMESTAMPNTZ + field_timestamp_tz: + type: timestamp_tz + description: More precise datetime with time zone. + config: + snowflakeType: TIMESTAMPTZ + field_binary: + type: bytes + description: Fixed-length binary + config: + snowflakeType: BINARY(16) + field_varbinary: + type: bytes + description: Variable-length binary + config: + snowflakeType: VARBINARY(100) + field_variant: + type: object + description: VARIANT data + config: + snowflakeType: VARIANT + field_json: + type: object + description: JSON (Stored as text) + config: + snowflakeType: OBJECT""" + + print("Result", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + # Disable linters so we don't get "missing description" warnings + assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed()