diff --git a/ci/environment-rapids.yml b/ci/environment-rapids.yml
new file mode 100644
index 0000000000..b35eb0fd05
--- /dev/null
+++ b/ci/environment-rapids.yml
@@ -0,0 +1,11 @@
+# This is an addition to ci/environment.yml.
+# Add cudf and downgrade some pinned dependencies.
+channels:
+  - rapidsai-nightly
+  - conda-forge
+  - nvidia
+dependencies:
+  - dask-cudf =24.02
+  - dask-cuda =24.02
+  - pandas ==1.5.3  # pinned by cudf
+  - pynvml ==11.4.1  # pinned by dask-cuda
diff --git a/ci/environment.yml b/ci/environment.yml
index 5ca908f06f..4940b57b19 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -8,14 +8,13 @@ dependencies:
   # - AB_environments/AB_baseline.conda.yaml
   # - AB_environments/AB_sample.conda.yaml
   ########################################################
-  - python >=3.9
+  - python >=3.9,<3.11
   - pip
   - coiled >=0.2.54
   - numpy ==1.26.2
   - pandas ==2.1.4
   - dask ==2023.12.0
   - distributed ==2023.12.0
-  - dask-expr ==0.2.8
   - dask-labextension ==7.0.0
   - dask-ml ==2023.3.24
   - fsspec ==2023.12.1
diff --git a/tests/tpch/conftest.py b/tests/tpch/conftest.py
index a2f3c3bf6d..5487fa367e 100644
--- a/tests/tpch/conftest.py
+++ b/tests/tpch/conftest.py
@@ -17,6 +17,7 @@ def pytest_addoption(parser):
     parser.addoption("--local", action="store_true", default=False, help="")
+    parser.addoption("--rapids", action="store_true", default=False, help="")
     parser.addoption("--cloud", action="store_false", dest="local", help="")
     parser.addoption("--restart", action="store_true", default=True, help="")
     parser.addoption("--no-restart", action="store_false", dest="restart", help="")
@@ -48,6 +49,11 @@ def local(request):
     return request.config.getoption("local")


+@pytest.fixture(scope="session")
+def rapids(request):
+    return request.config.getoption("rapids")
+
+
 @pytest.fixture(scope="session")
 def restart(request):
     return request.config.getoption("restart")
@@ -186,6 +192,7 @@ def cluster_spec(scale):
 @pytest.fixture(scope="module")
 def cluster(
     local,
+    rapids,
     scale,
     module,
     dask_env_variables,
@@ -195,19 +202,38 @@
     make_chart,
 ):
     if local:
-        with LocalCluster() as cluster:
-            yield cluster
-    else:
-        kwargs = dict(
-            name=f"tpch-{module}-{scale}-{name}",
-            environ=dask_env_variables,
-            tags=github_cluster_tags,
-            region="us-east-2",
-            **cluster_spec,
-        )
-        with dask.config.set({"distributed.scheduler.worker-saturation": "inf"}):
-            with coiled.Cluster(**kwargs) as cluster:
+        if not rapids:
+            with LocalCluster() as cluster:
                 yield cluster
+        else:
+            from dask_cuda import LocalCUDACluster
+
+            with dask.config.set(
+                {"dataframe.backend": "cudf", "dataframe.shuffle.method": "tasks"}
+            ):
+                with LocalCUDACluster(rmm_pool_size="24GB") as cluster:
+                    yield cluster
+    else:
+        if not rapids:
+            kwargs = dict(
+                name=f"tpch-{module}-{scale}-{name}",
+                environ=dask_env_variables,
+                tags=github_cluster_tags,
+                region="us-east-2",
+                **cluster_spec,
+            )
+            with dask.config.set({"distributed.scheduler.worker-saturation": "inf"}):
+                with coiled.Cluster(**kwargs) as cluster:
+                    yield cluster
+        else:
+            # should be using Coiled for this
+            from dask_cuda import LocalCUDACluster
+
+            with dask.config.set(
+                {"dataframe.backend": "cudf", "dataframe.shuffle.method": "tasks"}
+            ):
+                with LocalCUDACluster(rmm_pool_size="24GB") as cluster:
+                    yield cluster


 @pytest.fixture
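For readers skimming the diff: the new rapids branch of the cluster fixture reduces to the following standalone pattern. This is a minimal sketch, not code from the PR, assuming a machine with at least one NVIDIA GPU and the packages pinned in ci/environment-rapids.yml; the Parquet path is a placeholder.

import dask
import dask.dataframe as dd
from dask_cuda import LocalCUDACluster
from distributed import Client

with dask.config.set(
    {
        # creation functions such as read_parquet return cudf-backed frames
        "dataframe.backend": "cudf",
        # use the task-based shuffle instead of p2p
        "dataframe.shuffle.method": "tasks",
    }
):
    # dask-cuda starts one worker per visible GPU and preallocates a
    # 24 GB RMM memory pool on each of them
    with LocalCUDACluster(rmm_pool_size="24GB") as cluster, Client(cluster) as client:
        df = dd.read_parquet("s3://bucket/lineitem")  # placeholder path
        print(df.l_quantity.sum().compute())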
diff --git a/tests/tpch/test_dask.py b/tests/tpch/test_dask.py
index e2b629c77e..8b00d55f9e 100644
--- a/tests/tpch/test_dask.py
+++ b/tests/tpch/test_dask.py
@@ -6,10 +6,12 @@
 dd = pytest.importorskip("dask_expr")

+BLOCKSIZE = "default"
+

 def test_query_1(client, dataset_path, fs):
     VAR1 = datetime(1998, 9, 2)

-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)

     lineitem_filtered = lineitem_ds[lineitem_ds.l_shipdate <= VAR1]
     lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
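The hunk above introduces a module-level BLOCKSIZE constant, and the hunks below thread it through every read_parquet call. With the value "default" the queries partition exactly as before; the constant exists so partition sizing can be changed in one place, e.g. to feed GPU workers fewer, larger partitions. A hypothetical override, not part of this PR (the "1GiB" value and the path are illustrative):

import dask.dataframe as dd

BLOCKSIZE = "1GiB"  # hypothetical: coalesce row groups into ~1 GiB partitions

# blocksize caps how many bytes of Parquet data land in each output partition
df = dd.read_parquet("s3://bucket/lineitem", blocksize=BLOCKSIZE)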
@@ -50,11 +52,11 @@ def test_query_2(client, dataset_path, fs):
     var2 = "BRASS"
     var3 = "EUROPE"

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
-    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs)
-    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs, blocksize=BLOCKSIZE)
+    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)
+    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs, blocksize=BLOCKSIZE)
+    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs, blocksize=BLOCKSIZE)

     region_filtered = region_ds[(region_ds["r_name"] == var3)]
     r_n_merged = nation_filtered.merge(
@@ -118,9 +120,9 @@ def test_query_3(client, dataset_path, fs):
     var1 = datetime.strptime("1995-03-15", "%Y-%m-%d")
     var2 = "BUILDING"

-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)

     lsel = lineitem_ds.l_shipdate > var1
     osel = orders_ds.o_orderdate < var1
@@ -144,8 +146,8 @@ def test_query_4(client, dataset_path, fs):
     date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
     date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)

     lsel = line_item_ds.l_commitdate < line_item_ds.l_receiptdate
     osel = (orders_ds.o_orderdate < date1) & (orders_ds.o_orderdate >= date2)
@@ -168,12 +170,12 @@ def test_query_5(client, dataset_path, fs):
     date1 = datetime.strptime("1994-01-01", "%Y-%m-%d")
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs, blocksize=BLOCKSIZE)
+    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)

     rsel = region_ds.r_name == "ASIA"
     osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
@@ -198,7 +200,7 @@ def test_query_6(client, dataset_path, fs):
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var3 = 24

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)

     sel = (
         (line_item_ds.l_shipdate >= date1)
@@ -217,11 +219,11 @@ def test_query_7(client, dataset_path, fs):
     var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)

     lineitem_filtered = line_item_ds[
         (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
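A hypothetical invocation to exercise the new path (any flags beyond the two shown are omitted): the option added in conftest.py composes with the existing --local flag, so

    pytest tests/tpch/test_dask.py --local --rapids

runs the queries on a LocalCUDACluster. Note that, per the cluster fixture above, --rapids without --local currently also falls back to a local GPU cluster rather than Coiled; the "# should be using Coiled for this" comment marks that as a known gap.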