diff --git a/samples/ml/distributed_partition_function/README.md b/samples/ml/distributed_partition_function/README.md new file mode 100644 index 00000000..7ccda16e --- /dev/null +++ b/samples/ml/distributed_partition_function/README.md @@ -0,0 +1,42 @@ +# Distributed Partition Function (DPF) - Example Walkthrough + +## Introduction + +The **Distributed Partition Function (DPF)** lets you process data in parallel across one or more nodes in a compute pool. DPF partitions your data by a specified column (or by staged files) and executes your Python function on each partition concurrently. It handles distributed orchestration, errors, observability, and artifact persistence automatically. + +This example uses a **supply chain allocation** scenario: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`. Each region is solved as an independent DPF partition. + +## Execution Modes + +DPF supports two execution modes, both demonstrated in this notebook: + +| Mode | Method | Description | |------|--------|-------------| | **DataFrame mode** | `run()` | Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently. | | **Stage mode** | `run_from_stage()` | Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing. | + +## What This Notebook Covers + +1. **Setup** - Session creation, stage setup, compute scaling, and synthetic data generation +2. **DataFrame mode** - Define a processing function, run DPF, monitor progress, retrieve results, inspect logs, restore completed runs +3. **Stage mode** - Copy data to parquet files on stage, run DPF from stage +4. 
**ML Jobs deployment** - Deploy DPF workloads via the `@remote` decorator + +## Prerequisites + +- A [compute pool](https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool) with a maximum node count of at least 3 (e.g., instance family `CPU_X64_S`), or use the system-provided `SYSTEM_COMPUTE_POOL_CPU` +- A Snowflake Notebook running on the compute pool (Container Runtime) +- Stage access permissions for storing results and artifacts + +## Getting Started + +This notebook is intended to be run in a **Snowflake Notebook** environment on Snowpark Container Services. If running locally, use the **ML Jobs deployment** section at the bottom of the notebook to submit DPF workloads via the `@remote` decorator. + +Open the [DPF Example Notebook](./dpf_example.ipynb) for a full end-to-end walkthrough. + +## References + +- [DPF Documentation](https://docs.snowflake.com/en/developer-guide/snowflake-ml/process-data-across-partitions) +- [DPF API Reference](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/container-runtime/distributors.distributed_partition_function) +- [ML Jobs Documentation](https://docs.snowflake.com/developer-guide/snowflake-ml/ml-jobs/overview) +- [Many Model Training (MMT) Example](../many_model_training/mmt_example.ipynb) diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb b/samples/ml/distributed_partition_function/dpf_example.ipynb new file mode 100644 index 00000000..e15d71e4 --- /dev/null +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -0,0 +1,812 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e16f2bf5-88f3-4dfa-8d7a-6be220007ba3", + "metadata": { + "collapsed": false, + "name": "cell0" + }, + "source": [ + "# Distributed Partition Function (DPF) - Example Walkthrough\n", + "\n", + "This notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. 
DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n", + "\n", + "We'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n", + "\n", + "DPF supports two execution modes:\n", + "\n", + "- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n", + "- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n", + "\n", + "**Environment:** This notebook is designed to run in a Snowflake Notebook on Container Runtime. If running locally, see the **ML Jobs deployment** section at the bottom.\n", + "\n", + "**Prerequisites:**\n", + "- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" + ] + }, + { + "cell_type": "markdown", + "id": "bebe7269-edd8-4117-957b-d19a5be03ff2", + "metadata": { + "collapsed": false, + "name": "cell1" + }, + "source": [ + "---\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85a70bc5-f1d0-45e7-99ee-7d9c811df886", + "metadata": { + "language": "python", + "name": "cell2" + }, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "import json\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from snowflake.snowpark import Session\n", + "\n", + "\n", + "session = Session.builder.getOrCreate()\n", + "\n", + "# Configuration\n", + "database = session.get_current_database() or \"MY_DATABASE\" # Change to your database\n", + "schema = session.get_current_schema() or \"MY_SCHEMA\" # Change to your schema\n", + "\n", + 
"input_stage = \"DPF_INPUT_STAGE\"\n", + "dpf_stage = \"DPF_RESULTS_STAGE\"\n", + "input_table = \"SUPPLY_CHAIN_DATA\"\n", + "output_table = \"OPTIMIZED_SHIPPING_MANIFEST\"\n", + "\n", + "# Create stages\n", + "session.use_schema(f\"{database}.{schema}\")\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n", + "\n", + "print(f\"Session: {session}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b8a5fdce-1b4d-4041-a1b9-862cefd1eade", + "metadata": { + "collapsed": false, + "name": "cell3" + }, + "source": [ + "### Import DPF modules and Scale Compute Nodes\n", + "Snowflake Notebook on Container Runtime only - skip this cell if running locally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbaef190-6856-4616-8a74-30a9b452fbf8", + "metadata": { + "language": "python", + "name": "cell4" + }, + "outputs": [], + "source": [ + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n", + " DPFRun,\n", + ")\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", + " RunStatus,\n", + " ExecutionOptions,\n", + ")\n", + "from snowflake.ml.runtime_cluster import scale_cluster\n", + "\n", + "# Scale to 3 nodes for parallel processing\n", + "scale_cluster(expected_cluster_size=3)" + ] + }, + { + "cell_type": "markdown", + "id": "830b0d3a-0f5c-4100-afac-6a5be0e36a17", + "metadata": { + "collapsed": false, + "name": "cell5" + }, + "source": [ + "### Create Synthetic Supply Chain Data\n", + "\n", + "Generate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42b3e109-80a8-4fcd-a1b0-a41842e6cbd5", + "metadata": { + "language": "python", + "name": "cell6" + }, + "outputs": [], + "source": [ + "def create_supply_chain_data(session, table_name):\n", + " \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n", + " regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n", + " np.random.seed(42)\n", + " data = []\n", + "\n", + " for reg in regions:\n", + " # 3 Factories per region (supply)\n", + " for i in range(3):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n", + " \"TYPE\": \"FACTORY\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 1000,\n", + " \"DEMAND\": 0,\n", + " }\n", + " )\n", + " # 10 Warehouses per region (demand)\n", + " for j in range(10):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n", + " \"TYPE\": \"WAREHOUSE\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 0,\n", + " \"DEMAND\": 250,\n", + " }\n", + " )\n", + "\n", + " df = pd.DataFrame(data)\n", + " sdf = session.create_dataframe(df)\n", + " sdf.write.mode(\"overwrite\").save_as_table(table_name)\n", + " print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n", + " return session.table(table_name)\n", + "\n", + "\n", + "supply_chain_sdf = create_supply_chain_data(session, input_table)\n", + "supply_chain_sdf.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0cafdbc-ee9b-4fb4-b228-820dc3dcf5c1", + "metadata": { + "collapsed": false, + "name": "cell7" + }, + "source": [ + "---\n", + "## DataFrame Mode: Process Data by Column Partitions\n", + "\n", + "Partition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in 
parallel.\n", + "\n", + "1. **Define the processing function** - optimization logic that runs on each partition.\n", + "2. **Initialize and run DPF** - launch parallel execution across all partitions.\n", + "3. **Monitor progress** - track status and wait for completion.\n", + "4. **Retrieve results** - collect artifacts and output data from each partition.\n", + "5. **Restore a completed run** - access results from a previous run without re-executing.\n", + "\n", + "### Step 1: Define the Processing Function\n", + "\n", + "This function runs on each partition (region). It receives the partition's data via `data_connector` and\n", + "uses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\n", + "satisfying warehouse demand without exceeding factory capacity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "785989a9-c244-4761-9462-8cb5a62decd1", + "metadata": { + "language": "python", + "name": "cell8" + }, + "outputs": [], + "source": [ + "def solve_allocation(data_connector, context):\n", + " \"\"\"\n", + " Solve the supply chain allocation problem for a single region.\n", + "\n", + " Uses linear programming to find the optimal shipment plan that minimizes\n", + " total transportation cost (based on distance) subject to:\n", + " - Factory capacity constraints (supply)\n", + " - Warehouse demand constraints (demand)\n", + "\n", + " Args:\n", + " data_connector: Provides access to the partition's data.\n", + " context: PartitionContext with partition_id and artifact methods.\n", + " \"\"\"\n", + " from scipy.optimize import linprog\n", + " from scipy.spatial.distance import cdist\n", + " import pandas as pd\n", + " import numpy as np\n", + " import json\n", + "\n", + " df = data_connector.to_pandas()\n", + " region = context.partition_id\n", + " print(f\"[{region}] Processing {len(df)} locations\")\n", + "\n", + " factories = df[df[\"TYPE\"] == \"FACTORY\"].reset_index(drop=True)\n", + " warehouses = 
df[df[\"TYPE\"] == \"WAREHOUSE\"].reset_index(drop=True)\n", + " n_fact = len(factories)\n", + " n_wh = len(warehouses)\n", + "\n", + " # Build cost matrix (Euclidean distance as proxy for shipping cost)\n", + " cost_matrix = cdist(\n", + " factories[[\"LAT\", \"LON\"]], warehouses[[\"LAT\", \"LON\"]], metric=\"euclidean\"\n", + " )\n", + " c = cost_matrix.flatten()\n", + "\n", + " # Inequality constraint: total outbound from Factory_i <= Capacity_i\n", + " A_ub = np.zeros((n_fact, n_fact * n_wh))\n", + " for i in range(n_fact):\n", + " A_ub[i, i * n_wh : (i + 1) * n_wh] = 1\n", + " b_ub = factories[\"CAPACITY\"].values.astype(float)\n", + "\n", + " # Equality constraint: total inbound to Warehouse_j == Demand_j\n", + " A_eq = np.zeros((n_wh, n_fact * n_wh))\n", + " for j in range(n_wh):\n", + " A_eq[j, j::n_wh] = 1\n", + " b_eq = warehouses[\"DEMAND\"].values.astype(float)\n", + "\n", + " # Solve\n", + " res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, method=\"highs\")\n", + "\n", + " if res.success:\n", + " allocation = res.x.reshape((n_fact, n_wh))\n", + " manifest = []\n", + " for i in range(n_fact):\n", + " for j in range(n_wh):\n", + " qty = allocation[i, j]\n", + " if qty > 0.1:\n", + " manifest.append(\n", + " {\n", + " \"REGION\": region,\n", + " \"ORIGIN\": factories.loc[i, \"LOCATION_ID\"],\n", + " \"DESTINATION\": warehouses.loc[j, \"LOCATION_ID\"],\n", + " \"SHIPMENT_QTY\": round(float(qty), 2),\n", + " \"UNIT_DISTANCE\": round(float(cost_matrix[i, j]), 4),\n", + " }\n", + " )\n", + "\n", + " manifest_df = pd.DataFrame(manifest)\n", + "\n", + " summary = {\n", + " \"region\": region,\n", + " \"status\": \"OPTIMAL\",\n", + " \"total_cost\": round(float(res.fun), 2),\n", + " \"shipment_count\": len(manifest),\n", + " \"total_units_shipped\": round(sum(m[\"SHIPMENT_QTY\"] for m in manifest), 2),\n", + " }\n", + " print(\n", + " f\"[{region}] Optimal cost: {summary['total_cost']}, shipments: {len(manifest)}\"\n", + " )\n", + "\n", + " # Upload 
summary as a stage artifact\n", + " context.upload_to_stage(\n", + " summary,\n", + " \"summary.json\",\n", + " write_function=lambda obj, path: json.dump(obj, open(path, \"w\")),\n", + " )\n", + "\n", + " # Write results to a Snowflake table using the bounded session pool\n", + " context.with_session(\n", + " lambda session: session.create_dataframe(manifest_df)\n", + " .write.mode(\"append\")\n", + " .save_as_table(output_table)\n", + " )\n", + " else:\n", + " print(f\"[{region}] Optimization failed: {res.message}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7d6f4691-a3cf-4c46-8785-1aa016476b5d", + "metadata": { + "collapsed": false, + "name": "cell9" + }, + "source": [ + "### Step 2: Initialize and Run DPF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e07682-7bca-4ccb-b0cf-51cbf4bc478c", + "metadata": { + "language": "python", + "name": "cell10" + }, + "outputs": [], + "source": [ + "dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "run = dpf.run(\n", + " partition_by=\"REGION\",\n", + " snowpark_dataframe=session.table(input_table),\n", + " run_id=f\"supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " execution_options=ExecutionOptions(use_head_node=True, num_cpus_per_worker=1),\n", + ")\n", + "print(f\"Launched: {run.run_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f4b8b230-b983-48ca-a9e7-18ca97279764", + "metadata": { + "collapsed": false, + "name": "cell11" + }, + "source": [ + "### Step 3: Monitor Progress and Wait for Completion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5c3a1c7-d809-4a86-b395-16e56da332ec", + "metadata": { + "language": "python", + "name": "cell12" + }, + "outputs": [], + "source": [ + "final_status = run.wait() # Shows progress\n", + "print(f\"Job completed with status: {final_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"2b824586-57ba-4d1d-bd96-f38f76ba2e73", + "metadata": { + "language": "python", + "name": "cell13" + }, + "outputs": [], + "source": [ + "# Quick summary of all partition statuses\n", + "progress = run.get_progress()\n", + "for status, partitions in progress.items():\n", + " print(f\"{status}: {len(partitions)} partitions\")" + ] + }, + { + "cell_type": "markdown", + "id": "2f92856d-4e0f-45cb-ba82-1b2ca6047ba6", + "metadata": { + "collapsed": false, + "name": "cell14" + }, + "source": [ + "### Step 4: Retrieve Results from Each Partition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fedc35ff-c572-49e3-b10e-abdb31db8053", + "metadata": { + "language": "python", + "name": "cell15" + }, + "outputs": [], + "source": [ + "def print_results(summaries):\n", + " \"\"\"Format and display the supply chain optimization results.\"\"\"\n", + " for s in summaries:\n", + " print(f\" {s['region']}: cost={s['total_cost']}, shipments={s['shipment_count']}\")\n", + "\n", + " total_cost = sum(s[\"total_cost\"] for s in summaries)\n", + " total_shipments = sum(s[\"shipment_count\"] for s in summaries)\n", + " print(f\"\\n TOTAL: cost={total_cost:.2f}, shipments={total_shipments}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ab86af1-e6f6-40ea-95f5-a1f5e5fae72a", + "metadata": { + "language": "python", + "name": "cell16" + }, + "outputs": [], + "source": [ + "if final_status == RunStatus.SUCCESS:\n", + " summaries = []\n", + " for partition_id, details in run.partition_details.items():\n", + " files = details.stage_artifacts_manager.list()\n", + " print(f\"Partition '{partition_id}' artifacts: {files}\")\n", + "\n", + " summary = details.stage_artifacts_manager.get(\n", + " \"summary.json\",\n", + " read_function=lambda path: json.load(open(path, \"r\")),\n", + " )\n", + " summaries.append(summary)\n", + "\n", + " print_results(summaries)\n", + "else:\n", + " print(f\"Run did not succeed: {final_status}\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "ae7b476e-340e-4103-904e-85298e9a3cdd", + "metadata": { + "language": "python", + "name": "cell17" + }, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "id": "7fcab272-38a5-47b4-8a10-74f5f87eb299", + "metadata": { + "collapsed": false, + "name": "cell18" + }, + "source": [ + "### Inspect Partition Logs\n", + "\n", + "View stdout/stderr from individual partitions to verify processing or debug failures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c4a28a4-469b-46c9-b0cf-79e142bef101", + "metadata": { + "language": "python", + "name": "cell19" + }, + "outputs": [], + "source": [ + "# View logs from each partition\n", + "for partition_id, details in run.partition_details.items():\n", + " print(f\"--- {partition_id} ---\")\n", + " print(details.logs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8d1a479-8928-4f78-a61a-e67b82df1a7f", + "metadata": { + "language": "python", + "name": "cell20" + }, + "outputs": [], + "source": [ + "# Debug failed partitions (if any)\n", + "# progress = run.get_progress()\n", + "# for partition in progress.get(\"FAILED\", []):\n", + "# print(f\"--- Failed: {partition.partition_id} ---\")\n", + "# print(partition.logs)" + ] + }, + { + "cell_type": "markdown", + "id": "a9ad12c6-faf5-4fb1-b5de-c2289b9f4670", + "metadata": { + "collapsed": false, + "name": "cell21" + }, + "source": [ + "### Step 5: Restore Results from a Completed Run\n", + "\n", + "Access results from a previous run without re-executing. Useful after restarting a notebook session." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d71d432-6c2c-4a49-96a8-8ba85397844f", + "metadata": { + "language": "python", + "name": "cell22" + }, + "outputs": [], + "source": [ + "restored_run = DPFRun.restore_from(\n", + " run_id=run.run_id,\n", + " stage_name=dpf_stage,\n", + ")\n", + "\n", + "print(f\"Restored run status: {restored_run.status}\")\n", + "for partition_id, details in restored_run.partition_details.items():\n", + " print(f\" {partition_id}: {details.status}\")\n", + "\n", + "# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." + ] + }, + { + "cell_type": "markdown", + "id": "9e8d3b12-9019-4153-b808-195d16356657", + "metadata": { + "collapsed": false, + "name": "cell23" + }, + "source": [ + "---\n", + "## Stage Mode: Process Files from a Stage\n", + "\n", + "Process pre-staged parquet files where each file becomes a partition.\n", + "First, copy the data from the table to stage as parquet files, one per region." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efc121f3-43d2-4f7f-88f6-0f7a8fce63d4", + "metadata": { + "language": "python", + "name": "cell24" + }, + "outputs": [], + "source": [ + "# Prepare parquet files on stage - one file per region\n", + "session.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n", + "\n", + "session.sql(\n", + " f\"\"\"\n", + " COPY INTO @{input_stage}/supply_chain/\n", + " FROM {input_table}\n", + " PARTITION BY REGION\n", + " FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n", + " HEADER = TRUE\n", + "\"\"\"\n", + ").collect()\n", + "\n", + "# Verify staged files\n", + "session.sql(f\"LIST @{input_stage}/supply_chain/\").show()" + ] + }, + { + "cell_type": "markdown", + "id": "c5e47877-9163-49a6-a6ce-ac8bf82e23b1", + "metadata": { + "collapsed": false, + "name": "cell25" + }, + "source": [ + "### Run DPF from Stage\n", + "\n", + "The processing function signature is the same as DataFrame mode. 
The `data_connector` provides access\n", + "to each file's data, and `context.partition_id` is the relative file path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d320ec5-e978-4392-a1fd-d35d3b772dfc", + "metadata": { + "language": "python", + "name": "cell26" + }, + "outputs": [], + "source": [ + "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "stage_run = dpf_from_stage.run_from_stage(\n", + " stage_location=f\"@{input_stage}/supply_chain/\",\n", + " run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " file_pattern=\"*.parquet\",\n", + " execution_options=ExecutionOptions(\n", + " use_head_node=True,\n", + " num_cpus_per_worker=1,\n", + " ),\n", + ")\n", + "print(f\"Launched: {stage_run.run_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4bb2ffc-6906-4341-a758-511545a2209f", + "metadata": { + "language": "python", + "name": "cell27" + }, + "outputs": [], + "source": [ + "stage_status = stage_run.wait()\n", + "print(f\"Stage mode completed with status: {stage_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ab59f01-345e-4bda-8248-cb74f9941287", + "metadata": { + "language": "python", + "name": "cell28" + }, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "id": "f1cad585-2349-4d16-ad60-2de39db7a30b", + "metadata": { + "collapsed": false, + "name": "cell29" + }, + "source": [ + "---\n", + "## Deploy with ML Jobs via `@remote`\n", + "\n", + "Run DPF in an ML Job from any IDE. ML Jobs execute inside Snowpark Container Services\n", + "and can scale across multiple nodes. Logs are available in Snowsight under Monitoring > Services & Jobs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f5b59b6-e8d5-47c3-91ed-d3366ea16885", + "metadata": { + "language": "python", + "name": "cell30" + }, + "outputs": [], + "source": [ + "job_stage = \"DPF_JOB_STAGE\"\n", + "compute_pool = \"SYSTEM_COMPUTE_POOL_CPU\" # Update with your compute pool name\n", + "\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {job_stage}\").collect()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c00d321-7185-4d77-9111-98d4b19b5a83", + "metadata": { + "language": "python", + "name": "cell31" + }, + "outputs": [], + "source": [ + "from snowflake.ml.jobs import remote\n", + "\n", + "@remote(\n", + " compute_pool=compute_pool,\n", + " stage_name=job_stage,\n", + " target_instances=3,\n", + ")\n", + "def launch_supply_chain_job():\n", + " \"\"\"\n", + " Launch a DPF supply chain optimization run as an ML Job.\n", + " \"\"\"\n", + " from datetime import datetime\n", + " from snowflake.snowpark import Session\n", + " from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import (\n", + " DPF,\n", + " )\n", + " from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", + " ExecutionOptions,\n", + " )\n", + "\n", + " session = Session.builder.getOrCreate()\n", + " dpf_input = session.table(input_table)\n", + "\n", + " dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + " run = dpf.run(\n", + " partition_by=\"REGION\",\n", + " snowpark_dataframe=dpf_input,\n", + " run_id=f\"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " execution_options=ExecutionOptions(use_head_node=True),\n", + " )\n", + " run.wait()\n", + "\n", + " print(f\"DPF run complete: {run.run_id}\")\n", + " return run.run_id\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "job = launch_supply_chain_job()\n", + "print(f\"Job ID: {job.id}\")\n", + "print(f\"Status: {job.status}\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "fadd70e4-048c-45cf-b883-dcb71d525985", + "metadata": { + "language": "python", + "name": "cell32" + }, + "outputs": [], + "source": [ + "# Check the status and logs of the ML Job\n", + "print(f\"Status: {job.status}\")\n", + "print(job.get_logs())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f8aadcc-9e10-4ab3-aa86-96a7483c308f", + "metadata": { + "language": "python", + "name": "cell33" + }, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "id": "d3697df9-557d-4603-be01-d4dda3ac9b3b", + "metadata": { + "collapsed": false, + "name": "cell34" + }, + "source": [ + "---\n", + "## Cleanup\n", + "\n", + "Scale the cluster back down to a single node when you're done." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b2d20d9-6934-42e7-9072-66723ce5884d", + "metadata": { + "language": "python", + "name": "cell35" + }, + "outputs": [], + "source": [ + "scale_cluster(expected_cluster_size=1)\n", + "\n", + "# Uncomment to drop objects created by this notebook\n", + "# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n", + "# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dpf-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}