11# type: ignore
22from __future__ import annotations
33
4- import os
54import pathlib
65import re
76import sys
1918
2019from sqlmesh import Config , Context
2120from sqlmesh .cli .project_init import init_example_project
22- from sqlmesh .core .config import load_config_from_paths
2321from sqlmesh .core .config .connection import ConnectionConfig
2422import sqlmesh .core .dialect as d
2523from sqlmesh .core .environment import EnvironmentSuffixTarget
@@ -1936,49 +1934,16 @@ def test_transaction(ctx: TestContext):
19361934 ctx .compare_with_current (table , input_data )
19371935
19381936
1939- def test_sushi (ctx : TestContext , tmp_path_factory : pytest . TempPathFactory ):
1937+ def test_sushi (ctx : TestContext , tmp_path : pathlib . Path ):
19401938 if ctx .mark == "athena_hive" :
19411939 pytest .skip (
19421940 "Sushi end-to-end tests only need to run once for Athena because sushi needs a hybrid of both Hive and Iceberg"
19431941 )
19441942
1945- tmp_path = tmp_path_factory .mktemp (f"sushi_{ ctx .test_id } " )
1946-
19471943 sushi_test_schema = ctx .add_test_suffix ("sushi" )
19481944 sushi_state_schema = ctx .add_test_suffix ("sushi_state" )
19491945 raw_test_schema = ctx .add_test_suffix ("raw" )
19501946
1951- config = load_config_from_paths (
1952- Config ,
1953- project_paths = [
1954- pathlib .Path (os .path .join (os .path .dirname (__file__ ), "config.yaml" )),
1955- ],
1956- personal_paths = [pathlib .Path ("~/.sqlmesh/config.yaml" ).expanduser ()],
1957- )
1958- before_all = [
1959- f"CREATE SCHEMA IF NOT EXISTS { raw_test_schema } " ,
1960- f"DROP VIEW IF EXISTS { raw_test_schema } .demographics" ,
1961- f"CREATE VIEW { raw_test_schema } .demographics AS (SELECT 1 AS customer_id, '00000' AS zip)" ,
1962- ]
1963- config .before_all = [
1964- quote_identifiers (
1965- parse_one (e , dialect = config .model_defaults .dialect ),
1966- dialect = config .model_defaults .dialect ,
1967- ).sql (dialect = config .model_defaults .dialect )
1968- for e in before_all
1969- ]
1970-
1971- # To enable parallelism in integration tests
1972- config .gateways = {ctx .gateway : config .gateways [ctx .gateway ]}
1973- current_gateway_config = config .gateways [ctx .gateway ]
1974- current_gateway_config .state_schema = sushi_state_schema
1975-
1976- if ctx .dialect == "athena" :
1977- # Ensure that this test is using the same s3_warehouse_location as TestContext (which includes the testrun_id)
1978- current_gateway_config .connection .s3_warehouse_location = (
1979- ctx .engine_adapter .s3_warehouse_location
1980- )
1981-
19821947 # Copy sushi example to tmpdir
19831948 shutil .copytree (pathlib .Path ("./examples/sushi" ), tmp_path , dirs_exist_ok = True )
19841949
@@ -2000,7 +1965,23 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
20001965 contents = contents .replace (search , replace )
20011966 f .write_text (contents )
20021967
2003- context = Context (paths = tmp_path , config = config , gateway = ctx .gateway )
1968+ before_all = [
1969+ f"CREATE SCHEMA IF NOT EXISTS { raw_test_schema } " ,
1970+ f"DROP VIEW IF EXISTS { raw_test_schema } .demographics" ,
1971+ f"CREATE VIEW { raw_test_schema } .demographics AS (SELECT 1 AS customer_id, '00000' AS zip)" ,
1972+ ]
1973+
1974+ def _mutate_config (gateway : str , config : Config ) -> None :
1975+ config .gateways [gateway ].state_schema = sushi_state_schema
1976+ config .before_all = [
1977+ quote_identifiers (
1978+ parse_one (e , dialect = config .model_defaults .dialect ),
1979+ dialect = config .model_defaults .dialect ,
1980+ ).sql (dialect = config .model_defaults .dialect )
1981+ for e in before_all
1982+ ]
1983+
1984+ context = ctx .create_context (_mutate_config , path = tmp_path , ephemeral_state_connection = False )
20041985
20051986 end = now ()
20061987 start = to_date (end - timedelta (days = 7 ))
@@ -2355,9 +2336,7 @@ def validate_no_comments(
23552336 ctx ._schemas .append (schema )
23562337
23572338
2358- def test_init_project (ctx : TestContext , tmp_path_factory : pytest .TempPathFactory ):
2359- tmp_path = tmp_path_factory .mktemp (f"init_project_{ ctx .test_id } " )
2360-
2339+ def test_init_project (ctx : TestContext , tmp_path : pathlib .Path ):
23612340 schema_name = ctx .add_test_suffix (TEST_SCHEMA )
23622341 state_schema = ctx .add_test_suffix ("sqlmesh_state" )
23632342
@@ -2383,33 +2362,15 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"):
23832362
23842363 init_example_project (tmp_path , ctx .engine_type , schema_name = schema_name )
23852364
2386- config = load_config_from_paths (
2387- Config ,
2388- project_paths = [
2389- pathlib .Path (os .path .join (os .path .dirname (__file__ ), "config.yaml" )),
2390- ],
2391- personal_paths = [pathlib .Path ("~/.sqlmesh/config.yaml" ).expanduser ()],
2392- )
2393-
2394- # ensure default dialect comes from init_example_project and not ~/.sqlmesh/config.yaml
2395- if config .model_defaults .dialect != ctx .dialect :
2396- config .model_defaults = config .model_defaults .copy (update = {"dialect" : ctx .dialect })
2397-
2398- # To enable parallelism in integration tests
2399- config .gateways = {ctx .gateway : config .gateways [ctx .gateway ]}
2400- current_gateway_config = config .gateways [ctx .gateway ]
2401-
2402- if ctx .dialect == "athena" :
2403- # Ensure that this test is using the same s3_warehouse_location as TestContext (which includes the testrun_id)
2404- current_gateway_config .connection .s3_warehouse_location = (
2405- ctx .engine_adapter .s3_warehouse_location
2406- )
2365+ def _mutate_config (gateway : str , config : Config ):
2366+ # ensure default dialect comes from init_example_project and not ~/.sqlmesh/config.yaml
2367+ if config .model_defaults .dialect != ctx .dialect :
2368+ config .model_defaults = config .model_defaults .copy (update = {"dialect" : ctx .dialect })
24072369
2408- # Ensure the state schema is unique to this test
2409- config .gateways [ctx . gateway ].state_schema = state_schema
2370+ # Ensure the state schema is unique to this test (since we deliberately use the warehouse as the state connection)
2371+ config .gateways [gateway ].state_schema = state_schema
24102372
2411- context = Context (paths = tmp_path , config = config , gateway = ctx .gateway )
2412- ctx .engine_adapter = context .engine_adapter
2373+ context = ctx .create_context (_mutate_config , path = tmp_path , ephemeral_state_connection = False )
24132374
24142375 if ctx .default_table_format :
24152376 # if the default table format is explicitly set, ensure its being used
0 commit comments