Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 122 additions & 47 deletions datacontract/imports/sql_importer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
import os
import re

import sqlglot
from sqlglot.dialects.dialect import Dialects

from datacontract.imports.importer import Importer
from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server
from datacontract.model.data_contract_specification import (
DataContractSpecification,
Field,
Model,
Server,
)
from datacontract.model.exceptions import DataContractException
from datacontract.model.run import ResultEnum

Expand All @@ -18,16 +24,28 @@ def import_source(


def import_sql(
data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None
data_contract_specification: DataContractSpecification,
format: str,
source: str,
import_args: dict = None,
) -> DataContractSpecification:
dialect = to_dialect(import_args)

server_type: str | None = to_server_type(source, dialect)
if server_type is not None:
data_contract_specification.servers[server_type] = Server(type=server_type)

sql = read_file(source)

dialect = to_dialect(import_args)
parsed = None

try:
parsed = sqlglot.parse_one(sql=sql, read=dialect)
parsed = sqlglot.parse_one(sql=sql, read=dialect.lower())

tables = parsed.find_all(sqlglot.expressions.Table)

except Exception as e:
logging.error(f"Error parsing SQL: {str(e)}")
logging.error(f"Error simple-dd-parser SQL: {str(e)}")
raise DataContractException(
type="import",
name=f"Reading source from {source}",
Expand All @@ -36,50 +54,92 @@ def import_sql(
result=ResultEnum.error,
)

server_type: str | None = to_server_type(source, dialect)
if server_type is not None:
data_contract_specification.servers[server_type] = Server(type=server_type)

tables = parsed.find_all(sqlglot.expressions.Table)

for table in tables:
if data_contract_specification.models is None:
data_contract_specification.models = {}

table_name = table.this.name

fields = {}
for column in parsed.find_all(sqlglot.exp.ColumnDef):
if column.parent.this.name != table_name:
continue

field = Field()
col_name = column.this.name
col_type = to_col_type(column, dialect)
field.type = map_type_from_sql(col_type)
col_description = get_description(column)
field.description = col_description
field.maxLength = get_max_length(column)
precision, scale = get_precision_scale(column)
field.precision = precision
field.scale = scale
field.primaryKey = get_primary_key(column)
field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None
physical_type_key = to_physical_type_key(dialect)
field.config = {
physical_type_key: col_type,
}

fields[col_name] = field
table_name, fields, table_description, table_tags = sqlglot_model_wrapper(table, parsed, dialect)

data_contract_specification.models[table_name] = Model(
type="table",
description=table_description,
tags=table_tags,
fields=fields,
)

return data_contract_specification


def sqlglot_model_wrapper(table, parsed, dialect):
    """Extract one model's metadata from a sqlglot parse tree.

    Returns a tuple (table_name, fields, table_description, table_tags) for
    the given ``table`` expression found inside ``parsed`` (the parse_one
    result of a CREATE statement). ``dialect`` selects type normalization
    and the physical-type config key.
    """
    table_description = None
    table_tag = None

    table_name = table.this.name

    # Table-level COMMENT property (e.g. Snowflake `... ) COMMENT = '...'`).
    # NOTE(review): searched on the whole parsed statement rather than on
    # `table` — fine for a single CREATE statement; confirm behavior if the
    # input could ever contain more than one table.
    table_comment_property = parsed.find(sqlglot.expressions.SchemaCommentProperty)
    if table_comment_property:
        table_description = table_comment_property.this.this

    prop = parsed.find(sqlglot.expressions.Properties)
    if prop:
        tags = prop.find(sqlglot.expressions.Tags)
        if tags:
            # NOTE(review): `find` returns only the FIRST Property node, which
            # is then iterated — confirm this enumerates every tag; `find_all`
            # may be what was intended here.
            tag_enum = tags.find(sqlglot.expressions.Property)
            table_tag = [str(t) for t in tag_enum]

    fields = {}
    for column in parsed.find_all(sqlglot.exp.ColumnDef):
        # Skip column definitions belonging to a different table expression.
        if column.parent.this.name != table_name:
            continue

        field = Field()
        col_name = column.this.name
        col_type = to_col_type(column, dialect)
        field.type = map_type_from_sql(col_type)
        col_description = get_description(column)
        field.description = col_description
        field.maxLength = get_max_length(column)
        precision, scale = get_precision_scale(column)
        field.precision = precision
        field.scale = scale
        field.primaryKey = get_primary_key(column)
        # NOT NULL constraint -> required=True; otherwise leave unset (None)
        # so the attribute is omitted from the serialized contract.
        field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None
        physical_type_key = to_physical_type_key(dialect)
        field.tags = get_tags(column)
        field.config = {
            physical_type_key: col_type,
        }

        fields[col_name] = field

    return table_name, fields, table_description, table_tag


def map_physical_type(column, dialect) -> str | None:
autoincrement = ""
if column.get("autoincrement") and dialect == Dialects.SNOWFLAKE:
autoincrement = " AUTOINCREMENT" + " START " + str(column.get("start")) if column.get("start") else ""
autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else ""
autoincrement += " NOORDER" if not column.get("increment_order") else ""
elif column.get("autoincrement"):
autoincrement = " IDENTITY"

if column.get("size") and isinstance(column.get("size"), tuple):
return (
column.get("type")
+ "("
+ str(column.get("size")[0])
+ ","
+ str(column.get("size")[1])
+ ")"
+ autoincrement
)
elif column.get("size"):
return column.get("type") + "(" + str(column.get("size")) + ")" + autoincrement
else:
return column.get("type") + autoincrement


def get_primary_key(column) -> bool | None:
if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None:
return True
Expand All @@ -100,9 +160,7 @@ def to_dialect(import_args: dict) -> Dialects | None:
return Dialects.TSQL
if dialect.upper() in Dialects.__members__:
return Dialects[dialect.upper()]
if dialect == "sqlserver":
return Dialects.TSQL
return None
return "None"


def to_physical_type_key(dialect: Dialects | str | None) -> str:
Expand Down Expand Up @@ -154,9 +212,22 @@ def to_col_type_normalized(column):

def get_description(column: sqlglot.expressions.ColumnDef) -> str | None:
if column.comments is None:
description = column.find(sqlglot.expressions.CommentColumnConstraint)
if description:
return description.this.this
else:
return None
return " ".join(comment.strip() for comment in column.comments)


def get_tags(column: "sqlglot.expressions.ColumnDef") -> list[str] | None:
    """Return the column's WITH TAG entries as strings, or None if untagged.

    BUG FIX: removed an unreachable trailing
    ``return " ".join(comment.strip() for comment in column.comments)``
    (dead copy-paste from get_description — both branches above it already
    return), and corrected the return annotation: a list is returned, not a str.
    """
    tags = column.find(sqlglot.expressions.Tags)
    if not tags:
        return None
    # NOTE(review): `find` returns only the FIRST Property node, which is then
    # iterated; confirm this enumerates all tags — `find_all` may be intended.
    tag_enum = tags.find(sqlglot.expressions.Property)
    return [str(t) for t in tag_enum]



def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None:
col_type = to_col_type_normalized(column)
Expand Down Expand Up @@ -222,18 +293,20 @@ def map_type_from_sql(sql_type: str) -> str | None:
return "string"
elif sql_type_normed.startswith("int"):
return "int"
elif sql_type_normed.startswith("bigint"):
return "long"
elif sql_type_normed.startswith("tinyint"):
return "int"
elif sql_type_normed.startswith("smallint"):
return "int"
elif sql_type_normed.startswith("float"):
elif sql_type_normed.startswith("bigint"):
return "long"
elif sql_type_normed.startswith("float") or sql_type_normed.startswith("double") or sql_type_normed == "real":
return "float"
elif sql_type_normed.startswith("decimal"):
elif sql_type_normed.startswith("number"):
return "decimal"
elif sql_type_normed.startswith("numeric"):
return "decimal"
elif sql_type_normed.startswith("decimal"):
return "decimal"
elif sql_type_normed.startswith("bool"):
return "boolean"
elif sql_type_normed.startswith("bit"):
Expand All @@ -252,6 +325,7 @@ def map_type_from_sql(sql_type: str) -> str | None:
sql_type_normed == "timestamptz"
or sql_type_normed == "timestamp_tz"
or sql_type_normed == "timestamp with time zone"
or sql_type_normed == "timestamp_ltz"
):
return "timestamp_tz"
elif sql_type_normed == "timestampntz" or sql_type_normed == "timestamp_ntz":
Expand All @@ -271,7 +345,7 @@ def map_type_from_sql(sql_type: str) -> str | None:
elif sql_type_normed == "xml": # tsql
return "string"
else:
return "variant"
return "object"


def read_file(path):
Expand All @@ -285,4 +359,5 @@ def read_file(path):
)
with open(path, "r") as file:
file_content = file.read()
return file_content

