diff --git a/datacontract/imports/odcs_helper.py b/datacontract/imports/odcs_helper.py index 5c01daf0a..dbb9caed5 100644 --- a/datacontract/imports/odcs_helper.py +++ b/datacontract/imports/odcs_helper.py @@ -34,6 +34,7 @@ def create_schema_object( description: str = None, business_name: str = None, properties: List[SchemaProperty] = None, + tags: List[str] = None, ) -> SchemaObject: """Create a SchemaObject (equivalent to DCS Model).""" schema = SchemaObject( @@ -48,6 +49,8 @@ def create_schema_object( schema.businessName = business_name if properties: schema.properties = properties + if tags: + schema.tags = tags return schema @@ -130,9 +133,7 @@ def create_property( # Custom properties if custom_properties: - prop.customProperties = [ - CustomProperty(property=k, value=v) for k, v in custom_properties.items() - ] + prop.customProperties = [CustomProperty(property=k, value=v) for k, v in custom_properties.items()] return prop diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index 5d3f4ac0c..2e84327c3 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -1,5 +1,6 @@ import logging import os +import re import sqlglot from open_data_contract_standard.model import OpenDataContractStandard @@ -17,22 +18,23 @@ class SqlImporter(Importer): - def import_source( - self, source: str, import_args: dict - ) -> OpenDataContractStandard: + def import_source(self, source: str, import_args: dict) -> OpenDataContractStandard: return import_sql(self.import_format, source, import_args) -def import_sql( - format: str, source: str, import_args: dict = None -) -> OpenDataContractStandard: +def import_sql(format: str, source: str, import_args: dict = None) -> OpenDataContractStandard: sql = read_file(source) - dialect = to_dialect(import_args) + dialect = to_dialect(import_args) or None + + parsed = None try: parsed = sqlglot.parse_one(sql=sql, read=dialect) + + tables = 
parsed.find_all(sqlglot.expressions.Table) + except Exception as e: - logging.error(f"Error parsing SQL: {str(e)}") + logging.error(f"Error sqlglot SQL: {str(e)}") raise DataContractException( type="import", name=f"Reading source from {source}", @@ -67,6 +69,7 @@ def import_sql( precision, scale = get_precision_scale(column) is_primary_key = get_primary_key(column) is_required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None + tags = get_tags(column) prop = create_property( name=col_name, @@ -79,6 +82,7 @@ def import_sql( primary_key=is_primary_key, primary_key_position=primary_key_position if is_primary_key else None, required=is_required if is_required else None, + tags=tags, ) if is_primary_key: @@ -86,9 +90,23 @@ def import_sql( properties.append(prop) + table_comment_property = parsed.find(sqlglot.expressions.SchemaCommentProperty) + + if table_comment_property: + table_description = table_comment_property.this.this + + prop = parsed.find(sqlglot.expressions.Properties) + if prop: + tags = prop.find(sqlglot.expressions.Tags) + if tags: + tag_enum = tags.find(sqlglot.expressions.Property) + table_tags = [str(t) for t in tag_enum] + schema_obj = create_schema_object( name=table_name, physical_type="table", + description=table_description if table_comment_property else None, + tags=table_tags if tags else None, properties=properties, ) odcs.schema_.append(schema_obj) @@ -96,6 +114,31 @@ def import_sql( return odcs +def map_physical_type(column, dialect) -> str | None: + autoincrement = "" + if column.get("autoincrement") and dialect == Dialects.SNOWFLAKE: + autoincrement = " AUTOINCREMENT" + " START " + str(column.get("start")) if column.get("start") else "" + autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else "" + autoincrement += " NOORDER" if not column.get("increment_order") else "" + elif column.get("autoincrement"): + autoincrement = " IDENTITY" + + if column.get("size") and 
isinstance(column.get("size"), tuple):
+        return (
+            column.get("type")
+            + "("
+            + str(column.get("size")[0])
+            + ","
+            + str(column.get("size")[1])
+            + ")"
+            + autoincrement
+        )
+    elif column.get("size"):
+        return column.get("type") + "(" + str(column.get("size")) + ")" + autoincrement
+    else:
+        return column.get("type") + autoincrement
+
+
 def get_primary_key(column) -> bool | None:
     if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None:
         return True
@@ -116,25 +159,7 @@ def to_dialect(import_args: dict) -> Dialects | None:
         return Dialects.TSQL
     if dialect.upper() in Dialects.__members__:
         return Dialects[dialect.upper()]
-    if dialect == "sqlserver":
-        return Dialects.TSQL
-    return None
-
-
-def to_physical_type_key(dialect: Dialects | str | None) -> str:
-    dialect_map = {
-        Dialects.TSQL: "sqlserverType",
-        Dialects.POSTGRES: "postgresType",
-        Dialects.BIGQUERY: "bigqueryType",
-        Dialects.SNOWFLAKE: "snowflakeType",
-        Dialects.REDSHIFT: "redshiftType",
-        Dialects.ORACLE: "oracleType",
-        Dialects.MYSQL: "mysqlType",
-        Dialects.DATABRICKS: "databricksType",
-    }
-    if isinstance(dialect, str):
-        dialect = Dialects[dialect.upper()] if dialect.upper() in Dialects.__members__ else None
-    return dialect_map.get(dialect, "physicalType")
+    return None


 def to_server_type(source, dialect: Dialects | None) -> str | None:
@@ -170,10 +195,23 @@ def to_col_type_normalized(column):

 def get_description(column: sqlglot.expressions.ColumnDef) -> str | None:
     if column.comments is None:
-        return None
+        description = column.find(sqlglot.expressions.CommentColumnConstraint)
+        if description:
+            return description.this.this
+        else:
+            return None
     return " ".join(comment.strip() for comment in column.comments)


+def get_tags(column: sqlglot.expressions.ColumnDef) -> list[str] | None:
+    tags = column.find(sqlglot.expressions.Tags)
+    if tags:
+        tag_enum = tags.find(sqlglot.expressions.Property)
+        return [str(t) for t in tag_enum]
+    else:
+        return None
+
+
 def get_max_length(column: 
sqlglot.expressions.ColumnDef) -> int | None: col_type = to_col_type_normalized(column) if col_type is None: @@ -237,30 +275,28 @@ def map_type_from_sql(sql_type: str) -> str | None: return "string" elif sql_type_normed.startswith("ntext"): return "string" - elif sql_type_normed.startswith("int") and not sql_type_normed.startswith("interval"): - return "integer" - elif sql_type_normed.startswith("bigint"): - return "integer" - elif sql_type_normed.startswith("tinyint"): + elif sql_type_normed.endswith("integer"): return "integer" - elif sql_type_normed.startswith("smallint"): + elif sql_type_normed.endswith("int"): # covers int, bigint, smallint, tinyint return "integer" - elif sql_type_normed.startswith("float"): + elif sql_type_normed.startswith("float") or sql_type_normed.startswith("double") or sql_type_normed == "real": return "number" - elif sql_type_normed.startswith("double"): + elif sql_type_normed.startswith("number"): + return "number" + elif sql_type_normed.startswith("numeric"): return "number" elif sql_type_normed.startswith("decimal"): return "number" - elif sql_type_normed.startswith("numeric"): + elif sql_type_normed.startswith("money"): return "number" elif sql_type_normed.startswith("bool"): return "boolean" elif sql_type_normed.startswith("bit"): return "boolean" elif sql_type_normed.startswith("binary"): - return "array" + return "object" elif sql_type_normed.startswith("varbinary"): - return "array" + return "object" elif sql_type_normed.startswith("raw"): return "array" elif sql_type_normed == "blob" or sql_type_normed == "bfile": @@ -270,12 +306,10 @@ def map_type_from_sql(sql_type: str) -> str | None: elif sql_type_normed == "time": return "string" elif sql_type_normed.startswith("timestamp"): - return "date" - elif sql_type_normed == "datetime" or sql_type_normed == "datetime2": - return "date" + return "timestamp" elif sql_type_normed == "smalldatetime": return "date" - elif sql_type_normed == "datetimeoffset": + elif 
sql_type_normed.startswith("datetime"):  # tsql datetime2
         return "date"
     elif sql_type_normed == "uniqueidentifier":  # tsql
         return "string"
@@ -291,6 +325,14 @@ def map_type_from_sql(sql_type: str) -> str | None:
         return "object"


+def remove_variable_tokens(sql_script: str) -> str:
+    ## to cleanse sql statement's script token like $(...) in sqlcmd for T-SQL language, ${...} for liquibase, {{...}} as Jinja
+    ## https://learn.microsoft.com/en-us/sql/tools/sqlcmd/sqlcmd-use-scripting-variables?view=sql-server-ver17#b-use-the-setvar-command-interactively
+    ## https://docs.liquibase.com/concepts/changelogs/property-substitution.html
+    ## https://docs.getdbt.com/guides/using-jinja?step=1
+    return re.sub(r"\$\((\w+)\)|\$\{(\w+)\}|\{\{\s*(\w+)\s*\}\}", r"\1\2\3", sql_script)
+
+
 def read_file(path):
     if not os.path.exists(path):
         raise DataContractException(
@@ -302,4 +344,5 @@ def read_file(path):
         )
     with open(path, "r") as file:
         file_content = file.read()
-    return file_content
+
+    return remove_variable_tokens(file_content)
diff --git a/tests/fixtures/databricks-unity/import/datacontract.yaml b/tests/fixtures/databricks-unity/import/datacontract.yaml
index 3dff216cd..ae894b326 100644
--- a/tests/fixtures/databricks-unity/import/datacontract.yaml
+++ b/tests/fixtures/databricks-unity/import/datacontract.yaml
@@ -53,7 +53,7 @@ schema:
       customProperties:
       - property: databricksType
         value: timestamp
-      logicalType: date
+      logicalType: timestamp
     - name: is_active
       physicalType: boolean
       customProperties:
diff --git a/tests/fixtures/dbml/import/datacontract.yaml b/tests/fixtures/dbml/import/datacontract.yaml
index 2aa614e83..d14583e85 100644
--- a/tests/fixtures/dbml/import/datacontract.yaml
+++ b/tests/fixtures/dbml/import/datacontract.yaml
@@ -26,7 +26,7 @@ schema:
       physicalType: timestamp
      description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. 
- logicalType: date + logicalType: timestamp required: true - name: order_total physicalType: record @@ -46,7 +46,7 @@ schema: - name: processed_timestamp physicalType: timestamp description: The timestamp when the record was processed by the data platform. - logicalType: date + logicalType: timestamp required: true - name: line_items physicalType: table diff --git a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml index b8d5bb2d9..2247ebc1b 100644 --- a/tests/fixtures/dbml/import/datacontract_table_filtered.yaml +++ b/tests/fixtures/dbml/import/datacontract_table_filtered.yaml @@ -26,7 +26,7 @@ schema: physicalType: timestamp description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. - logicalType: date + logicalType: timestamp required: true - name: order_total physicalType: record @@ -46,5 +46,5 @@ schema: - name: processed_timestamp physicalType: timestamp description: The timestamp when the record was processed by the data platform. 
- logicalType: date + logicalType: timestamp required: true diff --git a/tests/fixtures/snowflake/import/ddl.sql b/tests/fixtures/snowflake/import/ddl.sql new file mode 100644 index 000000000..c458db7e5 --- /dev/null +++ b/tests/fixtures/snowflake/import/ddl.sql @@ -0,0 +1,42 @@ +CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table ( + -- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types + field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 COMMENT 'Primary key', + field_not_null INT NOT NULL COMMENT 'Not null', + field_char CHAR(10) COMMENT 'Fixed-length string', + field_character CHARACTER(10) COMMENT 'Fixed-length string', + field_varchar VARCHAR(100) WITH TAG (SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER', SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME') COMMENT 'Variable-length string', + + field_text TEXT COMMENT 'Large variable-length string', + field_string STRING COMMENT 'Large variable-length Unicode string', + + field_tinyint TINYINT COMMENT 'Integer (0-255)', + field_smallint SMALLINT COMMENT 'Integer (-32,768 to 32,767)', + field_int INT COMMENT 'Integer (-2.1B to 2.1B)', + field_integer INTEGER COMMENT 'Integer full name(-2.1B to 2.1B)', + field_bigint BIGINT COMMENT 'Large integer (-9 quintillion to 9 quintillion)', + + field_decimal DECIMAL(10, 2) COMMENT 'Fixed precision decimal', + field_numeric NUMERIC(10, 2) COMMENT 'Same as DECIMAL', + + field_float FLOAT COMMENT 'Approximate floating-point', + field_float4 FLOAT4 COMMENT 'Approximate floating-point 4', + field_float8 FLOAT8 COMMENT 'Approximate floating-point 8', + field_real REAL COMMENT 'Smaller floating-point', + + field_boulean BOOLEAN COMMENT 'Boolean-like (0 or 1)', + + field_date DATE COMMENT 'Date only (YYYY-MM-DD)', + field_time TIME COMMENT 'Time only (HH:MM:SS)', + field_timestamp TIMESTAMP COMMENT 'More precise datetime', + field_timestamp_ltz TIMESTAMP_LTZ COMMENT 'More precise datetime with local time zone; time zone, if provided, isn`t 
stored.', + field_timestamp_ntz TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() COMMENT 'More precise datetime with no time zone; time zone, if provided, isn`t stored.', + field_timestamp_tz TIMESTAMP_TZ COMMENT 'More precise datetime with time zone.', + + field_binary BINARY(16) COMMENT 'Fixed-length binary', + field_varbinary VARBINARY(100) COMMENT 'Variable-length binary', + + field_variant VARIANT COMMENT 'VARIANT data', + field_json OBJECT COMMENT 'JSON (Stored as text)', + UNIQUE(field_not_null), + PRIMARY KEY (field_primary_key) +) COMMENT = 'My Comment' diff --git a/tests/test_import_sql_oracle.py b/tests/test_import_sql_oracle.py index 37ffc5604..74c3408e8 100644 --- a/tests/test_import_sql_oracle.py +++ b/tests/test_import_sql_oracle.py @@ -83,15 +83,15 @@ def test_import_sql_oracle(): physicalType: DOUBLE PRECISION description: 64-bit floating point number - name: field_timestamp - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP description: Timestamp with fractional second precision of 6, no timezones - name: field_timestamp_tz - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP WITH TIME ZONE description: Timestamp with fractional second precision of 6, with timezones (TZ) - name: field_timestamp_ltz - logicalType: date + logicalType: timestamp physicalType: TIMESTAMPLTZ description: Timestamp with fractional second precision of 6, with local timezone (LTZ) - name: field_interval_year @@ -176,7 +176,7 @@ def test_import_sql_constraints(): physicalType: VARCHAR(30) required: true - name: create_date - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP required: true - name: changed_by @@ -185,7 +185,7 @@ def test_import_sql_constraints(): maxLength: 30 physicalType: VARCHAR(30) - name: change_date - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP - name: name logicalType: string diff --git a/tests/test_import_sql_postgres.py b/tests/test_import_sql_postgres.py index 908a9810a..fdf3b181e 
100644 --- a/tests/test_import_sql_postgres.py +++ b/tests/test_import_sql_postgres.py @@ -20,6 +20,8 @@ def test_cli(): "sql", "--source", sql_file_path, + "--dialect", + "postgres" ], ) assert result.exit_code == 0 @@ -56,7 +58,7 @@ def test_import_sql_postgres(): physicalType: INT required: true - name: field_three - logicalType: date + logicalType: timestamp physicalType: TIMESTAMPTZ """ print("Result", result.to_yaml()) @@ -93,7 +95,7 @@ def test_import_sql_constraints(): physicalType: VARCHAR(30) required: true - name: create_date - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP required: true - name: changed_by @@ -102,7 +104,7 @@ def test_import_sql_constraints(): maxLength: 30 physicalType: VARCHAR(30) - name: change_date - logicalType: date + logicalType: timestamp physicalType: TIMESTAMP - name: name logicalType: string diff --git a/tests/test_import_sql_snowflake.py b/tests/test_import_sql_snowflake.py new file mode 100644 index 000000000..e378ccd89 --- /dev/null +++ b/tests/test_import_sql_snowflake.py @@ -0,0 +1,169 @@ +import yaml + +from datacontract.data_contract import DataContract + +sql_file_path = "fixtures/snowflake/import/ddl.sql" + + +def test_import_sql_snowflake(): + result = DataContract().import_from_source("sql", sql_file_path, dialect="snowflake") + + expected = """version: 1.0.0 +kind: DataContract +apiVersion: v3.1.0 +id: my-data-contract +name: My Data Contract +status: draft +servers: +- server: snowflake + type: snowflake +schema: +- name: my_table + physicalType: table + description: My Comment + logicalType: object + physicalName: my_table + properties: + - name: field_primary_key + physicalType: DECIMAL(38, 0) + description: Primary key + logicalType: number + logicalTypeOptions: + precision: 38 + scale: 0 + required: true + - name: field_not_null + physicalType: INT + description: Not null + logicalType: integer + required: true + - name: field_char + physicalType: CHAR(10) + description: Fixed-length 
string + logicalType: string + logicalTypeOptions: + maxLength: 10 + - name: field_character + physicalType: CHAR(10) + description: Fixed-length string + logicalType: string + logicalTypeOptions: + maxLength: 10 + - name: field_varchar + physicalType: VARCHAR(100) + description: Variable-length string + tags: + - SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER' + - SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME' + logicalType: string + logicalTypeOptions: + maxLength: 100 + - name: field_text + physicalType: VARCHAR + description: Large variable-length string + logicalType: string + - name: field_string + physicalType: VARCHAR + description: Large variable-length Unicode string + logicalType: string + - name: field_tinyint + physicalType: TINYINT + description: Integer (0-255) + logicalType: integer + - name: field_smallint + physicalType: SMALLINT + description: Integer (-32,768 to 32,767) + logicalType: integer + - name: field_int + physicalType: INT + description: Integer (-2.1B to 2.1B) + logicalType: integer + - name: field_integer + physicalType: INT + description: Integer full name(-2.1B to 2.1B) + logicalType: integer + - name: field_bigint + physicalType: BIGINT + description: Large integer (-9 quintillion to 9 quintillion) + logicalType: integer + - name: field_decimal + physicalType: DECIMAL(10, 2) + description: Fixed precision decimal + logicalType: number + logicalTypeOptions: + precision: 10 + scale: 2 + - name: field_numeric + physicalType: DECIMAL(10, 2) + description: Same as DECIMAL + logicalType: number + logicalTypeOptions: + precision: 10 + scale: 2 + - name: field_float + physicalType: DOUBLE + description: Approximate floating-point + logicalType: number + - name: field_float4 + physicalType: FLOAT + description: Approximate floating-point 4 + logicalType: number + - name: field_float8 + physicalType: DOUBLE + description: Approximate floating-point 8 + logicalType: number + - name: field_real + physicalType: FLOAT + description: Smaller floating-point 
+ logicalType: number + - name: field_boulean + physicalType: BOOLEAN + description: Boolean-like (0 or 1) + logicalType: boolean + - name: field_date + physicalType: DATE + description: Date only (YYYY-MM-DD) + logicalType: date + - name: field_time + physicalType: TIME + description: Time only (HH:MM:SS) + logicalType: string + - name: field_timestamp + physicalType: TIMESTAMP + description: More precise datetime + logicalType: timestamp + - name: field_timestamp_ltz + physicalType: TIMESTAMPLTZ + description: More precise datetime with local time zone; time zone, if provided, + isn`t stored. + logicalType: timestamp + - name: field_timestamp_ntz + physicalType: TIMESTAMPNTZ + description: More precise datetime with no time zone; time zone, if provided, + isn`t stored. + logicalType: timestamp + - name: field_timestamp_tz + description: More precise datetime with time zone. + logicalType: timestamp + physicalType: 'TIMESTAMPTZ' + - name: field_binary + physicalType: BINARY(16) + description: Fixed-length binary + logicalType: object + - name: field_varbinary + physicalType: VARBINARY(100) + description: Variable-length binary + logicalType: object + - name: field_variant + physicalType: VARIANT + description: VARIANT data + logicalType: object + - name: field_json + physicalType: OBJECT + description: JSON (Stored as text) + logicalType: object""" + + print("Result", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + # Disable linters so we don't get "missing description" warnings account, db, schema name are required + #assert DataContract(data_contract_str=expected).lint().has_passed() diff --git a/tests/test_import_sql_sqlserver.py b/tests/test_import_sql_sqlserver.py index 479421856..e573ede04 100644 --- a/tests/test_import_sql_sqlserver.py +++ b/tests/test_import_sql_sqlserver.py @@ -130,11 +130,11 @@ def test_import_sql_sqlserver(): physicalType: DATETIMEOFFSET description: Datetime with time zone - name: field_binary 
- logicalType: array + logicalType: object physicalType: BINARY(16) description: Fixed-length binary - name: field_varbinary - logicalType: array + logicalType: object physicalType: VARBINARY(100) description: Variable-length binary - name: field_uniqueidentifier