diff --git a/.github/CI_README.md b/.github/CI_README.md index f01ad04a..2a739056 100644 --- a/.github/CI_README.md +++ b/.github/CI_README.md @@ -44,8 +44,8 @@ Fast validation for pull requests and develop branch pushes. ### Biweekly Tests (`biweekly-tests.yml`) Comprehensive testing for large model validation. -**Triggers**: Biweekly schedule (Monday/Thursday 00:00 UTC) -**Runtime**: ~24 hours +**Triggers**: Biweekly schedule (Monday/Thursday 00:00 UTC) +**Runtime**: ~24 hours **Job**: `bert-large-comprehensive-test` (BERT Large Model Comprehensive Test) **Steps**: @@ -95,7 +95,7 @@ Complete test lifecycle with conditional artifact collection. - `docker-cleanup` - Cleans containers AND persistent build directories - `collect-artifacts` - Collects system info, container logs, and test artifacts -#### Docker Actions +#### Docker Actions - `build-docker` - Builds image with verification and timing fixes - `docker-exec` - Executes commands with container lifecycle management diff --git a/.github/actions/build-docker/action.yml b/.github/actions/build-docker/action.yml index 12c0fcab..3e7eba2c 100644 --- a/.github/actions/build-docker/action.yml +++ b/.github/actions/build-docker/action.yml @@ -10,17 +10,17 @@ runs: chmod +x ctl-docker.sh echo "=== Building Docker image ===" ./ctl-docker.sh build - + echo "=== Verifying image was built ===" echo "Expected image tag: $BSMITH_DOCKER_TAG" - + # Wait a moment for Docker to register the image sleep 2 - + # Show all docker images for debugging echo "All Docker images:" docker images - + # Check for the specific tag if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "microsoft/brainsmith:"; then echo "✓ Docker image built and verified successfully" @@ -29,4 +29,4 @@ runs: echo "Looking for images matching: microsoft/brainsmith:" docker images | grep "microsoft/brainsmith" || echo "No matching images found" exit 1 - fi \ No newline at end of file + fi diff --git a/.github/actions/collect-artifacts/action.yml b/.github/actions/collect-artifacts/action.yml index 63e1e2c2..add068d9 100644 --- a/.github/actions/collect-artifacts/action.yml +++ b/.github/actions/collect-artifacts/action.yml @@ -13,32 +13,32 @@ runs: shell: bash run: | ARTIFACT_DIR="${{ inputs.artifact-directory }}" - + # Validate directory name if [[ ! 
"$ARTIFACT_DIR" =~ ^[a-zA-Z0-9_\-]+$ ]]; then echo "ERROR: Artifact directory contains invalid characters" echo "Allowed: alphanumeric, underscore, hyphen" exit 1 fi - + # Prevent path traversal if [[ "$ARTIFACT_DIR" =~ \.\./|^/ ]]; then echo "ERROR: Invalid artifact directory path (path traversal detected)" exit 1 fi - + echo "✓ Artifact directory validated: $ARTIFACT_DIR" - + - name: Collect artifacts shell: bash run: | ARTIFACT_DIR="${{ inputs.artifact-directory }}" mkdir -p "$ARTIFACT_DIR" - + echo "=== Collecting system info ===" df -h > "$ARTIFACT_DIR/disk_usage.txt" 2>/dev/null || true free -h > "$ARTIFACT_DIR/memory_usage.txt" 2>/dev/null || true - + echo "=== Collecting container info ===" if [ -x ./ctl-docker.sh ]; then ./ctl-docker.sh status > "$ARTIFACT_DIR/container_status.txt" 2>&1 || echo "Status failed" > "$ARTIFACT_DIR/container_status.txt" @@ -57,4 +57,4 @@ runs: echo "No container diagnostics found (expected if test succeeded or COLLECT_DIAGNOSTICS not enabled)" fi - echo "✓ Artifacts collected in $ARTIFACT_DIR" \ No newline at end of file + echo "✓ Artifacts collected in $ARTIFACT_DIR" diff --git a/.github/actions/docker-cleanup/action.yml b/.github/actions/docker-cleanup/action.yml index c095dc58..e98ff191 100644 --- a/.github/actions/docker-cleanup/action.yml +++ b/.github/actions/docker-cleanup/action.yml @@ -31,4 +31,4 @@ runs: # Note: BuildKit cache layers are preserved automatically echo "✓ BuildKit cache layers preserved for future builds" - echo "Available space: $(df -h / | tail -1 | awk '{print $4}')" \ No newline at end of file + echo "Available space: $(df -h / | tail -1 | awk '{print $4}')" diff --git a/.github/actions/run-test-with-artifacts/action.yml b/.github/actions/run-test-with-artifacts/action.yml index bac0fda7..d2b8f78a 100644 --- a/.github/actions/run-test-with-artifacts/action.yml +++ b/.github/actions/run-test-with-artifacts/action.yml @@ -46,8 +46,8 @@ runs: - name: Collect artifacts id: collect-artifacts if: | - always() && - (inputs.collect-on == 'always' || + always() && + (inputs.collect-on == 'always' || (inputs.collect-on == 'failure' && steps.test-execution.outcome == 'failure')) uses: ./.github/actions/collect-artifacts with: @@ -56,8 +56,8 @@ runs: - name: Upload artifacts id: upload-artifacts if: | - always() && - (inputs.collect-on == 'always' || + always() && + (inputs.collect-on == 'always' || (inputs.collect-on == 'failure' && steps.test-execution.outcome == 'failure')) uses: actions/upload-artifact@v4 with: @@ -67,4 +67,4 @@ runs: - name: Final cleanup if: always() - uses: ./.github/actions/docker-cleanup \ No newline at end of file + uses: ./.github/actions/docker-cleanup diff --git a/.github/workflows/biweekly-tests.yml b/.github/workflows/biweekly-tests.yml index 7f9fe6c5..9cc01ca3 100644 --- a/.github/workflows/biweekly-tests.yml +++ b/.github/workflows/biweekly-tests.yml @@ -44,4 +44,4 @@ jobs: timeout-minutes: 1400 artifact-name: "biweekly-artifacts" collect-on: "always" - retention-days: 14 \ No newline at end of file + retention-days: 14 diff --git a/docker/Dockerfile b/docker/Dockerfile index f116fa62..1f993252 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,4 +69,4 @@ RUN chmod 755 /usr/local/bin/entrypoint.sh /usr/local/bin/entrypoint-exec.sh /us # Set default entrypoint ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] -CMD ["bash"] \ No newline at end of file +CMD ["bash"] diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index f24724ed..515926ee 100755 --- a/docker/entrypoint.sh +++ 
b/docker/entrypoint.sh @@ -29,7 +29,7 @@ fi # Run one-stop setup (unless explicitly skipped) if [ "$BSMITH_SKIP_SETUP" != "1" ]; then log "Running setup..." - + # Build setup flags SETUP_FLAGS="--docker" # Always use docker mode [ -n "$BSMITH_FORCE_SETUP" ] && SETUP_FLAGS="$SETUP_FLAGS --force" @@ -39,7 +39,7 @@ if [ "$BSMITH_SKIP_SETUP" != "1" ]; then if [ -n "$BSMITH_SKIP_COMPONENTS" ]; then SETUP_FLAGS="$SETUP_FLAGS --skip $BSMITH_SKIP_COMPONENTS" fi - + # Run setup if ./setup-venv.sh $SETUP_FLAGS; then log "✓ Setup completed successfully" @@ -56,4 +56,4 @@ source /usr/local/bin/setup-shell.sh # Keep container alive in daemon mode log "Container ready" -exec tail -f /dev/null \ No newline at end of file +exec tail -f /dev/null diff --git a/examples/bert/bert_demo.py b/examples/bert/bert_demo.py index 2379859a..1a066823 100644 --- a/examples/bert/bert_demo.py +++ b/examples/bert/bert_demo.py @@ -57,14 +57,14 @@ def generate_bert_model(args): """Generate quantized BERT model from HuggingFace with Brevitas quantization. - + This matches the functionality from old end2end_bert.py::gen_initial_bert_model() """ print(f"Generating BERT model with {args.num_hidden_layers} layers...") - + # Global consts used by Brevitas build step dtype = torch.float32 - + # Create BERT configuration config = BertConfig( hidden_size=args.hidden_size, @@ -74,39 +74,39 @@ def generate_bert_model(args): attn_implementation="sdpa", hidden_act="relu", ) - + # Initialize model model = BertModel(config=config) model.to(dtype=dtype) model.eval() - + # Prepare inputs vocab_size = model.config.vocab_size seq_len = args.seqlen batch_size = 1 - + input_ids = torch.randint(vocab_size, (batch_size, seq_len), dtype=torch.int64) inp = {'input_ids': input_ids} - + # Symbolic tracing input_names = inp.keys() model = symbolic_trace(model, input_names) - + # Replace SDPA with quantizable layers print("Replacing SDPA with quantizable variants...") model = replace_sdpa_with_quantizable_layers(model) print("Replacement done.") - + # Configure quantization unsigned_hidden_act = config.hidden_act == 'relu' layerwise_compute_layer_map = {} - + # Linear layer quantization layerwise_compute_layer_map[nn.Linear] = ( qnn.QuantLinear, { - 'input_quant': lambda module: Uint8ActPerTensorFloat - if module.in_features == config.intermediate_size and unsigned_hidden_act + 'input_quant': lambda module: Uint8ActPerTensorFloat + if module.in_features == config.intermediate_size and unsigned_hidden_act else Int8ActPerTensorFloat, 'weight_quant': Int8WeightPerTensorFloat, 'weight_bit_width': args.bitwidth, @@ -115,7 +115,7 @@ def generate_bert_model(args): 'return_quant_tensor': False } ) - + # Attention quantization layerwise_compute_layer_map[qnn.ScaledDotProductAttention] = ( qnn.QuantScaledDotProductAttention, @@ -134,7 +134,7 @@ def generate_bert_model(args): 'return_quant_tensor': False } ) - + # Tanh quantization layerwise_compute_layer_map[nn.Tanh] = ( qnn.QuantTanh, @@ -145,19 +145,19 @@ def generate_bert_model(args): 'return_quant_tensor': False } ) - + # Apply quantization quant_model = layerwise_quantize(model, compute_layer_map=layerwise_compute_layer_map) quant_model.to(dtype=dtype) - + # Calibration with torch.no_grad(), calibration_mode(quant_model): quant_model(**inp) - + # Export to ONNX with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp: tmp_path = tmp.name - + with torch.no_grad(): bo.export_qonnx( quant_model, @@ -167,11 +167,11 @@ def generate_bert_model(args): input_names=['input_ids'], opset_version=17, ) - + 
# Load and return model model = onnx.load(tmp_path) os.unlink(tmp_path) - + # Save initial Brevitas model for debugging debug_path = os.path.join(args.output_dir, "debug_models") os.makedirs(debug_path, exist_ok=True) @@ -180,7 +180,7 @@ def generate_bert_model(args): print(f" - Model inputs: {[i.name for i in model.graph.input]}") print(f" - Model outputs: {[o.name for o in model.graph.output]}") print(f" - Number of nodes: {len(model.graph.node)}") - + return model @@ -190,25 +190,25 @@ def run_brainsmith_dse(model, args): os.makedirs(args.output_dir, exist_ok=True) model_dir = os.path.join(args.output_dir, "intermediate_models") os.makedirs(model_dir, exist_ok=True) - + # Simplify model (matches old hw_compiler.py) model, check = simplify(model) if not check: raise RuntimeError("Unable to simplify the Brevitas BERT model") - + # Save simplified model onnx.save(model, os.path.join(model_dir, "simp.onnx")) # Also save to debug directory for comparison debug_dir = os.path.join(args.output_dir, "debug_models") onnx.save(model, os.path.join(debug_dir, "01_after_simplify.onnx")) print(f"Saved simplified model to debug_models/01_after_simplify.onnx") - + # Run cleanup cleanup( in_file=os.path.join(model_dir, "simp.onnx"), out_file=os.path.join(args.output_dir, "df_input.onnx") ) - + # Save a copy of the cleaned model for visualization import shutil debug_dir = os.path.join(args.output_dir, "debug_models") @@ -217,10 +217,10 @@ def run_brainsmith_dse(model, args): os.path.join(args.output_dir, "df_input.onnx"), os.path.join(debug_dir, "02_after_qonnx_cleanup.onnx") ) - + # Get blueprint path from args blueprint_path = Path(__file__).parent / args.blueprint - + # Create the FPGA accelerator print("Creating FPGA accelerator...") results = explore_design_space( @@ -228,22 +228,22 @@ def run_brainsmith_dse(model, args): blueprint_path=str(blueprint_path), output_dir=args.output_dir ) - + # Results are automatically logged by explore_design_space() # Just check if we succeeded stats = results.compute_stats() if stats['successful'] == 0: raise RuntimeError(f"No successful builds") - + # The new execution tree handles output automatically final_model_dst = os.path.join(args.output_dir, "output.onnx") - + # Find the output from the successful execution for segment_id, result in results.segment_results.items(): if result.status == SegmentStatus.COMPLETED and result.output_model: shutil.copy2(result.output_model, final_model_dst) break - + # Handle shell metadata (matches old hw_compiler.py) handover_file = os.path.join(args.output_dir, "stitched_ip", "shell_handover.json") if os.path.exists(handover_file): @@ -252,7 +252,7 @@ def run_brainsmith_dse(model, args): handover["num_layers"] = args.num_hidden_layers with open(handover_file, "w") as fp: json.dump(handover, fp, indent=4) - + return results @@ -260,42 +260,42 @@ def main(): parser = argparse.ArgumentParser( description='Modern BERT FINN demo - Exact parity with old system using Brainsmith DFC' ) - + # Model configuration parser.add_argument('-o', '--output', help='Output build directory name', required=True) - parser.add_argument('-z', '--hidden_size', type=int, default=384, + parser.add_argument('-z', '--hidden_size', type=int, default=384, help='BERT hidden_size parameter') - parser.add_argument('-n', '--num_attention_heads', type=int, default=12, + parser.add_argument('-n', '--num_attention_heads', type=int, default=12, help='BERT num_attention_heads parameter') - parser.add_argument('-l', '--num_hidden_layers', type=int, default=1, + 
parser.add_argument('-l', '--num_hidden_layers', type=int, default=1, help='Number of hidden layers') - parser.add_argument('-i', '--intermediate_size', type=int, default=1536, + parser.add_argument('-i', '--intermediate_size', type=int, default=1536, help='BERT intermediate_size parameter') - parser.add_argument('-b', '--bitwidth', type=int, default=8, + parser.add_argument('-b', '--bitwidth', type=int, default=8, help='Quantization bitwidth (4 or 8)') - parser.add_argument('-q', '--seqlen', type=int, default=128, + parser.add_argument('-q', '--seqlen', type=int, default=128, help='Sequence length parameter') - + # Blueprint configuration parser.add_argument('--blueprint', type=str, default='bert_demo.yaml', help='Blueprint YAML file to use (default: bert_demo.yaml)') - + # Force flag parser.add_argument('--force', action='store_true', help='Remove existing output directory before building') - + args = parser.parse_args() - + # Determine output directory build_dir = get_config().build_dir print(build_dir) args.output_dir = os.path.join(str(build_dir), args.output) - + # Clean up existing directory if --force flag is set if args.force and os.path.exists(args.output_dir): print(f"Removing existing output directory: {args.output_dir}") shutil.rmtree(args.output_dir) - + print("=" * 70) print("BERT Demo Using Brainsmith DFC") print("=" * 70) @@ -309,25 +309,25 @@ def main(): print(f" Blueprint: {args.blueprint}") print(f" Output directory: {args.output_dir}") print("=" * 70) - + try: # Step 1: Generate BERT model print("\nStep 1: Generating quantized BERT model...") model = generate_bert_model(args) - + # Step 2: Create dataflow core accelerator print("\nStep 2: Creating dataflow core accelerator...") result = run_brainsmith_dse(model, args) - + print("\n" + "=" * 70) print("BUILD COMPLETED SUCCESSFULLY") print("=" * 70) print(f"Output directory: {args.output_dir}") - + except Exception as e: print(f"\nERROR: Build failed with error: {e}") raise if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/bert/quicktest.sh b/examples/bert/quicktest.sh index 119262f5..1bd26f0c 100755 --- a/examples/bert/quicktest.sh +++ b/examples/bert/quicktest.sh @@ -75,4 +75,4 @@ python bert_demo.py \ --force -echo "Quick test completed!" \ No newline at end of file +echo "Quick test completed!" diff --git a/examples/dlrm/README_onnx_context.md b/examples/dlrm/README_onnx_context.md new file mode 100644 index 00000000..08e74c9f --- /dev/null +++ b/examples/dlrm/README_onnx_context.md @@ -0,0 +1,121 @@ +# ONNX Context Export Utility + +## Overview + +`export_onnx_context.py` is a utility that converts ONNX models into a comprehensive, human-readable format suitable for providing context to AI assistants or for documentation purposes. + +## Features + +The utility generates a structured context file containing: + +1. **Model Metadata** - IR version, opset, producer information +2. **Graph Inputs/Outputs** - Names, types, and shapes +3. **Initializers Summary** - List of all weights and constants with sizes +4. **Operation Statistics** - Count of each operation type in the model +5. **Node Connectivity** - Which nodes consume each input +6. **Complete Graph Structure** - Full ONNX Script IR representation + +## Usage + +### Basic Usage + +```bash +python export_onnx_context.py model.onnx +``` + +This creates `model_context.txt` in the same directory. 
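+
+Under the hood (the exporter script itself is not shown in this diff, so this is only a rough, hypothetical sketch rather than the actual implementation), the metadata and operation-statistics sections listed under Features can be produced with the standard `onnx` Python API along these lines:
+
+```python
+# Hypothetical sketch: dump model metadata, graph inputs, and per-op-type node
+# counts, mirroring the section names used in the example output below.
+import sys
+from collections import Counter
+
+import onnx
+
+
+def summarize(path: str) -> str:
+    model = onnx.load(path)
+    graph = model.graph
+    opsets = {imp.domain or "": imp.version for imp in model.opset_import}
+    lines = [
+        "# Model Metadata",
+        f"IR Version: {model.ir_version}",
+        f"Opset Imports: {opsets}",
+        f"Producer: {model.producer_name} {model.producer_version}".strip(),
+        "",
+        "# Graph Inputs",
+    ]
+    for value_info in graph.input:
+        tensor_type = value_info.type.tensor_type
+        dims = [d.dim_param or d.dim_value for d in tensor_type.shape.dim]
+        dtype = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
+        lines.append(f"{value_info.name}: {dtype} {dims}")
+    lines += ["", "# Operation Statistics", f"Total Nodes: {len(graph.node)}"]
+    for op_type, count in sorted(Counter(n.op_type for n in graph.node).items()):
+        lines.append(f"  {op_type:<24}: {count}")
+    return "\n".join(lines)
+
+
+if __name__ == "__main__":
+    print(summarize(sys.argv[1]))
+```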
+ +### Specify Output File + +```bash +python export_onnx_context.py model.onnx -o custom_output.txt +``` + +### Verbose Mode + +```bash +python export_onnx_context.py model.onnx -v +``` + +Shows progress and statistics during export. + +### Command Line Options + +``` +positional arguments: + onnx_file Path to the ONNX model file + +optional arguments: + -h, --help show this help message and exit + -o OUTPUT, --output OUTPUT + Output file path (default: {model_name}_context.txt) + --max-array-size MAX_ARRAY_SIZE + Maximum number of array elements to display (default: 20) + -v, --verbose Print verbose output +``` + +## Example Output Structure + +``` +================================================================================ +ONNX Model Context: dlrm_s_pytorch.onnx +================================================================================ + +# Model Metadata +-------------------------------------------------------------------------------- +IR Version: 10 +Opset Imports: {'': 18} +... + +# Graph Inputs +-------------------------------------------------------------------------------- +1. dense_x + Type: FLOAT + Shape: [s32,4] +... + +# Operation Statistics +-------------------------------------------------------------------------------- +Total Nodes: 63 + +Operation Type Distribution: + Concat : 6 + Gather : 6 +... + +# Complete Graph Structure (ONNX Script IR Format) +================================================================================ + +... +``` + +## Use Cases + +1. **Providing Context to AI Assistants** - Share the generated `.txt` file to give comprehensive model information +2. **Documentation** - Create readable documentation of model architecture +3. **Debugging** - Understand model structure and data flow +4. **Code Review** - Share model structure without needing ONNX visualization tools + +## Tips + +- The context file is plain text and can be easily read with any text editor +- For very large models, the file may be large; consider using `head` or `less` to view sections +- The ONNX Script IR format shows exact operation types, attributes, and connections +- Use verbose mode (`-v`) to see file size and node count statistics + +## Integration with AI Assistants + +When working with an AI assistant on ONNX models: + +1. Export the model to context format: + ```bash + python export_onnx_context.py your_model.onnx -v + ``` + +2. The assistant can then read the generated `your_model_context.txt` file to understand: + - Model architecture + - Input/output specifications + - Operation types and counts + - Complete computational graph + +This provides much better context than trying to describe the model verbally! 
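+
+The context file can get large for big models (see Tips above). As a small, hypothetical convenience (assuming the section headers match the example output shown earlier), a single section can be pulled out programmatically:
+
+```python
+# Print only the "Operation Statistics" section of a generated context file.
+from pathlib import Path
+
+text = Path("your_model_context.txt").read_text()
+start = text.index("# Operation Statistics")
+end = text.index("# Complete Graph Structure", start)
+print(text[start:end].rstrip())
+```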
diff --git a/examples/dlrm/dlrm.yaml b/examples/dlrm/dlrm.yaml new file mode 100644 index 00000000..37809f03 --- /dev/null +++ b/examples/dlrm/dlrm.yaml @@ -0,0 +1,54 @@ + +name: "DLRM Demo" +description: "FaceBook PyTorch DLRM model" + +#extends: "${BSMITH_DIR}/examples/blueprints/bert.yaml" + +# Configuration overrides +clock_ns: 5.0 # Target clock period in nanoseconds +output: "bitfile" # estimates | rtl | bitfile +board: "V80" # Target FPGA board +save_intermediate_models: true # Save intermediate ONNX models + +# Direct override FINN configuration options +finn_config: + enable_build_pdb_debug: true +# standalone_thresholds: true +# target_fps: 3000 # Target inference FPS (auto-determines PE/SIMD) +# folding_config_file: null # Path to manual folding config JSON (optional) +# split_large_fifos: true + +design_space: + # Inherit kernels from parent blueprint + kernels: + - LayerNorm + - DuplicateStreams + - ElementwiseBinaryOperation + - Shuffle + - Softmax + - Thresholding + - MVAU + - StreamingConcat + - Lookup + + # Add pre/post-processing steps to standard BERT blueprint + steps: + - bert_cleanup # brainsmith.steps.bert_custom_steps + - split_sparse_processing + - dense_cleanup + - qonnx_to_finn # custom_step_qonnx2finn + # Topology optimization + - streamline + # Core FINN steps + - infer_kernels # Brainsmith dynamic kernel inference + # - create_dataflow_partition + - specialize_layers + - target_fps_parallelization + - apply_folding_config + - minimize_bit_width + - generate_estimate_reports + - hw_codegen + - hw_ipgen + - set_fifo_depths + - create_stitched_ip + - measure_rtlsim_performance diff --git a/examples/dlrm/dlrm_demo.py b/examples/dlrm/dlrm_demo.py new file mode 100644 index 00000000..3558c392 --- /dev/null +++ b/examples/dlrm/dlrm_demo.py @@ -0,0 +1,503 @@ +import torch +import torchrec + +from torchrec.models.dlrm import DLRM, DLRM_DCN +from torchrec import EmbeddingBagCollection, EmbeddingBagConfig +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES + + +### define dlrm in PyTorch ### +class DLRM_Net(nn.Module): + def create_mlp(self, ln, sigmoid_layer): + # build MLP layer by layer + layers = nn.ModuleList() + for i in range(0, ln.size - 1): + n = ln[i] + m = ln[i + 1] + + # construct fully connected operator + LL = nn.Linear(int(n), int(m), bias=True) + + # initialize the weights + # with torch.no_grad(): + # custom Xavier input, output or two-sided fill + mean = 0.0 # std_dev = np.sqrt(variance) + std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n) + W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32) + std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1)) + bt = np.random.normal(mean, std_dev, size=m).astype(np.float32) + # approach 1 + LL.weight.data = torch.tensor(W, requires_grad=True) + LL.bias.data = torch.tensor(bt, requires_grad=True) + # approach 2 + # LL.weight.data.copy_(torch.tensor(W)) + # LL.bias.data.copy_(torch.tensor(bt)) + # approach 3 + # LL.weight = Parameter(torch.tensor(W),requires_grad=True) + # LL.bias = Parameter(torch.tensor(bt),requires_grad=True) + layers.append(LL) + + # construct sigmoid or relu operator + if i == sigmoid_layer: + layers.append(nn.Sigmoid()) + else: + layers.append(nn.ReLU()) + + # approach 1: use ModuleList + # return layers + # approach 2: use Sequential container to wrap all layers + return torch.nn.Sequential(*layers) + + def create_emb(self, m, ln, weighted_pooling=None): + emb_l = nn.ModuleList() + v_W_l = [] + for i in range(0, ln.size): + if ext_dist.my_size > 
1: + if i not in self.local_emb_indices: + continue + n = ln[i] + + # construct embedding operator + if self.qr_flag and n > self.qr_threshold: + EE = QREmbeddingBag( + n, + m, + self.qr_collisions, + operation=self.qr_operation, + mode="sum", + sparse=True, + ) + elif self.md_flag and n > self.md_threshold: + base = max(m) + _m = m[i] if n > self.md_threshold else base + EE = PrEmbeddingBag(n, _m, base) + # use np initialization as below for consistency... + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, _m) + ).astype(np.float32) + EE.embs.weight.data = torch.tensor(W, requires_grad=True) + else: + EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True, padding_idx=m-1) + # initialize embeddings + # nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) + ).astype(np.float32) + # approach 1 + EE.weight.data = torch.tensor(W, requires_grad=True) + # approach 2 + # EE.weight.data.copy_(torch.tensor(W)) + # approach 3 + # EE.weight = Parameter(torch.tensor(W),requires_grad=True) + if weighted_pooling is None: + v_W_l.append(None) + else: + v_W_l.append(torch.ones(n, dtype=torch.float32)) + emb_l.append(EE) + return emb_l, v_W_l + + def __init__( + self, + m_spa=None, + ln_emb=None, + ln_bot=None, + ln_top=None, + arch_interaction_op=None, + arch_interaction_itself=False, + sigmoid_bot=-1, + sigmoid_top=-1, + sync_dense_params=True, + loss_threshold=0.0, + ndevices=-1, + qr_flag=False, + qr_operation="mult", + qr_collisions=0, + qr_threshold=200, + md_flag=False, + md_threshold=200, + weighted_pooling=None, + loss_function="bce", + ): + super(DLRM_Net, self).__init__() + + if ( + (m_spa is not None) + and (ln_emb is not None) + and (ln_bot is not None) + and (ln_top is not None) + and (arch_interaction_op is not None) + ): + # save arguments + self.ndevices = ndevices + self.output_d = 0 + self.parallel_model_batch_size = -1 + self.parallel_model_is_not_prepared = True + self.arch_interaction_op = arch_interaction_op + self.arch_interaction_itself = arch_interaction_itself + self.sync_dense_params = sync_dense_params + self.loss_threshold = loss_threshold + self.loss_function = loss_function + if weighted_pooling is not None and weighted_pooling != "fixed": + self.weighted_pooling = "learned" + else: + self.weighted_pooling = weighted_pooling + # create variables for QR embedding if applicable + self.qr_flag = qr_flag + if self.qr_flag: + self.qr_collisions = qr_collisions + self.qr_operation = qr_operation + self.qr_threshold = qr_threshold + # create variables for MD embedding if applicable + self.md_flag = md_flag + if self.md_flag: + self.md_threshold = md_threshold + + # If running distributed, get local slice of embedding tables + if ext_dist.my_size > 1: + n_emb = len(ln_emb) + if n_emb < ext_dist.my_size: + sys.exit( + "only (%d) sparse features for (%d) devices, table partitions will fail" + % (n_emb, ext_dist.my_size) + ) + self.n_global_emb = n_emb + self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths( + n_emb + ) + self.local_emb_slice = ext_dist.get_my_slice(n_emb) + self.local_emb_indices = list(range(n_emb))[self.local_emb_slice] + + # create operators + if ndevices <= 1: + self.emb_l, w_list = self.create_emb(m_spa, ln_emb, weighted_pooling) + if self.weighted_pooling == "learned": + self.v_W_l = nn.ParameterList() + for w in w_list: + self.v_W_l.append(Parameter(w)) + else: + self.v_W_l = w_list + self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) 
+ self.top_l = self.create_mlp(ln_top, sigmoid_top) + + # quantization + self.quantize_emb = False + self.emb_l_q = [] + self.quantize_bits = 32 + + # specify the loss function + if self.loss_function == "mse": + self.loss_fn = torch.nn.MSELoss(reduction="mean") + elif self.loss_function == "bce": + self.loss_fn = torch.nn.BCELoss(reduction="mean") + elif self.loss_function == "wbce": + self.loss_ws = torch.tensor( + np.fromstring(args.loss_weights, dtype=float, sep="-") + ) + self.loss_fn = torch.nn.BCELoss(reduction="none") + else: + sys.exit( + "ERROR: --loss-function=" + self.loss_function + " is not supported" + ) + + def apply_mlp(self, x, layers): + # approach 1: use ModuleList + # for layer in layers: + # x = layer(x) + # return x + # approach 2: use Sequential container to wrap all layers + return layers(x) + + def apply_emb(self, lS_o, lS_i, emb_l, v_W_l): + # WARNING: notice that we are processing the batch at once. We implicitly + # assume that the data is laid out such that: + # 1. each embedding is indexed with a group of sparse indices, + # corresponding to a single lookup + # 2. for each embedding the lookups are further organized into a batch + # 3. for a list of embedding tables there is a list of batched lookups + + ly = [] + for k, sparse_index_group_batch in enumerate(lS_i): + sparse_offset_group_batch = lS_o[k] + + # embedding lookup + # We are using EmbeddingBag, which implicitly uses sum operator. + # The embeddings are represented as tall matrices, with sum + # happening vertically across 0 axis, resulting in a row vector + # E = emb_l[k] + + if v_W_l[k] is not None: + per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch) + else: + per_sample_weights = None + + if self.quantize_emb: + s1 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + s2 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + print("quantized emb sizes:", s1, s2) + + if self.quantize_bits == 4: + QV = ops.quantized.embedding_bag_4bit_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + elif self.quantize_bits == 8: + QV = ops.quantized.embedding_bag_byte_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(QV) + else: + E = emb_l[k] + V = E( + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(V) + + # print(ly) + return ly + + + +def main(): + + m_spa = 2 + ln_emb = torch.tensor([4, 3, 2]) + ln_bot = torch.tensor([4, 3, 2]) + ln_top = torch.tensor([8, 4, 1, 2]) + arch_interaction_op = "dot" + arch_interaction_itself = False + sync_dense_params = True + loss_threshold = 0.0 + ndevices = -1 + qr_flag = False + qr_operation = "mult" + qr_collisions = 0 + qr_threshold = 200 + md_flag = False + md_threshold = 200 + weighted_pooling = None + loss_function = "mse" + + dlrm = DLRM_Net( + m_spa, + ln_emb, + ln_bot, + ln_top, + arch_interaction_op=arch_interaction_op, + arch_interaction_itself=arch_interaction_itself, + sigmoid_bot=-1, + sigmoid_top=ln_top.size - 2, + sync_dense_params=sync_dense_params, + loss_threshold=loss_threshold, + ndevices=ndevices, + qr_flag=qr_flag, + qr_operation=qr_operation, + qr_collisions=qr_collisions, + qr_threshold=qr_threshold, + md_flag=md_flag, + md_threshold=md_threshold, + weighted_pooling=weighted_pooling, + loss_function=loss_function, + ) + + """ + # workaround 1: tensor -> list + if 
torch.is_tensor(lS_i_onnx): + lS_i_onnx = [lS_i_onnx[j] for j in range(len(lS_i_onnx))] + # workaound 2: list -> tensor + lS_i_onnx = torch.stack(lS_i_onnx) + """ + # debug prints + print("inputs", X_onnx, lS_o_onnx, lS_i_onnx) + print("output", dlrm_wrap(X_onnx, lS_o_onnx, lS_i_onnx, use_gpu, device)) + dlrm_pytorch_onnx_file = "dlrm_s_pytorch.onnx" + batch_size = X_onnx.shape[0] + print("X_onnx.shape", X_onnx.shape) + if torch.is_tensor(lS_o_onnx): + print("lS_o_onnx.shape", lS_o_onnx.shape) + else: + for oo in lS_o_onnx: + print("oo.shape", oo.shape) + if torch.is_tensor(lS_i_onnx): + print("lS_i_onnx.shape", lS_i_onnx.shape) + else: + for ii in lS_i_onnx: + print("ii.shape", ii.shape) + + # name inputs and outputs + o_inputs = ( + ["offsets"] + if torch.is_tensor(lS_o_onnx) + else ["offsets_" + str(i) for i in range(len(lS_o_onnx))] + ) + i_inputs = ( + ["indices"] + if torch.is_tensor(lS_i_onnx) + else ["indices_" + str(i) for i in range(len(lS_i_onnx))] + ) + all_inputs = ["dense_x"] + o_inputs + i_inputs + # debug prints + print("inputs", all_inputs) + + # create dynamic_axis dictionaries + do_inputs = ( + [{"offsets": {1: "batch_size"}}] + if torch.is_tensor(lS_o_onnx) + else [ + {"offsets_" + str(i): {0: "batch_size"}} for i in range(len(lS_o_onnx)) + ] + ) + di_inputs = ( + [{"indices": {1: "batch_size"}}] + if torch.is_tensor(lS_i_onnx) + else [ + {"indices_" + str(i): {0: "batch_size"}} for i in range(len(lS_i_onnx)) + ] + ) + dynamic_axes = {"dense_x": {0: "batch_size"}, "pred": {0: "batch_size"}} + for do in do_inputs: + dynamic_axes.update(do) + for di in di_inputs: + dynamic_axes.update(di) + # debug prints + print(dynamic_axes) + # export model + torch.onnx.export( + dlrm, + (X_onnx, lS_o_onnx, lS_i_onnx), + dlrm_pytorch_onnx_file, + verbose=True, + opset_version=11, + input_names=all_inputs, + output_names=["pred"], + dynamic_axes=dynamic_axes, + dynamo=True, + ) + # Define model parameter + + # embedding_dim = 128 + # num_embeddings_per_feature = [2 for _ in range(26)] + # eb_configs = [ + # EmbeddingBagConfig( + # name=f"t_{feature_name}", + # embedding_dim=embedding_dim, + # num_embeddings=num_embeddings_per_feature[feature_idx], + # feature_names=[feature_name], + # ) + # for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + # ] + # # Initialize the DLRM model + # dlrm_model = DLRM_DCN( + # embedding_bag_collection=EmbeddingBagCollection( + # tables=eb_configs, device=torch.device("cpu") + # ), + # dense_in_features=len(DEFAULT_INT_NAMES), + # dense_arch_layer_sizes=[512, 256, 128], + # over_arch_layer_sizes=[1024, 1024, 512, 256, 1], + # dcn_num_layers=3, + # dcn_low_rank_dim=512, + # dense_device=torch.device("cpu"), + # ) + + + # class DLRM_ONNX_WRAPPER(DLRM): + # def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + + # def forward(self, dense_features, keys, values, offsets): + + # sparse_features = torchrec.KeyedJaggedTensor.from_offsets_sync( + # keys=keys, + # values=values, + # offsets=offsets, + # ) + # return super().forward(dense_features, sparse_features).squeeze(-1) + + # B = 2 + # D = 8 + + # eb1_config = EmbeddingBagConfig( + # name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + # ) + # eb2_config = EmbeddingBagConfig( + # name="t2", + # embedding_dim=D, + # num_embeddings=100, + # feature_names=["f2"], + # ) + + # ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + # model = DLRM_ONNX_WRAPPER( + # embedding_bag_collection=ebc, + # dense_in_features=100, + # dense_arch_layer_sizes=[20, D], 
+ # over_arch_layer_sizes=[5, 1], + # ) + + # features = torch.rand((B, 100)) + + # # 0 1 + # # 0 [1,2] [4,5] + # # 1 [4,3] [2,9] + # # ^ + # # feature + # from torchrec import KeyedJaggedTensor + # sparse_features = KeyedJaggedTensor.from_offsets_sync( + # keys=["f1", "f2"], + # values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + # offsets=torch.tensor([0, 2, 4, 6, 8]), + # ) + + # import pdb; pdb.set_trace() + # logits = model( + # dense_features=features, + # keys=["f1", "f2"], + # values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + # offsets=torch.tensor([0, 2, 4, 6, 8]), + # ) + + # print(logits) + # with torch.no_grad(): + # torch.onnx.export( + # model, + # (features, ["f1", "f2"], torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), torch.tensor([0, 2, 4, 6, 8])), + # "dlrm.onnx", + # input_names=["dense_features", "keys", "values", "offsets"], + # output_names=["logits"], + # opset_version=18, + # dynamo=True, + # report=True + # #optimize=True, + # ) + + # import onnx + + # proto = onnx.load("dlrm.onnx") + # onnx.checker.check_model(proto) + # onnx.shape_inference.infer_shapes(proto) + # onnx.save(proto, "dlrm2.onnx") + + + # Wrap the DLRM model with DLRM_DCN for recommendation tasks + #dlrm_rec_model = DLRM_DCN(dlrm_model) + + # Create dummy input data + #atch_size = 4 + #dense_input = torch.randn(batch_size, num_dense_features) + #sparse_input = torch.randint(0, 1000, (batch_size, num_sparse_features)) + + # Forward pass through the model + #output = dlrm_model(dense_input, sparse_input) + + #print("Model output:", output) + +if __name__ == "__main__": + main() diff --git a/examples/dlrm/dlrm_s_pytorch.py b/examples/dlrm/dlrm_s_pytorch.py new file mode 100644 index 00000000..8069c0dc --- /dev/null +++ b/examples/dlrm/dlrm_s_pytorch.py @@ -0,0 +1,1918 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Description: an implementation of a deep learning recommendation model (DLRM) +# The model input consists of dense and sparse features. The former is a vector +# of floating point values. The latter is a list of sparse indices into +# embedding tables, which consist of vectors of floating point values. +# The selected vectors are passed to mlp networks denoted by triangles, +# in some cases the vectors are interacted through operators (Ops). +# +# output: +# vector of values +# model: | +# /\ +# /__\ +# | +# _____________________> Op <___________________ +# / | \ +# /\ /\ /\ +# /__\ /__\ ... /__\ +# | | | +# | Op Op +# | ____/__\_____ ____/__\____ +# | |_Emb_|____|__| ... |_Emb_|__|___| +# input: +# [ dense features ] [sparse indices] , ..., [sparse indices] +# +# More precise definition of model layers: +# 1) fully connected layers of an mlp +# z = f(y) +# y = Wx + b +# +# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk]) +# z = Op(e1,...,ek) +# obtain vectors e1=E[:,p1], ..., ek=E[:,pk] +# +# 3) Operator Op can be one of the following +# Sum(e1,...,ek) = e1 + ... + ek +# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek] +# Cat(e1,...,ek) = [e1', ..., ek']' +# where ' denotes transpose operation +# +# References: +# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang, +# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu, +# Alisson G. 
Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii, +# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko, +# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong, +# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and +# Recommendation Systems", CoRR, arXiv:1906.00091, 2019 + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse + +# miscellaneous +import builtins +import datetime +import json +import sys +import time + +# onnx +# The onnx import causes deprecation warnings every time workers +# are spawned during testing. So, we filter out those warnings. +import warnings + +# data generation +import dlrm_data_pytorch as dp + +# For distributed run +import extend_distributed as ext_dist +import mlperf_logger + +# numpy +import numpy as np +import optim.rwsadagrad as RowWiseSparseAdagrad +import sklearn.metrics + +# pytorch +import torch +import torch.nn as nn + +# dataloader +try: + from internals import fbDataLoader, fbInputBatchFormatter + + has_internal_libs = True +except ImportError: + has_internal_libs = False + +from torch._ops import ops +from torch.autograd.profiler import record_function +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter +from torch.nn.parameter import Parameter +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter + +import brevitas + +# mixed-dimension trick +from tricks.md_embedding_bag import md_solver, PrEmbeddingBag + +# quotient-remainder trick +from tricks.qr_embedding_bag import QREmbeddingBag + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + try: + import onnx + except ImportError as error: + print("Unable to import onnx. ", error) + +# from torchviz import make_dot +# import torch.nn.functional as Functional +# from torch.nn.parameter import Parameter + +exc = getattr(builtins, "IOError", "FileNotFoundError") + + +def time_wrap(use_gpu): + if use_gpu: + torch.cuda.synchronize() + return time.time() + + +def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1): + with record_function("DLRM forward"): + if use_gpu: # .cuda() + # lS_i can be either a list of tensors or a stacked tensor. + # Handle each case below: + if ndevices == 1: + lS_i = ( + [S_i.to(device) for S_i in lS_i] + if isinstance(lS_i, list) + else lS_i.to(device) + ) + lS_o = ( + [S_o.to(device) for S_o in lS_o] + if isinstance(lS_o, list) + else lS_o.to(device) + ) + return dlrm(X.to(device), lS_o, lS_i) + + +def loss_fn_wrap(Z, T, use_gpu, device): + with record_function("DLRM loss compute"): + if args.loss_function == "mse" or args.loss_function == "bce": + return dlrm.loss_fn(Z, T.to(device)) + elif args.loss_function == "wbce": + loss_ws_ = dlrm.loss_ws[T.data.view(-1).long()].view_as(T).to(device) + loss_fn_ = dlrm.loss_fn(Z, T.to(device)) + loss_sc_ = loss_ws_ * loss_fn_ + return loss_sc_.mean() + + +# The following function is a wrapper to avoid checking this multiple times in th +# loop below. 
+def unpack_batch(b): + if args.data_generation == "internal": + return fbInputBatchFormatter(b, args.data_size) + else: + # Experiment with unweighted samples + return b[0], b[1], b[2], b[3], torch.ones(b[3].size()), None + + +class LRPolicyScheduler(_LRScheduler): + def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps): + self.num_warmup_steps = num_warmup_steps + self.decay_start_step = decay_start_step + self.decay_end_step = decay_start_step + num_decay_steps + self.num_decay_steps = num_decay_steps + + if self.decay_start_step < self.num_warmup_steps: + sys.exit("Learning rate warmup must finish before the decay starts") + + super(LRPolicyScheduler, self).__init__(optimizer) + + def get_lr(self): + step_count = self._step_count + if step_count < self.num_warmup_steps: + # warmup + scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps + lr = [base_lr * scale for base_lr in self.base_lrs] + self.last_lr = lr + elif self.decay_start_step <= step_count and step_count < self.decay_end_step: + # decay + decayed_steps = step_count - self.decay_start_step + scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2 + min_lr = 0.0000001 + lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs] + self.last_lr = lr + else: + if self.num_decay_steps > 0: + # freeze at last, either because we're after decay + # or because we're between warmup and decay + lr = self.last_lr + else: + # do not adjust + lr = self.base_lrs + return lr + +import brevitas.onnx as bo +from brevitas.export import export_qonnx +from brevitas.nn import QuantLinear +from brevitas.quant import Int8WeightPerTensorFloat, Uint8ActPerTensorFloat, Int8Bias + + + +### define dlrm in PyTorch ### +class DLRM_Net(nn.Module): + def create_mlp(self, ln, sigmoid_layer): + # build MLP layer by layer + layers = nn.ModuleList() + for i in range(0, ln.size - 1): + n = ln[i] + m = ln[i + 1] + + # construct fully connected operator + #LL = nn.Linear(int(n), int(m), bias=True) + LL = QuantLinear(int(n), int(m), bias=True, weight_quant=Int8WeightPerTensorFloat, input_quant=Uint8ActPerTensorFloat, bias_quant=Int8Bias) + + # initialize the weights + # with torch.no_grad(): + # custom Xavier input, output or two-sided fill + mean = 0.0 # std_dev = np.sqrt(variance) + std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n) + W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32) + std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1)) + bt = np.random.normal(mean, std_dev, size=m).astype(np.float32) + # approach 1 + LL.weight.data = torch.tensor(W, requires_grad=True) + LL.bias.data = torch.tensor(bt, requires_grad=True) + # approach 2 + # LL.weight.data.copy_(torch.tensor(W)) + # LL.bias.data.copy_(torch.tensor(bt)) + # approach 3 + # LL.weight = Parameter(torch.tensor(W),requires_grad=True) + # LL.bias = Parameter(torch.tensor(bt),requires_grad=True) + layers.append(LL) + + # construct sigmoid or relu operator + if i == sigmoid_layer: + layers.append(nn.Sigmoid()) + else: + layers.append(nn.ReLU()) + + # approach 1: use ModuleList + # return layers + # approach 2: use Sequential container to wrap all layers + return torch.nn.Sequential(*layers) + + def create_emb(self, m, ln, weighted_pooling=None): + emb_l = nn.ModuleList() + v_W_l = [] + for i in range(0, ln.size): + if ext_dist.my_size > 1: + if i not in self.local_emb_indices: + continue + n = ln[i] + + # construct embedding operator + if self.qr_flag and n > self.qr_threshold: + EE = QREmbeddingBag( + 
n, + m, + self.qr_collisions, + operation=self.qr_operation, + mode="sum", + sparse=True, + ) + elif self.md_flag and n > self.md_threshold: + base = max(m) + _m = m[i] if n > self.md_threshold else base + EE = PrEmbeddingBag(n, _m, base) + # use np initialization as below for consistency... + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, _m) + ).astype(np.float32) + EE.embs.weight.data = torch.tensor(W, requires_grad=True) + else: + EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)#, padding_idx=m-1) + # initialize embeddings + # nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) + ).astype(np.float32) + # approach 1 + EE.weight.data = torch.tensor(W, requires_grad=True) + # approach 2 + # EE.weight.data.copy_(torch.tensor(W)) + # approach 3 + # EE.weight = Parameter(torch.tensor(W),requires_grad=True) + if weighted_pooling is None: + v_W_l.append(None) + else: + v_W_l.append(torch.ones(n, dtype=torch.float32)) + emb_l.append(EE) + return emb_l, v_W_l + + def __init__( + self, + m_spa=None, + ln_emb=None, + ln_bot=None, + ln_top=None, + arch_interaction_op=None, + arch_interaction_itself=False, + sigmoid_bot=-1, + sigmoid_top=-1, + sync_dense_params=True, + loss_threshold=0.0, + ndevices=-1, + qr_flag=False, + qr_operation="mult", + qr_collisions=0, + qr_threshold=200, + md_flag=False, + md_threshold=200, + weighted_pooling=None, + loss_function="bce", + ): + super(DLRM_Net, self).__init__() + if ( + (m_spa is not None) + and (ln_emb is not None) + and (ln_bot is not None) + and (ln_top is not None) + and (arch_interaction_op is not None) + ): + # save arguments + self.ndevices = ndevices + self.output_d = 0 + self.parallel_model_batch_size = -1 + self.parallel_model_is_not_prepared = True + self.arch_interaction_op = arch_interaction_op + self.arch_interaction_itself = arch_interaction_itself + self.sync_dense_params = sync_dense_params + self.loss_threshold = loss_threshold + self.loss_function = loss_function + if weighted_pooling is not None and weighted_pooling != "fixed": + self.weighted_pooling = "learned" + else: + self.weighted_pooling = weighted_pooling + # create variables for QR embedding if applicable + self.qr_flag = qr_flag + if self.qr_flag: + self.qr_collisions = qr_collisions + self.qr_operation = qr_operation + self.qr_threshold = qr_threshold + # create variables for MD embedding if applicable + self.md_flag = md_flag + if self.md_flag: + self.md_threshold = md_threshold + + # If running distributed, get local slice of embedding tables + if ext_dist.my_size > 1: + n_emb = len(ln_emb) + if n_emb < ext_dist.my_size: + sys.exit( + "only (%d) sparse features for (%d) devices, table partitions will fail" + % (n_emb, ext_dist.my_size) + ) + self.n_global_emb = n_emb + self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths( + n_emb + ) + self.local_emb_slice = ext_dist.get_my_slice(n_emb) + self.local_emb_indices = list(range(n_emb))[self.local_emb_slice] + + # create operators + if ndevices <= 1: + self.emb_l, w_list = self.create_emb(m_spa, ln_emb, weighted_pooling) + if self.weighted_pooling == "learned": + self.v_W_l = nn.ParameterList() + for w in w_list: + self.v_W_l.append(Parameter(w)) + else: + self.v_W_l = w_list + self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) + self.top_l = self.create_mlp(ln_top, sigmoid_top) + + # quantization + self.quantize_emb = False + self.emb_l_q = [] + self.quantize_bits = 32 + + # specify the loss 
function + if self.loss_function == "mse": + self.loss_fn = torch.nn.MSELoss(reduction="mean") + elif self.loss_function == "bce": + self.loss_fn = torch.nn.BCELoss(reduction="mean") + elif self.loss_function == "wbce": + self.loss_ws = torch.tensor( + np.fromstring(args.loss_weights, dtype=float, sep="-") + ) + self.loss_fn = torch.nn.BCELoss(reduction="none") + else: + sys.exit( + "ERROR: --loss-function=" + self.loss_function + " is not supported" + ) + + def apply_mlp(self, x, layers): + # approach 1: use ModuleList + # for layer in layers: + # x = layer(x) + # return x + # approach 2: use Sequential container to wrap all layers + return layers(x) + + def apply_emb(self, lS_o, lS_i, emb_l, v_W_l): + # WARNING: notice that we are processing the batch at once. We implicitly + # assume that the data is laid out such that: + # 1. each embedding is indexed with a group of sparse indices, + # corresponding to a single lookup + # 2. for each embedding the lookups are further organized into a batch + # 3. for a list of embedding tables there is a list of batched lookups + + ly = [] + for k, sparse_index_group_batch in enumerate(lS_i): + sparse_offset_group_batch = lS_o[k] + + # embedding lookup + # We are using EmbeddingBag, which implicitly uses sum operator. + # The embeddings are represented as tall matrices, with sum + # happening vertically across 0 axis, resulting in a row vector + # E = emb_l[k] + + if v_W_l[k] is not None: + per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch) + else: + per_sample_weights = None + + if self.quantize_emb: + s1 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + s2 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + print("quantized emb sizes:", s1, s2) + + if self.quantize_bits == 4: + QV = ops.quantized.embedding_bag_4bit_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + elif self.quantize_bits == 8: + QV = ops.quantized.embedding_bag_byte_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(QV) + else: + E = emb_l[k] + V = E( + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(V) + + # print(ly) + return ly + + # using quantizing functions from caffe2/aten/src/ATen/native/quantized/cpu + def quantize_embedding(self, bits): + n = len(self.emb_l) + self.emb_l_q = [None] * n + for k in range(n): + if bits == 4: + self.emb_l_q[k] = ops.quantized.embedding_bag_4bit_prepack( + self.emb_l[k].weight + ) + elif bits == 8: + self.emb_l_q[k] = ops.quantized.embedding_bag_byte_prepack( + self.emb_l[k].weight + ) + else: + return + self.emb_l = None + self.quantize_emb = True + self.quantize_bits = bits + + def interact_features(self, x, ly): + if self.arch_interaction_op == "dot": + # concatenate dense and sparse features + (batch_size, d) = x.shape + T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) + # perform a dot product + Z = torch.bmm(T, torch.transpose(T, 1, 2)) + # append dense feature with the interactions (into a row vector) + # approach 1: all + # Zflat = Z.view((batch_size, -1)) + # approach 2: unique + _, ni, nj = Z.shape + # approach 1: tril_indices + # offset = 0 if self.arch_interaction_itself else -1 + # li, lj = torch.tril_indices(ni, nj, offset=offset) + # approach 2: custom + offset = 1 if self.arch_interaction_itself else 0 + li = torch.tensor([i for 
i in range(ni) for j in range(i + offset)]) + lj = torch.tensor([j for i in range(nj) for j in range(i + offset)]) + Zflat = Z[:, li, lj] + # concatenate dense features and interactions + R = torch.cat([x] + [Zflat], dim=1) + elif self.arch_interaction_op == "cat": + # concatenation features (into a row vector) + R = torch.cat([x] + ly, dim=1) + else: + sys.exit( + "ERROR: --arch-interaction-op=" + + self.arch_interaction_op + + " is not supported" + ) + + return R + + def forward(self, dense_x, lS_o, lS_i): + if ext_dist.my_size > 1: + # multi-node multi-device run + return self.distributed_forward(dense_x, lS_o, lS_i) + elif self.ndevices <= 1: + # single device run + return self.sequential_forward(dense_x, lS_o, lS_i) + else: + # single-node multi-device run + return self.parallel_forward(dense_x, lS_o, lS_i) + + def distributed_forward(self, dense_x, lS_o, lS_i): + batch_size = dense_x.size()[0] + # WARNING: # of ranks must be <= batch size in distributed_forward call + if batch_size < ext_dist.my_size: + sys.exit( + "ERROR: batch_size (%d) must be larger than number of ranks (%d)" + % (batch_size, ext_dist.my_size) + ) + if batch_size % ext_dist.my_size != 0: + sys.exit( + "ERROR: batch_size %d can not split across %d ranks evenly" + % (batch_size, ext_dist.my_size) + ) + + dense_x = dense_x[ext_dist.get_my_slice(batch_size)] + lS_o = lS_o[self.local_emb_slice] + lS_i = lS_i[self.local_emb_slice] + + if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)): + sys.exit( + "ERROR: corrupted model input detected in distributed_forward call" + ) + + # embeddings + with record_function("DLRM embedding forward"): + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + + # WARNING: Note that at this point we have the result of the embedding lookup + # for the entire batch on each rank. We would like to obtain partial results + # corresponding to all embedding lookups, but part of the batch on each rank. + # Therefore, matching the distribution of output of bottom mlp, so that both + # could be used for subsequent interactions on each device. 
+ if len(self.emb_l) != len(ly): + sys.exit("ERROR: corrupted intermediate result in distributed_forward call") + + a2a_req = ext_dist.alltoall(ly, self.n_emb_per_rank) + + with record_function("DLRM bottom nlp forward"): + x = self.apply_mlp(dense_x, self.bot_l) + + ly = a2a_req.wait() + ly = list(ly) + + # interactions + with record_function("DLRM interaction forward"): + z = self.interact_features(x, ly) + + # top mlp + with record_function("DLRM top nlp forward"): + p = self.apply_mlp(z, self.top_l) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) + else: + z = p + + return z + + def sequential_forward(self, dense_x, lS_o, lS_i): + # process dense features (using bottom mlp), resulting in a row vector + x = self.apply_mlp(dense_x, self.bot_l) + # debug prints + # print("intermediate") + # print(x.detach().cpu().numpy()) + + # process sparse features(using embeddings), resulting in a list of row vectors + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + # for y in ly: + # print(y.detach().cpu().numpy()) + + # interact features (dense and sparse) + z = self.interact_features(x, ly) + # print(z.detach().cpu().numpy()) + + # obtain probability of a click (using top mlp) + p = self.apply_mlp(z, self.top_l) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) + else: + z = p + + return z + + def parallel_forward(self, dense_x, lS_o, lS_i): + ### prepare model (overwrite) ### + # WARNING: # of devices must be >= batch size in parallel_forward call + batch_size = dense_x.size()[0] + ndevices = min(self.ndevices, batch_size, len(self.emb_l)) + device_ids = range(ndevices) + # WARNING: must redistribute the model if mini-batch size changes(this is common + # for last mini-batch, when # of elements in the dataset/batch size is not even + if self.parallel_model_batch_size != batch_size: + self.parallel_model_is_not_prepared = True + + if self.parallel_model_is_not_prepared or self.sync_dense_params: + # replicate mlp (data parallelism) + self.bot_l_replicas = replicate(self.bot_l, device_ids) + self.top_l_replicas = replicate(self.top_l, device_ids) + self.parallel_model_batch_size = batch_size + + if self.parallel_model_is_not_prepared: + # distribute embeddings (model parallelism) + t_list = [] + w_list = [] + for k, emb in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + t_list.append(emb.to(d)) + if self.weighted_pooling == "learned": + w_list.append(Parameter(self.v_W_l[k].to(d))) + elif self.weighted_pooling == "fixed": + w_list.append(self.v_W_l[k].to(d)) + else: + w_list.append(None) + self.emb_l = nn.ModuleList(t_list) + if self.weighted_pooling == "learned": + self.v_W_l = nn.ParameterList(w_list) + else: + self.v_W_l = w_list + self.parallel_model_is_not_prepared = False + + ### prepare input (overwrite) ### + # scatter dense features (data parallelism) + # print(dense_x.device) + dense_x = scatter(dense_x, device_ids, dim=0) + # distribute sparse features (model parallelism) + if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)): + sys.exit("ERROR: corrupted model input detected in parallel_forward call") + + t_list = [] + i_list = [] + for k, _ in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + t_list.append(lS_o[k].to(d)) + i_list.append(lS_i[k].to(d)) + lS_o = t_list + lS_i = 
i_list + + ### compute results in parallel ### + # bottom mlp + # WARNING: Note that the self.bot_l is a list of bottom mlp modules + # that have been replicated across devices, while dense_x is a tuple of dense + # inputs that has been scattered across devices on the first (batch) dimension. + # The output is a list of tensors scattered across devices according to the + # distribution of dense_x. + x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids) + # debug prints + # print(x) + + # embeddings + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + # debug prints + # print(ly) + + # butterfly shuffle (implemented inefficiently for now) + # WARNING: Note that at this point we have the result of the embedding lookup + # for the entire batch on each device. We would like to obtain partial results + # corresponding to all embedding lookups, but part of the batch on each device. + # Therefore, matching the distribution of output of bottom mlp, so that both + # could be used for subsequent interactions on each device. + if len(self.emb_l) != len(ly): + sys.exit("ERROR: corrupted intermediate result in parallel_forward call") + + t_list = [] + for k, _ in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + y = scatter(ly[k], device_ids, dim=0) + t_list.append(y) + # adjust the list to be ordered per device + ly = list(map(lambda y: list(y), zip(*t_list))) + # debug prints + # print(ly) + + # interactions + z = [] + for k in range(ndevices): + zk = self.interact_features(x[k], ly[k]) + z.append(zk) + # debug prints + # print(z) + + # top mlp + # WARNING: Note that the self.top_l is a list of top mlp modules that + # have been replicated across devices, while z is a list of interaction results + # that by construction are scattered across devices on the first (batch) dim. + # The output is a list of tensors scattered across devices according to the + # distribution of z. 
+ p = parallel_apply(self.top_l_replicas, z, None, device_ids) + + ### gather the distributed results ### + p0 = gather(p, self.output_d, dim=0) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z0 = torch.clamp( + p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold) + ) + else: + z0 = p0 + + return z0 + + +def dash_separated_ints(value): + vals = value.split("-") + for val in vals: + try: + int(val) + except ValueError: + raise argparse.ArgumentTypeError( + "%s is not a valid dash separated list of ints" % value + ) + + return value + + +def dash_separated_floats(value): + vals = value.split("-") + for val in vals: + try: + float(val) + except ValueError: + raise argparse.ArgumentTypeError( + "%s is not a valid dash separated list of floats" % value + ) + + return value + + +def inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + log_iter=-1, +): + test_accu = 0 + test_samp = 0 + + if args.mlperf_logging: + scores = [] + targets = [] + + for i, testBatch in enumerate(test_ld): + # early exit if nbatches was set by the user and was exceeded + if nbatches > 0 and i >= nbatches: + break + + X_test, lS_o_test, lS_i_test, T_test, W_test, CBPP_test = unpack_batch( + testBatch + ) + + # Skip the batch if batch size not multiple of total ranks + if ext_dist.my_size > 1 and X_test.size(0) % ext_dist.my_size != 0: + print("Warning: Skiping the batch %d with size %d" % (i, X_test.size(0))) + continue + + # forward pass + Z_test = dlrm_wrap( + X_test, + lS_o_test, + lS_i_test, + use_gpu, + device, + ndevices=ndevices, + ) + ### gather the distributed results on each rank ### + # For some reason it requires explicit sync before all_gather call if + # tensor is on GPU memory + if Z_test.is_cuda: + torch.cuda.synchronize() + (_, batch_split_lengths) = ext_dist.get_split_lengths(X_test.size(0)) + if ext_dist.my_size > 1: + Z_test = ext_dist.all_gather(Z_test, batch_split_lengths) + + if args.mlperf_logging: + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + scores.append(S_test) + targets.append(T_test) + else: + with record_function("DLRM accuracy compute"): + # compute loss and accuracy + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + + mbs_test = T_test.shape[0] # = mini_batch_size except last + A_test = np.sum((np.round(S_test, 0) == T_test).astype(np.uint8)) + + test_accu += A_test + test_samp += mbs_test + + if args.mlperf_logging: + with record_function("DLRM mlperf sklearn metrics compute"): + scores = np.concatenate(scores, axis=0) + targets = np.concatenate(targets, axis=0) + + metrics = { + "recall": lambda y_true, y_score: sklearn.metrics.recall_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "precision": lambda y_true, y_score: sklearn.metrics.precision_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "f1": lambda y_true, y_score: sklearn.metrics.f1_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "ap": sklearn.metrics.average_precision_score, + "roc_auc": sklearn.metrics.roc_auc_score, + "accuracy": lambda y_true, y_score: sklearn.metrics.accuracy_score( + y_true=y_true, y_pred=np.round(y_score) + ), + } + + validation_results = {} + for metric_name, metric_function in metrics.items(): + validation_results[metric_name] = metric_function(targets, scores) + writer.add_scalar( + "mlperf-metrics-test/" + metric_name, + validation_results[metric_name], + 
log_iter, + ) + acc_test = validation_results["accuracy"] + else: + acc_test = test_accu / test_samp + writer.add_scalar("Test/Acc", acc_test, log_iter) + + model_metrics_dict = { + "nepochs": args.nepochs, + "nbatches": nbatches, + "nbatches_test": nbatches_test, + "state_dict": dlrm.state_dict(), + "test_acc": acc_test, + } + + if args.mlperf_logging: + is_best = validation_results["roc_auc"] > best_auc_test + if is_best: + best_auc_test = validation_results["roc_auc"] + model_metrics_dict["test_auc"] = best_auc_test + print( + "recall {:.4f}, precision {:.4f},".format( + validation_results["recall"], + validation_results["precision"], + ) + + " f1 {:.4f}, ap {:.4f},".format( + validation_results["f1"], validation_results["ap"] + ) + + " auc {:.4f}, best auc {:.4f},".format( + validation_results["roc_auc"], best_auc_test + ) + + " accuracy {:3.3f} %, best accuracy {:3.3f} %".format( + validation_results["accuracy"] * 100, best_acc_test * 100 + ), + flush=True, + ) + else: + is_best = acc_test > best_acc_test + if is_best: + best_acc_test = acc_test + print( + " accuracy {:3.3f} %, best {:3.3f} %".format( + acc_test * 100, best_acc_test * 100 + ), + flush=True, + ) + return model_metrics_dict, is_best + + +def run(): + ### parse arguments ### + parser = argparse.ArgumentParser( + description="Train Deep Learning Recommendation Model (DLRM)" + ) + # model related parameters + parser.add_argument("--arch-sparse-feature-size", type=int, default=2) + parser.add_argument( + "--arch-embedding-size", type=dash_separated_ints, default="4-3-2" + ) + # j will be replaced with the table number + parser.add_argument("--arch-mlp-bot", type=dash_separated_ints, default="4-3-2") + parser.add_argument("--arch-mlp-top", type=dash_separated_ints, default="4-2-1") + parser.add_argument( + "--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot" + ) + parser.add_argument("--arch-interaction-itself", action="store_true", default=False) + parser.add_argument("--weighted-pooling", type=str, default=None) + # embedding table options + parser.add_argument("--md-flag", action="store_true", default=False) + parser.add_argument("--md-threshold", type=int, default=200) + parser.add_argument("--md-temperature", type=float, default=0.3) + parser.add_argument("--md-round-dims", action="store_true", default=False) + parser.add_argument("--qr-flag", action="store_true", default=False) + parser.add_argument("--qr-threshold", type=int, default=200) + parser.add_argument("--qr-operation", type=str, default="mult") + parser.add_argument("--qr-collisions", type=int, default=4) + # activations and loss + parser.add_argument("--activation-function", type=str, default="relu") + parser.add_argument("--loss-function", type=str, default="mse") # or bce or wbce + parser.add_argument( + "--loss-weights", type=dash_separated_floats, default="1.0-1.0" + ) # for wbce + parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7 + parser.add_argument("--round-targets", type=bool, default=False) + # data + parser.add_argument("--data-size", type=int, default=1) + parser.add_argument("--num-batches", type=int, default=0) + parser.add_argument( + "--data-generation", + type=str, + choices=["random", "dataset", "internal"], + default="random", + ) # synthetic, dataset or internal + parser.add_argument( + "--rand-data-dist", type=str, default="uniform" + ) # uniform or gaussian + parser.add_argument("--rand-data-min", type=float, default=0) + parser.add_argument("--rand-data-max", type=float, default=1) + 
parser.add_argument("--rand-data-mu", type=float, default=-1) + parser.add_argument("--rand-data-sigma", type=float, default=1) + parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log") + parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte + parser.add_argument("--raw-data-file", type=str, default="") + parser.add_argument("--processed-data-file", type=str, default="") + parser.add_argument("--data-randomize", type=str, default="total") # or day or none + parser.add_argument("--data-trace-enable-padding", type=bool, default=False) + parser.add_argument("--max-ind-range", type=int, default=-1) + parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1] + parser.add_argument("--num-indices-per-lookup", type=int, default=10) + parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False) + parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument("--memory-map", action="store_true", default=False) + # training + parser.add_argument("--mini-batch-size", type=int, default=1) + parser.add_argument("--nepochs", type=int, default=1) + parser.add_argument("--learning-rate", type=float, default=0.01) + parser.add_argument("--print-precision", type=int, default=5) + parser.add_argument("--numpy-rand-seed", type=int, default=123) + parser.add_argument("--sync-dense-params", type=bool, default=True) + parser.add_argument("--optimizer", type=str, default="sgd") + parser.add_argument( + "--dataset-multiprocessing", + action="store_true", + default=False, + help="The Kaggle dataset can be multiprocessed in an environment \ + with more than 7 CPU cores and more than 20 GB of memory. \n \ + The Terabyte dataset can be multiprocessed in an environment \ + with more than 24 CPU cores and at least 1 TB of memory.", + ) + # inference + parser.add_argument("--inference-only", action="store_true", default=False) + # quantize + parser.add_argument("--quantize-mlp-with-bit", type=int, default=32) + parser.add_argument("--quantize-emb-with-bit", type=int, default=32) + # onnx + parser.add_argument("--save-onnx", action="store_true", default=False) + # gpu + parser.add_argument("--use-gpu", action="store_true", default=False) + # distributed + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dist-backend", type=str, default="") + # debugging and profiling + parser.add_argument("--print-freq", type=int, default=1) + parser.add_argument("--test-freq", type=int, default=-1) + parser.add_argument("--test-mini-batch-size", type=int, default=-1) + parser.add_argument("--test-num-workers", type=int, default=-1) + parser.add_argument("--print-time", action="store_true", default=False) + parser.add_argument("--print-wall-time", action="store_true", default=False) + parser.add_argument("--debug-mode", action="store_true", default=False) + parser.add_argument("--enable-profiling", action="store_true", default=False) + parser.add_argument("--plot-compute-graph", action="store_true", default=False) + parser.add_argument("--tensor-board-filename", type=str, default="run_kaggle_pt") + # store/load model + parser.add_argument("--save-model", type=str, default="") + parser.add_argument("--load-model", type=str, default="") + # mlperf logging (disables other output and stops early) + parser.add_argument("--mlperf-logging", action="store_true", default=False) + # stop at target accuracy Kaggle 0.789, Terabyte (sub-sampled=0.875) 0.8107 + parser.add_argument("--mlperf-acc-threshold", 
type=float, default=0.0) + # stop at target AUC Terabyte (no subsampling) 0.8025 + parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0) + parser.add_argument("--mlperf-bin-loader", action="store_true", default=False) + parser.add_argument("--mlperf-bin-shuffle", action="store_true", default=False) + # mlperf gradient accumulation iterations + parser.add_argument("--mlperf-grad-accum-iter", type=int, default=1) + # LR policy + parser.add_argument("--lr-num-warmup-steps", type=int, default=0) + parser.add_argument("--lr-decay-start-step", type=int, default=0) + parser.add_argument("--lr-num-decay-steps", type=int, default=0) + + global args + global nbatches + global nbatches_test + global writer + args = parser.parse_args() + + if args.dataset_multiprocessing: + assert sys.version_info[0] >= 3 and sys.version_info[1] > 7, ( + "The dataset_multiprocessing " + + "flag is susceptible to a bug in Python 3.7 and under. " + + "https://github.com/facebookresearch/dlrm/issues/172" + ) + + if args.mlperf_logging: + mlperf_logger.log_event(key=mlperf_logger.constants.CACHE_CLEAR, value=True) + mlperf_logger.log_start( + key=mlperf_logger.constants.INIT_START, log_all_ranks=True + ) + + if args.weighted_pooling is not None: + if args.qr_flag: + sys.exit("ERROR: quotient remainder with weighted pooling is not supported") + if args.md_flag: + sys.exit("ERROR: mixed dimensions with weighted pooling is not supported") + if args.quantize_emb_with_bit in [4, 8]: + if args.qr_flag: + sys.exit( + "ERROR: 4 and 8-bit quantization with quotient remainder is not supported" + ) + if args.md_flag: + sys.exit( + "ERROR: 4 and 8-bit quantization with mixed dimensions is not supported" + ) + if args.use_gpu: + sys.exit("ERROR: 4 and 8-bit quantization on GPU is not supported") + + ### some basic setup ### + np.random.seed(args.numpy_rand_seed) + np.set_printoptions(precision=args.print_precision) + torch.set_printoptions(precision=args.print_precision) + torch.manual_seed(args.numpy_rand_seed) + + if args.test_mini_batch_size < 0: + # if the parameter is not set, use the training batch size + args.test_mini_batch_size = args.mini_batch_size + if args.test_num_workers < 0: + # if the parameter is not set, use the same parameter for training + args.test_num_workers = args.num_workers + + use_gpu = args.use_gpu and torch.cuda.is_available() + + if not args.debug_mode: + ext_dist.init_distributed( + local_rank=args.local_rank, use_gpu=use_gpu, backend=args.dist_backend + ) + + if use_gpu: + torch.cuda.manual_seed_all(args.numpy_rand_seed) + torch.backends.cudnn.deterministic = True + if ext_dist.my_size > 1: + ngpus = 1 + device = torch.device("cuda", ext_dist.my_local_rank) + else: + ngpus = torch.cuda.device_count() + device = torch.device("cuda", 0) + print("Using {} GPU(s)...".format(ngpus)) + else: + device = torch.device("cpu") + print("Using CPU...") + + ### prepare training data ### + ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") + # input data + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP) + mlperf_logger.barrier() + mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START) + mlperf_logger.barrier() + + if args.data_generation == "dataset": + train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(args) + table_feature_map = {idx: idx for idx in range(len(train_data.counts))} + nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) + nbatches_test = len(test_ld) + + 
ln_emb = train_data.counts + # enforce maximum limit on number of vectors per embedding + if args.max_ind_range > 0: + ln_emb = np.array( + list( + map( + lambda x: x if x < args.max_ind_range else args.max_ind_range, + ln_emb, + ) + ) + ) + else: + ln_emb = np.array(ln_emb) + m_den = train_data.m_den + ln_bot[0] = m_den + elif args.data_generation == "internal": + if not has_internal_libs: + raise Exception("Internal libraries are not available.") + NUM_BATCHES = 5000 + nbatches = args.num_batches if args.num_batches > 0 else NUM_BATCHES + train_ld, feature_to_num_embeddings = fbDataLoader(args.data_size, nbatches) + ln_emb = np.array(list(feature_to_num_embeddings.values())) + m_den = ln_bot[0] + else: + # input and target at random + ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") + m_den = ln_bot[0] + train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader( + args, ln_emb, m_den + ) + nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) + nbatches_test = len(test_ld) + + args.ln_emb = ln_emb.tolist() + if args.mlperf_logging: + print("command line args: ", json.dumps(vars(args))) + + ### parse command line arguments ### + m_spa = args.arch_sparse_feature_size + ln_emb = np.asarray(ln_emb) + num_fea = ln_emb.size + 1 # num sparse + num dense features + + m_den_out = ln_bot[ln_bot.size - 1] + if args.arch_interaction_op == "dot": + # approach 1: all + # num_int = num_fea * num_fea + m_den_out + # approach 2: unique + if args.arch_interaction_itself: + num_int = (num_fea * (num_fea + 1)) // 2 + m_den_out + else: + num_int = (num_fea * (num_fea - 1)) // 2 + m_den_out + elif args.arch_interaction_op == "cat": + num_int = num_fea * m_den_out + else: + sys.exit( + "ERROR: --arch-interaction-op=" + + args.arch_interaction_op + + " is not supported" + ) + arch_mlp_top_adjusted = str(num_int) + "-" + args.arch_mlp_top + ln_top = np.fromstring(arch_mlp_top_adjusted, dtype=int, sep="-") + + # sanity check: feature sizes and mlp dimensions must match + if m_den != ln_bot[0]: + sys.exit( + "ERROR: arch-dense-feature-size " + + str(m_den) + + " does not match first dim of bottom mlp " + + str(ln_bot[0]) + ) + if args.qr_flag: + if args.qr_operation == "concat" and 2 * m_spa != m_den_out: + sys.exit( + "ERROR: 2 arch-sparse-feature-size " + + str(2 * m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + + " (note that the last dim of bottom mlp must be 2x the embedding dim)" + ) + if args.qr_operation != "concat" and m_spa != m_den_out: + sys.exit( + "ERROR: arch-sparse-feature-size " + + str(m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + ) + else: + if m_spa != m_den_out: + sys.exit( + "ERROR: arch-sparse-feature-size " + + str(m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + ) + if num_int != ln_top[0]: + sys.exit( + "ERROR: # of feature interactions " + + str(num_int) + + " does not match first dimension of top mlp " + + str(ln_top[0]) + ) + + # assign mixed dimensions if applicable + if args.md_flag: + m_spa = md_solver( + torch.tensor(ln_emb), + args.md_temperature, # alpha + d0=m_spa, + round_dim=args.md_round_dims, + ).tolist() + + # test prints (model arch) + if args.debug_mode: + print("model arch:") + print( + "mlp top arch " + + str(ln_top.size - 1) + + " layers, with input to output dimensions:" + ) + print(ln_top) + print("# of interactions") + print(num_int) + print( + "mlp bot arch " + + str(ln_bot.size - 1) + + " layers, with input to output dimensions:" + ) 
+ print(ln_bot) + print("# of features (sparse and dense)") + print(num_fea) + print("dense feature size") + print(m_den) + print("sparse feature size") + print(m_spa) + print( + "# of embeddings (= # of sparse features) " + + str(ln_emb.size) + + ", with dimensions " + + str(m_spa) + + "x:" + ) + print(ln_emb) + + print("data (inputs and targets):") + for j, inputBatch in enumerate(train_ld): + X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch) + + torch.set_printoptions(precision=4) + # early exit if nbatches was set by the user and has been exceeded + if nbatches > 0 and j >= nbatches: + break + print("mini-batch: %d" % j) + print(X.detach().cpu()) + # transform offsets to lengths when printing + print( + torch.IntTensor( + [ + np.diff( + S_o.detach().cpu().tolist() + list(lS_i[i].shape) + ).tolist() + for i, S_o in enumerate(lS_o) + ] + ) + ) + print([S_i.detach().cpu() for S_i in lS_i]) + print(T.detach().cpu()) + + global ndevices + ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1 + + ### construct the neural network specified above ### + # WARNING: to obtain exactly the same initialization for + # the weights we need to start from the same random seed. + # np.random.seed(args.numpy_rand_seed) + global dlrm + dlrm = DLRM_Net( + m_spa, + ln_emb, + ln_bot, + ln_top, + arch_interaction_op=args.arch_interaction_op, + arch_interaction_itself=args.arch_interaction_itself, + sigmoid_bot=-1, + sigmoid_top=ln_top.size - 2, + sync_dense_params=args.sync_dense_params, + loss_threshold=args.loss_threshold, + ndevices=ndevices, + qr_flag=args.qr_flag, + qr_operation=args.qr_operation, + qr_collisions=args.qr_collisions, + qr_threshold=args.qr_threshold, + md_flag=args.md_flag, + md_threshold=args.md_threshold, + weighted_pooling=args.weighted_pooling, + loss_function=args.loss_function, + ) + + # test prints + if args.debug_mode: + print("initial parameters (weights and bias):") + for param in dlrm.parameters(): + print(param.detach().cpu().numpy()) + # print(dlrm) + + if use_gpu: + # Custom Model-Data Parallel + # the mlps are replicated and use data parallelism, while + # the embeddings are distributed and use model parallelism + dlrm = dlrm.to(device) # .cuda() + if dlrm.ndevices > 1: + dlrm.emb_l, dlrm.v_W_l = dlrm.create_emb( + m_spa, ln_emb, args.weighted_pooling + ) + else: + if dlrm.weighted_pooling == "fixed": + for k, w in enumerate(dlrm.v_W_l): + dlrm.v_W_l[k] = w.cuda() + + # distribute data parallel mlps + if ext_dist.my_size > 1: + if use_gpu: + device_ids = [ext_dist.my_local_rank] + dlrm.bot_l = ext_dist.DDP(dlrm.bot_l, device_ids=device_ids) + dlrm.top_l = ext_dist.DDP(dlrm.top_l, device_ids=device_ids) + else: + dlrm.bot_l = ext_dist.DDP(dlrm.bot_l) + dlrm.top_l = ext_dist.DDP(dlrm.top_l) + + if not args.inference_only: + if use_gpu and args.optimizer in ["rwsadagrad", "adagrad"]: + sys.exit("GPU version of Adagrad is not supported by PyTorch.") + # specify the optimizer algorithm + opts = { + "sgd": torch.optim.SGD, + "rwsadagrad": RowWiseSparseAdagrad.RWSAdagrad, + "adagrad": torch.optim.Adagrad, + } + + parameters = ( + dlrm.parameters() + if ext_dist.my_size == 1 + else [ + { + "params": [p for emb in dlrm.emb_l for p in emb.parameters()], + "lr": args.learning_rate, + }, + # TODO check this lr setup + # bottom mlp has no data parallelism + # need to check how do we deal with top mlp + { + "params": dlrm.bot_l.parameters(), + "lr": args.learning_rate, + }, + { + "params": dlrm.top_l.parameters(), + "lr": args.learning_rate, + }, + ] + ) + 
optimizer = opts[args.optimizer](parameters, lr=args.learning_rate) + lr_scheduler = LRPolicyScheduler( + optimizer, + args.lr_num_warmup_steps, + args.lr_decay_start_step, + args.lr_num_decay_steps, + ) + + ### main loop ### + + # training or inference + best_acc_test = 0 + best_auc_test = 0 + skip_upto_epoch = 0 + skip_upto_batch = 0 + total_time = 0 + total_loss = 0 + total_iter = 0 + total_samp = 0 + + if args.mlperf_logging: + mlperf_logger.mlperf_submission_log("dlrm") + mlperf_logger.log_event( + key=mlperf_logger.constants.SEED, value=args.numpy_rand_seed + ) + mlperf_logger.log_event( + key=mlperf_logger.constants.GLOBAL_BATCH_SIZE, value=args.mini_batch_size + ) + + # Load model is specified + if not (args.load_model == ""): + print("Loading saved model {}".format(args.load_model)) + if use_gpu: + if dlrm.ndevices > 1: + # NOTE: when targeting inference on multiple GPUs, + # load the model as is on CPU or GPU, with the move + # to multiple GPUs to be done in parallel_forward + ld_model = torch.load(args.load_model) + else: + # NOTE: when targeting inference on single GPU, + # note that the call to .to(device) has already happened + ld_model = torch.load( + args.load_model, + map_location=torch.device("cuda"), + # map_location=lambda storage, loc: storage.cuda(0) + ) + else: + # when targeting inference on CPU + ld_model = torch.load(args.load_model, map_location=torch.device("cpu")) + dlrm.load_state_dict(ld_model["state_dict"]) + ld_j = ld_model["iter"] + ld_k = ld_model["epoch"] + ld_nepochs = ld_model["nepochs"] + ld_nbatches = ld_model["nbatches"] + ld_nbatches_test = ld_model["nbatches_test"] + ld_train_loss = ld_model["train_loss"] + ld_total_loss = ld_model["total_loss"] + if args.mlperf_logging: + ld_gAUC_test = ld_model["test_auc"] + ld_acc_test = ld_model["test_acc"] + if not args.inference_only: + optimizer.load_state_dict(ld_model["opt_state_dict"]) + best_acc_test = ld_acc_test + total_loss = ld_total_loss + skip_upto_epoch = ld_k # epochs + skip_upto_batch = ld_j # batches + else: + args.print_freq = ld_nbatches + args.test_freq = 0 + + print( + "Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format( + ld_k, ld_nepochs, ld_j, ld_nbatches, ld_nbatches_test + ) + ) + print( + "Training state: loss = {:.6f}".format( + ld_train_loss, + ) + ) + if args.mlperf_logging: + print( + "Testing state: accuracy = {:3.3f} %, auc = {:.3f}".format( + ld_acc_test * 100, ld_gAUC_test + ) + ) + else: + print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100)) + + if args.inference_only: + # Currently only dynamic quantization with INT8 and FP16 weights are + # supported for MLPs and INT4 and INT8 weights for EmbeddingBag + # post-training quantization during the inference. 
+ # By default we don't do the quantization: quantize_{mlp,emb}_with_bit == 32 (FP32) + assert args.quantize_mlp_with_bit in [ + 8, + 16, + 32, + ], "only support 8/16/32-bit but got {}".format(args.quantize_mlp_with_bit) + assert args.quantize_emb_with_bit in [ + 4, + 8, + 32, + ], "only support 4/8/32-bit but got {}".format(args.quantize_emb_with_bit) + if args.quantize_mlp_with_bit != 32: + if args.quantize_mlp_with_bit in [8]: + quantize_dtype = torch.qint8 + else: + quantize_dtype = torch.float16 + dlrm = torch.quantization.quantize_dynamic( + dlrm, {torch.nn.Linear}, quantize_dtype + ) + if args.quantize_emb_with_bit != 32: + dlrm.quantize_embedding(args.quantize_emb_with_bit) + # print(dlrm) + + print("time/loss/accuracy (if enabled):") + + if args.mlperf_logging: + # LR is logged twice for now because of a compliance checker bug + mlperf_logger.log_event( + key=mlperf_logger.constants.OPT_BASE_LR, value=args.learning_rate + ) + mlperf_logger.log_event( + key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS, + value=args.lr_num_warmup_steps, + ) + + # use logging keys from the official HP table and not from the logging library + mlperf_logger.log_event( + key="sgd_opt_base_learning_rate", value=args.learning_rate + ) + mlperf_logger.log_event( + key="lr_decay_start_steps", value=args.lr_decay_start_step + ) + mlperf_logger.log_event( + key="sgd_opt_learning_rate_decay_steps", value=args.lr_num_decay_steps + ) + mlperf_logger.log_event(key="sgd_opt_learning_rate_decay_poly_power", value=2) + + tb_file = "./" + args.tensor_board_filename + writer = SummaryWriter(tb_file) + + ext_dist.barrier() + with torch.autograd.profiler.profile( + args.enable_profiling, use_cuda=use_gpu, record_shapes=True + ) as prof: + if not args.inference_only: + k = 0 + total_time_begin = 0 + while k < args.nepochs: + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.BLOCK_START, + metadata={ + mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1), + mlperf_logger.constants.EPOCH_COUNT: 1, + }, + ) + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.EPOCH_START, + metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)}, + ) + + if k < skip_upto_epoch: + continue + + if args.mlperf_logging: + previous_iteration_time = None + + for j, inputBatch in enumerate(train_ld): + if j == 0 and args.save_onnx: + X_onnx, lS_o_onnx, lS_i_onnx, _, _, _ = unpack_batch(inputBatch) + + if j < skip_upto_batch: + continue + + X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch) + + if args.mlperf_logging: + current_time = time_wrap(use_gpu) + if previous_iteration_time: + iteration_time = current_time - previous_iteration_time + else: + iteration_time = 0 + previous_iteration_time = current_time + else: + t1 = time_wrap(use_gpu) + + # early exit if nbatches was set by the user and has been exceeded + if nbatches > 0 and j >= nbatches: + break + + # Skip the batch if batch size not multiple of total ranks + if ext_dist.my_size > 1 and X.size(0) % ext_dist.my_size != 0: + print( + "Warning: Skiping the batch %d with size %d" + % (j, X.size(0)) + ) + continue + + mbs = T.shape[0] # = args.mini_batch_size except maybe for last + + # forward pass + Z = dlrm_wrap( + X, + lS_o, + lS_i, + use_gpu, + device, + ndevices=ndevices, + ) + + if ext_dist.my_size > 1: + T = T[ext_dist.get_my_slice(mbs)] + W = W[ext_dist.get_my_slice(mbs)] + + # loss + E = loss_fn_wrap(Z, T, use_gpu, device) + + # compute loss and accuracy + L = E.detach().cpu().numpy() # numpy array + # 
training accuracy is not disabled + # S = Z.detach().cpu().numpy() # numpy array + # T = T.detach().cpu().numpy() # numpy array + + # # print("res: ", S) + + # # print("j, train: BCE ", j, L) + + # mbs = T.shape[0] # = args.mini_batch_size except maybe for last + # A = np.sum((np.round(S, 0) == T).astype(np.uint8)) + + with record_function("DLRM backward"): + # scaled error gradient propagation + # (where we do not accumulate gradients across mini-batches) + if ( + args.mlperf_logging + and (j + 1) % args.mlperf_grad_accum_iter == 0 + ) or not args.mlperf_logging: + optimizer.zero_grad() + # backward pass + E.backward() + + # optimizer + if ( + args.mlperf_logging + and (j + 1) % args.mlperf_grad_accum_iter == 0 + ) or not args.mlperf_logging: + optimizer.step() + lr_scheduler.step() + + if args.mlperf_logging: + total_time += iteration_time + else: + t2 = time_wrap(use_gpu) + total_time += t2 - t1 + + total_loss += L * mbs + total_iter += 1 + total_samp += mbs + + should_print = ((j + 1) % args.print_freq == 0) or ( + j + 1 == nbatches + ) + should_test = ( + (args.test_freq > 0) + and (args.data_generation in ["dataset", "random"]) + and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches)) + ) + + # print time, loss and accuracy + if should_print or should_test: + gT = 1000.0 * total_time / total_iter if args.print_time else -1 + total_time = 0 + + train_loss = total_loss / total_samp + total_loss = 0 + + str_run_type = ( + "inference" if args.inference_only else "training" + ) + + wall_time = "" + if args.print_wall_time: + wall_time = " ({})".format(time.strftime("%H:%M")) + + print( + "Finished {} it {}/{} of epoch {}, {:.2f} ms/it,".format( + str_run_type, j + 1, nbatches, k, gT + ) + + " loss {:.6f}".format(train_loss) + + wall_time, + flush=True, + ) + + log_iter = nbatches * k + j + 1 + writer.add_scalar("Train/Loss", train_loss, log_iter) + + total_iter = 0 + total_samp = 0 + + # testing + if should_test: + epoch_num_float = (j + 1) / len(train_ld) + k + 1 + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.EVAL_START, + metadata={ + mlperf_logger.constants.EPOCH_NUM: epoch_num_float + }, + ) + + # don't measure training iter time in a test iteration + if args.mlperf_logging: + previous_iteration_time = None + print( + "Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k) + ) + model_metrics_dict, is_best = inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + log_iter, + ) + + if ( + is_best + and not (args.save_model == "") + and not args.inference_only + ): + model_metrics_dict["epoch"] = k + model_metrics_dict["iter"] = j + 1 + model_metrics_dict["train_loss"] = train_loss + model_metrics_dict["total_loss"] = total_loss + model_metrics_dict["opt_state_dict"] = ( + optimizer.state_dict() + ) + print("Saving model to {}".format(args.save_model)) + torch.save(model_metrics_dict, args.save_model) + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.EVAL_STOP, + metadata={ + mlperf_logger.constants.EPOCH_NUM: epoch_num_float + }, + ) + + # Uncomment the line below to print out the total time with overhead + # print("Total test time for this group: {}" \ + # .format(time_wrap(use_gpu) - accum_test_time_begin)) + + if ( + args.mlperf_logging + and (args.mlperf_acc_threshold > 0) + and (best_acc_test > args.mlperf_acc_threshold) + ): + print( + "MLPerf testing accuracy threshold " + + str(args.mlperf_acc_threshold) + + " 
reached, stop training" + ) + break + + if ( + args.mlperf_logging + and (args.mlperf_auc_threshold > 0) + and (best_auc_test > args.mlperf_auc_threshold) + ): + print( + "MLPerf testing auc threshold " + + str(args.mlperf_auc_threshold) + + " reached, stop training" + ) + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.RUN_STOP, + metadata={ + mlperf_logger.constants.STATUS: mlperf_logger.constants.SUCCESS + }, + ) + break + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.EPOCH_STOP, + metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)}, + ) + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.BLOCK_STOP, + metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1)}, + ) + k += 1 # nepochs + if args.mlperf_logging and best_auc_test <= args.mlperf_auc_threshold: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.RUN_STOP, + metadata={ + mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED + }, + ) + else: + print("Testing for inference only") + inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + ) + + # profiling + if args.enable_profiling: + time_stamp = str(datetime.datetime.now()).replace(" ", "_") + with open("dlrm_s_pytorch" + time_stamp + "_shape.prof", "w") as prof_f: + prof_f.write( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cpu_time_total" + ) + ) + with open("dlrm_s_pytorch" + time_stamp + "_total.prof", "w") as prof_f: + prof_f.write(prof.key_averages().table(sort_by="self_cpu_time_total")) + prof.export_chrome_trace("dlrm_s_pytorch" + time_stamp + ".json") + # print(prof.key_averages().table(sort_by="cpu_time_total")) + + # plot compute graph + if args.plot_compute_graph: + sys.exit( + "ERROR: Please install pytorchviz package in order to use the" + + " visualization. Then, uncomment its import above as well as" + + " three lines below and run the code again." 
+ ) + # V = Z.mean() if args.inference_only else E + # dot = make_dot(V, params=dict(dlrm.named_parameters())) + # dot.render('dlrm_s_pytorch_graph') # write .pdf file + + # test prints + if not args.inference_only and args.debug_mode: + print("updated parameters (weights and bias):") + for param in dlrm.parameters(): + print(param.detach().cpu().numpy()) + + # export the model in onnx + if args.save_onnx: + """ + # workaround 1: tensor -> list + if torch.is_tensor(lS_i_onnx): + lS_i_onnx = [lS_i_onnx[j] for j in range(len(lS_i_onnx))] + # workaound 2: list -> tensor + lS_i_onnx = torch.stack(lS_i_onnx) + """ + # debug prints + print("inputs", X_onnx, lS_o_onnx, lS_i_onnx) + print("output", dlrm_wrap(X_onnx, lS_o_onnx, lS_i_onnx, use_gpu, device)) + dlrm_pytorch_onnx_file = "dlrm_s_pytorch.onnx" + batch_size = X_onnx.shape[0] + print("X_onnx.shape", X_onnx.shape) + if torch.is_tensor(lS_o_onnx): + print("lS_o_onnx.shape", lS_o_onnx.shape) + else: + for oo in lS_o_onnx: + print("oo.shape", oo.shape) + if torch.is_tensor(lS_i_onnx): + print("lS_i_onnx.shape", lS_i_onnx.shape) + else: + for ii in lS_i_onnx: + print("ii.shape", ii.shape) + + # name inputs and outputs + o_inputs = ( + ["offsets"] + if torch.is_tensor(lS_o_onnx) + else ["offsets_" + str(i) for i in range(len(lS_o_onnx))] + ) + i_inputs = ( + ["indices"] + if torch.is_tensor(lS_i_onnx) + else ["indices_" + str(i) for i in range(len(lS_i_onnx))] + ) + all_inputs = ["dense_x"] + o_inputs + i_inputs + # debug prints + print("inputs", all_inputs) + + # create dynamic_axis dictionaries + do_inputs = ( + [{"offsets": {1: "batch_size"}}] + if torch.is_tensor(lS_o_onnx) + else [ + {"offsets_" + str(i): {0: "batch_size"}} for i in range(len(lS_o_onnx)) + ] + ) + di_inputs = ( + [{"indices": {1: "batch_size"}}] + if torch.is_tensor(lS_i_onnx) + else [ + {"indices_" + str(i): {0: "batch_size"}} for i in range(len(lS_i_onnx)) + ] + ) + dynamic_axes = {"dense_x": {0: "batch_size"}, "pred": {0: "batch_size"}} + for do in do_inputs: + dynamic_axes.update(do) + for di in di_inputs: + dynamic_axes.update(di) + # debug prints + print(dynamic_axes) + + # export model + with torch.no_grad(): + bo.export_qonnx( + dlrm, + (X_onnx, lS_o_onnx, lS_i_onnx), + dlrm_pytorch_onnx_file, + verbose=True, + opset_version=11, + input_names=all_inputs, + output_names=["pred"], + dynamic_axes=dynamic_axes, + dynamo=False, + ) + # recover the model back + dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx") + # check the onnx model + onnx.checker.check_model(dlrm_pytorch_onnx) + total_time_end = time_wrap(use_gpu) + + +if __name__ == "__main__": + run() diff --git a/examples/dlrm/export_onnx_context.py b/examples/dlrm/export_onnx_context.py new file mode 100755 index 00000000..fb31a8e2 --- /dev/null +++ b/examples/dlrm/export_onnx_context.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +Export ONNX models to a human-readable context format. + +This script generates a comprehensive but readable representation of an ONNX model +that can be easily consumed by AI assistants or used for documentation. +""" + +import argparse +import sys +from pathlib import Path +from io import StringIO +import onnx +from onnxscript import ir + + +def truncate_array_repr(text, max_elements=10): + """Truncate large array representations in the text.""" + # This is a simple heuristic - could be improved + return text + + +def export_onnx_context(onnx_path, output_path=None, max_array_size=20, verbose=False): + """ + Export ONNX model to a readable context file. 
+ + Args: + onnx_path: Path to the ONNX model file + output_path: Path to save the output (default: {model_name}_context.txt) + max_array_size: Maximum number of elements to show in arrays + verbose: Print progress messages + """ + onnx_path = Path(onnx_path) + + if not onnx_path.exists(): + raise FileNotFoundError(f"ONNX file not found: {onnx_path}") + + # Determine output path + if output_path is None: + output_path = onnx_path.parent / f"{onnx_path.stem}_context.txt" + else: + output_path = Path(output_path) + + if verbose: + print(f"Loading ONNX model from: {onnx_path}") + + # Load and deserialize the model + proto = onnx.load(str(onnx_path)) + ir_model = ir.serde.deserialize_model(proto) + + if verbose: + print(f"Writing context to: {output_path}") + + with open(output_path, 'w', encoding='utf-8') as f: + # Header + f.write("=" * 80 + "\n") + f.write(f"ONNX Model Context: {onnx_path.name}\n") + f.write("=" * 80 + "\n\n") + + # 1. High-level summary + f.write("# Model Metadata\n") + f.write("-" * 80 + "\n") + f.write(f"IR Version: {ir_model.ir_version}\n") + f.write(f"Opset Imports: {ir_model.opset_imports}\n") + f.write(f"Producer Name: {ir_model.producer_name}\n") + f.write(f"Producer Version: {ir_model.producer_version}\n") + f.write(f"Domain: {ir_model.domain}\n") + f.write(f"Model Version: {ir_model.model_version}\n") + f.write(f"Graph Name: {ir_model.graph.name}\n") + f.write("\n") + + # 2. Inputs + f.write("# Graph Inputs\n") + f.write("-" * 80 + "\n") + if ir_model.graph.inputs: + for i, inp in enumerate(ir_model.graph.inputs): + f.write(f"{i+1}. {inp.name}\n") + f.write(f" Type: {inp.type}\n") + f.write(f" Shape: {inp.shape if hasattr(inp, 'shape') else 'N/A'}\n") + else: + f.write("No inputs\n") + f.write("\n") + + # 3. Outputs + f.write("# Graph Outputs\n") + f.write("-" * 80 + "\n") + if ir_model.graph.outputs: + for i, out in enumerate(ir_model.graph.outputs): + f.write(f"{i+1}. {out.name}\n") + f.write(f" Type: {out.type}\n") + f.write(f" Shape: {out.shape if hasattr(out, 'shape') else 'N/A'}\n") + else: + f.write("No outputs\n") + f.write("\n") + + # 4. Initializers summary + f.write("# Initializers (Weights & Constants)\n") + f.write("-" * 80 + "\n") + initializers = list(ir_model.graph.initializers.values()) + if initializers: + f.write(f"Total: {len(initializers)}\n\n") + for init in initializers: + f.write(f" • {init.name}: {init.type}") + if hasattr(init, 'shape') and init.shape: + total_elements = 1 + for dim in init.shape.dims: + if dim: + total_elements *= dim + f.write(f" [{total_elements} elements]") + f.write("\n") + else: + f.write("No initializers\n") + f.write("\n") + + # 5. Operation statistics + f.write("# Operation Statistics\n") + f.write("-" * 80 + "\n") + op_counts = {} + total_nodes = 0 + for node in ir_model.graph: + op_type = node.op_type + op_counts[op_type] = op_counts.get(op_type, 0) + 1 + total_nodes += 1 + + f.write(f"Total Nodes: {total_nodes}\n\n") + f.write("Operation Type Distribution:\n") + for op, count in sorted(op_counts.items(), key=lambda x: (-x[1], x[0])): + f.write(f" {op:30s} : {count:4d}\n") + f.write("\n") + + # 6. Node connectivity summary + f.write("# Node Connectivity\n") + f.write("-" * 80 + "\n") + f.write("Input -> Consumers:\n") + for inp in ir_model.graph.inputs: + consumers = list(inp.consumers()) + f.write(f" {inp.name} -> {len(consumers)} consumer(s)") + if consumers: + consumer_names = [c.name for c in consumers[:5]] + if len(consumers) > 5: + consumer_names.append(f"... 
and {len(consumers)-5} more")
+                f.write(f": {', '.join(consumer_names)}")
+            f.write("\n")
+        f.write("\n")
+
+        # 7. Full IR display
+        f.write("# Complete Graph Structure (ONNX Script IR Format)\n")
+        f.write("=" * 80 + "\n")
+        f.write("This section shows the complete computational graph in ONNX Script format.\n")
+        f.write("Each node shows: operation type, inputs, outputs, and attributes.\n")
+        f.write("=" * 80 + "\n\n")
+
+        # Capture the display output
+        old_stdout = sys.stdout
+        sys.stdout = StringIO()
+        ir_model.display()
+        full_display = sys.stdout.getvalue()
+        sys.stdout = old_stdout
+
+        f.write(full_display)
+        f.write("\n\n")
+
+        # 8. Footer
+        f.write("=" * 80 + "\n")
+        f.write("End of ONNX Model Context\n")
+        f.write("=" * 80 + "\n")
+
+    if verbose:
+        print(f"✓ Context file created successfully: {output_path}")
+        print(f"  Total nodes: {total_nodes}")
+        print(f"  File size: {output_path.stat().st_size / 1024:.1f} KB")
+
+    return output_path
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Export ONNX model to readable context format",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+    # Export with default output name
+    python export_onnx_context.py model.onnx
+
+    # Specify output file
+    python export_onnx_context.py model.onnx -o context.txt
+
+    # Verbose output
+    python export_onnx_context.py model.onnx -v
+        """
+    )
+
+    parser.add_argument(
+        "onnx_file",
+        help="Path to the ONNX model file"
+    )
+
+    parser.add_argument(
+        "-o", "--output",
+        help="Output file path (default: {model_name}_context.txt)",
+        default=None
+    )
+
+    parser.add_argument(
+        "--max-array-size",
+        type=int,
+        default=20,
+        help="Maximum number of array elements to display (default: 20)"
+    )
+
+    parser.add_argument(
+        "-v", "--verbose",
+        action="store_true",
+        help="Print verbose output"
+    )
+
+    args = parser.parse_args()
+
+    try:
+        output_path = export_onnx_context(
+            args.onnx_file,
+            args.output,
+            args.max_array_size,
+            args.verbose
+        )
+        return 0
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        if args.verbose:
+            import traceback
+            traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/examples/dlrm/plugins/__init__.py b/examples/dlrm/plugins/__init__.py
new file mode 100644
index 00000000..748fdf5c
--- /dev/null
+++ b/examples/dlrm/plugins/__init__.py
@@ -0,0 +1,2 @@
+from . import custom_steps
+
diff --git a/examples/dlrm/plugins/custom_steps.py b/examples/dlrm/plugins/custom_steps.py
new file mode 100644
index 00000000..5547859d
--- /dev/null
+++ b/examples/dlrm/plugins/custom_steps.py
@@ -0,0 +1,144 @@
+############################################################################
+# @author Joshua Monson
+# @author Thomas Keller
+############################################################################
+
+"""
+DLRM-Specific Custom Build Steps
+
+Custom steps specifically for DLRM model processing, including:
+- Splitting sparse (embedding) processing from the dense datapath
+- Partition assignment based on reachability from each graph input
+- Cleanup and simplification of the dense subgraph
+
+These steps are highly specific to the DLRM model architecture and
+are not general-purpose FINN dataflow compilation steps.
+""" + +from brainsmith.registry import step +from finn.util import onnxscript_helpers as oxh +from qonnx.core.modelwrapper import ModelWrapper +from onnxscript import ir +from .import reachablefrominputx as reachable +import onnx + + +@step(name="split_sparse_processing") +def split_sparse_processing(model, cfg): + """Separate the sparse processing parts of the model into their own graph + because they will not go on the FPGA. + """ + transform = reachable.ReachableFromInputTransform() + transform.apply(model) + + # Get the IR model from the transform (same instance used for analysis) + model_ir = transform.ir_model + + # Get reachable nodes for each input (these are IR node objects from model_ir) + dense_x_nodes = transform.get_reachable_nodes("dense_x") + indices_0_nodes = transform.get_reachable_nodes("indices_0") + indices_1_nodes = transform.get_reachable_nodes("indices_1") + indices_2_nodes = transform.get_reachable_nodes("indices_2") + offsets_nodes = transform.get_reachable_nodes("offsets") + + # All nodes reachable from sparse inputs + all_sparse = indices_0_nodes | indices_1_nodes | indices_2_nodes | offsets_nodes + + # Dense nodes: reachable from dense_x (includes overlapping nodes) + dense_nodes = dense_x_nodes + + # Sparse nodes: ONLY reachable from sparse inputs, NOT from dense_x + sparse_nodes = all_sparse - dense_x_nodes + + # Check for missing nodes (not reachable from any input) + all_graph_nodes = set(model_ir.graph) + all_partitioned_nodes = sparse_nodes | dense_nodes + missing_nodes = all_graph_nodes - all_partitioned_nodes + + # Assign missing nodes to the partition where their consumers are + # (these are typically constant/empty initializer nodes with no inputs) + for missing_node in missing_nodes: + # Check which partition(s) consume this node's outputs + sparse_consumers = 0 + dense_consumers = 0 + for output in missing_node.outputs: + for consumer in output.consumers(): + if consumer in sparse_nodes: + sparse_consumers += 1 + elif consumer in dense_nodes: + dense_consumers += 1 + + # Assign to the partition with more consumers (or sparse if tied/no consumers) + if dense_consumers > sparse_consumers: + dense_nodes.add(missing_node) + else: + sparse_nodes.add(missing_node) + + # Recompute coverage after assigning missing nodes + all_partitioned_nodes = sparse_nodes | dense_nodes + remaining_missing = all_graph_nodes - all_partitioned_nodes + + print(f"\n=== Graph Partitioning Summary ===") + print(f"Total nodes in graph: {len(all_graph_nodes)}") + print(f"Sparse nodes: {len(sparse_nodes)} (includes {len(missing_nodes)} non-input nodes)") + print(f"Dense nodes: {len(dense_nodes)}") + print(f"Overlap nodes (going to FPGA): {len(dense_x_nodes & all_sparse)}") + print(f"Accounted for: {len(all_partitioned_nodes)}") + print(f"Remaining missing: {len(remaining_missing)}") + + if missing_nodes: + print(f"\n=== Originally Missing Nodes (now assigned) ===") + for node in missing_nodes: + partition = "SPARSE" if node in sparse_nodes else "DENSE" + print(f" - {node.name} ({node.op_type}) -> {partition}") + + if remaining_missing: + print(f"\n=== Still Missing Nodes (could not assign) ===") + for node in remaining_missing: + print(f" - {node.name} ({node.op_type})") + print(f" Inputs: {[inp.name if inp else 'None' for inp in node.inputs]}") + print(f" Outputs: {[out.name for out in node.outputs]}") + # Check if outputs are consumed and which partition consumers are in + consumers = [] + consumer_partitions = [] + for out in node.outputs: + for consumer in out.consumers(): + 
consumers.append(consumer.name)
+                    if consumer in sparse_nodes:
+                        consumer_partitions.append(f"{consumer.name} (SPARSE)")
+                    elif consumer in dense_nodes:
+                        consumer_partitions.append(f"{consumer.name} (DENSE)")
+                    else:
+                        consumer_partitions.append(f"{consumer.name} (UNKNOWN)")
+            print(f"      Consumers: {consumers if consumers else 'None (unused)'}")
+            print(f"      Consumer partitions: {consumer_partitions}")
+
+    print(f"\nSample sparse node names: {[n.name for n in list(sparse_nodes)[:5]]}")
+    print(f"Sample dense node names: {[n.name for n in list(dense_nodes)[:5]]}")
+
+    model_ir.graph.sort()
+
+    dense_subgraph = oxh.SubGraphView(model_ir.graph, 'dense', dense_nodes, include_initializers=True)
+    sparse_subgraph = oxh.SubGraphView(model_ir.graph, 'sparse', sparse_nodes, include_initializers=True)
+
+    dense_model = ir.Model(dense_subgraph, ir_version=model_ir.ir_version)
+    sparse_model = ir.Model(sparse_subgraph, ir_version=model_ir.ir_version)
+
+
+
+    sparse_proto = ir.serde.serialize_model(sparse_model)
+    dense_proto = ir.serde.serialize_model(dense_model)
+    onnx.save(sparse_proto, "sparse.onnx")
+    onnx.save(dense_proto, "dense.onnx")
+    return ModelWrapper(dense_proto)
+
+@step(name="dense_cleanup")
+def dense_cleanup(model: ModelWrapper, cfg):
+
+    model = model.cleanup()
+    # Simplify the dense model with onnx-simplifier (onnxsim)
+    import onnxsim
+    model, check = onnxsim.simplify(model.model)
+    if not check:
+        raise RuntimeError("Unable to simplify the DLRM datapath")
+    return ModelWrapper(model)
diff --git a/examples/dlrm/plugins/reachablefrominputx.py b/examples/dlrm/plugins/reachablefrominputx.py
new file mode 100644
index 00000000..206f0bd8
--- /dev/null
+++ b/examples/dlrm/plugins/reachablefrominputx.py
@@ -0,0 +1,117 @@
+import onnx
+from onnx import numpy_helper
+from typing import Dict, Set, List
+from collections import deque
+from qonnx.transformation.base import Transformation
+from qonnx.core.modelwrapper import ModelWrapper
+from onnxscript import ir
+
+class ReachableFromInputTransform(Transformation):
+    """
+    Analyzes the ONNX model to determine which nodes are reachable from each input.
+    This transform does not modify the model, only analyzes reachability.
+    Assumes the graph is directed and acyclic (DAG).
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.input_reachability: Dict[str, Set[ir.Node]] = {}
+        self.ir_model = None
+
+    def _build_subgraph_to_node_map(self, graph):
+        """
+        Build a mapping from subgraph to the node that contains it.
+        Recursively walks through all nodes and their subgraphs.
+        """
+        subgraph_to_node = {}
+
+        def visit_node(node):
+            # Check all attributes for subgraphs
+            for attr in node.attributes.values():
+                if hasattr(attr, 'value') and isinstance(attr.value, ir.Graph):
+                    subgraph = attr.value
+                    subgraph_to_node[id(subgraph)] = node
+                    # Recursively visit nodes in the subgraph
+                    for sub_node in subgraph:
+                        visit_node(sub_node)
+
+        # Visit all nodes in the main graph
+        for node in graph:
+            visit_node(node)
+
+        return subgraph_to_node
+
+    def _get_top_level_node(self, node, main_graph, subgraph_map):
+        """
+        Get the top-level node in the main graph.
+        If node is in a subgraph, return its enclosing parent node.
+        Otherwise, return the node itself.
+        """
+        current = node
+        # Walk up the graph hierarchy until we reach the main graph
+        while current.graph != main_graph:
+            graph_id = id(current.graph)
+            if graph_id not in subgraph_map:
+                # Couldn't find parent, return current node
+                break
+            current = subgraph_map[graph_id]
+        return current
+
+    def apply(self, model: ModelWrapper) -> ModelWrapper:
+        """
+        Analyze the model and compute reachability for each input.
+        Returns the original model unchanged.
+        """
+        # Convert ONNX model to IR and store it
+        self.ir_model = ir.serde.deserialize_model(model.model)
+        graph = self.ir_model.graph
+
+        # Build mapping from subgraphs to their containing nodes
+        subgraph_map = self._build_subgraph_to_node_map(graph)
+
+        # Get all model inputs
+        model_inputs = list(graph.inputs)
+
+        # Compute reachability from each input using BFS
+        self.input_reachability = {}
+        for input_value in model_inputs:
+            reachable_nodes = set()
+            queue = deque()
+
+            # Start by adding all direct consumers of the input
+            for consumer in input_value.consumers():
+                top_level_node = self._get_top_level_node(consumer, graph, subgraph_map)
+                queue.append(consumer)
+                reachable_nodes.add(top_level_node)
+
+            # BFS traversal through nodes (no cycle checking needed for DAG)
+            while queue:
+                current_node = queue.popleft()
+
+                # Add all consumers of this node's outputs to the queue
+                for output_value in current_node.outputs:
+                    for consumer in output_value.consumers():
+                        top_level_node = self._get_top_level_node(consumer, graph, subgraph_map)
+                        reachable_nodes.add(top_level_node)
+                        queue.append(consumer)
+
+            self.input_reachability[input_value.name] = reachable_nodes
+
+        # Log the results
+        for input_name, reachable_nodes in self.input_reachability.items():
+            print(f"Input '{input_name}' reaches {len(reachable_nodes)} nodes")
+
+        # Return the original model unchanged
+        return model
+
+    def get_reachable_nodes(self, input_name: str) -> Set[ir.Node]:
+        """
+        Get the set of graph nodes reachable from a specific input.
+        """
+        return self.input_reachability.get(input_name, set())
+
+    def get_all_reachability(self) -> Dict[str, Set[ir.Node]]:
+        """
+        Get the complete reachability mapping for all inputs.
+        """
+        return self.input_reachability
diff --git a/examples/dlrm/print_model.py b/examples/dlrm/print_model.py
new file mode 100644
index 00000000..2dc50adf
--- /dev/null
+++ b/examples/dlrm/print_model.py
@@ -0,0 +1,9 @@
+import onnx
+from onnxscript import ir
+
+proto = onnx.load("dlrm_s_pytorch.onnx")
+ir_model = ir.serde.deserialize_model(proto)
+
+# Print the model in ONNX Script IR text form
+# (ir.Model.display() writes to stdout and returns None)
+ir_model.display()
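
Editor's note: for orientation, below is a minimal standalone sketch (not part of the patch) of how the ReachableFromInputTransform introduced above might be exercised outside a full build. It assumes you run from examples/dlrm with an exported dlrm_s_pytorch.onnx in the working directory, that the brainsmith/FINN dependencies pulled in by plugins/__init__.py are installed, and that the graph inputs are named dense_x, indices_0..2, and offsets as in split_sparse_processing.

# Hypothetical standalone usage of ReachableFromInputTransform (illustration only).
from qonnx.core.modelwrapper import ModelWrapper
from plugins.reachablefrominputx import ReachableFromInputTransform

model = ModelWrapper("dlrm_s_pytorch.onnx")   # produced by dlrm_s_pytorch.py --save-onnx
transform = ReachableFromInputTransform()
transform.apply(model)                        # analysis only; the model is returned unchanged

# Compare what is reachable from the dense input vs. the sparse inputs,
# mirroring the partitioning logic in split_sparse_processing.
dense_nodes = transform.get_reachable_nodes("dense_x")
sparse_nodes = set()
for name in ("indices_0", "indices_1", "indices_2", "offsets"):
    sparse_nodes |= transform.get_reachable_nodes(name)
print(f"dense-only: {len(dense_nodes - sparse_nodes)}, "
      f"sparse-only: {len(sparse_nodes - dense_nodes)}, "
      f"shared: {len(dense_nodes & sparse_nodes)}")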