diff --git a/code-exercises/01 - JAX AI Stack.ipynb b/code-exercises/01 - JAX AI Stack.ipynb index 1cbc054..02d46a6 100644 --- a/code-exercises/01 - JAX AI Stack.ipynb +++ b/code-exercises/01 - JAX AI Stack.ipynb @@ -1 +1,1372 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"183UawZ8L3Tbm1ueDynDqO_TyGysgZ8rt","timestamp":1755114181793}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rcrowe-google/Learning-JAX/blob/main/code-exercises/01%20-%20JAX%20AI%20Stack.ipynb)\n","\n","# Introduction\n","\n","**Welcome to the JAX AI Stack Exercises!**\n","\n","This notebook is designed to accompany the \"Leveraging the JAX AI Stack\" lecture. You'll get hands-on experience with core JAX concepts, Flax NNX for model building, Optax for optimization, and Orbax for checkpointing.\n","\n","The exercises will guide you through implementing key components, drawing parallels to PyTorch where appropriate, to solidify your understanding.\n","\n","Let's get started!"],"metadata":{"id":"AEYnLrsY27El"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"OPA5MMD621LQ"},"outputs":[],"source":["# @title Setup: Install and Import Libraries\n","# Install necessary libraries\n","!pip install -q jax-ai-stack==2025.9.3\n","\n","import jax\n","import jax.numpy as jnp\n","import flax\n","from flax import nnx\n","import optax\n","import orbax.checkpoint as ocp # For Orbax\n","from typing import Any, Dict, Tuple # For type hints\n","\n","# Helper to print PyTrees more nicely for demonstration\n","import pprint\n","import os # For Orbax directory management\n","import shutil # For cleaning up Orbax directory\n","\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"Flax version: {flax.__version__}\")\n","print(f\"Optax version: {optax.__version__}\")\n","print(f\"Orbax version: {ocp.__version__}\")\n","\n","# Global JAX PRNG key for reproducibility in exercises\n","# Students can learn to split this key for different operations.\n","main_key = jax.random.key(0)"]},{"cell_type":"markdown","source":["## Exercise 1: JAX Core & NumPy API\n","\n","**Goal**: Get familiar with jax.numpy and JAX's functional programming style.\n","\n","### Instructions:\n","\n","1. Create two JAX arrays, a (a 2x2 matrix of random numbers) and b (a 2x2 matrix of ones) using jax.numpy (jnp). You'll need a jax.random.key for creating random numbers.\n","2. Perform element-wise addition of a and b.\n","3. Perform matrix multiplication of a and b.\n","4. Demonstrate JAX's immutability:\n"," - Store the Python id() of array a.\n"," - Perform an operation like a = a + 1.\n"," - Print the new id() of a and observe that it has changed, indicating a new array was created."],"metadata":{"id":"3gC7luR35tJd"}},{"cell_type":"code","source":["# Instructions for Exercise 1\n","key_ex1, main_key = jax.random.split(main_key) # Split the main key\n","\n","# 1. Create JAX arrays a and b\n","# TODO: Create array 'a' (2x2 random normal) and 'b' (2x2 ones)\n","a = None # Placeholder\n","b = None # Placeholder\n","\n","print(\"Array a:\\n\", a)\n","print(\"Array b:\\n\", b)\n","\n","# 2. Perform element-wise addition\n","# TODO: Add a and b\n","c = None # Placeholder\n","print(\"Element-wise sum c = a + b:\\n\", c)\n","\n","# 3. Perform matrix multiplication\n","# TODO: Matrix multiply a and b\n","d = None # Placeholder\n","print(\"Matrix product d = a @ b:\\n\", d)\n","\n","# 4. Demonstrate immutability\n","# original_a_id = id(a)\n","# print(f\"Original id(a): {original_a_id}\")\n","\n","# TODO: Perform an operation that reassigns 'a', e.g., a = a + 1\n","# a_new_ref = None # Placeholder\n","# new_a_id = id(a_new_ref)\n","# print(f\"New id(a) after 'a = a + 1': {new_a_id}\")\n","\n","# TODO: Check if original_a_id is different from new_a_id\n","# print(f\"IDs are different: {None}\") # Placeholder"],"metadata":{"id":"8Tq_WFzc5Ycl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 1: JAX Core & NumPy API\n","key_ex1_sol, main_key = jax.random.split(main_key)\n","\n","# 1. Create JAX arrays a and b\n","a_sol = jax.random.normal(key_ex1_sol, (2, 2))\n","b_sol = jnp.ones((2, 2))\n","\n","print(\"Array a:\\n\", a_sol)\n","print(\"Array b:\\n\", b_sol)\n","\n","# 2. Perform element-wise addition\n","c_sol = a_sol + b_sol\n","print(\"Element-wise sum c = a + b:\\n\", c_sol)\n","\n","# 3. Perform matrix multiplication\n","d_sol = jnp.dot(a_sol, b_sol) # or d = a @ b\n","print(\"Matrix product d = a @ b:\\n\", d_sol)\n","\n","# 4. Demonstrate immutability\n","original_a_id_sol = id(a_sol)\n","print(f\"Original id(a_sol): {original_a_id_sol}\")\n","\n","a_sol_new_ref = a_sol + 1 # This creates a new array and rebinds the Python variable.\n","new_a_id_sol = id(a_sol_new_ref)\n","print(f\"New id(a_sol_new_ref) after 'a_sol = a_sol + 1': {new_a_id_sol}\")\n","print(f\"IDs are different: {original_a_id_sol != new_a_id_sol}\")\n","print(\"This shows that the original array was not modified in-place; a new array was created.\")"],"metadata":{"id":"0p2HrUzH6NYQ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 2: jax.jit (Just-In-Time Compilation)\n","\n","**Goal**: Understand how to use jax.jit to compile JAX functions for performance.\n","\n","### Instructions:\n","\n","1. Define a Python function compute_heavy_stuff(x, w, b) that performs a sequence of jnp operations:\n"," - y = jnp.dot(x, w)\n"," - y = y + b\n"," - y = jnp.tanh(y)\n"," - result = jnp.sum(y)\n"," - Return result.\n","2. Create a JIT-compiled version of this function, fast_compute_heavy_stuff, using jax.jit.\n","3. Create some large dummy JAX arrays for x, w, and b.\n","4. Call both the original and JIT-compiled functions with the dummy data.\n","5. (Optional) Use the `%timeit` magic command in Colab (in separate cells) to compare their execution speeds. Remember that the first call to a JIT-compiled function includes compilation time."],"metadata":{"id":"MK4rErEp6WPx"}},{"cell_type":"code","source":["# Instructions for Exercise 2\n","key_ex2_main, main_key = jax.random.split(main_key)\n","key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)\n","\n","# 1. Define the Python function\n","def compute_heavy_stuff(x, w, b):\n"," # TODO: Implement the operations\n"," y1 = None # Placeholder\n"," y2 = None # Placeholder\n"," y3 = None # Placeholder\n"," result = None # Placeholder\n"," return result\n","\n","# 2. Create a JIT-compiled version\n","# TODO: Use jax.jit to compile compute_heavy_stuff\n","fast_compute_heavy_stuff = None # Placeholder\n","\n","# 3. Create dummy data\n","dim1, dim2, dim3 = 500, 1000, 500\n","x_data = jax.random.normal(key_ex2_x, (dim1, dim2))\n","w_data = jax.random.normal(key_ex2_w, (dim2, dim3))\n","b_data = jax.random.normal(key_ex2_b, (dim3,))\n","\n","# 4. Call both functions\n","result_original = None # Placeholder compute_heavy_stuff(x_data, w_data, b_data)\n","result_fast_first_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # First call (compiles)\n","result_fast_second_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # Second call (uses compiled)\n","\n","print(f\"Result (original): {result_original}\")\n","print(f\"Result (fast, 1st call): {result_fast_first_call}\")\n","print(f\"Result (fast, 2nd call): {result_fast_second_call}\")\n","\n","# if result_original is not None and result_fast_first_call is not None:\n","# assert jnp.allclose(result_original, result_fast_first_call), \"Results should match!\"\n","# print(\"\\nResults from original and JIT-compiled functions match.\")\n","\n","# 5. Optional: Timing (use %timeit in separate cells for accuracy)\n","# print(\"\\nTo see the speed difference, run these in separate cells:\")\n","# print(\"%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")\n","# print(\"%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")"],"metadata":{"id":"SNwAyNyO6SM3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 2: `jax.jit` (Just-In-Time Compilation)\n","key_ex2_sol_main, main_key = jax.random.split(main_key)\n","key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)\n","\n","# 1. Define the Python function\n","def compute_heavy_stuff_sol(x, w, b):\n"," y = jnp.dot(x, w)\n"," y = y + b\n"," y = jnp.tanh(y)\n"," result = jnp.sum(y)\n"," return result\n","\n","# 2. Create a JIT-compiled version\n","fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)\n","\n","# 3. Create dummy data\n","dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500\n","x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))\n","w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))\n","b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))\n","\n","# 4. Call both functions\n","# Call original once to ensure it's not timed with any JAX overhead if it were the first JAX op\n","result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","# First call to JITed function includes compilation time\n","result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","# Subsequent calls use the cached compiled code\n","result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","print(f\"Result (original): {result_original_sol}\")\n","print(f\"Result (fast, 1st call): {result_fast_sol_first_call}\")\n","print(f\"Result (fast, 2nd call): {result_fast_sol_second_call}\")\n","\n","assert jnp.allclose(result_original_sol, result_fast_sol_first_call), \"Results should match!\"\n","print(\"\\nResults from original and JIT-compiled functions match.\")\n","\n","# 5. Optional: Timing\n","# To accurately measure, run these in separate Colab cells:\n","# Cell 1:\n","# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","# Cell 2:\n","# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","# You should observe that the JIT-compiled version is significantly faster after the initial compilation.\n","print(\"\\nTo see the speed difference, run the %timeit commands (provided in comments above) in separate cells.\")"],"metadata":{"id":"xOLQxFay61ls"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 3: jax.grad (Automatic Differentiation)\n","\n","**Goal**: Learn to use jax.grad to compute gradients of functions.\n","\n","### Instructions:\n","\n","1. Define a Python function scalar_loss(params, x, y_true) that:\n"," - Takes a dictionary params with keys 'w' and 'b'.\n"," - Computes y_pred = params['w'] * x + params['b'].\n"," - Returns a scalar loss, e.g., jnp.mean((y_pred - y_true)**2).\n","2. Use jax.grad to create a new function, compute_gradients, that computes the gradient of scalar_loss with respect to its first argument (params).\n","3. Initialize some dummy params, x_input, and y_target values.\n","4. Call compute_gradients to get the gradients. Print the gradients."],"metadata":{"id":"MNZqLNB57CpS"}},{"cell_type":"code","source":["# Instructions for Exercise 3\n","\n","# 1. Define the scalar_loss function\n","def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Implement the prediction and loss calculation\n"," y_pred = None # Placeholder\n"," loss = None # Placeholder\n"," return loss\n","\n","# 2. Create the gradient function using jax.grad\n","# TODO: Gradient of scalar_loss w.r.t. 'params' (argnums=0)\n","compute_gradients = None # Placeholder\n","\n","# 3. Initialize dummy data\n","params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n","x_input_data = jnp.array([1.0, 2.0, 3.0])\n","y_target_data = jnp.array([7.0, 9.0, 11.0]) # Targets for y = 3x + 4 (to make non-zero loss with init_params)\n","\n","# 4. Call the gradient function\n","gradients = None # Placeholder compute_gradients(params_init, x_input_data, y_target_data)\n","print(\"Initial params:\", params_init)\n","print(\"Gradients w.r.t params:\\n\", gradients)\n","\n","# Expected gradients (manual calculation for y_pred = wx+b, loss = mean((y_pred - y_true)^2)):\n","# dL/dw = mean(2 * (wx+b - y_true) * x)\n","# dL/db = mean(2 * (wx+b - y_true) * 1)\n","# For params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]\n","# x=1: y_pred = 2*1+1 = 3. Error = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8. dL/db_i_term = 2*(-4)*1 = -8\n","# x=2: y_pred = 2*2+1 = 5. Error = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8\n","# x=3: y_pred = 2*3+1 = 7. Error = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8\n","# Mean gradients: dL/dw = (-8-16-24)/3 = -48/3 = -16. dL/db = (-8-8-8)/3 = -24/3 = -8.\n","# if gradients is not None:\n","# assert jnp.isclose(gradients['w'], -16.0)\n","# assert jnp.isclose(gradients['b'], -8.0)\n","# print(\"\\nGradients match expected values.\")"],"metadata":{"id":"g8S-6snP69KI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 3: `jax.grad` (Automatic Differentiation)\n","\n","# 1. Define the scalar_loss function\n","def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n"," y_pred = params['w'] * x + params['b']\n"," loss = jnp.mean((y_pred - y_true)**2)\n"," return loss\n","\n","# 2. Create the gradient function using jax.grad\n","# Gradient of scalar_loss w.r.t. 'params' (which is the 0-th argument)\n","compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)\n","\n","# 3. Initialize dummy data\n","params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n","x_input_data_sol = jnp.array([1.0, 2.0, 3.0])\n","y_target_data_sol = jnp.array([7.0, 9.0, 11.0])\n","\n","# 4. Call the gradient function\n","gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)\n","print(\"Initial params:\", params_init_sol)\n","print(\"Gradients w.r.t params:\\n\", pprint.pformat(gradients_sol))\n","\n","# Verify with expected values (calculated in instructions)\n","expected_dL_dw = -16.0\n","expected_dL_db = -8.0\n","assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f\"Grad w.r.t 'w' is {gradients_sol['w']}, expected {expected_dL_dw}\"\n","assert jnp.isclose(gradients_sol['b'], expected_dL_db), f\"Grad w.r.t 'b' is {gradients_sol['b']}, expected {expected_dL_db}\"\n","print(\"\\nGradients match expected values.\")"],"metadata":{"id":"jcjiql4O7ZQy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 4: jax.vmap (Automatic Vectorization)\n","\n","**Goal**: Use jax.vmap to automatically batch operations.\n","\n","### Instructions:\n","\n","1. Define a function apply_affine(vector, matrix, bias) that takes a single 1D vector, a 2D matrix, and a 1D bias. It should compute jnp.dot(matrix, vector) + bias.\n","2. You have a batch of vectors (a 2D array where each row is a vector), but a single matrix and a single bias that should be applied to each vector in the batch.\n","3. Use jax.vmap to create batched_apply_affine that efficiently applies apply_affine to each vector in the batch.\n"," - Hint: in_axes for jax.vmap should specify 0 for the batched vector argument, and None for matrix and bias as they are not batched (broadcasted). The out_axes should be 0 to indicate the output is batched along the first axis.\n","4. Test batched_apply_affine with sample data."],"metadata":{"id":"XWoB6bD-7g2M"}},{"cell_type":"code","source":["# Instructions for Exercise 4\n","key_ex4_main, main_key = jax.random.split(main_key)\n","key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)\n","\n","# 1. Define apply_affine for a single vector\n","def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Compute jnp.dot(matrix, vector) + bias\n"," result = None # Placeholder\n"," return result\n","\n","# 2. Prepare data\n","batch_size = 4\n","input_features = 3\n","output_features = 2\n","\n","# batch_of_vectors: (batch_size, input_features)\n","# single_matrix: (output_features, input_features)\n","# single_bias: (output_features,)\n","batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))\n","single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))\n","single_bias = jax.random.normal(key_ex4_bias, (output_features,))\n","\n","\n","# 3. Use jax.vmap to create batched_apply_affine\n","# TODO: Specify in_axes correctly: vector is batched, matrix and bias are not. out_axes should be 0.\n","batched_apply_affine = None # Placeholder jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)\n","\n","\n","# 4. Test batched_apply_affine\n","result_vmap = None # Placeholder batched_apply_affine(batch_of_vectors, single_matrix, single_bias)\n","print(\"Batch of vectors shape:\", batch_of_vectors.shape)\n","print(\"Single matrix shape:\", single_matrix.shape)\n","print(\"Single bias shape:\", single_bias.shape)\n","if result_vmap is not None:\n"," print(\"Result using vmap shape:\", result_vmap.shape) # Expected: (batch_size, output_features)\n","\n"," # For comparison, a manual loop (less efficient):\n"," # manual_results = []\n"," # for i in range(batch_size):\n"," # manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))\n"," # result_manual_loop = jnp.stack(manual_results)\n"," # assert jnp.allclose(result_vmap, result_manual_loop)\n"," # print(\"vmap result matches manual loop result.\")\n","else:\n"," print(\"result_vmap is None.\")"],"metadata":{"id":"vA9mu1si7dii"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 4: `jax.vmap` (Automatic Vectorization)\n","key_ex4_sol_main, main_key = jax.random.split(main_key)\n","key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)\n","\n","# 1. Define apply_affine for a single vector\n","def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n"," return jnp.dot(matrix, vector) + bias\n","\n","# 2. Prepare data\n","batch_size_sol = 4\n","input_features_sol = 3\n","output_features_sol = 2\n","\n","batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))\n","single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))\n","single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))\n","\n","# 3. Use jax.vmap to create batched_apply_affine\n","# Vector is batched along axis 0, matrix and bias are not batched (broadcasted).\n","# out_axes=0 means the output will also be batched along its first axis.\n","batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)\n","\n","# 4. Test batched_apply_affine\n","result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)\n","print(\"Batch of vectors shape:\", batch_of_vectors_sol.shape)\n","print(\"Single matrix shape:\", single_matrix_sol.shape)\n","print(\"Single bias shape:\", single_bias_sol.shape)\n","print(\"Result using vmap shape:\", result_vmap_sol.shape) # Expected: (batch_size, output_features)\n","assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)\n","\n","# For comparison, a manual loop (less efficient):\n","manual_results_sol = []\n","for i in range(batch_size_sol):\n"," manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))\n","result_manual_loop_sol = jnp.stack(manual_results_sol)\n","\n","assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)\n","print(\"\\nvmap result matches manual loop result, demonstrating correct vectorization.\")"],"metadata":{"id":"q1QkKEtF76yo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 5: Flax NNX - Defining a Model\n","\n","**Goal**: Learn to define a simple neural network model using Flax NNX.\n","\n","### Instructions:\n","\n","1. Define a Flax NNX model class SimpleNNXModel that inherits from nnx.Module.\n","2. In its __init__, define one nnx.Linear layer. The layer should take din (input features) and dout (output features) as arguments. Remember to pass the rngs argument to nnx.Linear for parameter initialization (e.g., rngs=rngs).\n","3. Implement the __call__ method (the forward pass) which takes an input x and passes it through the linear layer.\n","4. Instantiate your SimpleNNXModel. You'll need to create an nnx.Rngs object using a JAX PRNG key (e.g., nnx.Rngs(params=jax.random.key(seed))). The key name params is conventional for nnx.Linear.\n","5. Test your model instance with a dummy input batch. Print the output and the model's state (parameters) using nnx.display()."],"metadata":{"id":"3LAlhdzq8D_S"}},{"cell_type":"code","source":["# Instructions for Exercise 5\n","key_ex5_model_init, main_key = jax.random.split(main_key)\n","\n","# 1. & 2. & 3. Define the SimpleNNXModel\n","class SimpleNNXModel(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # TODO: Define an nnx.Linear layer named 'dense_layer'\n"," # self.dense_layer = nnx.Linear(...)\n"," self.some_attribute = None # Placeholder, remove later\n"," pass # Remove this placeholder if class is not empty\n","\n"," def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Pass input x through the dense_layer\n"," # return self.dense_layer(x)\n"," return x # Placeholder\n","\n","# 4. Instantiate the model\n","model_din = 3\n","model_dout = 2\n","# TODO: Create nnx.Rngs for parameter initialization. Use 'params' as the key name.\n","model_rngs = None # Placeholder nnx.Rngs(params=key_ex5_model_init)\n","my_model = None # Placeholder SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)\n","\n","# 5. Test with dummy data\n","dummy_batch_size = 4\n","dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))\n","\n","model_output = None # Placeholder\n","if my_model is not None:\n"," model_output = my_model(dummy_input_ex5)\n"," print(f\"Model output shape: {model_output.shape}\")\n"," print(f\"Model output:\\n{model_output}\")\n","\n"," model_state = my_model.get_state()\n"," print(f\"\\nModel state (parameters, etc.):\")\n"," pprint.pprint(model_state)\n","else:\n"," print(\"my_model is None.\")"],"metadata":{"id":"BzUjMHll7--R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 5: Flax NNX - Defining a Model\n","key_ex5_sol_model_init, main_key = jax.random.split(main_key)\n","\n","# 1. & 2. & 3. Define the SimpleNNXModel\n","class SimpleNNXModel_Sol(nnx.Module): # Renamed for solution cell\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # nnx.Linear will use the 'params' key from rngs by default for its parameters\n"," self.dense_layer = nnx.Linear(din, dout, rngs=rngs)\n","\n"," def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n"," return self.dense_layer(x)\n","\n","# 4. Instantiate the model\n","model_din_sol = 3\n","model_dout_sol = 2\n","# Create nnx.Rngs for parameter initialization.\n","# 'params' is the default key nnx.Linear looks for in the rngs object.\n","model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)\n","my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)\n","\n","# 5. Test with dummy data\n","dummy_batch_size_sol = 4\n","dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))\n","\n","model_output_sol = my_model_sol(dummy_input_ex5_sol)\n","print(f\"Model output shape: {model_output_sol.shape}\")\n","print(f\"Model output:\\n{model_output_sol}\")\n","\n","# model_state_sol = my_model_sol.get_state()\n","_, model_state_sol = nnx.split(my_model_sol)\n","print(f\"\\nModel state (parameters, etc.):\")\n","nnx.display(model_state_sol)\n","\n","# Check that parameters are present\n","assert 'dense_layer' in model_state_sol, \"Key 'dense_layer' not in model_state\"\n","assert 'kernel' in model_state_sol['dense_layer'], \"Key 'kernel' not in model_state['dense_layer']\"\n","assert 'bias' in model_state_sol['dense_layer'], \"Key 'bias' not in model_state['dense_layer']\"\n","print(\"\\nModel parameters (kernel and bias for dense_layer) are present in the state.\")"],"metadata":{"id":"QbBqSZse8V9y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 6: Optax & Flax NNX - Creating an Optimizer\n","\n","**Goal**: Set up an Optax optimizer and wrap it with nnx.Optimizer for use with a Flax NNX model.\n","\n","### Instructions:\n","1. Use the SimpleNNXModel_Sol class and an instance my_model_sol from the previous exercise's solution. (If running standalone, re-instantiate it).\n","2. Create an Optax optimizer, for example, optax.adam with a learning rate of 0.001.\n","3. Create an nnx.Optimizer instance. This wrapper links the Optax optimizer with your Flax NNX model (my_model_sol).\n","4. Print the nnx.Optimizer instance and its state attribute to see the initialized optimizer state (e.g., Adam's momentum terms)."],"metadata":{"id":"i4kuv2IH-FbA"}},{"cell_type":"code","source":["# Instructions for Exercise 6\n","\n","# 1. Assume my_model_sol is available from Exercise 5 solution\n","# (If running standalone, re-instantiate it)\n","if 'my_model_sol' not in globals():\n"," print(\"Re-initializing model from Ex5 solution for Ex6.\")\n"," key_ex6_model_init, main_key = jax.random.split(main_key)\n"," _model_din_ex6 = 3\n"," _model_dout_ex6 = 2\n"," _model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)\n"," # Use solution class name if defined, otherwise student's class name\n"," _ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n"," model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)\n"," print(\"Model for optimizer created.\")\n","else:\n"," model_for_opt = my_model_sol # Use the one from previous solution\n"," print(\"Using model 'my_model_sol' from previous exercise for 'model_for_opt'.\")\n","\n","\n","# 2. Create an Optax optimizer\n","learning_rate = 0.001\n","# TODO: Create an optax.adam optimizer transform\n","optax_tx = None # Placeholder optax.adam(...)\n","\n","# 3. Create an nnx.Optimizer wrapper\n","# TODO: Wrap the model (model_for_opt) and the optax transform (optax_tx)\n","# The `wrt` argument is now required to specify what to differentiate with respect to.\n","nnx_optimizer = None # Placeholder nnx.Optimizer(...)\n","\n","# 4. Print the optimizer and its state\n","print(\"\\nFlax NNX Optimizer wrapper:\")\n","nnx.display(nnx_optimizer)\n","\n","print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n","if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):\n"," pprint.pprint(nnx_optimizer.state)\n"," # if hasattr(nnx_optimizer, 'opt_state'):\n"," # adam_state = nnx_optimizer.opt_state\n"," # assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')\n"," # print(\"\\nOptimizer state structure looks plausible for Adam.\")\n","else:\n"," print(\"nnx_optimizer or its state is None or not structured as expected.\")"],"metadata":{"id":"ytaIj3xK8ZMI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 6: Optax & Flax NNX - Creating an Optimizer\n","\n","# 1. Use my_model_sol from Exercise 5 solution\n","# If not run sequentially, ensure my_model_sol is defined:\n","if 'my_model_sol' not in globals():\n"," print(\"Re-initializing model from Ex5 solution for Ex6.\")\n"," key_ex6_sol_model_init, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex6 = 3\n"," _model_dout_sol_ex6 = 2\n"," _model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)\n"," # Ensure SimpleNNXModel_Sol is used\n"," my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)\n"," print(\"Model for optimizer re-created as 'my_model_sol'.\")\n","else:\n"," print(\"Using model 'my_model_sol' from previous exercise.\")\n","\n","\n","# 2. Create an Optax optimizer\n","learning_rate_sol = 0.001\n","# Create an optax.adam optimizer transform\n","optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)\n","\n","# 3. Create an nnx.Optimizer wrapper\n","# This links the model and the Optax optimizer.\n","# The optimizer state will be initialized based on the model's parameters.\n","nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)\n","\n","# 4. Print the optimizer and its state\n","print(\"\\nFlax NNX Optimizer wrapper:\")\n","nnx.display(nnx_optimizer_sol) # Shows the model it's associated with and the Optax transform\n","\n","print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n","# nnx.Optimizer stores the actual Optax state in its .opt_state attribute.\n","# This state is a PyTree that matches the structure of the model's parameters.\n","pprint.pprint(nnx_optimizer_sol.opt_state)\n","\n","# Verify the structure of the optimizer state for Adam (count, mu, nu for each param)\n","assert hasattr(nnx_optimizer_sol, 'opt_state'), \"Optax opt_state not found in nnx.Optimizer\"\n","# The opt_state is a tuple, typically (CountState(), ScaleByAdamState()) for adam\n","adam_optax_internal_state = nnx_optimizer_sol.opt_state\n","assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), \"Adam 'count' state not found.\"\n","# The second element of the tuple is often where parameter-specific states like mu and nu reside\n","if len(adam_optax_internal_state) > 1 and hasattr(adam_optax_internal_state[1], 'mu'):\n"," param_specific_state = adam_optax_internal_state[1]\n"," assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], \"Adam 'mu' state for kernel not found.\"\n"," print(\"\\nOptimizer state structure looks correct for Adam.\")\n","else:\n"," print(\"\\nWarning: Optimizer state structure for Adam might be different or not fully verified.\")"],"metadata":{"id":"f1ccATgB-Zed"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 7: Training Step with Flax NNX and Optax\n","\n","**Goal**: Implement a complete JIT-compiled training step for a Flax NNX model using Optax.\n","\n","### Instructions:\n","\n","1. You'll need:\n"," - An instance of your model class (e.g., my_model_sol from Ex 5/6 solution).\n"," - An instance of nnx.Optimizer (e.g., nnx_optimizer_sol from Ex 6 solution).\n","2. Define a train_step function that is decorated with @nnx.jit. This function should take the model, optimizer, input x_batch, and target y_batch as arguments.\n","3. Inside train_step:\n"," - Define an inner loss_fn_for_grad. This function must take the model as its first argument. Inside, it computes the model's predictions for x_batch and then calculates the mean squared error (MSE) against y_batch.\n"," - Use nnx.value_and_grad(loss_fn_for_grad)(model_arg) to compute both the loss value and the gradients with respect to the model passed to loss_fn_for_grad. (model_arg is the model instance passed into train_step).\n"," - Update the model's parameters (and the optimizer's state) using optimizer_arg.update(model_arg, grads). The update method takes the model and gradients, and updates the model's state in-place.\n"," - Return the computed loss_value.\n","4. Create dummy x_batch and y_batch data.\n","5. Call your train_step function. Print the returned loss.\n","6. (Optional) Verify that the model's parameters have changed after the train_step by comparing a parameter value before and after the call."],"metadata":{"id":"i7jXowc9ACNB"}},{"cell_type":"code","source":["# Instructions for Exercise 7\n","key_ex7_main, main_key = jax.random.split(main_key)\n","key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)\n","\n","# 1. Use model and optimizer from previous exercises' solutions\n","# Ensure my_model_sol and nnx_optimizer_sol are available\n","if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7.\")\n"," key_ex7_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_ex7 = 3\n"," _model_dout_ex7 = 2\n"," _model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used\n"," my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)\n"," _optax_tx_ex7 = optax.adam(learning_rate=0.001)\n"," nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7)\n"," print(\"Model and optimizer re-created for Ex7.\")\n","else:\n"," my_model_ex7 = my_model_sol\n"," nnx_optimizer_ex7 = nnx_optimizer_sol\n"," print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for 'my_model_ex7' and 'nnx_optimizer_ex7'.\")\n","\n","\n","# 2. & 3. Define the train_step function\n","# TODO: Decorate with @nnx.jit\n","# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Type hint with base nnx.Module\n","# x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n","\n"," # TODO: Define inner loss_fn_for_grad(current_model_state_for_grad_fn)\n"," # def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Type hint with base nnx.Module\n"," # y_pred = model_in_grad_fn(x_batch)\n"," # loss = jnp.mean((y_pred - y_batch)**2)\n"," # return loss\n"," # return jnp.array(0.0) # Placeholder\n","\n"," # TODO: Compute loss value and gradients using nnx.value_and_grad\n"," # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # Pass model_arg\n","\n"," # TODO: Update the optimizer (which updates the model_arg in-place)\n"," # optimizer_arg.update(model_arg, grads)\n","\n"," # return loss_value\n","# return jnp.array(0.0) # Placeholder defined train_step function\n","\n","# For the student to define:\n","# Make sure the function signature is correct for nnx.jit\n","@nnx.jit\n","def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,\n"," x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n"," # Placeholder implementation for student\n"," def loss_fn_for_grad(model_in_grad_fn: nnx.Module):\n"," # y_pred = model_in_grad_fn(x_batch)\n"," # loss = jnp.mean((y_pred - y_batch)**2)\n"," # return loss\n"," return jnp.array(0.0) # Student TODO: replace this\n","\n"," # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n"," # optimizer_arg.update(grads)\n"," # return loss_value\n"," return jnp.array(-1.0) # Student TODO: replace this\n","\n","\n","# 4. Create dummy data\n","batch_s = 8\n","# Access features_in and features_out carefully\n","_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3\n","_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2\n","\n","x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))\n","y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))\n","\n","# Optional: Store initial param value for comparison\n","initial_kernel_val = None\n","if hasattr(my_model_ex7, 'get_state'):\n"," _current_model_state_ex7 = my_model_ex7.get_state()\n"," if 'dense_layer' in _current_model_state_ex7:\n"," initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()\n","print(f\"Initial kernel value (sample): {initial_kernel_val}\")\n","\n","# 5. Call the train_step\n","# loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data) # Student will uncomment\n","loss_after_step = jnp.array(-1.0) # Placeholder until student implements train_step\n","if train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data).item() != -1.0: # Check if student implemented\n"," loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)\n"," print(f\"Loss after one training step: {loss_after_step}\")\n","else:\n"," print(\"Student needs to implement `train_step` function.\")\n","\n","\n","# # 6. Optional: Verify parameter change\n","# updated_kernel_val_sol = None\n","# _, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n","# if 'dense_layer' in updated_model_state_sol:\n","# updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n","# print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n","\n","# if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n","# assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n","# print(\"Kernel parameter changed as expected after the training step.\")\n","# else:\n","# print(\"Could not verify kernel change (initial or updated value was None).\")"],"metadata":{"id":"KEQCcmBI-ce2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 7: Training Step with Flax NNX and Optax\n","key_ex7_sol_main, main_key = jax.random.split(main_key)\n","key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)\n","\n","# 1. Use model and optimizer from previous exercises' solutions\n","if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7 solution.\")\n"," key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex7 = 3\n"," _model_dout_sol_ex7 = 2\n"," _model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used for the solution\n"," my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)\n"," _optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)\n"," nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7)\n"," print(\"Model and optimizer re-created for Ex7 solution.\")\n","else:\n"," # If solutions are run sequentially, these will be the correct instances\n"," my_model_sol_ex7 = my_model_sol\n"," nnx_optimizer_sol_ex7 = nnx_optimizer_sol\n"," print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for Ex7 solution.\")\n","\n","\n","# 2. & 3. Define the train_step function\n","@nnx.jit # Decorate with @nnx.jit for JIT compilation\n","def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Use base nnx.Module for generality\n"," x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n","\n"," # Define inner loss_fn_for_grad. It takes the model as its first argument.\n"," # It captures x_batch and y_batch from the outer scope.\n"," def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Use base nnx.Module\n"," y_pred = model_in_grad_fn(x_batch) # Use the model passed to this inner function\n"," loss = jnp.mean((y_pred - y_batch)**2)\n"," return loss\n","\n"," # Compute loss value and gradients using nnx.value_and_grad.\n"," # This will differentiate loss_fn_for_grad with respect to its first argument (model_in_grad_fn).\n"," # We pass the current state of our model (model_arg) to it.\n"," loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n","\n"," # Update the optimizer. This updates the model_arg (which nnx_optimizer_sol_ex7 references) in-place.\n"," optimizer_arg.update(model_arg, grads)\n","\n"," return loss_value\n","\n","\n","# 4. Create dummy data\n","batch_s_sol = 8\n","# Ensure din and dout match the model instantiation from Ex5/Ex6\n","# my_model_sol_ex7.dense_layer is an nnx.Linear object\n","din_from_model_sol = my_model_sol_ex7.dense_layer.in_features\n","dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features\n","\n","x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))\n","y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))\n","\n","# Optional: Store initial param value for comparison\n","initial_kernel_val_sol = None\n","_, current_model_state_sol = nnx.split(my_model_sol_ex7)\n","if 'dense_layer' in current_model_state_sol:\n"," initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()\n","print(f\"Initial kernel value (sample): {initial_kernel_val_sol}\")\n","\n","\n","# 5. Call the train_step\n","# First call will JIT compile the train_step_sol function.\n","loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n","print(f\"Loss after one training step (1st call, JIT): {loss_after_step_sol}\")\n","# Second call to show it's faster (though %timeit is better for measurement)\n","loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n","print(f\"Loss after one training step (2nd call, cached): {loss_after_step_sol_2}\")\n","\n","\n","# 6. Optional: Verify parameter change\n","updated_kernel_val_sol = None\n","_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n","if 'dense_layer' in updated_model_state_sol:\n"," updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n"," print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n","\n","if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n"," assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n"," print(\"Kernel parameter changed as expected after the training step.\")\n","else:\n"," print(\"Could not verify kernel change (initial or updated value was None).\")"],"metadata":{"id":"7bVlg9_-Ae6Z"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 8: Orbax - Saving and Restoring Checkpoints\n","\n","**Goal**: Learn to use Orbax to save and restore JAX PyTrees, specifically Flax NNX model states and Optax optimizer states.\n","\n","### Instructions:\n","1. You'll need your model (e.g., my_model_sol_ex7) and optimizer (e.g., nnx_optimizer_sol_ex7) from the previous exercise's solution.\n","2. Define a checkpoint directory (e.g., /tmp/my_nnx_checkpoint/).\n","3. Create an Orbax CheckpointManagerOptions and then a CheckpointManager.\n","4. Bundle the states you want to save into a dictionary. For NNX, this is my_model_sol_ex7.get_state() for the model, and nnx_optimizer_sol_ex7.state for the optimizer's internal state. Also include a training step counter.\n","5. Use checkpoint_manager.save() with ocp.args.StandardSave() to save the bundled state. Call checkpoint_manager.wait_until_finished() to ensure saving completes.\n","6. To restore:\n"," - Create new instances of your model (restored_model) and Optax transform (restored_optax_tx). The new model should have a different PRNG key for its initial parameters to demonstrate that restoration works.\n"," - Use checkpoint_manager.restore() with ocp.args.StandardRestore() to load the bundled state.\n"," - Apply the loaded model state to restored_model using restored_model.update_state(loaded_bundle['model']).\n"," - Create a new nnx.Optimizer (restored_optimizer) associating restored_model and restored_optax_tx.\n"," - Assign the loaded optimizer state to the new optimizer: restored_optimizer.state = loaded_bundle['optimizer'].\n","7. Verify that a parameter from restored_model matches the corresponding parameter from the original my_model_sol_ex7 (before saving, or from the saved state). Also, compare optimizer states if possible.\n","8. Clean up the checkpoint directory."],"metadata":{"id":"_t8KGFhqDoSu"}},{"cell_type":"code","source":["# Instructions for Exercise 8\n","# import orbax.checkpoint as ocp # Already imported\n","# import os, shutil # Already imported\n","\n","# 1. Use model and optimizer from previous exercise solution\n","if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex7 solution for Ex8.\")\n"," key_ex8_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_ex8 = 3\n"," _model_dout_ex8 = 2\n"," _model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)\n"," _ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n"," model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)\n"," _optax_tx_ex8 = optax.adam(learning_rate=0.001)\n"," optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8)\n"," print(\"Model and optimizer re-created for Ex8.\")\n","else:\n"," model_to_save = my_model_sol_ex7\n"," optimizer_to_save = nnx_optimizer_sol_ex7\n"," print(\"Using model and optimizer from Ex7 solution for Ex8.\")\n","\n","# 2. Define checkpoint directory\n","# TODO: Define checkpoint_dir\n","checkpoint_dir = None # Placeholder e.g., \"/tmp/my_nnx_checkpoint_exercise/\"\n","# if checkpoint_dir and os.path.exists(checkpoint_dir):\n","# shutil.rmtree(checkpoint_dir) # Clean up previous runs for safety\n","# if checkpoint_dir:\n","# os.makedirs(checkpoint_dir, exist_ok=True)\n","\n","\n","# 3. Create Orbax CheckpointManager\n","# TODO: Create options and manager\n","# options = ocp.CheckpointManagerOptions(...)\n","# mngr = ocp.CheckpointManager(...)\n","options = None\n","mngr = None\n","\n","# 4. Bundle states\n","# current_step = 100 # Example step\n","# TODO: Get model_state and optimizer_state\n","# model_state_to_save = nnx.split(model_to_save)\n","# The optimizer state is now accessed via the .state attribute.\n","# opt_state_to_save = optimizer_to_save.state\n","# save_bundle = {\n","# 'model': model_state_to_save,\n","# 'optimizer': opt_state_to_save,\n","# 'step': current_step\n","# }\n","save_bundle = None\n","\n","# 5. Save the checkpoint\n","# if mngr and save_bundle:\n","# TODO: Save checkpoint\n","# mngr.save(...)\n","# mngr.wait_until_finished()\n","# print(f\"Checkpoint saved at step {current_step} to {checkpoint_dir}\")\n","# else:\n","# print(\"Checkpoint manager or save_bundle not initialized.\")\n","\n","# --- Restoration ---\n","# 6.a Create new model and Optax transform (for restoration)\n","# key_ex8_restore_model, main_key = jax.random.split(main_key)\n","# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3\n","# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2\n","# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n","# restored_model = _ModelClassRestore(\n","# din=din_restore, dout=dout_restore,\n","# rngs=nnx.Rngs(params=key_ex8_restore_model) # New key for different initial params\n","# )\n","# restored_optax_tx = optax.adam(learning_rate=0.001) # Same Optax config\n","restored_model = None\n","restored_optax_tx = None\n","\n","# 6.b Restore the checkpoint\n","# loaded_bundle = None\n","# if mngr:\n","# TODO: Restore checkpoint\n","# latest_step = mngr.latest_step()\n","# if latest_step is not None:\n","# loaded_bundle = mngr.restore(...)\n","# print(f\"Checkpoint restored from step {latest_step}\")\n","# else:\n","# print(\"No checkpoint found to restore.\")\n","# else:\n","# print(\"Checkpoint manager not initialized for restore.\")\n","\n","# 6.c Apply loaded states\n","# if loaded_bundle and restored_model:\n","# TODO: Update restored_model state\n","# nnx.update(restored_model, ...)\n","# print(\"Restored model state applied.\")\n","\n"," # TODO: Create new nnx.Optimizer and assign its state\n","# restored_optimizer = nnx.Optimizer(...)\n","# restored_optimizer.state = ...\n","# print(\"Restored optimizer state applied.\")\n","# else:\n","# print(\"Loaded_bundle or restored_model is None, cannot apply states.\")\n","restored_optimizer = None\n","\n","# 7. Verify restoration\n","# original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']\n","# _, restored_model_state = nnx.split(restored_model_sol)\n","# kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n","# assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \\\n","# \"Model kernel parameters differ after restoration!\"\n","# print(\"\\nModel parameters successfully restored and verified (kernel match).\")\n","\n","# # Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)\n","# original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value\n","# restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value\n","# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \\\n","# \"Optimizer Adam mu for kernel differs!\"\n","# print(\"Optimizer state (sample mu) successfully restored and verified.\")\n","\n","\n","# 8. Clean up\n","# if mngr:\n","# mngr.close()\n","# if checkpoint_dir and os.path.exists(checkpoint_dir):\n","# shutil.rmtree(checkpoint_dir)\n","# print(f\"Cleaned up checkpoint directory: {checkpoint_dir}\")"],"metadata":{"id":"V7XdNy-vAjpG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 8: Orbax - Saving and Restoring Checkpoints\n","\n","# 1. Use model and optimizer from previous exercise solution\n","if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex7 solution for Ex8 solution.\")\n"," key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex8 = 3\n"," _model_dout_sol_ex8 = 2\n"," _model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used for the solution\n"," model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,\n"," dout=_model_dout_sol_ex8,\n"," rngs=_model_rngs_sol_ex8)\n"," _optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # Store the transform for later\n"," optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8)\n"," print(\"Model and optimizer re-created for Ex8 solution.\")\n","else:\n"," model_to_save_sol = my_model_sol_ex7\n"," optimizer_to_save_sol = nnx_optimizer_sol_ex7\n"," # We need the optax transform used to create the optimizer for restoration\n"," _optax_tx_sol_ex8 = optimizer_to_save_sol.tx # Access the original Optax transform\n"," print(\"Using model and optimizer from Ex7 solution for Ex8 solution.\")\n","\n","# 2. Define checkpoint directory\n","checkpoint_dir_sol = \"/tmp/my_nnx_checkpoint_exercise_solution/\"\n","if os.path.exists(checkpoint_dir_sol):\n"," shutil.rmtree(checkpoint_dir_sol) # Clean up previous runs\n","os.makedirs(checkpoint_dir_sol, exist_ok=True)\n","print(f\"Orbax checkpoint directory: {checkpoint_dir_sol}\")\n","\n","# 3. Create Orbax CheckpointManager\n","options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)\n","mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)\n","\n","# 4. Bundle states\n","current_step_sol = 100 # Example step\n","_, model_state_to_save_sol = nnx.split(model_to_save_sol)\n","# The optimizer state is now a PyTree directly available in the .state attribute.\n","opt_state_to_save_sol = optimizer_to_save_sol.opt_state\n","save_bundle_sol = {\n"," 'model': model_state_to_save_sol,\n"," 'optimizer': opt_state_to_save_sol,\n"," 'step': current_step_sol\n","}\n","print(\"\\nState bundle to be saved:\")\n","pprint.pprint(f\"Model state keys: {model_state_to_save_sol.keys()}\")\n","pprint.pprint(f\"Optimizer state type: {type(opt_state_to_save_sol)}\")\n","\n","\n","# 5. Save the checkpoint\n","mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))\n","mngr_sol.wait_until_finished()\n","print(f\"\\nCheckpoint saved at step {current_step_sol} to {checkpoint_dir_sol}\")\n","\n","# --- Restoration ---\n","# 6.a Create new model and Optax transform (for restoration)\n","key_ex8_sol_restore_model, main_key = jax.random.split(main_key)\n","# Ensure din/dout are correctly obtained from the saved model's structure if possible\n","# Assuming model_to_save_sol is SimpleNNXModel_Sol which has a dense_layer\n","din_restore_sol = model_to_save_sol.dense_layer.in_features\n","dout_restore_sol = model_to_save_sol.dense_layer.out_features\n","\n","restored_model_sol = SimpleNNXModel_Sol( # Use the solution's model class\n"," din=din_restore_sol, dout=dout_restore_sol,\n"," rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # New key for different initial params\n",")\n","# We need the original Optax transform definition for the new nnx.Optimizer\n","# _optax_tx_sol_ex8 was stored earlier, or can be re-created if config is known\n","restored_optax_tx_sol = _optax_tx_sol_ex8\n","\n","# Print a param from new model BEFORE restoration to show it's different\n","_, kernel_before_restore_sol = nnx.split(restored_model_sol)\n","print(f\"\\nSample kernel from 'restored_model_sol' BEFORE restoration:\")\n","nnx.display(kernel_before_restore_sol['dense_layer']['kernel'])\n","\n","# 6.b Restore the checkpoint\n","loaded_bundle_sol = None\n","latest_step_sol = mngr_sol.latest_step()\n","if latest_step_sol is not None:\n"," # For NNX, we are restoring raw PyTrees, StandardRestore is suitable.\n"," loaded_bundle_sol = mngr_sol.restore(latest_step_sol,\n"," args=ocp.args.StandardRestore(save_bundle_sol))\n"," print(f\"\\nCheckpoint restored from step {latest_step_sol}\")\n"," print(f\"Loaded bundle contains keys: {loaded_bundle_sol.keys()}\")\n","else:\n"," raise ValueError(\"No checkpoint found to restore.\")\n","\n","# 6.c Apply loaded states\n","assert loaded_bundle_sol is not None, \"Loaded bundle is None\"\n","nnx.update(restored_model_sol, loaded_bundle_sol['model'])\n","print(\"Restored model state applied to 'restored_model_sol'.\")\n","\n","# Create new nnx.Optimizer with the restored_model and original optax_tx\n","restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,\n"," wrt=nnx.Param)\n","# Now assign the loaded Optax state PyTree\n","restored_optimizer_sol.state = loaded_bundle_sol['optimizer']\n","print(\"Restored optimizer state applied to 'restored_optimizer_sol'.\")\n","\n","\n","# 7. Verify restoration\n","original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']\n","_, restored_model_state = nnx.split(restored_model_sol)\n","kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n","assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \\\n"," \"Model kernel parameters differ after restoration!\"\n","print(\"\\nModel parameters successfully restored and verified (kernel match).\")\n","\n","# Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)\n","original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value\n","restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value\n","assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \\\n"," \"Optimizer Adam mu for kernel differs!\"\n","print(\"Optimizer state (sample mu) successfully restored and verified.\")\n","\n","\n","# 8. Clean up\n","mngr_sol.close()\n","if os.path.exists(checkpoint_dir_sol):\n"," shutil.rmtree(checkpoint_dir_sol)\n"," print(f\"Cleaned up checkpoint directory: {checkpoint_dir_sol}\")"],"metadata":{"id":"2-Fk8aukEGVL"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Conclusion\n","\n","### Congratulations on completing the JAX AI Stack exercises!\n","\n","You've now had a hands-on introduction to:\n","\n","- Core JAX: jax.numpy, functional programming, jax.jit, jax.grad, jax.vmap.\n","- Flax NNX: Defining and instantiating Pythonic neural network models.\n","- Optax: Creating and using composable optimizers with Flax NNX.\n","- Training Loop: Implementing an end-to-end training step in Flax NNX.\n","- Orbax: Saving and restoring model and optimizer states.\n","\n","This forms a strong foundation for developing high-performance machine learning models with the JAX ecosystem.\n","\n","For further learning, refer to the official documentation:\n","- JAX AI Stack: https://jaxstack.ai\n","- JAX: https://jax.dev\n","- Flax NNX: https://flax.readthedocs.io\n","- Optax: https://optax.readthedocs.io\n","- Orbax: https://orbax.readthedocs.io\n","\n","Don't forget to provide feedback on the training session:\n","https://goo.gle/jax-training-feedback"],"metadata":{"id":"9kotBqE7Qhiv"}},{"cell_type":"code","source":[],"metadata":{"id":"TdQIp5G9QqwR"},"execution_count":null,"outputs":[]}]} \ No newline at end of file +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AEYnLrsY27El" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rcrowe-google/Learning-JAX/blob/main/code-exercises/01%20-%20JAX%20AI%20Stack.ipynb)\n", + "\n", + "# Introduction\n", + "\n", + "**Welcome to the JAX AI Stack Exercises!**\n", + "\n", + "This notebook is designed to accompany the \"Leveraging the JAX AI Stack\" lecture. You'll get hands-on experience with core JAX concepts, Flax NNX for model building, Optax for optimization, and Orbax for checkpointing.\n", + "\n", + "The exercises will guide you through implementing key components, drawing parallels to PyTorch where appropriate, to solidify your understanding.\n", + "\n", + "Let's get started!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OPA5MMD621LQ" + }, + "outputs": [], + "source": [ + "# @title Setup: Install and Import Libraries\n", + "# Install necessary libraries\n", + "!pip install -q jax-ai-stack==2025.9.3\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import flax\n", + "from flax import nnx\n", + "import optax\n", + "import orbax.checkpoint as ocp # For Orbax\n", + "from typing import Any, Dict, Tuple # For type hints\n", + "\n", + "# Helper to print PyTrees more nicely for demonstration\n", + "import pprint\n", + "import os # For Orbax directory management\n", + "import shutil # For cleaning up Orbax directory\n", + "\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"Flax version: {flax.__version__}\")\n", + "print(f\"Optax version: {optax.__version__}\")\n", + "print(f\"Orbax version: {ocp.__version__}\")\n", + "\n", + "# Global JAX PRNG key for reproducibility in exercises\n", + "# Students can learn to split this key for different operations.\n", + "main_key = jax.random.key(0)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3gC7luR35tJd" + }, + "source": [ + "## Exercise 1: JAX Core & NumPy API\n", + "\n", + "**Goal**: Get familiar with jax.numpy and JAX's functional programming style.\n", + "\n", + "### Instructions:\n", + "\n", + "1. Create two JAX arrays, a (a 2x2 matrix of random numbers) and b (a 2x2 matrix of ones) using jax.numpy (jnp). You'll need a jax.random.key for creating random numbers.\n", + "2. Perform element-wise addition of a and b.\n", + "3. Perform matrix multiplication of a and b.\n", + "4. Demonstrate JAX's immutability:\n", + " - Store the Python id() of array a.\n", + " - Perform an operation like a = a + 1.\n", + " - Print the new id() of a and observe that it has changed, indicating a new array was created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8Tq_WFzc5Ycl" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 1\n", + "key_ex1, main_key = jax.random.split(main_key) # Split the main key\n", + "\n", + "# 1. Create JAX arrays a and b\n", + "# TODO: Create array 'a' (2x2 random normal) and 'b' (2x2 ones)\n", + "a = None # Placeholder\n", + "b = None # Placeholder\n", + "\n", + "print(\"Array a:\\n\", a)\n", + "print(\"Array b:\\n\", b)\n", + "\n", + "# 2. Perform element-wise addition\n", + "# TODO: Add a and b\n", + "c = None # Placeholder\n", + "print(\"Element-wise sum c = a + b:\\n\", c)\n", + "\n", + "# 3. Perform matrix multiplication\n", + "# TODO: Matrix multiply a and b\n", + "d = None # Placeholder\n", + "print(\"Matrix product d = a @ b:\\n\", d)\n", + "\n", + "# 4. Demonstrate immutability\n", + "# original_a_id = id(a)\n", + "# print(f\"Original id(a): {original_a_id}\")\n", + "\n", + "# TODO: Perform an operation that reassigns 'a', e.g., a = a + 1\n", + "# a = None # Placeholder\n", + "# new_a_id = id(a)\n", + "# print(f\"New id(a) after 'a = a + 1': {new_a_id}\")\n", + "\n", + "# TODO: Check if original_a_id is different from new_a_id\n", + "# print(f\"IDs are different: {None}\") # Placeholder\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0p2HrUzH6NYQ" + }, + "outputs": [], + "source": [ + "# @title Solution 1: JAX Core & NumPy API\n", + "key_ex1_sol, main_key = jax.random.split(main_key)\n", + "\n", + "# 1. Create JAX arrays a and b\n", + "a_sol = jax.random.normal(key_ex1_sol, (2, 2))\n", + "b_sol = jnp.ones((2, 2))\n", + "\n", + "print(\"Array a:\\n\", a_sol)\n", + "print(\"Array b:\\n\", b_sol)\n", + "\n", + "# 2. Perform element-wise addition\n", + "c_sol = a_sol + b_sol\n", + "print(\"Element-wise sum c = a + b:\\n\", c_sol)\n", + "\n", + "# 3. Perform matrix multiplication\n", + "d_sol = jnp.dot(a_sol, b_sol) # or d = a @ b\n", + "print(\"Matrix product d = a @ b:\\n\", d_sol)\n", + "\n", + "# 4. Demonstrate immutability\n", + "original_a_id_sol = id(a_sol)\n", + "print(f\"Original id(a_sol): {original_a_id_sol}\")\n", + "\n", + "a_sol += 1 # This creates a new array and rebinds the variable `a_sol`.\n", + "# a_sol = a_sol + 1 # same as above line, an out-of-place update.\n", + "new_a_id_sol = id(a_sol)\n", + "print(f\"New id(a_sol_new_ref) after 'a_sol = a_sol + 1': {new_a_id_sol}\")\n", + "print(f\"IDs are different: {original_a_id_sol != new_a_id_sol}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This shows that the original array was **not** modified in-place; a new array was created. This is unlike `NumPy` where arrays are mutable, and so one can do in-place modification, leaving the id unchanged:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np \n", + "a = np.random.random((2,2))\n", + "original_id = id(a)\n", + "a += 1 \n", + "new_id = id(a)\n", + "print(f\"(NumPy, in-place update) IDs are different? {original_id != new_id}\")\n", + "\n", + "# One can also do an out-of-place update:\n", + "a = a + 1\n", + "new_id_out = id(a)\n", + "print(f\"(NumPy, out-of-place update) IDs are different? {original_id != new_id_out}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MK4rErEp6WPx" + }, + "source": [ + "## Exercise 2: jax.jit (Just-In-Time Compilation)\n", + "\n", + "**Goal**: Understand how to use jax.jit to compile JAX functions for performance.\n", + "\n", + "### Instructions:\n", + "\n", + "1. Define a Python function compute_heavy_stuff(x, w, b) that performs a sequence of jnp operations:\n", + " - y = jnp.dot(x, w)\n", + " - y = y + b\n", + " - y = jnp.tanh(y)\n", + " - result = jnp.sum(y)\n", + " - Return result.\n", + "2. Create a JIT-compiled version of this function, fast_compute_heavy_stuff, using jax.jit.\n", + "3. Create some large dummy JAX arrays for x, w, and b.\n", + "4. Call both the original and JIT-compiled functions with the dummy data.\n", + "5. (Optional) Use the `%timeit` magic command in Colab (in separate cells) to compare their execution speeds. Remember that the first call to a JIT-compiled function includes compilation time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SNwAyNyO6SM3" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 2\n", + "key_ex2_main, main_key = jax.random.split(main_key)\n", + "key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)\n", + "\n", + "# 1. Define the Python function\n", + "def compute_heavy_stuff(x, w, b):\n", + " # TODO: Implement the operations\n", + " y1 = None # Placeholder\n", + " y2 = None # Placeholder\n", + " y3 = None # Placeholder\n", + " result = None # Placeholder\n", + " return result\n", + "\n", + "# 2. Create a JIT-compiled version\n", + "# TODO: Use jax.jit to compile compute_heavy_stuff\n", + "fast_compute_heavy_stuff = None # Placeholder\n", + "\n", + "# 3. Create dummy data\n", + "dim1, dim2, dim3 = 500, 1000, 500\n", + "x_data = jax.random.normal(key_ex2_x, (dim1, dim2))\n", + "w_data = jax.random.normal(key_ex2_w, (dim2, dim3))\n", + "b_data = jax.random.normal(key_ex2_b, (dim3,))\n", + "\n", + "# 4. Call both functions\n", + "result_original = None # Placeholder compute_heavy_stuff(x_data, w_data, b_data)\n", + "result_fast_first_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # First call (compiles)\n", + "result_fast_second_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # Second call (uses compiled)\n", + "\n", + "print(f\"Result (original): {result_original}\")\n", + "print(f\"Result (fast, 1st call): {result_fast_first_call}\")\n", + "print(f\"Result (fast, 2nd call): {result_fast_second_call}\")\n", + "\n", + "# if result_original is not None and result_fast_first_call is not None:\n", + "# assert jnp.allclose(result_original, result_fast_first_call), \"Results should match!\"\n", + "# print(\"\\nResults from original and JIT-compiled functions match.\")\n", + "\n", + "# 5. Optional: Timing (use %timeit in separate cells for accuracy)\n", + "# print(\"\\nTo see the speed difference, run these in separate cells:\")\n", + "# print(\"%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")\n", + "# print(\"%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xOLQxFay61ls" + }, + "outputs": [], + "source": [ + "# @title Solution 2: `jax.jit` (Just-In-Time Compilation)\n", + "key_ex2_sol_main, main_key = jax.random.split(main_key)\n", + "key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)\n", + "\n", + "# 1. Define the Python function\n", + "def compute_heavy_stuff_sol(x, w, b):\n", + " y = jnp.dot(x, w)\n", + " y = y + b\n", + " y = jnp.tanh(y)\n", + " result = jnp.sum(y)\n", + " return result\n", + "\n", + "# 2. Create a JIT-compiled version\n", + "fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)\n", + "\n", + "# 3. Create dummy data\n", + "dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500\n", + "x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))\n", + "w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))\n", + "b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))\n", + "\n", + "# 4. Call both functions\n", + "# Call original once to ensure it's not timed with any JAX overhead if it were the first JAX op\n", + "result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n", + "\n", + "# First call to JITed function includes compilation time\n", + "result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n", + "\n", + "# Subsequent calls use the cached compiled code\n", + "result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n", + "\n", + "print(f\"Result (original): {result_original_sol}\")\n", + "print(f\"Result (fast, 1st call): {result_fast_sol_first_call}\")\n", + "print(f\"Result (fast, 2nd call): {result_fast_sol_second_call}\")\n", + "\n", + "assert jnp.allclose(result_original_sol, result_fast_sol_first_call), \"Results should match!\"\n", + "print(\"\\nResults from original and JIT-compiled functions match.\")\n", + "\n", + "# 5. Optional: Timing\n", + "# To accurately measure, run these in separate Colab cells:\n", + "# Cell 1:\n", + "# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n", + "# Cell 2:\n", + "# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n", + "# You should observe that the JIT-compiled version is significantly faster after the initial compilation.\n", + "print(\"\\nTo see the speed difference, run the %timeit commands (provided in comments above) in separate cells.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MNZqLNB57CpS" + }, + "source": [ + "## Exercise 3: jax.grad (Automatic Differentiation)\n", + "\n", + "**Goal**: Learn to use jax.grad to compute gradients of functions.\n", + "\n", + "### Instructions:\n", + "\n", + "1. Define a Python function scalar_loss(params, x, y_true) that:\n", + " - Takes a dictionary params with keys 'w' and 'b'.\n", + " - Computes y_pred = params['w'] * x + params['b'].\n", + " - Returns a scalar loss, e.g., jnp.mean((y_pred - y_true)**2).\n", + "2. Use jax.grad to create a new function, compute_gradients, that computes the gradient of scalar_loss with respect to its first argument (params).\n", + "3. Initialize some dummy params, x_input, and y_target values.\n", + "4. Call compute_gradients to get the gradients. Print the gradients." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "g8S-6snP69KI" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 3\n", + "\n", + "# 1. Define the scalar_loss function\n", + "def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n", + " # TODO: Implement the prediction and loss calculation\n", + " y_pred = None # Placeholder\n", + " loss = None # Placeholder\n", + " return loss\n", + "\n", + "# 2. Create the gradient function using jax.grad\n", + "# TODO: Gradient of scalar_loss w.r.t. 'params' (argnums=0)\n", + "compute_gradients = None # Placeholder\n", + "\n", + "# 3. Initialize dummy data\n", + "params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n", + "x_input_data = jnp.array([1.0, 2.0, 3.0])\n", + "y_target_data = jnp.array([7.0, 9.0, 11.0]) # Targets for y = 3x + 4 (to make non-zero loss with init_params)\n", + "\n", + "# 4. Call the gradient function\n", + "gradients = None # Placeholder compute_gradients(params_init, x_input_data, y_target_data)\n", + "print(\"Initial params:\", params_init)\n", + "print(\"Gradients w.r.t params:\\n\", gradients)\n", + "\n", + "# Expected gradients (manual calculation for y_pred = wx+b, loss = mean((y_pred - y_true)^2)):\n", + "# dL/dw = mean(2 * (wx+b - y_true) * x)\n", + "# dL/db = mean(2 * (wx+b - y_true) * 1)\n", + "# For params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]\n", + "# x=1: y_pred = 2*1+1 = 3. Error = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8. dL/db_i_term = 2*(-4)*1 = -8\n", + "# x=2: y_pred = 2*2+1 = 5. Error = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8\n", + "# x=3: y_pred = 2*3+1 = 7. Error = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8\n", + "# Mean gradients: dL/dw = (-8-16-24)/3 = -48/3 = -16. dL/db = (-8-8-8)/3 = -24/3 = -8.\n", + "# if gradients is not None:\n", + "# assert jnp.isclose(gradients['w'], -16.0)\n", + "# assert jnp.isclose(gradients['b'], -8.0)\n", + "# print(\"\\nGradients match expected values.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcjiql4O7ZQy" + }, + "outputs": [], + "source": [ + "# @title Solution 3: `jax.grad` (Automatic Differentiation)\n", + "\n", + "# 1. Define the scalar_loss function\n", + "def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n", + " y_pred = params['w'] * x + params['b']\n", + " loss = jnp.mean((y_pred - y_true)**2)\n", + " return loss\n", + "\n", + "# 2. Create the gradient function using jax.grad\n", + "# Gradient of scalar_loss w.r.t. 'params' (which is the 0-th argument)\n", + "compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)\n", + "\n", + "# 3. Initialize dummy data\n", + "params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n", + "x_input_data_sol = jnp.array([1.0, 2.0, 3.0])\n", + "y_target_data_sol = jnp.array([7.0, 9.0, 11.0])\n", + "\n", + "# 4. Call the gradient function\n", + "gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)\n", + "print(\"Initial params:\", params_init_sol)\n", + "print(\"Gradients w.r.t params:\\n\", pprint.pformat(gradients_sol))\n", + "\n", + "# Verify with expected values (calculated in instructions)\n", + "expected_dL_dw = -16.0\n", + "expected_dL_db = -8.0\n", + "assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f\"Grad w.r.t 'w' is {gradients_sol['w']}, expected {expected_dL_dw}\"\n", + "assert jnp.isclose(gradients_sol['b'], expected_dL_db), f\"Grad w.r.t 'b' is {gradients_sol['b']}, expected {expected_dL_db}\"\n", + "print(\"\\nGradients match expected values.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XWoB6bD-7g2M" + }, + "source": [ + "## Exercise 4: jax.vmap (Automatic Vectorization)\n", + "\n", + "**Goal**: Use jax.vmap to automatically batch operations.\n", + "\n", + "### Instructions:\n", + "\n", + "1. Define a function apply_affine(vector, matrix, bias) that takes a single 1D vector, a 2D matrix, and a 1D bias. It should compute jnp.dot(matrix, vector) + bias.\n", + "2. You have a batch of vectors (a 2D array where each row is a vector), but a single matrix and a single bias that should be applied to each vector in the batch.\n", + "3. Use jax.vmap to create batched_apply_affine that efficiently applies apply_affine to each vector in the batch.\n", + " - Hint: in_axes for jax.vmap should specify 0 for the batched vector argument, and None for matrix and bias as they are not batched (broadcasted). The out_axes should be 0 to indicate the output is batched along the first axis.\n", + "4. Test batched_apply_affine with sample data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vA9mu1si7dii" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 4\n", + "key_ex4_main, main_key = jax.random.split(main_key)\n", + "key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)\n", + "\n", + "# 1. Define apply_affine for a single vector\n", + "def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n", + " # TODO: Compute jnp.dot(matrix, vector) + bias\n", + " result = None # Placeholder\n", + " return result\n", + "\n", + "# 2. Prepare data\n", + "batch_size = 4\n", + "input_features = 3\n", + "output_features = 2\n", + "\n", + "# batch_of_vectors: (batch_size, input_features)\n", + "# single_matrix: (output_features, input_features)\n", + "# single_bias: (output_features,)\n", + "batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))\n", + "single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))\n", + "single_bias = jax.random.normal(key_ex4_bias, (output_features,))\n", + "\n", + "\n", + "# 3. Use jax.vmap to create batched_apply_affine\n", + "# TODO: Specify in_axes correctly: vector is batched, matrix and bias are not. out_axes should be 0.\n", + "batched_apply_affine = None # Placeholder jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)\n", + "\n", + "\n", + "# 4. Test batched_apply_affine\n", + "result_vmap = None # Placeholder batched_apply_affine(batch_of_vectors, single_matrix, single_bias)\n", + "print(\"Batch of vectors shape:\", batch_of_vectors.shape)\n", + "print(\"Single matrix shape:\", single_matrix.shape)\n", + "print(\"Single bias shape:\", single_bias.shape)\n", + "if result_vmap is not None:\n", + " print(\"Result using vmap shape:\", result_vmap.shape) # Expected: (batch_size, output_features)\n", + "\n", + " # For comparison, a manual loop (less efficient):\n", + " # manual_results = []\n", + " # for i in range(batch_size):\n", + " # manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))\n", + " # result_manual_loop = jnp.stack(manual_results)\n", + " # assert jnp.allclose(result_vmap, result_manual_loop)\n", + " # print(\"vmap result matches manual loop result.\")\n", + "else:\n", + " print(\"result_vmap is None.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q1QkKEtF76yo" + }, + "outputs": [], + "source": [ + "# @title Solution 4: `jax.vmap` (Automatic Vectorization)\n", + "key_ex4_sol_main, main_key = jax.random.split(main_key)\n", + "key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)\n", + "\n", + "# 1. Define apply_affine for a single vector\n", + "def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n", + " return jnp.dot(matrix, vector) + bias\n", + "\n", + "# 2. Prepare data\n", + "batch_size_sol = 4\n", + "input_features_sol = 3\n", + "output_features_sol = 2\n", + "\n", + "batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))\n", + "single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))\n", + "single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))\n", + "\n", + "# 3. Use jax.vmap to create batched_apply_affine\n", + "# Vector is batched along axis 0, matrix and bias are not batched (broadcasted).\n", + "# out_axes=0 means the output will also be batched along its first axis.\n", + "batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)\n", + "\n", + "# 4. Test batched_apply_affine\n", + "result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)\n", + "print(\"Batch of vectors shape:\", batch_of_vectors_sol.shape)\n", + "print(\"Single matrix shape:\", single_matrix_sol.shape)\n", + "print(\"Single bias shape:\", single_bias_sol.shape)\n", + "print(\"Result using vmap shape:\", result_vmap_sol.shape) # Expected: (batch_size, output_features)\n", + "assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)\n", + "\n", + "# For comparison, a manual loop (less efficient):\n", + "manual_results_sol = []\n", + "for i in range(batch_size_sol):\n", + " manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))\n", + "result_manual_loop_sol = jnp.stack(manual_results_sol)\n", + "\n", + "assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)\n", + "print(\"\\nvmap result matches manual loop result, demonstrating correct vectorization.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3LAlhdzq8D_S" + }, + "source": [ + "## Exercise 5: Flax NNX - Defining a Model\n", + "\n", + "**Goal**: Learn to define a simple neural network model using Flax NNX.\n", + "\n", + "### Instructions:\n", + "\n", + "1. Define a Flax NNX model class SimpleNNXModel that inherits from nnx.Module.\n", + "2. In its __init__, define one nnx.Linear layer. The layer should take din (input features) and dout (output features) as arguments. Remember to pass the rngs argument to nnx.Linear for parameter initialization (e.g., rngs=rngs).\n", + "3. Implement the __call__ method (the forward pass) which takes an input x and passes it through the linear layer.\n", + "4. Instantiate your SimpleNNXModel. You'll need to create an nnx.Rngs object using a JAX PRNG key (e.g., nnx.Rngs(params=jax.random.key(seed))). The key name params is conventional for nnx.Linear.\n", + "5. Test your model instance with a dummy input batch. Print the output and the model's state (parameters) using nnx.display()." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BzUjMHll7--R" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 5\n", + "key_ex5_model_init, main_key = jax.random.split(main_key)\n", + "\n", + "# 1. & 2. & 3. Define the SimpleNNXModel\n", + "class SimpleNNXModel(nnx.Module):\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " # TODO: Define an nnx.Linear layer named 'dense_layer'\n", + " # self.dense_layer = nnx.Linear(...)\n", + " self.some_attribute = None # Placeholder, remove later\n", + " pass # Remove this placeholder if class is not empty\n", + "\n", + " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", + " # TODO: Pass input x through the dense_layer\n", + " # return self.dense_layer(x)\n", + " return x # Placeholder\n", + "\n", + "# 4. Instantiate the model\n", + "model_din = 3\n", + "model_dout = 2\n", + "# TODO: Create nnx.Rngs for parameter initialization. Use 'params' as the key name.\n", + "model_rngs = None # Placeholder nnx.Rngs(params=key_ex5_model_init)\n", + "my_model = None # Placeholder SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)\n", + "\n", + "# 5. Test with dummy data\n", + "dummy_batch_size = 4\n", + "dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))\n", + "\n", + "model_output = None # Placeholder\n", + "if my_model is not None:\n", + " model_output = my_model(dummy_input_ex5)\n", + " print(f\"Model output shape: {model_output.shape}\")\n", + " print(f\"Model output:\\n{model_output}\")\n", + "\n", + " model_state = my_model.get_state()\n", + " print(f\"\\nModel state (parameters, etc.):\")\n", + " pprint.pprint(model_state)\n", + "else:\n", + " print(\"my_model is None.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QbBqSZse8V9y" + }, + "outputs": [], + "source": [ + "# @title Solution 5: Flax NNX - Defining a Model\n", + "key_ex5_sol_model_init, main_key = jax.random.split(main_key)\n", + "\n", + "# 1. & 2. & 3. Define the SimpleNNXModel\n", + "class SimpleNNXModel_Sol(nnx.Module): # Renamed for solution cell\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " # nnx.Linear will use the 'params' key from rngs by default for its parameters\n", + " self.dense_layer = nnx.Linear(din, dout, rngs=rngs)\n", + "\n", + " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", + " return self.dense_layer(x)\n", + "\n", + "# 4. Instantiate the model\n", + "model_din_sol = 3\n", + "model_dout_sol = 2\n", + "# Create nnx.Rngs for parameter initialization.\n", + "# 'params' is the default key nnx.Linear looks for in the rngs object.\n", + "model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)\n", + "my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)\n", + "\n", + "# 5. Test with dummy data\n", + "dummy_batch_size_sol = 4\n", + "dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))\n", + "\n", + "model_output_sol = my_model_sol(dummy_input_ex5_sol)\n", + "print(f\"Model output shape: {model_output_sol.shape}\")\n", + "print(f\"Model output:\\n{model_output_sol}\")\n", + "\n", + "# model_state_sol = my_model_sol.get_state()\n", + "_, model_state_sol = nnx.split(my_model_sol)\n", + "print(f\"\\nModel state (parameters, etc.):\")\n", + "nnx.display(model_state_sol)\n", + "\n", + "# Check that parameters are present\n", + "assert 'dense_layer' in model_state_sol, \"Key 'dense_layer' not in model_state\"\n", + "assert 'kernel' in model_state_sol['dense_layer'], \"Key 'kernel' not in model_state['dense_layer']\"\n", + "assert 'bias' in model_state_sol['dense_layer'], \"Key 'bias' not in model_state['dense_layer']\"\n", + "print(\"\\nModel parameters (kernel and bias for dense_layer) are present in the state.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4kuv2IH-FbA" + }, + "source": [ + "## Exercise 6: Optax & Flax NNX - Creating an Optimizer\n", + "\n", + "**Goal**: Set up an Optax optimizer and wrap it with nnx.Optimizer for use with a Flax NNX model.\n", + "\n", + "### Instructions:\n", + "1. Use the SimpleNNXModel_Sol class and an instance my_model_sol from the previous exercise's solution. (If running standalone, re-instantiate it).\n", + "2. Create an Optax optimizer, for example, optax.adam with a learning rate of 0.001.\n", + "3. Create an nnx.Optimizer instance. This wrapper links the Optax optimizer with your Flax NNX model (my_model_sol).\n", + "4. Print the nnx.Optimizer instance and its state attribute to see the initialized optimizer state (e.g., Adam's momentum terms)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ytaIj3xK8ZMI" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 6\n", + "\n", + "# 1. Assume my_model_sol is available from Exercise 5 solution\n", + "# (If running standalone, re-instantiate it)\n", + "if 'my_model_sol' not in globals():\n", + " print(\"Re-initializing model from Ex5 solution for Ex6.\")\n", + " key_ex6_model_init, main_key = jax.random.split(main_key)\n", + " _model_din_ex6 = 3\n", + " _model_dout_ex6 = 2\n", + " _model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)\n", + " # Use solution class name if defined, otherwise student's class name\n", + " _ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n", + " model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)\n", + " print(\"Model for optimizer created.\")\n", + "else:\n", + " model_for_opt = my_model_sol # Use the one from previous solution\n", + " print(\"Using model 'my_model_sol' from previous exercise for 'model_for_opt'.\")\n", + "\n", + "\n", + "# 2. Create an Optax optimizer\n", + "learning_rate = 0.001\n", + "# TODO: Create an optax.adam optimizer transform\n", + "optax_tx = None # Placeholder optax.adam(...)\n", + "\n", + "# 3. Create an nnx.Optimizer wrapper\n", + "# TODO: Wrap the model (model_for_opt) and the optax transform (optax_tx)\n", + "# The `wrt` argument is now required to specify what to differentiate with respect to.\n", + "nnx_optimizer = None # Placeholder nnx.Optimizer(...)\n", + "\n", + "# 4. Print the optimizer and its state\n", + "print(\"\\nFlax NNX Optimizer wrapper:\")\n", + "nnx.display(nnx_optimizer)\n", + "\n", + "print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n", + "if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):\n", + " pprint.pprint(nnx_optimizer.state)\n", + " # if hasattr(nnx_optimizer, 'opt_state'):\n", + " # adam_state = nnx_optimizer.opt_state\n", + " # assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')\n", + " # print(\"\\nOptimizer state structure looks plausible for Adam.\")\n", + "else:\n", + " print(\"nnx_optimizer or its state is None or not structured as expected.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f1ccATgB-Zed" + }, + "outputs": [], + "source": [ + "# @title Solution 6: Optax & Flax NNX - Creating an Optimizer\n", + "\n", + "# 1. Use my_model_sol from Exercise 5 solution\n", + "# If not run sequentially, ensure my_model_sol is defined:\n", + "if 'my_model_sol' not in globals():\n", + " print(\"Re-initializing model from Ex5 solution for Ex6.\")\n", + " key_ex6_sol_model_init, main_key = jax.random.split(main_key)\n", + " _model_din_sol_ex6 = 3\n", + " _model_dout_sol_ex6 = 2\n", + " _model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)\n", + " # Ensure SimpleNNXModel_Sol is used\n", + " my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)\n", + " print(\"Model for optimizer re-created as 'my_model_sol'.\")\n", + "else:\n", + " print(\"Using model 'my_model_sol' from previous exercise.\")\n", + "\n", + "\n", + "# 2. Create an Optax optimizer\n", + "learning_rate_sol = 0.001\n", + "# Create an optax.adam optimizer transform\n", + "optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)\n", + "\n", + "# 3. Create an nnx.Optimizer wrapper\n", + "# This links the model and the Optax optimizer.\n", + "# The optimizer state will be initialized based on the model's parameters.\n", + "nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)\n", + "\n", + "# 4. Print the optimizer and its state\n", + "print(\"\\nFlax NNX Optimizer wrapper:\")\n", + "nnx.display(nnx_optimizer_sol) # Shows the model it's associated with and the Optax transform\n", + "\n", + "print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n", + "# nnx.Optimizer stores the actual Optax state in its .opt_state attribute.\n", + "# This state is a PyTree that matches the structure of the model's parameters.\n", + "pprint.pprint(nnx_optimizer_sol.opt_state)\n", + "\n", + "# Verify the structure of the optimizer state for Adam (count, mu, nu for each param)\n", + "assert hasattr(nnx_optimizer_sol, 'opt_state'), \"Optax opt_state not found in nnx.Optimizer\"\n", + "# The opt_state is a tuple, typically (CountState(), ScaleByAdamState()) for adam\n", + "adam_optax_internal_state = nnx_optimizer_sol.opt_state\n", + "assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), \"Adam 'count' state not found.\"\n", + "# The second element of the tuple is often where parameter-specific states like mu and nu reside\n", + "if len(adam_optax_internal_state) > 1 and hasattr(adam_optax_internal_state[1], 'mu'):\n", + " param_specific_state = adam_optax_internal_state[1]\n", + " assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], \"Adam 'mu' state for kernel not found.\"\n", + " print(\"\\nOptimizer state structure looks correct for Adam.\")\n", + "else:\n", + " print(\"\\nWarning: Optimizer state structure for Adam might be different or not fully verified.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i7jXowc9ACNB" + }, + "source": [ + "## Exercise 7: Training Step with Flax NNX and Optax\n", + "\n", + "**Goal**: Implement a complete JIT-compiled training step for a Flax NNX model using Optax.\n", + "\n", + "### Instructions:\n", + "\n", + "1. You'll need:\n", + " - An instance of your model class (e.g., my_model_sol from Ex 5/6 solution).\n", + " - An instance of nnx.Optimizer (e.g., nnx_optimizer_sol from Ex 6 solution).\n", + "2. Define a train_step function that is decorated with @nnx.jit. This function should take the model, optimizer, input x_batch, and target y_batch as arguments.\n", + "3. Inside train_step:\n", + " - Define an inner loss_fn_for_grad. This function must take the model as its first argument. Inside, it computes the model's predictions for x_batch and then calculates the mean squared error (MSE) against y_batch.\n", + " - Use nnx.value_and_grad(loss_fn_for_grad)(model_arg) to compute both the loss value and the gradients with respect to the model passed to loss_fn_for_grad. (model_arg is the model instance passed into train_step).\n", + " - Update the model's parameters (and the optimizer's state) using optimizer_arg.update(model_arg, grads). The update method takes the model and gradients, and updates the model's state in-place.\n", + " - Return the computed loss_value.\n", + "4. Create dummy x_batch and y_batch data.\n", + "5. Call your train_step function. Print the returned loss.\n", + "6. (Optional) Verify that the model's parameters have changed after the train_step by comparing a parameter value before and after the call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KEQCcmBI-ce2" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 7\n", + "key_ex7_main, main_key = jax.random.split(main_key)\n", + "key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)\n", + "\n", + "# 1. Use model and optimizer from previous exercises' solutions\n", + "# Ensure my_model_sol and nnx_optimizer_sol are available\n", + "if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n", + " print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7.\")\n", + " key_ex7_model_fallback, main_key = jax.random.split(main_key)\n", + " _model_din_ex7 = 3\n", + " _model_dout_ex7 = 2\n", + " _model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)\n", + " # Ensure SimpleNNXModel_Sol is used\n", + " my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)\n", + " _optax_tx_ex7 = optax.adam(learning_rate=0.001)\n", + " nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7)\n", + " print(\"Model and optimizer re-created for Ex7.\")\n", + "else:\n", + " my_model_ex7 = my_model_sol\n", + " nnx_optimizer_ex7 = nnx_optimizer_sol\n", + " print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for 'my_model_ex7' and 'nnx_optimizer_ex7'.\")\n", + "\n", + "\n", + "# 2. & 3. Define the train_step function\n", + "# TODO: Decorate with @nnx.jit\n", + "# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Type hint with base nnx.Module\n", + "# x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n", + "\n", + " # TODO: Define inner loss_fn_for_grad(current_model_state_for_grad_fn)\n", + " # def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Type hint with base nnx.Module\n", + " # y_pred = model_in_grad_fn(x_batch)\n", + " # loss = jnp.mean((y_pred - y_batch)**2)\n", + " # return loss\n", + " # return jnp.array(0.0) # Placeholder\n", + "\n", + " # TODO: Compute loss value and gradients using nnx.value_and_grad\n", + " # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # Pass model_arg\n", + "\n", + " # TODO: Update the optimizer (which updates the model_arg in-place)\n", + " # optimizer_arg.update(model_arg, grads)\n", + "\n", + " # return loss_value\n", + "# return jnp.array(0.0) # Placeholder defined train_step function\n", + "\n", + "# For the student to define:\n", + "# Make sure the function signature is correct for nnx.jit\n", + "@nnx.jit\n", + "def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,\n", + " x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n", + " # Placeholder implementation for student\n", + " def loss_fn_for_grad(model_in_grad_fn: nnx.Module):\n", + " # y_pred = model_in_grad_fn(x_batch)\n", + " # loss = jnp.mean((y_pred - y_batch)**2)\n", + " # return loss\n", + " return jnp.array(0.0) # Student TODO: replace this\n", + "\n", + " # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n", + " # optimizer_arg.update(grads)\n", + " # return loss_value\n", + " return jnp.array(-1.0) # Student TODO: replace this\n", + "\n", + "\n", + "# 4. Create dummy data\n", + "batch_s = 8\n", + "# Access features_in and features_out carefully\n", + "_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3\n", + "_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2\n", + "\n", + "x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))\n", + "y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))\n", + "\n", + "# Optional: Store initial param value for comparison\n", + "initial_kernel_val = None\n", + "if hasattr(my_model_ex7, 'get_state'):\n", + " _current_model_state_ex7 = my_model_ex7.get_state()\n", + " if 'dense_layer' in _current_model_state_ex7:\n", + " initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()\n", + "print(f\"Initial kernel value (sample): {initial_kernel_val}\")\n", + "\n", + "# 5. Call the train_step\n", + "# loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data) # Student will uncomment\n", + "loss_after_step = jnp.array(-1.0) # Placeholder until student implements train_step\n", + "if train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data).item() != -1.0: # Check if student implemented\n", + " loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)\n", + " print(f\"Loss after one training step: {loss_after_step}\")\n", + "else:\n", + " print(\"Student needs to implement `train_step` function.\")\n", + "\n", + "\n", + "# # 6. Optional: Verify parameter change\n", + "# updated_kernel_val_sol = None\n", + "# _, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n", + "# if 'dense_layer' in updated_model_state_sol:\n", + "# updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n", + "# print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n", + "\n", + "# if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n", + "# assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n", + "# print(\"Kernel parameter changed as expected after the training step.\")\n", + "# else:\n", + "# print(\"Could not verify kernel change (initial or updated value was None).\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7bVlg9_-Ae6Z" + }, + "outputs": [], + "source": [ + "# @title Solution 7: Training Step with Flax NNX and Optax\n", + "key_ex7_sol_main, main_key = jax.random.split(main_key)\n", + "key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)\n", + "\n", + "# 1. Use model and optimizer from previous exercises' solutions\n", + "if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n", + " print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7 solution.\")\n", + " key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)\n", + " _model_din_sol_ex7 = 3\n", + " _model_dout_sol_ex7 = 2\n", + " _model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)\n", + " # Ensure SimpleNNXModel_Sol is used for the solution\n", + " my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)\n", + " _optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)\n", + " nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7)\n", + " print(\"Model and optimizer re-created for Ex7 solution.\")\n", + "else:\n", + " # If solutions are run sequentially, these will be the correct instances\n", + " my_model_sol_ex7 = my_model_sol\n", + " nnx_optimizer_sol_ex7 = nnx_optimizer_sol\n", + " print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for Ex7 solution.\")\n", + "\n", + "\n", + "# 2. & 3. Define the train_step function\n", + "@nnx.jit # Decorate with @nnx.jit for JIT compilation\n", + "def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Use base nnx.Module for generality\n", + " x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n", + "\n", + " # Define inner loss_fn_for_grad. It takes the model as its first argument.\n", + " # It captures x_batch and y_batch from the outer scope.\n", + " def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Use base nnx.Module\n", + " y_pred = model_in_grad_fn(x_batch) # Use the model passed to this inner function\n", + " loss = jnp.mean((y_pred - y_batch)**2)\n", + " return loss\n", + "\n", + " # Compute loss value and gradients using nnx.value_and_grad.\n", + " # This will differentiate loss_fn_for_grad with respect to its first argument (model_in_grad_fn).\n", + " # We pass the current state of our model (model_arg) to it.\n", + " loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n", + "\n", + " # Update the optimizer. This updates the model_arg (which nnx_optimizer_sol_ex7 references) in-place.\n", + " optimizer_arg.update(model_arg, grads)\n", + "\n", + " return loss_value\n", + "\n", + "\n", + "# 4. Create dummy data\n", + "batch_s_sol = 8\n", + "# Ensure din and dout match the model instantiation from Ex5/Ex6\n", + "# my_model_sol_ex7.dense_layer is an nnx.Linear object\n", + "din_from_model_sol = my_model_sol_ex7.dense_layer.in_features\n", + "dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features\n", + "\n", + "x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))\n", + "y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))\n", + "\n", + "# Optional: Store initial param value for comparison\n", + "initial_kernel_val_sol = None\n", + "_, current_model_state_sol = nnx.split(my_model_sol_ex7)\n", + "if 'dense_layer' in current_model_state_sol:\n", + " initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()\n", + "print(f\"Initial kernel value (sample): {initial_kernel_val_sol}\")\n", + "\n", + "\n", + "# 5. Call the train_step\n", + "# First call will JIT compile the train_step_sol function.\n", + "loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n", + "print(f\"Loss after one training step (1st call, JIT): {loss_after_step_sol}\")\n", + "# Second call to show it's faster (though %timeit is better for measurement)\n", + "loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n", + "print(f\"Loss after one training step (2nd call, cached): {loss_after_step_sol_2}\")\n", + "\n", + "\n", + "# 6. Optional: Verify parameter change\n", + "updated_kernel_val_sol = None\n", + "_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n", + "if 'dense_layer' in updated_model_state_sol:\n", + " updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n", + " print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n", + "\n", + "if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n", + " assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n", + " print(\"Kernel parameter changed as expected after the training step.\")\n", + "else:\n", + " print(\"Could not verify kernel change (initial or updated value was None).\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_t8KGFhqDoSu" + }, + "source": [ + "## Exercise 8: Orbax - Saving and Restoring Checkpoints\n", + "\n", + "**Goal**: Learn to use Orbax to save and restore JAX PyTrees, specifically Flax NNX model states and Optax optimizer states.\n", + "\n", + "### Instructions:\n", + "1. You'll need your model (e.g., my_model_sol_ex7) and optimizer (e.g., nnx_optimizer_sol_ex7) from the previous exercise's solution.\n", + "2. Define a checkpoint directory (e.g., /tmp/my_nnx_checkpoint/).\n", + "3. Create an Orbax CheckpointManagerOptions and then a CheckpointManager.\n", + "4. Bundle the states you want to save into a dictionary. For NNX, this is my_model_sol_ex7.get_state() for the model, and nnx_optimizer_sol_ex7.state for the optimizer's internal state. Also include a training step counter.\n", + "5. Use checkpoint_manager.save() with ocp.args.StandardSave() to save the bundled state. Call checkpoint_manager.wait_until_finished() to ensure saving completes.\n", + "6. To restore:\n", + " - Create new instances of your model (restored_model) and Optax transform (restored_optax_tx). The new model should have a different PRNG key for its initial parameters to demonstrate that restoration works.\n", + " - Use checkpoint_manager.restore() with ocp.args.StandardRestore() to load the bundled state.\n", + " - Apply the loaded model state to restored_model using restored_model.update_state(loaded_bundle['model']).\n", + " - Create a new nnx.Optimizer (restored_optimizer) associating restored_model and restored_optax_tx.\n", + " - Assign the loaded optimizer state to the new optimizer: restored_optimizer.state = loaded_bundle['optimizer'].\n", + "7. Verify that a parameter from restored_model matches the corresponding parameter from the original my_model_sol_ex7 (before saving, or from the saved state). Also, compare optimizer states if possible.\n", + "8. Clean up the checkpoint directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V7XdNy-vAjpG" + }, + "outputs": [], + "source": [ + "# Instructions for Exercise 8\n", + "# import orbax.checkpoint as ocp # Already imported\n", + "# import os, shutil # Already imported\n", + "\n", + "# 1. Use model and optimizer from previous exercise solution\n", + "if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n", + " print(\"Re-initializing model and optimizer from Ex7 solution for Ex8.\")\n", + " key_ex8_model_fallback, main_key = jax.random.split(main_key)\n", + " _model_din_ex8 = 3\n", + " _model_dout_ex8 = 2\n", + " _model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)\n", + " _ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n", + " model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)\n", + " _optax_tx_ex8 = optax.adam(learning_rate=0.001)\n", + " optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8)\n", + " print(\"Model and optimizer re-created for Ex8.\")\n", + "else:\n", + " model_to_save = my_model_sol_ex7\n", + " optimizer_to_save = nnx_optimizer_sol_ex7\n", + " print(\"Using model and optimizer from Ex7 solution for Ex8.\")\n", + "\n", + "# 2. Define checkpoint directory\n", + "# TODO: Define checkpoint_dir\n", + "checkpoint_dir = None # Placeholder e.g., \"/tmp/my_nnx_checkpoint_exercise/\"\n", + "# if checkpoint_dir and os.path.exists(checkpoint_dir):\n", + "# shutil.rmtree(checkpoint_dir) # Clean up previous runs for safety\n", + "# if checkpoint_dir:\n", + "# os.makedirs(checkpoint_dir, exist_ok=True)\n", + "\n", + "\n", + "# 3. Create Orbax CheckpointManager\n", + "# TODO: Create options and manager\n", + "# options = ocp.CheckpointManagerOptions(...)\n", + "# mngr = ocp.CheckpointManager(...)\n", + "options = None\n", + "mngr = None\n", + "\n", + "# 4. Bundle states\n", + "# current_step = 100 # Example step\n", + "# TODO: Get model_state and optimizer_state\n", + "# model_state_to_save = nnx.split(model_to_save)\n", + "# The optimizer state is now accessed via the .state attribute.\n", + "# opt_state_to_save = optimizer_to_save.state\n", + "# save_bundle = {\n", + "# 'model': model_state_to_save,\n", + "# 'optimizer': opt_state_to_save,\n", + "# 'step': current_step\n", + "# }\n", + "save_bundle = None\n", + "\n", + "# 5. Save the checkpoint\n", + "# if mngr and save_bundle:\n", + "# TODO: Save checkpoint\n", + "# mngr.save(...)\n", + "# mngr.wait_until_finished()\n", + "# print(f\"Checkpoint saved at step {current_step} to {checkpoint_dir}\")\n", + "# else:\n", + "# print(\"Checkpoint manager or save_bundle not initialized.\")\n", + "\n", + "# --- Restoration ---\n", + "# 6.a Create new model and Optax transform (for restoration)\n", + "# key_ex8_restore_model, main_key = jax.random.split(main_key)\n", + "# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3\n", + "# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2\n", + "# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n", + "# restored_model = _ModelClassRestore(\n", + "# din=din_restore, dout=dout_restore,\n", + "# rngs=nnx.Rngs(params=key_ex8_restore_model) # New key for different initial params\n", + "# )\n", + "# restored_optax_tx = optax.adam(learning_rate=0.001) # Same Optax config\n", + "restored_model = None\n", + "restored_optax_tx = None\n", + "\n", + "# 6.b Restore the checkpoint\n", + "# loaded_bundle = None\n", + "# if mngr:\n", + "# TODO: Restore checkpoint\n", + "# latest_step = mngr.latest_step()\n", + "# if latest_step is not None:\n", + "# loaded_bundle = mngr.restore(...)\n", + "# print(f\"Checkpoint restored from step {latest_step}\")\n", + "# else:\n", + "# print(\"No checkpoint found to restore.\")\n", + "# else:\n", + "# print(\"Checkpoint manager not initialized for restore.\")\n", + "\n", + "# 6.c Apply loaded states\n", + "# if loaded_bundle and restored_model:\n", + "# TODO: Update restored_model state\n", + "# nnx.update(restored_model, ...)\n", + "# print(\"Restored model state applied.\")\n", + "\n", + " # TODO: Create new nnx.Optimizer and assign its state\n", + "# restored_optimizer = nnx.Optimizer(...)\n", + "# restored_optimizer.state = ...\n", + "# print(\"Restored optimizer state applied.\")\n", + "# else:\n", + "# print(\"Loaded_bundle or restored_model is None, cannot apply states.\")\n", + "restored_optimizer = None\n", + "\n", + "# 7. Verify restoration\n", + "# original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']\n", + "# _, restored_model_state = nnx.split(restored_model_sol)\n", + "# kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n", + "# assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \\\n", + "# \"Model kernel parameters differ after restoration!\"\n", + "# print(\"\\nModel parameters successfully restored and verified (kernel match).\")\n", + "\n", + "# # Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)\n", + "# original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value\n", + "# restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value\n", + "# assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \\\n", + "# \"Optimizer Adam mu for kernel differs!\"\n", + "# print(\"Optimizer state (sample mu) successfully restored and verified.\")\n", + "\n", + "\n", + "# 8. Clean up\n", + "# if mngr:\n", + "# mngr.close()\n", + "# if checkpoint_dir and os.path.exists(checkpoint_dir):\n", + "# shutil.rmtree(checkpoint_dir)\n", + "# print(f\"Cleaned up checkpoint directory: {checkpoint_dir}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2-Fk8aukEGVL" + }, + "outputs": [], + "source": [ + "# @title Solution 8: Orbax - Saving and Restoring Checkpoints\n", + "\n", + "# 1. Use model and optimizer from previous exercise solution\n", + "if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n", + " print(\"Re-initializing model and optimizer from Ex7 solution for Ex8 solution.\")\n", + " key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)\n", + " _model_din_sol_ex8 = 3\n", + " _model_dout_sol_ex8 = 2\n", + " _model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)\n", + " # Ensure SimpleNNXModel_Sol is used for the solution\n", + " model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,\n", + " dout=_model_dout_sol_ex8,\n", + " rngs=_model_rngs_sol_ex8)\n", + " _optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # Store the transform for later\n", + " optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8)\n", + " print(\"Model and optimizer re-created for Ex8 solution.\")\n", + "else:\n", + " model_to_save_sol = my_model_sol_ex7\n", + " optimizer_to_save_sol = nnx_optimizer_sol_ex7\n", + " # We need the optax transform used to create the optimizer for restoration\n", + " _optax_tx_sol_ex8 = optimizer_to_save_sol.tx # Access the original Optax transform\n", + " print(\"Using model and optimizer from Ex7 solution for Ex8 solution.\")\n", + "\n", + "# 2. Define checkpoint directory\n", + "checkpoint_dir_sol = \"/tmp/my_nnx_checkpoint_exercise_solution/\"\n", + "if os.path.exists(checkpoint_dir_sol):\n", + " shutil.rmtree(checkpoint_dir_sol) # Clean up previous runs\n", + "os.makedirs(checkpoint_dir_sol, exist_ok=True)\n", + "print(f\"Orbax checkpoint directory: {checkpoint_dir_sol}\")\n", + "\n", + "# 3. Create Orbax CheckpointManager\n", + "options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)\n", + "mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)\n", + "\n", + "# 4. Bundle states\n", + "current_step_sol = 100 # Example step\n", + "_, model_state_to_save_sol = nnx.split(model_to_save_sol)\n", + "# The optimizer state is now a PyTree directly available in the .state attribute.\n", + "opt_state_to_save_sol = optimizer_to_save_sol.opt_state\n", + "save_bundle_sol = {\n", + " 'model': model_state_to_save_sol,\n", + " 'optimizer': opt_state_to_save_sol,\n", + " 'step': current_step_sol\n", + "}\n", + "print(\"\\nState bundle to be saved:\")\n", + "pprint.pprint(f\"Model state keys: {model_state_to_save_sol.keys()}\")\n", + "pprint.pprint(f\"Optimizer state type: {type(opt_state_to_save_sol)}\")\n", + "\n", + "\n", + "# 5. Save the checkpoint\n", + "mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))\n", + "mngr_sol.wait_until_finished()\n", + "print(f\"\\nCheckpoint saved at step {current_step_sol} to {checkpoint_dir_sol}\")\n", + "\n", + "# --- Restoration ---\n", + "# 6.a Create new model and Optax transform (for restoration)\n", + "key_ex8_sol_restore_model, main_key = jax.random.split(main_key)\n", + "# Ensure din/dout are correctly obtained from the saved model's structure if possible\n", + "# Assuming model_to_save_sol is SimpleNNXModel_Sol which has a dense_layer\n", + "din_restore_sol = model_to_save_sol.dense_layer.in_features\n", + "dout_restore_sol = model_to_save_sol.dense_layer.out_features\n", + "\n", + "restored_model_sol = SimpleNNXModel_Sol( # Use the solution's model class\n", + " din=din_restore_sol, dout=dout_restore_sol,\n", + " rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # New key for different initial params\n", + ")\n", + "# We need the original Optax transform definition for the new nnx.Optimizer\n", + "# _optax_tx_sol_ex8 was stored earlier, or can be re-created if config is known\n", + "restored_optax_tx_sol = _optax_tx_sol_ex8\n", + "\n", + "# Print a param from new model BEFORE restoration to show it's different\n", + "_, kernel_before_restore_sol = nnx.split(restored_model_sol)\n", + "print(f\"\\nSample kernel from 'restored_model_sol' BEFORE restoration:\")\n", + "nnx.display(kernel_before_restore_sol['dense_layer']['kernel'])\n", + "\n", + "# 6.b Restore the checkpoint\n", + "loaded_bundle_sol = None\n", + "latest_step_sol = mngr_sol.latest_step()\n", + "if latest_step_sol is not None:\n", + " # For NNX, we are restoring raw PyTrees, StandardRestore is suitable.\n", + " loaded_bundle_sol = mngr_sol.restore(latest_step_sol,\n", + " args=ocp.args.StandardRestore(save_bundle_sol))\n", + " print(f\"\\nCheckpoint restored from step {latest_step_sol}\")\n", + " print(f\"Loaded bundle contains keys: {loaded_bundle_sol.keys()}\")\n", + "else:\n", + " raise ValueError(\"No checkpoint found to restore.\")\n", + "\n", + "# 6.c Apply loaded states\n", + "assert loaded_bundle_sol is not None, \"Loaded bundle is None\"\n", + "nnx.update(restored_model_sol, loaded_bundle_sol['model'])\n", + "print(\"Restored model state applied to 'restored_model_sol'.\")\n", + "\n", + "# Create new nnx.Optimizer with the restored_model and original optax_tx\n", + "restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol,\n", + " wrt=nnx.Param)\n", + "# Now assign the loaded Optax state PyTree\n", + "restored_optimizer_sol.state = loaded_bundle_sol['optimizer']\n", + "print(\"Restored optimizer state applied to 'restored_optimizer_sol'.\")\n", + "\n", + "\n", + "# 7. Verify restoration\n", + "original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']\n", + "_, restored_model_state = nnx.split(restored_model_sol)\n", + "kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n", + "assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \\\n", + " \"Model kernel parameters differ after restoration!\"\n", + "print(\"\\nModel parameters successfully restored and verified (kernel match).\")\n", + "\n", + "# Verify optimizer state (e.g., Adam's 'mu' for a specific parameter)\n", + "original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'][0].mu['dense_layer']['kernel'].value\n", + "restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state[0].mu['dense_layer']['kernel'].value\n", + "assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \\\n", + " \"Optimizer Adam mu for kernel differs!\"\n", + "print(\"Optimizer state (sample mu) successfully restored and verified.\")\n", + "\n", + "\n", + "# 8. Clean up\n", + "mngr_sol.close()\n", + "if os.path.exists(checkpoint_dir_sol):\n", + " shutil.rmtree(checkpoint_dir_sol)\n", + " print(f\"Cleaned up checkpoint directory: {checkpoint_dir_sol}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9kotBqE7Qhiv" + }, + "source": [ + "## Conclusion\n", + "\n", + "### Congratulations on completing the JAX AI Stack exercises!\n", + "\n", + "You've now had a hands-on introduction to:\n", + "\n", + "- Core JAX: jax.numpy, functional programming, jax.jit, jax.grad, jax.vmap.\n", + "- Flax NNX: Defining and instantiating Pythonic neural network models.\n", + "- Optax: Creating and using composable optimizers with Flax NNX.\n", + "- Training Loop: Implementing an end-to-end training step in Flax NNX.\n", + "- Orbax: Saving and restoring model and optimizer states.\n", + "\n", + "This forms a strong foundation for developing high-performance machine learning models with the JAX ecosystem.\n", + "\n", + "For further learning, refer to the official documentation:\n", + "- JAX AI Stack: https://jaxstack.ai\n", + "- JAX: https://jax.dev\n", + "- Flax NNX: https://flax.readthedocs.io\n", + "- Optax: https://optax.readthedocs.io\n", + "- Orbax: https://orbax.readthedocs.io\n", + "\n", + "Don't forget to provide feedback on the training session:\n", + "https://goo.gle/jax-training-feedback" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TdQIp5G9QqwR" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [ + { + "file_id": "183UawZ8L3Tbm1ueDynDqO_TyGysgZ8rt", + "timestamp": 1755114181793 + } + ] + }, + "kernelspec": { + "display_name": "learning-jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}