Skip to content
This repository was archived by the owner on Mar 3, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions .github/actions/e2e-setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,18 @@ runs:
- name: tp use
run: >
tp use
--project '${{ inputs.gcp_project }}'
--zone '${{ inputs.gcp_zone }}'
--cluster '${{ inputs.xpk_cluster_name }}'
--project '${INPUTS_GCP_PROJECT}'
--zone '${INPUTS_GCP_ZONE}'
--cluster '${INPUTS_XPK_CLUSTER_NAME}'
--num-slices 1
--artifact-dir '${{ inputs.artifact_dir }}'
--tpu-type '${{ inputs.tpu_type }}'
--artifact-dir '${INPUTS_ARTIFACT_DIR}'
--tpu-type '${INPUTS_TPU_TYPE}'
--bq-table 'torchprime-e2e-tests'
--upload-metrics
shell: bash
env:
INPUTS_GCP_PROJECT: ${{ inputs.gcp_project }}
INPUTS_GCP_ZONE: ${{ inputs.gcp_zone }}
INPUTS_XPK_CLUSTER_NAME: ${{ inputs.xpk_cluster_name }}
INPUTS_ARTIFACT_DIR: ${{ inputs.artifact_dir }}
INPUTS_TPU_TYPE: ${{ inputs.tpu_type }}
39 changes: 26 additions & 13 deletions .github/workflows/e2e_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,25 @@ jobs:
- name: Setup Docker URL option
id: docker-url-option
run: |
if [ -n "${{ github.event.inputs.docker_url }}" ]; then
echo "value=--base-docker-url ${{ github.event.inputs.docker_url }}" >> "$GITHUB_OUTPUT"
if [ -n "${GITHUB_EVENT_INPUTS_DOCKER_URL}" ]; then
echo "value=--base-docker-url ${GITHUB_EVENT_INPUTS_DOCKER_URL}" >> "$GITHUB_OUTPUT"
else
echo "value=" >> "$GITHUB_OUTPUT"
fi
env:
GITHUB_EVENT_INPUTS_DOCKER_URL: ${{ github.event.inputs.docker_url }}
# Launch training workloads.

