|
13 | 13 | WorkspaceInlayHintRefreshRequest, |
14 | 14 | ) |
15 | 15 | from pygls.server import LanguageServer |
| 16 | +from sqlglot import exp |
16 | 17 | from sqlmesh._version import __version__ |
17 | 18 | from sqlmesh.core.context import Context |
18 | 19 | from sqlmesh.utils.date import to_timestamp |
|
68 | 69 | from sqlmesh.lsp.uri import URI |
69 | 70 | from web.server.api.endpoints.lineage import column_lineage, model_lineage |
70 | 71 | from web.server.api.endpoints.models import get_models |
71 | | -from web.server.api.endpoints.table_diff import get_table_diff |
| 72 | +from web.server.api.endpoints.table_diff import _process_sample_data |
72 | 73 | from typing import Union |
73 | 74 | from dataclasses import dataclass |
74 | 75 |
|
| 76 | +from web.server.models import RowDiff, SchemaDiff, TableDiff |
| 77 | + |
75 | 78 |
|
76 | 79 | @dataclass |
77 | 80 | class NoContext: |
@@ -275,20 +278,72 @@ def _custom_api( |
275 | 278 | return ApiResponseGetColumnLineage(data=column_lineage_response) |
276 | 279 |
|
277 | 280 | if path_parts[:2] == ["api", "table_diff"]: |
| 281 | + import numpy as np |
| 282 | + |
278 | 283 | # /api/table_diff |
279 | 284 | params = request.params |
280 | | - table_diff_result = get_table_diff( |
281 | | - source=getattr(params, "source", "") if params else "", |
282 | | - target=getattr(params, "target", "") if params else "", |
283 | | - on=getattr(params, "on", None) if params else None, |
284 | | - model_or_snapshot=getattr(params, "model_or_snapshot", None) |
285 | | - if params |
286 | | - else None, |
287 | | - where=getattr(params, "where", None) if params else None, |
288 | | - temp_schema=getattr(params, "temp_schema", None) if params else None, |
289 | | - limit=getattr(params, "limit", 20) if params else 20, |
290 | | - context=context.context, |
291 | | - ) |
| 285 | + table_diff_result: t.Optional[TableDiff] = None |
| 286 | + if params := request.params: |
| 287 | + source = getattr(params, "source", "") if params else "" |
| 288 | + target = getattr(params, "target", "") if params else "" |
| 289 | + on = getattr(params, "on", None) if params else None |
| 290 | + model_or_snapshot = ( |
| 291 | + getattr(params, "model_or_snapshot", None) if params else None |
| 292 | + ) |
| 293 | + where = getattr(params, "where", None) if params else None |
| 294 | + temp_schema = getattr(params, "temp_schema", None) if params else None |
| 295 | + limit = getattr(params, "limit", 20) if params else 20 |
| 296 | + |
| 297 | + table_diffs = context.context.table_diff( |
| 298 | + source=source, |
| 299 | + target=target, |
| 300 | + on=exp.condition(on) if on else None, |
| 301 | + select_models={model_or_snapshot} if model_or_snapshot else None, |
| 302 | + where=where, |
| 303 | + limit=limit, |
| 304 | + show=False, |
| 305 | + ) |
| 306 | + |
| 307 | + if table_diffs: |
| 308 | + diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs |
| 309 | + |
| 310 | + _schema_diff = diff.schema_diff() |
| 311 | + _row_diff = diff.row_diff(temp_schema=temp_schema) |
| 312 | + schema_diff = SchemaDiff( |
| 313 | + source=_schema_diff.source, |
| 314 | + target=_schema_diff.target, |
| 315 | + source_schema=_schema_diff.source_schema, |
| 316 | + target_schema=_schema_diff.target_schema, |
| 317 | + added=_schema_diff.added, |
| 318 | + removed=_schema_diff.removed, |
| 319 | + modified=_schema_diff.modified, |
| 320 | + ) |
| 321 | + |
| 322 | + # create a readable column-centric sample data structure |
| 323 | + processed_sample_data = _process_sample_data(_row_diff, source, target) |
| 324 | + |
| 325 | + row_diff = RowDiff( |
| 326 | + source=_row_diff.source, |
| 327 | + target=_row_diff.target, |
| 328 | + stats=_row_diff.stats, |
| 329 | + sample=_row_diff.sample.replace({np.nan: None}).to_dict(), |
| 330 | + joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(), |
| 331 | + s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(), |
| 332 | + t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(), |
| 333 | + column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(), |
| 334 | + source_count=_row_diff.source_count, |
| 335 | + target_count=_row_diff.target_count, |
| 336 | + count_pct_change=_row_diff.count_pct_change, |
| 337 | + decimals=getattr(_row_diff, "decimals", 3), |
| 338 | + processed_sample_data=processed_sample_data, |
| 339 | + ) |
| 340 | + |
| 341 | + s_index, t_index, _ = diff.key_columns |
| 342 | + table_diff_result = TableDiff( |
| 343 | + schema_diff=schema_diff, |
| 344 | + row_diff=row_diff, |
| 345 | + on=[(s.name, t.name) for s, t in zip(s_index, t_index)], |
| 346 | + ) |
292 | 347 | return ApiResponseGetTableDiff(data=table_diff_result) |
293 | 348 |
|
294 | 349 | raise NotImplementedError(f"API request not implemented: {request.endpoint}") |
|
0 commit comments