return re.sub(r'\$\{(\w+)\}', r'\1', file_content)
2 changes: 1 addition & 1 deletion tests/fixtures/dbml/import/datacontract.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ models:
description: The business timestamp in UTC when the order was successfully
registered in the source system and the payment was successful.
order_total:
type: variant
type: object
required: true
primaryKey: false
unique: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ models:
description: The business timestamp in UTC when the order was successfully
registered in the source system and the payment was successful.
order_total:
type: variant
type: object
required: true
primaryKey: false
unique: false
Expand Down
42 changes: 42 additions & 0 deletions tests/fixtures/snowflake/import/ddl.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table (
-- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types
field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 COMMENT 'Primary key',
field_not_null INT NOT NULL COMMENT 'Not null',
field_char CHAR(10) COMMENT 'Fixed-length string',
field_character CHARACTER(10) COMMENT 'Fixed-length string',
field_varchar VARCHAR(100) WITH TAG (SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER', SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME') COMMENT 'Variable-length string',

field_text TEXT COMMENT 'Large variable-length string',
field_string STRING COMMENT 'Large variable-length Unicode string',

field_tinyint TINYINT COMMENT 'Integer (0-255)',
field_smallint SMALLINT COMMENT 'Integer (-32,768 to 32,767)',
field_int INT COMMENT 'Integer (-2.1B to 2.1B)',
field_integer INTEGER COMMENT 'Integer full name(-2.1B to 2.1B)',
field_bigint BIGINT COMMENT 'Large integer (-9 quintillion to 9 quintillion)',

field_decimal DECIMAL(10, 2) COMMENT 'Fixed precision decimal',
field_numeric NUMERIC(10, 2) COMMENT 'Same as DECIMAL',

field_float FLOAT COMMENT 'Approximate floating-point',
field_float4 FLOAT4 COMMENT 'Approximate floating-point 4',
field_float8 FLOAT8 COMMENT 'Approximate floating-point 8',
field_real REAL COMMENT 'Smaller floating-point',

field_boulean BOOLEAN COMMENT 'Boolean-like (0 or 1)',

field_date DATE COMMENT 'Date only (YYYY-MM-DD)',
field_time TIME COMMENT 'Time only (HH:MM:SS)',
field_timestamp TIMESTAMP COMMENT 'More precise datetime',
field_timestamp_ltz TIMESTAMP_LTZ COMMENT 'More precise datetime with local time zone; time zone, if provided, isn`t stored.',
field_timestamp_ntz TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() COMMENT 'More precise datetime with no time zone; time zone, if provided, isn`t stored.',
field_timestamp_tz TIMESTAMP_TZ COMMENT 'More precise datetime with time zone.',

field_binary BINARY(16) COMMENT 'Fixed-length binary',
field_varbinary VARBINARY(100) COMMENT 'Variable-length binary',

field_variant VARIANT COMMENT 'VARIANT data',
field_json OBJECT COMMENT 'JSON (Stored as text)',
UNIQUE(field_not_null),
PRIMARY KEY (field_primary_key)
) COMMENT = 'My Comment'
2 changes: 2 additions & 0 deletions tests/test_import_sql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def test_cli():
"sql",
"--source",
sql_file_path,
"--dialect",
"postgres"
],
)
assert result.exit_code == 0
Expand Down
Loading
Loading