diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py
index 68686863..ea2f7d0d 100644
--- a/tzrec/acc/aot_utils.py
+++ b/tzrec/acc/aot_utils.py
@@ -9,62 +9,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
-import os
 from typing import Dict
 
 import torch
-import torch._prims_common as prims_utils
-import torch.nn.functional as F
-from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
-    BoundsCheckMode,
-)
-from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
-    IntNBitTableBatchedEmbeddingBagsCodegen,
-)
 from torch import nn
-from torch._decomp import decomposition_table, register_decomposition
-from torch._prims_common.wrappers import out_wrapper
-from torch.export import Dim
-
-from tzrec.utils.fx_util import symbolic_trace
-from tzrec.utils.logging_util import logger
-
-# skip default bound check which is not allow by aot
-if "ENABLE_AOT" in os.environ:
-    # pyre-ignore [8]
-    IntNBitTableBatchedEmbeddingBagsCodegen.__init__ = functools.partialmethod(
-        IntNBitTableBatchedEmbeddingBagsCodegen.__init__,
-        bounds_check_mode=BoundsCheckMode.NONE,
-    )
-
-# add new aten._softmax decomposition which is supported by dynamo
-aten = torch._ops.ops.aten
-if aten._softmax.default in decomposition_table:
-    del decomposition_table[aten._softmax.default]
-    del decomposition_table[aten._softmax.out]
-
-
-# pyre-ignore [56]
-@register_decomposition(aten._softmax)
-@out_wrapper()
-def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
-    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
-    # a contiguous tensor.
-    x = x.contiguous()
-    if half_to_float:
-        assert x.dtype == torch.half
-    computation_dtype, result_dtype = prims_utils.elementwise_dtypes(
-        x, type_promotion_kind=prims_utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
-    )
-    x = x.to(computation_dtype)
-    x_max = torch.max(x, dim, keepdim=True).values
-    unnormalized = torch.exp(x - x_max)
-    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
-    if not half_to_float:
-        result = result.to(result_dtype)
-    return result
+from tzrec.acc.export_utils import export_pm
 
 
 def export_model_aot(
@@ -77,58 +29,7 @@
         data (Dict[str, torch.Tensor]): the test data
         save_dir (str): model save dir
     """
-    gm = symbolic_trace(model)
-    with open(os.path.join(save_dir, "gm.code"), "w") as f:
-        f.write(gm.code)
-
-    gm = gm.cuda()
-
-    def _is_dense(key: str, data: Dict[str, torch.Tensor]) -> bool:
-        return data[key].dtype in (torch.float32, torch.bfloat16, torch.float16)
-
-    def _is_dense_seq(key: str, data: Dict[str, torch.Tensor]) -> bool:
-        return (key.split(".")[0] + ".lengths") in data and _is_dense(key, data)
-
-    batch = Dim("batch")
-    dynamic_shapes = {}
-    for key in data:
-        if key.endswith(".lengths"):
-            if data[key].shape[0] == 1:
-                logger.info("uniq user sparse fea %s length=1" % key)
-                dynamic_shapes[key] = {}
-            else:
-                dynamic_shapes[key] = {0: batch}
-        elif key == "batch_size":
-            dynamic_shapes[key] = {}
-        elif _is_dense_seq(key, data) and data[key].shape[0] == 1:
-            logger.info("uniq seq_dense_fea=%s shape=%s" % (key, data[key].shape))
-            dynamic_shapes[key] = {}
-        elif _is_dense(key, data) and not _is_dense_seq(key, data):
-            if data[key].shape[0] == 1:
-                logger.info("uniq user dense_fea=%s shape=%s" % (key, data[key].shape))
-                dynamic_shapes[key] = {}
-            else:
-                logger.info("batch dense_fea=%s shape=%s" % (key, data[key].shape))
-                dynamic_shapes[key] = {0: batch}
-        else:
-            tmp_val_dim = Dim(key.replace(".", "__") + "__batch", min=0)
-            # to handle torch.export 0/1 specialization problem
-            if data[key].shape[0] < 2:
-                data[key] = F.pad(
-                    data[key],
-                    [0, 2] + [0, 0] * (len(data[key].shape) - 1),
-                    mode="constant",
-                )
-            dynamic_shapes[key] = {0: tmp_val_dim}
-
-    exported_pg = torch.export.export(
-        gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
-    )
-
-    export_path = os.path.join(save_dir, "exported_pg.py")
-    with open(export_path, "w") as fout:
-        fout.write(str(exported_pg))
-
-    exported_pg.module()(data)
-
+    exported_pg, data = export_pm(model, data, save_dir)
+
+    # TODO(aot compile)
     return exported_pg
diff --git a/tzrec/acc/export_utils.py b/tzrec/acc/export_utils.py
new file mode 100644
index 00000000..660743d8
--- /dev/null
+++ b/tzrec/acc/export_utils.py
@@ -0,0 +1,218 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+from typing import Dict, List, Tuple
+
+import torch
+import torch._prims_common as prims_utils
+import torch.nn.functional as F
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BoundsCheckMode,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+    IntNBitTableBatchedEmbeddingBagsCodegen,
+)
+from torch import nn
+from torch._decomp import decomposition_table, register_decomposition
+from torch._prims_common.wrappers import out_wrapper
+from torch.export import Dim
+
+from tzrec.utils.fx_util import symbolic_trace
+from tzrec.utils.logging_util import logger
+
+# skip the default bounds check, which is not allowed by AOT export
+if "ENABLE_AOT" in os.environ or "ENABLE_TRT" in os.environ:
+    # pyre-ignore [8]
+    IntNBitTableBatchedEmbeddingBagsCodegen.__init__ = functools.partialmethod(
+        IntNBitTableBatchedEmbeddingBagsCodegen.__init__,
+        bounds_check_mode=BoundsCheckMode.NONE,
+    )
+    logger.info("update IntNBitTableBatchedEmbeddingBagsCodegen for export")
+
+
+# register a new aten._softmax decomposition that is supported by dynamo
+aten = torch._ops.ops.aten
+if aten._softmax.default in decomposition_table:
+    del decomposition_table[aten._softmax.default]
+    del decomposition_table[aten._softmax.out]
+
+
+# pyre-ignore [56]
+@register_decomposition(aten._softmax)
+@out_wrapper()
+def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
+    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
+    # a contiguous tensor.
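+    # Numerically stable softmax: shift by the per-`dim` max so that exp()
+    # cannot overflow, then normalize by the sum along `dim`.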
+    x = x.contiguous()
+    if half_to_float:
+        assert x.dtype == torch.half
+    computation_dtype, result_dtype = prims_utils.elementwise_dtypes(
+        x, type_promotion_kind=prims_utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+    x = x.to(computation_dtype)
+    x_max = torch.max(x, dim, keepdim=True).values
+    unnormalized = torch.exp(x - x_max)
+    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
+    if not half_to_float:
+        result = result.to(result_dtype)
+    return result
+
+
+def export_pm(
+    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
+) -> Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]:
+    """Export a PyTorch model that takes a dict of input tensors.
+
+    Args:
+        model (nn.Module): The PyTorch model to export.
+        data (Dict[str, torch.Tensor]): A dictionary containing the model's
+            input tensors.
+        save_dir (str): The directory where the model should be saved.
+
+    Returns:
+        Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]: The
+            exported program and its (possibly padded) input data.
+    """
+    gm = symbolic_trace(model)
+    with open(os.path.join(save_dir, "gm.code"), "w") as f:
+        f.write(gm.code)
+
+    gm = gm.cuda()
+
+    batch = Dim("batch", min=1, max=8196)
+    dynamic_shapes = {}
+    for key in data:
+        # .lengths
+        if key.endswith(".lengths"):
+            # user feats
+            if key.split(".")[0] in model._data_parser.user_feats:
+                assert data[key].shape[0] == 1
+                logger.info(
+                    "uniq user sparse fea %s length=%s" % (key, data[key].shape)
+                )
+                dynamic_shapes[key] = {}
+            else:
+                dynamic_shapes[key] = {0: batch}
+        elif key == "batch_size":
+            dynamic_shapes[key] = {}
+        # dense values
+        elif key in model._data_parser.dense_keys_list:
+            # user feats
+            if key.split(".")[0] in model._data_parser.user_feats:
+                assert data[key].shape[0] == 1
+                logger.info(
+                    "uniq user dense_fea=%s shape=%s" % (key, data[key].shape)
+                )
+                dynamic_shapes[key] = {}
+            else:
+                logger.info("batch dense_fea=%s shape=%s" % (key, data[key].shape))
+                dynamic_shapes[key] = {0: batch}
+        # sparse or seq_dense values (seq_dense values are also sparse)
+        else:
+            logger.info(
+                "sparse or seq_dense_fea=%s shape=%s" % (key, data[key].shape)
+            )
+            tmp_val_dim = Dim(key.replace(".", "__") + "__batch", min=0, max=1000000)
+            # to handle the torch.export 0/1 specialization problem
+            if data[key].shape[0] < 2:
+                data[key] = F.pad(
+                    data[key],
+                    [0, 2] + [0, 0] * (len(data[key].shape) - 1),
+                    mode="constant",
+                )
+            dynamic_shapes[key] = {0: tmp_val_dim}
+        data[key] = data[key].contiguous()
+
+    logger.info("dynamic shapes=%s" % dynamic_shapes)
+    exported_pg = torch.export.export(
+        gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
+    )
+
+    export_path = os.path.join(save_dir, "exported_pg.py")
+    with open(export_path, "w") as fout:
+        fout.write(str(exported_pg))
+
+    exported_pg.module()(data)
+
+    return (exported_pg, data)
+
+
+def export_pm_list(
+    model: nn.Module, data_list: List[torch.Tensor], save_dir: str
+) -> Tuple[torch.export.ExportedProgram, List[torch.Tensor]]:
+    """Export a PyTorch model that takes a list of input tensors.
+
+    Args:
+        model (nn.Module): The PyTorch model to export.
+        data_list (List[torch.Tensor]): A list containing the model's input
+            tensors.
+        save_dir (str): The directory where the model should be saved.
+
+    Returns:
+        Tuple[torch.export.ExportedProgram, List[torch.Tensor]]: The exported
+            program and its (possibly padded) input list.
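+
+    Note:
+        ``data_list`` must be ordered according to
+        ``model._data_parser.data_list_keys``.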
+ """ + gm = symbolic_trace(model) + with open(os.path.join(save_dir, "gm.code"), "w") as f: + f.write(gm.code) + + gm = gm.cuda() + + batch = Dim("batch",min=1, max=8196) + + dynamic_shapes_list = [] + for idx, key in enumerate(model._data_parser.data_list_keys): + + # .lengths + if key.endswith(".lengths"): + # user feats + if key.split(".")[0] in model._data_parser.user_feats: + assert(data_list[idx].shape[0] == 1) + logger.info("uniq user sparse fea %s length=%s" % (key, data_list[idx].shape)) + dynamic_shapes_list.append({}) + else: + dynamic_shapes_list.append({0: batch}) + elif key == "batch_size": + dynamic_shapes_list.append({}) + # dense values + elif key in model._data_parser.dense_keys_list: + # user feats + if key.split(".")[0] in model._data_parser.user_feats: + assert(data_list[idx].shape[0] == 1) + logger.info("uniq user dense_fea=%s shape=%s" % (key, data_list[idx].shape)) + dynamic_shapes_list.append({}) + else: + logger.info("batch dense_fea=%s shape=%s" % (key, data_list[idx].shape)) + dynamic_shapes_list.append({0: batch}) + # sparse or seq_dense values + else: + # sparse or seq_dense(seq_dense values is also sparse) + logger.info("sparse or seq_dense_fea=%s shape=%s" % (key,data_list[idx].shape)) + tmp_val_dim = Dim(key.replace(".", "__") + "__batch", min=0, max=1000000) + # to handle torch.export 0/1 specialization problem + if data_list[idx].shape[0] < 2: + data_list[idx] = F.pad( + data_list[idx], + [0, 2] + [0, 0] * (len(data_list[idx].shape) - 1), + mode="constant", + ) + dynamic_shapes_list.append({0: tmp_val_dim}) + + # trt need input contiguous + data_list[idx] = data_list[idx].contiguous() + + logger.info("dynamic shapes=%s" %dynamic_shapes_list) + dynamic_shapes = {"data": dynamic_shapes_list} + exported_pg = torch.export.export( + gm, args=(data_list,), dynamic_shapes=dynamic_shapes) + + + export_path = os.path.join(save_dir, "exported_pg.py") + with open(export_path, "w") as fout: + fout.write(str(exported_pg)) + + exported_pg.module()(data_list) + + return (exported_pg, data_list) \ No newline at end of file diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 89d32efc..0aa63ec3 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -26,21 +26,18 @@ from tzrec.models.model import ScriptWrapper from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import logger - +from tzrec.acc.export_utils import export_pm_list,export_pm def trt_convert( - module: nn.Module, + exp_program: torch.export.ExportedProgram, # pyre-ignore [2] inputs: Optional[Sequence[Sequence[Any]]], - # pyre-ignore [2] - dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], ) -> torch.fx.GraphModule: """Convert model use trt. 
 
     Args:
-        module (nn.Module): Source module
+        exp_program (torch.export.ExportedProgram): Source exported program
         inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): inputs
-        dynamic_shapes: dynamic shapes
 
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -56,7 +53,6 @@
     # (Lower value allows more graph segmentation)
     min_block_size = 2
 
-    exp_program = torch.export.export(module, (inputs,), dynamic_shapes=dynamic_shapes)
     # use script model , unsupported the inputs : dict
     if is_debug_trt():
         with torch_tensorrt.logging.graphs():
@@ -167,7 +163,7 @@
     return int(os.environ.get("TRT_MAX_SEQ_LEN", 100))
 
 
-def export_model_trt(
+def export_model_trt1(
     model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
 ) -> None:
     """Export trt model.
@@ -201,6 +197,8 @@
         if v.size(0) < 2:
             v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype)
         values_list_cuda.append(v)
+        logger.debug(
+            "dense input contiguous=%s channels_last=%s",
+            v.is_contiguous(memory_format=torch.contiguous_format),
+            v.is_contiguous(memory_format=torch.channels_last),
+        )
     dynamic_shapes_list.append(dict_dy)
 
     # convert dense
@@ -208,7 +206,8 @@
     logger.info("dense res: %s", dense(values_list_cuda))
     dense_layer = symbolic_trace(dense)
     dynamic_shapes = {"args": dynamic_shapes_list}
-    dense_layer_trt = trt_convert(dense_layer, values_list_cuda, dynamic_shapes)
+    exp_program = torch.export.export(
+        dense_layer, (values_list_cuda,), dynamic_shapes=dynamic_shapes
+    )
+    dense_layer_trt = trt_convert(exp_program, values_list_cuda)
     dict_res = dense_layer_trt(values_list_cuda)
     logger.info("dense trt res: %s", dict_res)
@@ -255,3 +254,119 @@
     logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
 
     logger.info("trt convert success")
+
+
+def export_model_trt(
+    model: nn.Module, data: List[torch.Tensor], save_dir: str
+) -> None:
+    """Export trt model with list inputs.
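+
+    The input is a positional tensor list ordered by
+    ``DataParser.data_list_keys``; the exported program is lowered with
+    TensorRT and the scripted result is saved as ``scripted_model.pt``.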
+
+    Args:
+        model (nn.Module): the model
+        data (List[torch.Tensor]): the test data
+        save_dir (str): model save dir
+    """
+    result = model(data)
+    logger.info("origin model result: %s", result)
+
+    exported_pg, data = export_pm_list(model, data, save_dir)
+    model_trt = trt_convert(exported_pg, data)
+    gm = symbolic_trace(model_trt)
+    with open(os.path.join(save_dir, "gm.code"), "w") as f:
+        f.write(gm.code)
+
+    scripted_model = torch.jit.script(gm)
+    scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))
+
+    if is_debug_trt():
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+        ) as prof:
+            with record_function("model_inference"):
+                dict_res = model(data)
+        logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+        ) as prof:
+            with record_function("model_inference_trt"):
+                dict_res = model_trt(data)
+        logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+    model_gpu = torch.jit.load(
+        os.path.join(save_dir, "scripted_model.pt"), map_location="cuda:0"
+    )
+    res = model_gpu(data)
+    logger.info("final res: %s", res)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        record_shapes=True,
+    ) as prof:
+        with record_function("model_inference_combined_trt"):
+            dict_res = model_gpu(data)
+    logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+    logger.info("trt convert success")
+
+
+def export_model_trt2(
+    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
+) -> None:
+    """Export trt model with dict inputs.
+
+    Args:
+        model (nn.Module): the model
+        data (Dict[str, torch.Tensor]): the test data
+        save_dir (str): model save dir
+    """
+    result = model(data)
+    logger.info("origin model result: %s", result)
+
+    exported_pg, data = export_pm(model, data, save_dir)
+    model_trt = trt_convert(exported_pg, data)
+    gm = symbolic_trace(model_trt)
+    with open(os.path.join(save_dir, "gm.code"), "w") as f:
+        f.write(gm.code)
+
+    scripted_model = torch.jit.script(gm)
+    scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))
+
+    if is_debug_trt():
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+        ) as prof:
+            with record_function("model_inference"):
+                dict_res = model(data)
+        logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+        ) as prof:
+            with record_function("model_inference_trt"):
+                dict_res = model_trt(data)
+        logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+    model_gpu = torch.jit.load(
+        os.path.join(save_dir, "scripted_model.pt"), map_location="cuda:0"
+    )
+    res = model_gpu(data)
+    logger.info("final res: %s", res)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        record_shapes=True,
+    ) as prof:
+        with record_function("model_inference_combined_trt"):
+            dict_res = model_gpu(data)
+    logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
+
+    logger.info("trt convert success")
diff --git a/tzrec/acc/utils.py b/tzrec/acc/utils.py
index acfdb7ba..935bf8ff 100644
--- a/tzrec/acc/utils.py
+++ b/tzrec/acc/utils.py
@@ -12,7 +12,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import json
 import os
-from typing import Dict
+from typing import Dict, List
 
 import torch
@@ -52,6 +52,9 @@
         return True
     return False
 
+
+def is_cuda_export() -> bool:
+    """Judge is trt/aot export or not."""
+    return is_trt() or is_aot()
+
 
 def is_trt_predict(model_path: str) -> bool:
     """Judge is trt or not in predict."""
@@ -124,4 +127,8 @@
         acc_config["QUANT_EMB"] = os.environ["QUANT_EMB"]
     if "ENABLE_TRT" in os.environ:
         acc_config["ENABLE_TRT"] = os.environ["ENABLE_TRT"]
+    if "ENABLE_AOT" in os.environ:
+        acc_config["ENABLE_AOT"] = os.environ["ENABLE_AOT"]
     return acc_config
diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py
index 445aeed5..df519600 100644
--- a/tzrec/datasets/data_parser.py
+++ b/tzrec/datasets/data_parser.py
@@ -116,6 +116,35 @@
         logger.info(f"self.user_feats: {self.user_feats}")
         logger.info(f"self.user_inputs: {self.user_inputs}")
+        logger.info(f"self.sparse_keys: {self.sparse_keys}")
+        logger.info(f"self.dense_keys: {self.dense_keys}")
+        logger.info(f"self.sequence_dense_keys: {self.sequence_dense_keys}")
+
+        # collect all flattened "key.values"/"key.lengths" names as the real input keys
+        self.data_list_keys = []
+        self.dense_keys_list = []
+        self.sparse_keys_list = []
+        self.sequence_dense_keys_list = []
+        self.get_data_list_keys()
+        logger.info("data list keys: %s", self.data_list_keys)
+
+    def get_data_list_keys(self) -> None:
+        """Build the sorted list of flattened input keys."""
+        for _, keys in self.sparse_keys.items():
+            for key in keys:
+                self.sparse_keys_list.append(f"{key}.values")
+                self.sparse_keys_list.append(f"{key}.lengths")
+        for _, keys in self.dense_keys.items():
+            for key in keys:
+                self.dense_keys_list.append(f"{key}.values")
+        for key in self.sequence_dense_keys:
+            self.sequence_dense_keys_list.append(f"{key}.values")
+            self.sequence_dense_keys_list.append(f"{key}.lengths")
+
+        self.data_list_keys = list(
+            set(self.sparse_keys_list)
+            | set(self.dense_keys_list)
+            | set(self.sequence_dense_keys_list)
+        )
+        if is_input_tile():
+            self.data_list_keys.append("batch_size")
+        self.data_list_keys = sorted(self.data_list_keys)
+
     def _init_fg_hander(self) -> None:
         """Init pyfg dag handler."""
         if not self._fg_handler:
@@ -416,6 +445,7 @@
                         input_data[f"{key}.values"], dtype=torch.float32
                     )
                 )
+
         sparse_feature = KeyedJaggedTensor(
             keys=keys,
             values=torch.cat(values, dim=-1),
diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py
index ac0fc1bf..2948169b 100644
--- a/tzrec/datasets/utils.py
+++ b/tzrec/datasets/utils.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import numpy.typing as npt
 import pyarrow as pa
@@ -227,3 +227,36 @@
         if self.tile_size > 0:
             tensor_dict["batch_size"] = torch.tensor(self.tile_size, dtype=torch.int64)
         return tensor_dict
+
+    def to_list(
+        self,
+        sparse_dtype: Optional[torch.dtype] = None,
+    ) -> List[torch.Tensor]:
+        """Convert to a feature tensor list.
+
+        Used in export; label fields are skipped.
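+
+        Args:
+            sparse_dtype (torch.dtype, optional): if set, sparse values and
+                lengths are cast to this dtype.
+
+        Returns:
+            List[torch.Tensor]: feature tensors sorted by key name, matching
+                the order of ``DataParser.data_list_keys``.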
+ """ + tensor_dict = {} + for x in self.dense_features.values(): + for k, v in x.to_dict().items(): + tensor_dict[f"{k}.values"] = v + for x in self.sparse_features.values(): + if sparse_dtype: + x = KeyedJaggedTensor( + keys=x.keys(), + values=x.values().to(sparse_dtype), + lengths=x.lengths().to(sparse_dtype), + weights=x.weights_or_none(), + ) + for k, v in x.to_dict().items(): + tensor_dict[f"{k}.values"] = v.values() + tensor_dict[f"{k}.lengths"] = v.lengths() + if v.weights_or_none() is not None: + tensor_dict[f"{k}.weights"] = v.weights() + for k, v in self.sequence_dense_features.items(): + tensor_dict[f"{k}.values"] = v.values() + tensor_dict[f"{k}.lengths"] = v.lengths() + if self.tile_size > 0: + tensor_dict["batch_size"] = torch.tensor(self.tile_size, dtype=torch.int64) + sorted_dict = {k: tensor_dict[k] for k in sorted(tensor_dict)} + values_list = list(sorted_dict.values()) + return values_list \ No newline at end of file diff --git a/tzrec/main.py b/tzrec/main.py index 4105be3c..94e9c204 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -50,6 +50,7 @@ is_input_tile_emb, is_quant, is_trt, + is_cuda_export, is_trt_predict, write_mapping_file_for_input_tile, ) @@ -69,7 +70,7 @@ TowerWoEGWrapper, TowerWrapper, ) -from tzrec.models.model import BaseModel, ExportWrapperAOT, ScriptWrapper, TrainWrapper +from tzrec.models.model import BaseModel, CudaScriptWrapper, ScriptWrapper, TrainWrapper from tzrec.models.tdm import TDM, TDMEmbedding from tzrec.modules.embedding import EmbeddingGroup from tzrec.optim import optimizer_builder @@ -727,12 +728,8 @@ def _script_model( if is_rank_zero: if not os.path.exists(save_dir): os.makedirs(save_dir) - if is_trt_convert: - model = model.to_empty(device="cuda:0") - logger.info("gather states to cuda model...") - else: - model = model.to_empty(device="cpu") - logger.info("gather states to cpu model...") + model = model.to_empty(device="cpu") + logger.info("gather states to cpu model...") state_dict_gather(state_dict, model.state_dict()) @@ -741,22 +738,45 @@ def _script_model( if is_rank_zero: batch = next(iter(dataloader)) - if is_aot(): + if is_cuda_export(): model = model.cuda() if is_quant(): logger.info("quantize embeddings...") - quantize_embeddings(model, dtype=torch.qint8, inplace=True) + import torchrec + additional_qconfig_spec_keys = [torchrec.EmbeddingCollection] + from torchrec.quant.embedding_modules import EmbeddingCollection as quant_emb + additional_mapping = {torchrec.EmbeddingCollection: quant_emb} + quantize_embeddings(model, dtype=torch.qint8, inplace=True, + additional_qconfig_spec_keys=additional_qconfig_spec_keys, + additional_mapping=additional_mapping) + #quantize_embeddings(model, dtype=torch.qint8, inplace=True) model.eval() if is_trt_convert: data_cuda = batch.to_dict(sparse_dtype=torch.int64) - result = model(data_cuda, "cuda:0") + #print(data_cuda) + data_cuda_list = batch.to_list(sparse_dtype=torch.int64) + print(data_cuda_list) + result = model(data_cuda_list) result_info = {k: (v.size(), v.dtype) for k, v in result.items()} logger.info(f"Model Outputs: {result_info}") - export_model_trt(model, data_cuda, save_dir) + export_model_trt(model, data_cuda_list, save_dir) + + # result = model(data_cuda) + # result_info = {k: (v.size(), v.dtype) for k, v in result.items()} + # logger.info(f"Model Outputs: {result_info}") + + # export_model_trt(model, data_cuda, save_dir) + # data_cuda = batch.to_dict(sparse_dtype=torch.int64) + # result = model(data_cuda, "cuda:0") + # result_info = {k: (v.size(), v.dtype) for k, v in 
         elif is_aot():
             data_cuda = batch.to_dict(sparse_dtype=torch.int64)
             result = model(data_cuda)
@@ -876,20 +896,12 @@
     else:
         raise ValueError("checkpoint path should be specified.")
 
-    if is_trt_convert:
-        checkpoint_pg = dist.new_group(backend="nccl")
-        if is_rank_zero:
-            logger.info("copy sharded state_dict to cuda...")
-        device_state_dict = state_dict_to_device(
-            model.state_dict(), pg=checkpoint_pg, device=torch.device(device)
-        )
-    else:
-        checkpoint_pg = dist.new_group(backend="gloo")
-        if is_rank_zero:
-            logger.info("copy sharded state_dict to cpu...")
-        device_state_dict = state_dict_to_device(
-            model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
-        )
+    checkpoint_pg = dist.new_group(backend="gloo")
+    if is_rank_zero:
+        logger.info("copy sharded state_dict to cpu...")
+    device_state_dict = state_dict_to_device(
+        model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
+    )
 
     device_model = _create_model(
         pipeline_config.model_config,
@@ -897,7 +909,11 @@
         list(data_config.label_fields),
     )
 
-    InferWrapper = ExportWrapperAOT if is_aot() else ScriptWrapper
+    InferWrapper = CudaScriptWrapper if is_aot() else ScriptWrapper
+    if is_trt_convert:
+        from tzrec.models.model import CudaListScriptWrapper
+
+        InferWrapper = CudaListScriptWrapper
     if isinstance(device_model, MatchModel):
         for name, module in device_model.named_children():
             if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):
diff --git a/tzrec/models/model.py b/tzrec/models/model.py
index 6dbcd7d2..60d8a289 100644
--- a/tzrec/models/model.py
+++ b/tzrec/models/model.py
@@ -32,6 +32,16 @@
 _MODEL_CLASS_MAP = {}
 _meta_cls = get_register_class_meta(_MODEL_CLASS_MAP)
 
+
+@torch.fx.wrap
+def _get_dict(
+    grouped_features_keys: List[str], args: List[torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """Zip grouped feature keys back onto a positional tensor list."""
+    if len(grouped_features_keys) != len(args):
+        raise ValueError(
+            "The number of grouped_features_keys must match the number of arguments."
+        )
+    return {key: value for key, value in zip(grouped_features_keys, args)}
+
 
 class BaseModel(nn.Module, metaclass=_meta_cls):
     """TorchEasyRec base model.
@@ -238,8 +248,9 @@
         return self.model.predict(batch)
 
 
-class ExportWrapperAOT(ScriptWrapper):
-    """Model inference wrapper for aot export."""
+class CudaScriptWrapper(ScriptWrapper):
+    """Model inference wrapper for cuda export (aot/trt)."""
 
     # pyre-ignore [14]
     def forward(
@@ -257,3 +268,68 @@
         batch = self._data_parser.to_batch(data)
         batch = batch.to(torch.device("cuda"), non_blocking=True)
         return self.model.predict(batch)
+
+
+class CudaListScriptWrapper(ScriptWrapper):
+    """Model inference wrapper for cuda export (aot/trt) with list inputs."""
+
+    # pyre-ignore [14]
+    def forward(
+        self,
+        data: List[torch.Tensor],
+    ) -> Dict[str, torch.Tensor]:
+        """Predict the model.
+
+        Args:
+            data (List[torch.Tensor]): input tensors, one per key in
+                ``self._data_parser.data_list_keys``.
+
+        Return:
+            predictions (dict): a dict of predicted result.
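+
+        Note:
+            ``data`` must contain one tensor per key in
+            ``self._data_parser.data_list_keys``, in the same order.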
+ """ + data_dict = _get_dict(self._data_parser.data_list_keys, data) + batch = self._data_parser.to_batch(data_dict) + batch = batch.to(torch.device("cuda"), non_blocking=True) + return self.model.predict(batch) + +# only when use TRT will use the wrapper +class ScriptWrapper2(nn.Module): + """Model inference wrapper for jit.script.""" + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.model = module + self._features = self.model._features + self._data_parser = DataParser(self._features) + + def forward( + self, + click_50_seq__item_id_lengths, + click_50_seq__item_id_values, + item_id_lengths, + item_id_values, + user_id_lengths, + user_id_values, + # pyre-ignore [9] + # data: Dict[str, torch.Tensor], + # device: torch.device = "cpu", + ) -> Dict[str, torch.Tensor]: + """Predict the model. + Args: + data (dict): a dict of input data for Batch. + device (torch.device): inference device. + Return: + predictions (dict): a dict of predicted result. + """ + # data = dict(zip(self._data_parser.output_data_keys,args)) + data = dict() + # for i, key in enumerate(self._data_parser.output_data_keys): + # data[key]=args[i] + data["click_50_seq__item_id.lengths"] = click_50_seq__item_id_lengths + data["click_50_seq__item_id.values"] = click_50_seq__item_id_values + data["item_id.lengths"] = item_id_lengths + data["item_id.values"] = item_id_values + data["user_id.lengths"] = user_id_lengths + data["user_id.values"] = user_id_values + batch = self._data_parser.to_batch(data) # , long_lengths=True) + + batch = batch.to("cuda", non_blocking=True) + return self.model.predict(batch) + \ No newline at end of file diff --git a/tzrec/modules/embedding.py b/tzrec/modules/embedding.py index bf6c6fd7..d5fa49a4 100644 --- a/tzrec/modules/embedding.py +++ b/tzrec/modules/embedding.py @@ -120,6 +120,7 @@ def _int_item(x: torch.Tensor) -> int: return int(x.item()) + class EmbeddingGroup(nn.Module): """Applies embedding lookup transformation for feature group. 
@@ -963,6 +964,8 @@
             self.ec_list.append(
                 EmbeddingCollection(list(emb_configs.values()), device=device)
             )
+            print("emb configs:", list(emb_configs.values()))
         self.mc_ec_list = nn.ModuleList()
         for k, emb_configs in dim_to_mc_emb_configs.items():
             self.mc_ec_list.append(
@@ -1076,6 +1079,7 @@
                 query_t_list.append(query_t)
         if len(query_t_list) > 0:
             results[f"{group_name}.query"] = torch.cat(query_t_list, dim=1)
+            print(f"{group_name}.query", results[f"{group_name}.query"].shape)
 
         for group_name, v in self._group_to_shared_sequence.items():
             seq_t_list = []
@@ -1097,19 +1101,24 @@
                 if i == 0:
                     sequence_length = jt.lengths()
                     group_sequence_length = _int_item(torch.max(sequence_length))
+
                     if need_tile:
                         results[f"{group_name}.sequence_length"] = sequence_length.tile(
                             tile_size
                         )
                     else:
                         results[f"{group_name}.sequence_length"] = sequence_length
+
                 jt = jt.to_padded_dense(group_sequence_length)
-
+
                 if need_tile:
                     jt = jt.tile(tile_size, 1, 1)
                 seq_t_list.append(jt)
             if seq_t_list:
                 results[f"{group_name}.sequence"] = torch.cat(seq_t_list, dim=2)
-
+                print("group_sequence_length:", group_sequence_length)
+                print(f"{group_name}.sequence", results[f"{group_name}.sequence"].shape)
+                print(
+                    f"{group_name}.sequence_length",
+                    results[f"{group_name}.sequence_length"].shape,
+                )
+
     return results
diff --git a/tzrec/tests/rank_integration_test.py b/tzrec/tests/rank_integration_test.py
index 80c4f5b1..0ae31807 100644
--- a/tzrec/tests/rank_integration_test.py
+++ b/tzrec/tests/rank_integration_test.py
@@ -70,20 +70,81 @@
     @unittest.skipIf(not torch.cuda.is_available(), "cuda not found")
     def test_aot_export(self):
-        pipeline_config_path = "tzrec/tests/configs/multi_tower_din_mock.config"
-        self.success = utils.test_train_eval(pipeline_config_path, self.test_dir)
+        pipeline_config_path = "tzrec/tests/configs/multi_tower_din_fg_mock.config"
+        self.success = utils.test_train_eval(
+            pipeline_config_path, self.test_dir, user_id="user_id", item_id="item_id"
+        )
 
         if self.success:
             self.success = utils.test_eval(
                 os.path.join(self.test_dir, "pipeline.config"), self.test_dir
             )
         if self.success:
             self.success = utils.test_export(
                 os.path.join(self.test_dir, "pipeline.config"),
                 self.test_dir,
                 enable_aot=True,
             )
+        input_tile_dir = os.path.join(self.test_dir, "input_tile")
+        input_tile_dir_emb = os.path.join(self.test_dir, "input_tile_emb")
+        if self.success:
+            os.environ["INPUT_TILE"] = "2"
+            self.success = utils.test_export(
+                os.path.join(self.test_dir, "pipeline.config"),
+                input_tile_dir,
+                enable_aot=True,
+            )
+        if self.success:
+            os.environ["INPUT_TILE"] = "3"
+            self.success = utils.test_export(
+                os.path.join(self.test_dir, "pipeline.config"),
+                input_tile_dir_emb,
+                enable_aot=True,
+            )
         self.assertTrue(self.success)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "cuda not found")
+    def test_trt_export(self):
+        pipeline_config_path = "tzrec/tests/configs/multi_tower_din_fg_mock.config"
+        self.success = utils.test_train_eval(
+            pipeline_config_path, self.test_dir, user_id="user_id", item_id="item_id"
+        )
+        if self.success:
+            self.success = utils.test_eval(
+                os.path.join(self.test_dir, "pipeline.config"), self.test_dir
+            )
+        if self.success:
+            self.success = utils.test_export(
+                os.path.join(self.test_dir, "pipeline.config"),
+                self.test_dir,
+                enable_trt=True,
+            )
+        input_tile_dir = os.path.join(self.test_dir, "input_tile")
+        input_tile_dir_emb = os.path.join(self.test_dir, "input_tile_emb")
+        if self.success:
+            os.environ["INPUT_TILE"] = "2"
+            self.success = utils.test_export(
+                os.path.join(self.test_dir, "pipeline.config"),
+                input_tile_dir,
+                enable_trt=True,
+            )
+        if self.success:
+            os.environ["INPUT_TILE"] = "3"
+            self.success = utils.test_export(
+                os.path.join(self.test_dir, "pipeline.config"),
+                input_tile_dir_emb,
+                enable_trt=True,
+            )
+        self.assertTrue(self.success)
+
     def test_multi_tower_din_fg_encoded_train_eval_export(self):
         self._test_rank_nofg(
             "tzrec/tests/configs/multi_tower_din_mock.config",
diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py
index 8eb25c52..c8bb0845 100644
--- a/tzrec/tests/utils.py
+++ b/tzrec/tests/utils.py
@@ -875,7 +875,8 @@ def test_eval(pipeline_config_path: str, test_dir: str) -> bool:
 
 
 def test_export(
-    pipeline_config_path: str, test_dir: str, asset_files: str = "", enable_aot=False
+    pipeline_config_path: str, test_dir: str, asset_files: str = "",
+    enable_aot: bool = False, enable_trt: bool = False,
 ) -> bool:
     """Run export integration test."""
     log_dir = os.path.join(test_dir, "log_export")
@@ -888,6 +889,8 @@
     )
     if enable_aot:
         cmd_str = "ENABLE_AOT=1 " + cmd_str
+    if enable_trt:
+        cmd_str = "ENABLE_TRT=1 " + cmd_str
     if asset_files:
         cmd_str += f"--asset_files {asset_files}"
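
Note on the list-input calling convention introduced in this diff: Batch.to_list flattens the feature dict into a list sorted by key name, and _get_dict in tzrec/models/model.py zips DataParser.data_list_keys back onto that list inside CudaListScriptWrapper.forward. A minimal self-contained sketch of that round trip (to_sorted_list/from_sorted_list are illustrative stand-ins, not functions from this diff):

from typing import Dict, List

import torch


def to_sorted_list(tensor_dict: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
    # Mirrors Batch.to_list: values ordered by sorted key name.
    return [tensor_dict[k] for k in sorted(tensor_dict)]


def from_sorted_list(
    keys: List[str], values: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    # Mirrors _get_dict: zip the known key order back onto the tensors.
    if len(keys) != len(values):
        raise ValueError("The number of keys must match the number of tensors.")
    return dict(zip(keys, values))


feats = {
    "item_id.values": torch.tensor([3, 5]),
    "item_id.lengths": torch.tensor([1, 1]),
    "user_id.values": torch.tensor([7]),
    "user_id.lengths": torch.tensor([1]),
}
keys = sorted(feats)          # the DataParser.data_list_keys convention
flat = to_sorted_list(feats)  # positional list passed to the exported model
restored = from_sorted_list(keys, flat)
assert all(torch.equal(restored[k], feats[k]) for k in keys)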