Skip to content

Commit d604c51

Browse files
move tablediff from webapi to vscode
1 parent 3b7f775 commit d604c51

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

sqlmesh/lsp/main.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
WorkspaceInlayHintRefreshRequest,
1414
)
1515
from pygls.server import LanguageServer
16+
from sqlglot import exp
1617
from sqlmesh._version import __version__
1718
from sqlmesh.core.context import Context
1819
from sqlmesh.utils.date import to_timestamp
@@ -68,10 +69,12 @@
6869
from sqlmesh.lsp.uri import URI
6970
from web.server.api.endpoints.lineage import column_lineage, model_lineage
7071
from web.server.api.endpoints.models import get_models
71-
from web.server.api.endpoints.table_diff import get_table_diff
72+
from web.server.api.endpoints.table_diff import _process_sample_data
7273
from typing import Union
7374
from dataclasses import dataclass
7475

76+
from web.server.models import RowDiff, SchemaDiff, TableDiff
77+
7578

7679
@dataclass
7780
class NoContext:
@@ -275,20 +278,72 @@ def _custom_api(
275278
return ApiResponseGetColumnLineage(data=column_lineage_response)
276279

277280
if path_parts[:2] == ["api", "table_diff"]:
281+
import numpy as np
282+
278283
# /api/table_diff
279284
params = request.params
280-
table_diff_result = get_table_diff(
281-
source=getattr(params, "source", "") if params else "",
282-
target=getattr(params, "target", "") if params else "",
283-
on=getattr(params, "on", None) if params else None,
284-
model_or_snapshot=getattr(params, "model_or_snapshot", None)
285-
if params
286-
else None,
287-
where=getattr(params, "where", None) if params else None,
288-
temp_schema=getattr(params, "temp_schema", None) if params else None,
289-
limit=getattr(params, "limit", 20) if params else 20,
290-
context=context.context,
291-
)
285+
table_diff_result: t.Optional[TableDiff] = None
286+
if params := request.params:
287+
source = getattr(params, "source", "") if params else ""
288+
target = getattr(params, "target", "") if params else ""
289+
on = getattr(params, "on", None) if params else None
290+
model_or_snapshot = (
291+
getattr(params, "model_or_snapshot", None) if params else None
292+
)
293+
where = getattr(params, "where", None) if params else None
294+
temp_schema = getattr(params, "temp_schema", None) if params else None
295+
limit = getattr(params, "limit", 20) if params else 20
296+
297+
table_diffs = context.context.table_diff(
298+
source=source,
299+
target=target,
300+
on=exp.condition(on) if on else None,
301+
select_models={model_or_snapshot} if model_or_snapshot else None,
302+
where=where,
303+
limit=limit,
304+
show=False,
305+
)
306+
307+
if table_diffs:
308+
diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs
309+
310+
_schema_diff = diff.schema_diff()
311+
_row_diff = diff.row_diff(temp_schema=temp_schema)
312+
schema_diff = SchemaDiff(
313+
source=_schema_diff.source,
314+
target=_schema_diff.target,
315+
source_schema=_schema_diff.source_schema,
316+
target_schema=_schema_diff.target_schema,
317+
added=_schema_diff.added,
318+
removed=_schema_diff.removed,
319+
modified=_schema_diff.modified,
320+
)
321+
322+
# create a readable column-centric sample data structure
323+
processed_sample_data = _process_sample_data(_row_diff, source, target)
324+
325+
row_diff = RowDiff(
326+
source=_row_diff.source,
327+
target=_row_diff.target,
328+
stats=_row_diff.stats,
329+
sample=_row_diff.sample.replace({np.nan: None}).to_dict(),
330+
joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(),
331+
s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(),
332+
t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(),
333+
column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(),
334+
source_count=_row_diff.source_count,
335+
target_count=_row_diff.target_count,
336+
count_pct_change=_row_diff.count_pct_change,
337+
decimals=getattr(_row_diff, "decimals", 3),
338+
processed_sample_data=processed_sample_data,
339+
)
340+
341+
s_index, t_index, _ = diff.key_columns
342+
table_diff_result = TableDiff(
343+
schema_diff=schema_diff,
344+
row_diff=row_diff,
345+
on=[(s.name, t.name) for s, t in zip(s_index, t_index)],
346+
)
292347
return ApiResponseGetTableDiff(data=table_diff_result)
293348

294349
raise NotImplementedError(f"API request not implemented: {request.endpoint}")

web/server/api/endpoints/table_diff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _normalize(val: t.Any) -> t.Any:
2727

2828
def _process_sample_data(
2929
row_diff: t.Any, source_name: str, target_name: str
30-
) -> t.Optional[ProcessedSampleData]:
30+
) -> ProcessedSampleData:
3131
import pandas as pd
3232

3333
if row_diff.joined_sample.shape[0] == 0:

0 commit comments

Comments
 (0)