- name: Run Llama 3.0 8B
id: run-llama-3-8b
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
Expand All @@ -91,10 +94,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-pure-mlp)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
Expand All @@ -113,10 +117,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3.1-8b \
Expand All @@ -135,10 +140,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3dot1-8b-scan-offload)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3.1-8b \
Expand All @@ -157,10 +163,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-2d)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
Expand All @@ -180,10 +187,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-fsdp-cp)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b-cp \
Expand All @@ -202,10 +210,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py mixtral-8x7b)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=mixtral-8x7b \
Expand All @@ -224,10 +233,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-2-slice)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
--num-slices 2 \
torchprime/torch_xla_models/train.py \
Expand All @@ -248,10 +258,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-sft)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
--config-name llama-3-8b-sft-w-gsm8k \
Expand All @@ -268,10 +279,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-ddp-fsdp)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
--num-slices 2 \
torchprime/torch_xla_models/train.py \
Expand All @@ -292,10 +304,11 @@ jobs:
env:
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE: ${{ steps.docker-url-option.outputs.value }}
run: |
name=$(e2e_testing/gen_name.py ds-v3-shallow)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
tp run ${STEPS_DOCKER_URL_OPTION_OUTPUTS_VALUE} \
--name $name \
torchprime/torch_xla_models/train.py \
model=deepseek-v3 \
Expand Down
44 changes: 35 additions & 9 deletions .github/workflows/reusable_e2e_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@ jobs:
gcp_sa_key: ${{ secrets.GCP_SA_KEY }}
- name: Get GKE credentials
run: |
gcloud container clusters get-credentials ${{ vars.XPK_CLUSTER_NAME }} --region=${{ vars.GCP_ZONE }} --project=${{ vars.GCP_PROJECT }}
gcloud container clusters get-credentials ${VARS_XPK_CLUSTER_NAME} --region=${VARS_GCP_ZONE} --project=${VARS_GCP_PROJECT}
kubectl config view
kubectl config set-context --current --namespace=default
env:
VARS_XPK_CLUSTER_NAME: ${{ vars.XPK_CLUSTER_NAME }}
VARS_GCP_ZONE: ${{ vars.GCP_ZONE }}
VARS_GCP_PROJECT: ${{ vars.GCP_PROJECT }}
- name: Get pod name
id: get_pod_name
run: |
# Wait for pod to exist
for i in {1..60}; do
pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=${{ inputs.jobset_name }} -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || echo "")
pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=${INPUTS_JOBSET_NAME} -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || echo "")
if [[ -n "$pod_name" && "$pod_name" != "null" ]]; then
echo "pod_name=$pod_name" >> $GITHUB_OUTPUT
exit 0
Expand All @@ -64,9 +68,11 @@ jobs:
done
echo "❌ ERROR: Pod not found after 60 minutes"
exit 1
env:
INPUTS_JOBSET_NAME: ${{ inputs.jobset_name }}
- name: Wait for workload to start
run: |
pod_name="${{ steps.get_pod_name.outputs.pod_name }}"
pod_name="${STEPS_GET_POD_NAME_OUTPUTS_POD_NAME}"
# Check if pod is already done or running
for i in {1..60}; do
phase=$(kubectl get pod "$pod_name" -o jsonpath='{.status.phase}' 2>/dev/null || echo "Unknown")
Expand All @@ -88,27 +94,47 @@ jobs:
done
echo "❌ ERROR: Timeout waiting for pod to start"
exit 1
env:
STEPS_GET_POD_NAME_OUTPUTS_POD_NAME: ${{ steps.get_pod_name.outputs.pod_name }}
- name: Stream logs
run: |
# Save logs to a file for later checks
kubectl logs -c jax-tpu -f ${{ steps.get_pod_name.outputs.pod_name }} | tee /tmp/pod-${{ steps.get_pod_name.outputs.pod_name }}.log
kubectl logs -c jax-tpu -f ${STEPS_GET_POD_NAME_OUTPUTS_POD_NAME} | tee /tmp/pod-${STEPS_GET_POD_NAME_OUTPUTS_POD_NAME}.log
env:
STEPS_GET_POD_NAME_OUTPUTS_POD_NAME: ${{ steps.get_pod_name.outputs.pod_name }}
- name: Wait for workload to complete
run: |
xpk workload list --cluster ${{ vars.XPK_CLUSTER_NAME }} --wait-for-job-completion=${{ inputs.jobset_name }} --project ${{ vars.GCP_PROJECT }} --zone ${{ vars.GCP_ZONE }}
xpk workload list --cluster ${VARS_XPK_CLUSTER_NAME} --wait-for-job-completion=${INPUTS_JOBSET_NAME} --project ${VARS_GCP_PROJECT} --zone ${VARS_GCP_ZONE}
env:
VARS_XPK_CLUSTER_NAME: ${{ vars.XPK_CLUSTER_NAME }}
INPUTS_JOBSET_NAME: ${{ inputs.jobset_name }}
VARS_GCP_PROJECT: ${{ vars.GCP_PROJECT }}
VARS_GCP_ZONE: ${{ vars.GCP_ZONE }}
- name: Validate logs
run: |
e2e_testing/check_logs.py /tmp/pod-${{ steps.get_pod_name.outputs.pod_name }}.log
e2e_testing/check_logs.py /tmp/pod-${STEPS_GET_POD_NAME_OUTPUTS_POD_NAME}.log
env:
STEPS_GET_POD_NAME_OUTPUTS_POD_NAME: ${{ steps.get_pod_name.outputs.pod_name }}
- name: Validate profile
run: |
profile_dir="${{ inputs.artifact_dir }}/${{ inputs.jobset_name }}/profile/0-0"
profile_dir="${INPUTS_ARTIFACT_DIR}/${INPUTS_JOBSET_NAME}/profile/0-0"
e2e_testing/check_profile.py "$profile_dir"
env:
INPUTS_ARTIFACT_DIR: ${{ inputs.artifact_dir }}
INPUTS_JOBSET_NAME: ${{ inputs.jobset_name }}
- name: Validate metrics
run: |
output_dir="${{ inputs.artifact_dir }}/${{ inputs.jobset_name }}/outputs/0-0"
output_dir="${INPUTS_ARTIFACT_DIR}/${INPUTS_JOBSET_NAME}/outputs/0-0"
e2e_testing/check_step_time.py "$output_dir" "${{ inputs.step_time_lower_bound }}" "${{ inputs.step_time_upper_bound }}"
env:
INPUTS_ARTIFACT_DIR: ${{ inputs.artifact_dir }}
INPUTS_JOBSET_NAME: ${{ inputs.jobset_name }}
- name: Validate loss
if: ${{ inputs.target_loss }}
run: |
e2e_testing/check_loss.py \
/tmp/pod-${{ steps.get_pod_name.outputs.pod_name }}.log \
/tmp/pod-${STEPS_GET_POD_NAME_OUTPUTS_POD_NAME}.log \
"${{ inputs.target_loss }}" "${{ inputs.loss_tolerance }}"

env:
STEPS_GET_POD_NAME_OUTPUTS_POD_NAME: ${{ steps.get_pod_name.outputs.pod_name }}
Loading