From 657b68db7d8d6d1904b1c054ba4d4bae73eeedff Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Wed, 11 Jun 2025 15:03:31 -0400 Subject: [PATCH 1/8] init. snowflake sql ddl import to datacontract --- datacontract/imports/sql_importer.py | 178 ++++++++++++++++------ pyproject.toml | 1 + tests/fixtures/snowflake/import/ddl.sql | 42 ++++++ tests/test_import_sql_snowflake.py | 192 ++++++++++++++++++++++++ 4 files changed, 370 insertions(+), 43 deletions(-) create mode 100644 tests/fixtures/snowflake/import/ddl.sql create mode 100644 tests/test_import_sql_snowflake.py diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index c08efaee6..2b7718771 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -3,6 +3,7 @@ import sqlglot from sqlglot.dialects.dialect import Dialects +from simple_ddl_parser import parse_from_file from datacontract.imports.importer import Importer from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server @@ -20,12 +21,28 @@ def import_source( def import_sql( data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None ) -> DataContractSpecification: - sql = read_file(source) - + dialect = to_dialect(import_args) + server_type: str | None = to_server_type(source, dialect) + if server_type is not None: + data_contract_specification.servers[server_type] = Server(type=server_type) + + sql = read_file(source) + + parsed = None + try: - parsed = sqlglot.parse_one(sql=sql, read=dialect) + parsed = sqlglot.parse_one(sql=sql, read=dialect.lower()) + + tables = parsed.find_all(sqlglot.expressions.Table) + + except Exception as e: + # Second try with simple-ddl-parser + ddl = parse_from_file(source, group_by_type=True, encoding = "cp1252", output_mode = dialect.lower() ) + + tables = ddl["tables"] + except Exception as e: logging.error(f"Error parsing SQL: {str(e)}") raise DataContractException( @@ -36,49 +53,121 @@ def import_sql( result=ResultEnum.error, ) - server_type: str | None = to_server_type(source, dialect) - if server_type is not None: - data_contract_specification.servers[server_type] = Server(type=server_type) - - tables = parsed.find_all(sqlglot.expressions.Table) - for table in tables: if data_contract_specification.models is None: data_contract_specification.models = {} - - table_name = table.this.name - - fields = {} - for column in parsed.find_all(sqlglot.exp.ColumnDef): - if column.parent.this.name != table_name: - continue - - field = Field() - col_name = column.this.name - col_type = to_col_type(column, dialect) - field.type = map_type_from_sql(col_type) - col_description = get_description(column) - field.description = col_description - field.maxLength = get_max_length(column) - precision, scale = get_precision_scale(column) - field.precision = precision - field.scale = scale - field.primaryKey = get_primary_key(column) - field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None - physical_type_key = to_physical_type_key(dialect) - field.config = { - physical_type_key: col_type, - } - - fields[col_name] = field + + if hasattr(table, 'this'): # sqlglot + table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect) + else: # simple-ddl-parser + table_name, fields, table_description, table_tags = simple_ddl_model_wrapper(table, dialect) data_contract_specification.models[table_name] = Model( type="table", + description=table_description, + tags=table_tags, fields=fields, ) return data_contract_specification +def sqlglot_model_wrapper(table, parsed, dialect): + table_name = table.this.name + + fields = {} + for column in parsed.find_all(sqlglot.exp.ColumnDef): + if column.parent.this.name != table_name: + continue + + field = Field() + col_name = column.this.name + col_type = to_col_type(column, dialect) + field.type = map_type_from_sql(col_type) + col_description = get_description(column) + field.description = col_description + field.maxLength = get_max_length(column) + precision, scale = get_precision_scale(column) + field.precision = precision + field.scale = scale + field.primaryKey = get_primary_key(column) + field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None + physical_type_key = to_physical_type_key(dialect) + field.config = { + physical_type_key: col_type, + } + + fields[col_name] = field + + return table_name, fields, None, None + +def simple_ddl_model_wrapper(table, dialect): + table_name = table["table_name"] + + fields = {} + + for column in table["columns"]: + field = Field() + field.type = map_type_from_sql(column["type"]) + physical_type_key = to_physical_type_key(dialect) + datatype = map_physical_type(column, dialect) + field.config = { + physical_type_key: datatype, + } + + if not column["nullable"]: + field.required = True + if column["unique"]: + field.unique = True + + if column["size"] is not None and column["size"] and not isinstance(column["size"], tuple): + field.maxLength = column["size"] + elif isinstance(column["size"], tuple): + field.precision = column["size"][0] + field.scale = column["size"][1] + + field.description = column["comment"][1:-1].strip() if column.get("comment") else None + + if column.get("with_tag"): + field.tags = column["with_tag"] + if column.get("with_masking_policy"): + field.classification = ", ".join(column["with_masking_policy"]) + if column.get("generated"): + field.examples = str(column["generated"]) + + fields[column["name"]] = field + + if table.get("constraints"): + if table["constraints"].get("primary_key"): + for primary_key in table["constraints"]["primary_key"]["columns"]: + if primary_key in fields: + fields[primary_key].unique = True + fields[primary_key].required = True + fields[primary_key].primaryKey = True + + table_description = table["comment"][1:-1] if table.get("comment") else None + table_tags = table["with_tag"][1:-1] if table.get("with_tag") else None + + return table_name, fields, table_description, table_tags + +def map_physical_type(column, dialect) -> str | None: + autoincrement = "" + if column.get("autoincrement") == True and dialect == Dialects.SNOWFLAKE: + autoincrement = " AUTOINCREMENT" \ + + " START " + str(column.get("start")) if column.get("start") else "" + autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else "" + autoincrement += " NOORDER" if column.get("increment_order") == False else "" + elif column.get("autoincrement") == True: + autoincrement = " IDENTITY" + + if column.get("size") and isinstance(column.get("size"), tuple): + return column.get("type") + "(" + str(column.get("size")[0]) + "," + str(column.get("size")[1]) + ")" \ + + autoincrement + elif column.get("size"): + return column.get("type") + "(" + str(column.get("size")) + ")" \ + + autoincrement + else: + return column.get("type") + autoincrement + def get_primary_key(column) -> bool | None: if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None: @@ -100,8 +189,6 @@ def to_dialect(import_args: dict) -> Dialects | None: return Dialects.TSQL if dialect.upper() in Dialects.__members__: return Dialects[dialect.upper()] - if dialect == "sqlserver": - return Dialects.TSQL return None @@ -221,19 +308,23 @@ def map_type_from_sql(sql_type: str) -> str | None: elif sql_type_normed.startswith("ntext"): return "string" elif sql_type_normed.startswith("int"): - return "int" - elif sql_type_normed.startswith("bigint"): - return "long" + return "int" elif sql_type_normed.startswith("tinyint"): return "int" elif sql_type_normed.startswith("smallint"): return "int" - elif sql_type_normed.startswith("float"): + elif sql_type_normed.startswith("bigint"): + return "long" + elif (sql_type_normed.startswith("float") + or sql_type_normed.startswith("double") + or sql_type_normed == "real"): return "float" - elif sql_type_normed.startswith("decimal"): + elif sql_type_normed.startswith("number"): return "decimal" elif sql_type_normed.startswith("numeric"): return "decimal" + elif sql_type_normed.startswith("decimal"): + return "decimal" elif sql_type_normed.startswith("bool"): return "boolean" elif sql_type_normed.startswith("bit"): @@ -252,6 +343,7 @@ def map_type_from_sql(sql_type: str) -> str | None: sql_type_normed == "timestamptz" or sql_type_normed == "timestamp_tz" or sql_type_normed == "timestamp with time zone" + or sql_type_normed == "timestamp_ltz" ): return "timestamp_tz" elif sql_type_normed == "timestampntz" or sql_type_normed == "timestamp_ntz": @@ -271,7 +363,7 @@ def map_type_from_sql(sql_type: str) -> str | None: elif sql_type_normed == "xml": # tsql return "string" else: - return "variant" + return "object" def read_file(path): diff --git a/pyproject.toml b/pyproject.toml index 32e2e909b..d1045d87d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "python-multipart>=0.0.20,<1.0.0", "rich>=13.7,<15.0", "sqlglot>=26.6.0,<27.0.0", + "simple-ddl-parser>=1.7.1,<2.0.0", "duckdb>=1.0.0,<2.0.0", "soda-core-duckdb>=3.3.20,<3.6.0", # remove setuptools when https://github.com/sodadata/soda-core/issues/2091 is resolved diff --git a/tests/fixtures/snowflake/import/ddl.sql b/tests/fixtures/snowflake/import/ddl.sql new file mode 100644 index 000000000..f76058829 --- /dev/null +++ b/tests/fixtures/snowflake/import/ddl.sql @@ -0,0 +1,42 @@ +CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table ( + -- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types + field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 noorder COMMENT 'Primary key', + field_not_null INT NOT NULL COMMENT 'Not null', + field_char CHAR(10) COMMENT 'Fixed-length string', + field_character CHARACTER(10) COMMENT 'Fixed-length string', + field_varchar VARCHAR(100) WITH TAG (SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER', SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME') COMMENT 'Variable-length string', + + field_text TEXT COMMENT 'Large variable-length string', + field_string STRING COMMENT 'Large variable-length Unicode string', + + field_tinyint TINYINT COMMENT 'Integer (0-255)', + field_smallint SMALLINT COMMENT 'Integer (-32,768 to 32,767)', + field_int INT COMMENT 'Integer (-2.1B to 2.1B)', + field_integer INTEGER COMMENT 'Integer full name(-2.1B to 2.1B)', + field_bigint BIGINT COMMENT 'Large integer (-9 quintillion to 9 quintillion)', + + field_decimal DECIMAL(10, 2) COMMENT 'Fixed precision decimal', + field_numeric NUMERIC(10, 2) COMMENT 'Same as DECIMAL', + + field_float FLOAT COMMENT 'Approximate floating-point', + field_float4 FLOAT4 COMMENT 'Approximate floating-point 4', + field_float8 FLOAT8 COMMENT 'Approximate floating-point 8', + field_real REAL COMMENT 'Smaller floating-point', + + field_boulean BOOLEAN COMMENT 'Boolean-like (0 or 1)', + + field_date DATE COMMENT 'Date only (YYYY-MM-DD)', + field_time TIME COMMENT 'Time only (HH:MM:SS)', + field_timestamp TIMESTAMP COMMENT 'More precise datetime', + field_timestamp_ltz TIMESTAMP_LTZ COMMENT 'More precise datetime with local time zone; time zone, if provided, isn`t stored.', + field_timestamp_ntz TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() COMMENT 'More precise datetime with no time zone; time zone, if provided, isn`t stored.', + field_timestamp_tz TIMESTAMP_TZ COMMENT 'More precise datetime with time zone.', + + field_binary BINARY(16) COMMENT 'Fixed-length binary', + field_varbinary VARBINARY(100) COMMENT 'Variable-length binary', + + field_variant VARIANT COMMENT 'VARIANT data', + field_json OBJECT COMMENT 'JSON (Stored as text)', + UNIQUE(field_not_null), + PRIMARY KEY (field_primary_key) +) COMMENT = 'My Comment' diff --git a/tests/test_import_sql_snowflake.py b/tests/test_import_sql_snowflake.py new file mode 100644 index 000000000..e83942823 --- /dev/null +++ b/tests/test_import_sql_snowflake.py @@ -0,0 +1,192 @@ +import yaml + +from datacontract.data_contract import DataContract + +sql_file_path = "fixtures/snowflake/import/ddl.sql" + + +def test_import_sql_snowflake(): + + result = DataContract().import_from_source("sql", sql_file_path, dialect="snowflake") + + expected = """ +dataContractSpecification: 1.1.0 +id: my-data-contract-id +info: + title: My Data Contract + version: 0.0.1 +servers: + snowflake: + type: snowflake +models: + my_table: + description: My Comment + type: table + fields: + field_primary_key: + type: decimal + required: true + description: Primary key + precision: 38 + scale: 0 + config: + snowflakeType: NUMBER(38,0) AUTOINCREMENT START 1 INCREMENT 1 NOORDER + field_not_null: + type: int + required: true + unique: true + description: Not null + config: + snowflakeType: INT + field_char: + type: string + description: Fixed-length string + maxLength: 10 + config: + snowflakeType: CHAR(10) + field_character: + type: string + description: Fixed-length string + maxLength: 10 + config: + snowflakeType: CHARACTER(10) + field_varchar: + type: string + description: Variable-length string + maxLength: 100 + tags: ["SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER'", "SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME'"] + config: + snowflakeType: VARCHAR(100) + field_text: + type: string + description: Large variable-length string + config: + snowflakeType: TEXT + field_string: + type: string + description: Large variable-length Unicode string + config: + snowflakeType: STRING + field_tinyint: + type: int + description: Integer ( 0-255) + config: + snowflakeType: TINYINT + field_smallint: + type: int + description: Integer ( -32 , 768 to 32 , 767) + config: + snowflakeType: SMALLINT + field_int: + type: int + description: Integer ( -2.1B to 2.1B) + config: + snowflakeType: INT + field_integer: + type: int + description: Integer full name ( -2.1B to 2.1B) + config: + snowflakeType: INTEGER + field_bigint: + type: long + description: Large integer ( -9 quintillion to 9 quintillion) + config: + snowflakeType: BIGINT + field_decimal: + type: decimal + description: Fixed precision decimal + precision: 10 + scale: 2 + config: + snowflakeType: DECIMAL(10,2) + field_numeric: + type: decimal + description: Same as DECIMAL + precision: 10 + scale: 2 + config: + snowflakeType: NUMERIC(10,2) + field_float: + type: float + description: Approximate floating-point + config: + snowflakeType: FLOAT + field_float4: + type: float + description: Approximate floating-point 4 + config: + snowflakeType: FLOAT4 + field_float8: + type: float + description: Approximate floating-point 8 + config: + snowflakeType: FLOAT8 + field_real: + type: float + description: Smaller floating-point + config: + snowflakeType: REAL + field_boulean: + type: boolean + description: Boolean-like ( 0 or 1) + config: + snowflakeType: BOOLEAN + field_date: + type: date + description: Date only ( YYYY-MM-DD) + config: + snowflakeType: DATE + field_time: + type: string + description: Time only ( HH:MM:SS) + config: + snowflakeType: TIME + field_timestamp: + type: timestamp_ntz + description: More precise datetime + config: + snowflakeType: TIMESTAMP + field_timestamp_ltz: + type: timestamp_tz + description: More precise datetime with local time zone; time zone , if provided + , isn`t stored. + config: + snowflakeType: TIMESTAMP_LTZ + field_timestamp_ntz: + type: timestamp_ntz + description: More precise datetime with no time zone; time zone , if provided + , isn`t stored. + config: + snowflakeType: TIMESTAMP_NTZ + field_timestamp_tz: + type: timestamp_tz + description: More precise datetime with time zone. + config: + snowflakeType: TIMESTAMP_TZ + field_binary: + type: bytes + description: Fixed-length binary + maxLength: 16 + config: + snowflakeType: BINARY(16) + field_varbinary: + type: bytes + description: Variable-length binary + maxLength: 100 + config: + snowflakeType: VARBINARY(100) + field_variant: + type: object + description: VARIANT data + config: + snowflakeType: VARIANT + field_json: + type: object + description: JSON ( Stored as text) + config: + snowflakeType: OBJECT""" + + print("Result", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + # Disable linters so we don't get "missing description" warnings + assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed() From a224abac2aaa771357dfa0915c0efa67215bd228 Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Wed, 11 Jun 2025 15:23:45 -0400 Subject: [PATCH 2/8] apply ruff check and format --- datacontract/imports/sql_importer.py | 72 +++++++++++++++------------- tests/test_import_sql_snowflake.py | 15 +++--- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index 2b7718771..b890095bf 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -21,30 +21,30 @@ def import_source( def import_sql( data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None ) -> DataContractSpecification: - dialect = to_dialect(import_args) server_type: str | None = to_server_type(source, dialect) if server_type is not None: data_contract_specification.servers[server_type] = Server(type=server_type) - sql = read_file(source) + sql = read_file(source) parsed = None try: - parsed = sqlglot.parse_one(sql=sql, read=dialect.lower()) - + parsed = sqlglot.parse_one(sql=sql, read=dialect.lower()) + tables = parsed.find_all(sqlglot.expressions.Table) except Exception as e: + logging.error(f"Error parsing sqlglot: {str(e)}") # Second try with simple-ddl-parser - ddl = parse_from_file(source, group_by_type=True, encoding = "cp1252", output_mode = dialect.lower() ) + ddl = parse_from_file(source, group_by_type=True, encoding="cp1252", output_mode=dialect.lower()) tables = ddl["tables"] except Exception as e: - logging.error(f"Error parsing SQL: {str(e)}") + logging.error(f"Error simple-dd-parser SQL: {str(e)}") raise DataContractException( type="import", name=f"Reading source from {source}", @@ -56,10 +56,10 @@ def import_sql( for table in tables: if data_contract_specification.models is None: data_contract_specification.models = {} - - if hasattr(table, 'this'): # sqlglot + + if hasattr(table, "this"): # sqlglot table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect) - else: # simple-ddl-parser + else: # simple-ddl-parser table_name, fields, table_description, table_tags = simple_ddl_model_wrapper(table, dialect) data_contract_specification.models[table_name] = Model( @@ -71,6 +71,7 @@ def import_sql( return data_contract_specification + def sqlglot_model_wrapper(table, parsed, dialect): table_name = table.this.name @@ -100,6 +101,7 @@ def sqlglot_model_wrapper(table, parsed, dialect): return table_name, fields, None, None + def simple_ddl_model_wrapper(table, dialect): table_name = table["table_name"] @@ -115,10 +117,10 @@ def simple_ddl_model_wrapper(table, dialect): } if not column["nullable"]: - field.required = True + field.required = True if column["unique"]: field.unique = True - + if column["size"] is not None and column["size"] and not isinstance(column["size"], tuple): field.maxLength = column["size"] elif isinstance(column["size"], tuple): @@ -126,45 +128,51 @@ def simple_ddl_model_wrapper(table, dialect): field.scale = column["size"][1] field.description = column["comment"][1:-1].strip() if column.get("comment") else None - + if column.get("with_tag"): field.tags = column["with_tag"] if column.get("with_masking_policy"): - field.classification = ", ".join(column["with_masking_policy"]) + field.classification = ", ".join(column["with_masking_policy"]) if column.get("generated"): field.examples = str(column["generated"]) - + fields[column["name"]] = field - + if table.get("constraints"): - if table["constraints"].get("primary_key"): + if table["constraints"].get("primary_key"): for primary_key in table["constraints"]["primary_key"]["columns"]: if primary_key in fields: fields[primary_key].unique = True fields[primary_key].required = True fields[primary_key].primaryKey = True - table_description = table["comment"][1:-1] if table.get("comment") else None - table_tags = table["with_tag"][1:-1] if table.get("with_tag") else None + table_description = table["comment"][1:-1] if table.get("comment") else None + table_tags = table["with_tag"][1:-1] if table.get("with_tag") else None return table_name, fields, table_description, table_tags - + + def map_physical_type(column, dialect) -> str | None: autoincrement = "" - if column.get("autoincrement") == True and dialect == Dialects.SNOWFLAKE: - autoincrement = " AUTOINCREMENT" \ - + " START " + str(column.get("start")) if column.get("start") else "" + if column.get("autoincrement") and dialect == Dialects.SNOWFLAKE: + autoincrement = " AUTOINCREMENT" + " START " + str(column.get("start")) if column.get("start") else "" autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else "" - autoincrement += " NOORDER" if column.get("increment_order") == False else "" - elif column.get("autoincrement") == True: + autoincrement += " NOORDER" if not column.get("increment_order") else "" + elif column.get("autoincrement"): autoincrement = " IDENTITY" - + if column.get("size") and isinstance(column.get("size"), tuple): - return column.get("type") + "(" + str(column.get("size")[0]) + "," + str(column.get("size")[1]) + ")" \ - + autoincrement + return ( + column.get("type") + + "(" + + str(column.get("size")[0]) + + "," + + str(column.get("size")[1]) + + ")" + + autoincrement + ) elif column.get("size"): - return column.get("type") + "(" + str(column.get("size")) + ")" \ - + autoincrement + return column.get("type") + "(" + str(column.get("size")) + ")" + autoincrement else: return column.get("type") + autoincrement @@ -308,16 +316,14 @@ def map_type_from_sql(sql_type: str) -> str | None: elif sql_type_normed.startswith("ntext"): return "string" elif sql_type_normed.startswith("int"): - return "int" + return "int" elif sql_type_normed.startswith("tinyint"): return "int" elif sql_type_normed.startswith("smallint"): return "int" elif sql_type_normed.startswith("bigint"): return "long" - elif (sql_type_normed.startswith("float") - or sql_type_normed.startswith("double") - or sql_type_normed == "real"): + elif sql_type_normed.startswith("float") or sql_type_normed.startswith("double") or sql_type_normed == "real": return "float" elif sql_type_normed.startswith("number"): return "decimal" diff --git a/tests/test_import_sql_snowflake.py b/tests/test_import_sql_snowflake.py index e83942823..59a33c6c0 100644 --- a/tests/test_import_sql_snowflake.py +++ b/tests/test_import_sql_snowflake.py @@ -6,10 +6,9 @@ def test_import_sql_snowflake(): - - result = DataContract().import_from_source("sql", sql_file_path, dialect="snowflake") + result = DataContract().import_from_source("sql", sql_file_path, dialect="snowflake") - expected = """ + expected = """ dataContractSpecification: 1.1.0 id: my-data-contract-id info: @@ -185,8 +184,8 @@ def test_import_sql_snowflake(): description: JSON ( Stored as text) config: snowflakeType: OBJECT""" - - print("Result", result.to_yaml()) - assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) - # Disable linters so we don't get "missing description" warnings - assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed() + + print("Result", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + # Disable linters so we don't get "missing description" warnings + assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed() From 327c21a32c274572b48697949fbfea93e2ca7393 Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Wed, 11 Jun 2025 15:29:51 -0400 Subject: [PATCH 3/8] align import --- datacontract/imports/sql_importer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index b890095bf..ccd6d044b 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -8,7 +8,7 @@ from datacontract.imports.importer import Importer from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server from datacontract.model.exceptions import DataContractException -from datacontract.model.run import ResultEnum +from datacontract.model.run import ResultEnum class SqlImporter(Importer): From 234c2fb42f2f975c095340ef24ab497663c1a418 Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Wed, 11 Jun 2025 15:38:33 -0400 Subject: [PATCH 4/8] add dialect --- datacontract/imports/sql_importer.py | 2 +- tests/test_import_sql_postgres.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index ccd6d044b..72ff26d2c 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -197,7 +197,7 @@ def to_dialect(import_args: dict) -> Dialects | None: return Dialects.TSQL if dialect.upper() in Dialects.__members__: return Dialects[dialect.upper()] - return None + return "None" def to_physical_type_key(dialect: Dialects | str | None) -> str: diff --git a/tests/test_import_sql_postgres.py b/tests/test_import_sql_postgres.py index edc18b729..8efa07e0d 100644 --- a/tests/test_import_sql_postgres.py +++ b/tests/test_import_sql_postgres.py @@ -20,6 +20,8 @@ def test_cli(): "sql", "--source", sql_file_path, + "--dialect", + "postgres" ], ) assert result.exit_code == 0 From 5d412fd291a76faf1892402e19c84f4ccfb026eb Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Fri, 13 Jun 2025 18:40:27 -0400 Subject: [PATCH 5/8] sqlglot ${} token bypass and waiting for NOORDER ORDER AUTOINCREMENT waiting PR --- datacontract/imports/sql_importer.py | 113 +++++++++--------------- tests/fixtures/snowflake/import/ddl.sql | 2 +- tests/test_import_sql_snowflake.py | 53 +++++------ 3 files changed, 68 insertions(+), 100 deletions(-) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index 72ff26d2c..35ef1556f 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -1,14 +1,19 @@ import logging import os +import re import sqlglot from sqlglot.dialects.dialect import Dialects -from simple_ddl_parser import parse_from_file from datacontract.imports.importer import Importer -from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server +from datacontract.model.data_contract_specification import ( + DataContractSpecification, + Field, + Model, + Server, +) from datacontract.model.exceptions import DataContractException -from datacontract.model.run import ResultEnum +from datacontract.model.run import ResultEnum class SqlImporter(Importer): @@ -19,7 +24,10 @@ def import_source( def import_sql( - data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None + data_contract_specification: DataContractSpecification, + format: str, + source: str, + import_args: dict = None, ) -> DataContractSpecification: dialect = to_dialect(import_args) @@ -36,13 +44,6 @@ def import_sql( tables = parsed.find_all(sqlglot.expressions.Table) - except Exception as e: - logging.error(f"Error parsing sqlglot: {str(e)}") - # Second try with simple-ddl-parser - ddl = parse_from_file(source, group_by_type=True, encoding="cp1252", output_mode=dialect.lower()) - - tables = ddl["tables"] - except Exception as e: logging.error(f"Error simple-dd-parser SQL: {str(e)}") raise DataContractException( @@ -57,10 +58,7 @@ def import_sql( if data_contract_specification.models is None: data_contract_specification.models = {} - if hasattr(table, "this"): # sqlglot - table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect) - else: # simple-ddl-parser - table_name, fields, table_description, table_tags = simple_ddl_model_wrapper(table, dialect) + table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect) data_contract_specification.models[table_name] = Model( type="table", @@ -73,8 +71,22 @@ def import_sql( def sqlglot_model_wrapper(table, parsed, dialect): + table_description = None + table_tag = None + table_name = table.this.name + table_comment_property = parsed.find(sqlglot.expressions.SchemaCommentProperty) + if table_comment_property: + table_description = table_comment_property.this.this + + prop = parsed.find(sqlglot.expressions.Properties) + if prop: + tags = prop.find(sqlglot.expressions.Tags) + if tags: + tag_enum = tags.find(sqlglot.expressions.Property) + table_tag = [str(t) for t in tag_enum] + fields = {} for column in parsed.find_all(sqlglot.exp.ColumnDef): if column.parent.this.name != table_name: @@ -93,63 +105,14 @@ def sqlglot_model_wrapper(table, parsed, dialect): field.primaryKey = get_primary_key(column) field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None physical_type_key = to_physical_type_key(dialect) + field.tags = get_tags(column) field.config = { physical_type_key: col_type, } fields[col_name] = field - return table_name, fields, None, None - - -def simple_ddl_model_wrapper(table, dialect): - table_name = table["table_name"] - - fields = {} - - for column in table["columns"]: - field = Field() - field.type = map_type_from_sql(column["type"]) - physical_type_key = to_physical_type_key(dialect) - datatype = map_physical_type(column, dialect) - field.config = { - physical_type_key: datatype, - } - - if not column["nullable"]: - field.required = True - if column["unique"]: - field.unique = True - - if column["size"] is not None and column["size"] and not isinstance(column["size"], tuple): - field.maxLength = column["size"] - elif isinstance(column["size"], tuple): - field.precision = column["size"][0] - field.scale = column["size"][1] - - field.description = column["comment"][1:-1].strip() if column.get("comment") else None - - if column.get("with_tag"): - field.tags = column["with_tag"] - if column.get("with_masking_policy"): - field.classification = ", ".join(column["with_masking_policy"]) - if column.get("generated"): - field.examples = str(column["generated"]) - - fields[column["name"]] = field - - if table.get("constraints"): - if table["constraints"].get("primary_key"): - for primary_key in table["constraints"]["primary_key"]["columns"]: - if primary_key in fields: - fields[primary_key].unique = True - fields[primary_key].required = True - fields[primary_key].primaryKey = True - - table_description = table["comment"][1:-1] if table.get("comment") else None - table_tags = table["with_tag"][1:-1] if table.get("with_tag") else None - - return table_name, fields, table_description, table_tags + return table_name, fields, table_description, table_tag def map_physical_type(column, dialect) -> str | None: @@ -248,10 +211,19 @@ def to_col_type_normalized(column): def get_description(column: sqlglot.expressions.ColumnDef) -> str | None: - if column.comments is None: + description = column.find(sqlglot.expressions.CommentColumnConstraint) + if not description: + return + return description.this.this + +def get_tags(column: sqlglot.expressions.ColumnDef) -> str | None: + tags = column.find(sqlglot.expressions.Tags) + if tags: + tag_enum = tags.find(sqlglot.expressions.Property) + return [str(t) for t in tag_enum] + else: return None - return " ".join(comment.strip() for comment in column.comments) - + def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None: col_type = to_col_type_normalized(column) @@ -383,4 +355,5 @@ def read_file(path): ) with open(path, "r") as file: file_content = file.read() - return file_content + + return re.sub(r'\$\{(\w+)\}', r'\1', file_content) diff --git a/tests/fixtures/snowflake/import/ddl.sql b/tests/fixtures/snowflake/import/ddl.sql index f76058829..c458db7e5 100644 --- a/tests/fixtures/snowflake/import/ddl.sql +++ b/tests/fixtures/snowflake/import/ddl.sql @@ -1,6 +1,6 @@ CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table ( -- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types - field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 noorder COMMENT 'Primary key', + field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 COMMENT 'Primary key', field_not_null INT NOT NULL COMMENT 'Not null', field_char CHAR(10) COMMENT 'Fixed-length string', field_character CHARACTER(10) COMMENT 'Fixed-length string', diff --git a/tests/test_import_sql_snowflake.py b/tests/test_import_sql_snowflake.py index 59a33c6c0..b598b251e 100644 --- a/tests/test_import_sql_snowflake.py +++ b/tests/test_import_sql_snowflake.py @@ -29,11 +29,10 @@ def test_import_sql_snowflake(): precision: 38 scale: 0 config: - snowflakeType: NUMBER(38,0) AUTOINCREMENT START 1 INCREMENT 1 NOORDER + snowflakeType: DECIMAL(38, 0) field_not_null: type: int required: true - unique: true description: Not null config: snowflakeType: INT @@ -48,7 +47,7 @@ def test_import_sql_snowflake(): description: Fixed-length string maxLength: 10 config: - snowflakeType: CHARACTER(10) + snowflakeType: CHAR(10) field_varchar: type: string description: Variable-length string @@ -65,30 +64,30 @@ def test_import_sql_snowflake(): type: string description: Large variable-length Unicode string config: - snowflakeType: STRING + snowflakeType: TEXT field_tinyint: type: int - description: Integer ( 0-255) + description: Integer (0-255) config: snowflakeType: TINYINT field_smallint: type: int - description: Integer ( -32 , 768 to 32 , 767) + description: Integer (-32,768 to 32,767) config: snowflakeType: SMALLINT field_int: type: int - description: Integer ( -2.1B to 2.1B) + description: Integer (-2.1B to 2.1B) config: snowflakeType: INT field_integer: type: int - description: Integer full name ( -2.1B to 2.1B) + description: Integer full name(-2.1B to 2.1B) config: - snowflakeType: INTEGER + snowflakeType: INT field_bigint: type: long - description: Large integer ( -9 quintillion to 9 quintillion) + description: Large integer (-9 quintillion to 9 quintillion) config: snowflakeType: BIGINT field_decimal: @@ -97,14 +96,14 @@ def test_import_sql_snowflake(): precision: 10 scale: 2 config: - snowflakeType: DECIMAL(10,2) + snowflakeType: DECIMAL(10, 2) field_numeric: type: decimal description: Same as DECIMAL precision: 10 scale: 2 config: - snowflakeType: NUMERIC(10,2) + snowflakeType: DECIMAL(10, 2) field_float: type: float description: Approximate floating-point @@ -114,30 +113,30 @@ def test_import_sql_snowflake(): type: float description: Approximate floating-point 4 config: - snowflakeType: FLOAT4 + snowflakeType: FLOAT field_float8: type: float description: Approximate floating-point 8 config: - snowflakeType: FLOAT8 + snowflakeType: DOUBLE field_real: type: float description: Smaller floating-point config: - snowflakeType: REAL + snowflakeType: FLOAT field_boulean: type: boolean - description: Boolean-like ( 0 or 1) + description: Boolean-like (0 or 1) config: snowflakeType: BOOLEAN field_date: type: date - description: Date only ( YYYY-MM-DD) + description: Date only (YYYY-MM-DD) config: snowflakeType: DATE field_time: type: string - description: Time only ( HH:MM:SS) + description: Time only (HH:MM:SS) config: snowflakeType: TIME field_timestamp: @@ -146,32 +145,28 @@ def test_import_sql_snowflake(): config: snowflakeType: TIMESTAMP field_timestamp_ltz: - type: timestamp_tz - description: More precise datetime with local time zone; time zone , if provided - , isn`t stored. + type: object + description: More precise datetime with local time zone; time zone, if provided, isn`t stored. config: - snowflakeType: TIMESTAMP_LTZ + snowflakeType: TIMESTAMPLTZ field_timestamp_ntz: type: timestamp_ntz - description: More precise datetime with no time zone; time zone , if provided - , isn`t stored. + description: More precise datetime with no time zone; time zone, if provided, isn`t stored. config: - snowflakeType: TIMESTAMP_NTZ + snowflakeType: TIMESTAMPNTZ field_timestamp_tz: type: timestamp_tz description: More precise datetime with time zone. config: - snowflakeType: TIMESTAMP_TZ + snowflakeType: TIMESTAMPTZ field_binary: type: bytes description: Fixed-length binary - maxLength: 16 config: snowflakeType: BINARY(16) field_varbinary: type: bytes description: Variable-length binary - maxLength: 100 config: snowflakeType: VARBINARY(100) field_variant: @@ -181,7 +176,7 @@ def test_import_sql_snowflake(): snowflakeType: VARIANT field_json: type: object - description: JSON ( Stored as text) + description: JSON (Stored as text) config: snowflakeType: OBJECT""" From 76d53b83c07f3082012bc7a41e15604ea5815cf6 Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Fri, 13 Jun 2025 18:50:23 -0400 Subject: [PATCH 6/8] fix regression on sql server side (no formal or declarative comments) --- datacontract/imports/sql_importer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index 35ef1556f..fde47a619 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -211,10 +211,14 @@ def to_col_type_normalized(column): def get_description(column: sqlglot.expressions.ColumnDef) -> str | None: - description = column.find(sqlglot.expressions.CommentColumnConstraint) - if not description: - return - return description.this.this + if column.comments is None: + description = column.find(sqlglot.expressions.CommentColumnConstraint) + if description: + return description.this.this + else: + return None + return " ".join(comment.strip() for comment in column.comments) + def get_tags(column: sqlglot.expressions.ColumnDef) -> str | None: tags = column.find(sqlglot.expressions.Tags) From 020d879674e67a023df1de169b0dd66a6ae99c49 Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Fri, 13 Jun 2025 20:02:25 -0400 Subject: [PATCH 7/8] type variant not allow in lint DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed() --- tests/fixtures/dbml/import/datacontract.yaml | 2 +- tests/fixtures/dbml/import/datacontract_table_filtered.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/dbml/import/datacontract.yaml b/tests/fixtures/dbml/import/datacontract.yaml index 00d2a4440..d6b40c980 100644 --- a/tests/fixtures/dbml/import/datacontract.yaml +++ b/tests/fixtures/dbml/import/datacontract.yaml @@ -22,7 +22,7 @@ models: description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. order_total: - type: variant + type: object required: true primaryKey: false unique: false diff --git a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml index fe011d539..2f9855d62 100644 --- a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml +++ b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml @@ -22,7 +22,7 @@ models: description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. order_total: - type: variant + type: object required: true primaryKey: false unique: false From e2ee1e8b15e1b3fdf79c97135078cfb4809664ab Mon Sep 17 00:00:00 2001 From: Damien Maresma Date: Fri, 13 Jun 2025 20:13:00 -0400 Subject: [PATCH 8/8] remove simple-ddl-parser dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d1045d87d..32e2e909b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ dependencies = [ "python-multipart>=0.0.20,<1.0.0", "rich>=13.7,<15.0", "sqlglot>=26.6.0,<27.0.0", - "simple-ddl-parser>=1.7.1,<2.0.0", "duckdb>=1.0.0,<2.0.0", "soda-core-duckdb>=3.3.20,<3.6.0", # remove setuptools when https://github.com/sodadata/soda-core/issues/2091 is resolved