diff --git a/examples/modular/match/dssm_taobao_local_backbone.config b/examples/modular/match/dssm_taobao_local_backbone.config new file mode 100644 index 00000000..b715fa34 --- /dev/null +++ b/examples/modular/match/dssm_taobao_local_backbone.config @@ -0,0 +1,319 @@ +train_input_path: "data/taobao_data_recall_train/*.parquet" +eval_input_path: "data/taobao_data_recall_eval/*.parquet" +model_dir: "experiments/dssm_taobao_backbone" + +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 8 +} + +eval_config { +} + +data_config { + batch_size: 2048 + dataset_type: ParquetDataset + fg_mode: FG_DAG + label_fields: "clk" + num_workers: 8 + negative_sampler { + input_path: "data/taobao_ad_feature_gl" + num_sample: 4096 + attr_fields: "adgroup_id" + attr_fields: "cate_id" + attr_fields: "campaign_id" + attr_fields: "customer" + attr_fields: "brand" + attr_fields: "price" + item_id_field: "adgroup_id" + attr_delimiter: "\x02" + } +} + +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + 
feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + hash_bucket_size: 20 + embedding_dim: 16 + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 100 + sequence_delim: "|" + features { + 
id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } + } + } +} + +model_config { + feature_groups { + group_name: "user" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "pid" + sequence_groups { + group_name: "click_50_seq" + feature_names: "click_50_seq__adgroup_id" + feature_names: "click_50_seq__cate_id" + feature_names: "click_50_seq__brand" + } + sequence_encoders { + pooling_encoder: { + input: "click_50_seq" + pooling_type: "mean" + } + } + group_type: DEEP + } + feature_groups { + group_name: "item" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + group_type: DEEP + } + + # Use the modular (component-based) match_backbone configuration + match_backbone { + backbone { + # Main backbone block definitions + blocks { + name: "user" + inputs { + feature_group_name: "user" + } + input_layer { + } + } + blocks { + name: "item" + inputs { + feature_group_name: "item" + } + input_layer { + } + } + # User tower MLP + blocks { + name: "user_tower" + inputs { + block_name: "user" + } + module { + class_name: "MLP" + mlp { + hidden_units: [512, 256, 128] + activation: "nn.ReLU" + use_bn: false + dropout_ratio: [0.0, 0.0, 0.0] + } + } + } + # Item tower MLP + blocks { + name: "item_tower" + inputs { + block_name: "item" + } + module { + class_name: "MLP" + mlp { + hidden_units: [512, 256, 128] + activation: "nn.ReLU" + use_bn: false + 
dropout_ratio: [0.0, 0.0, 0.0] + } + } + } + # Output blocks config - specifies the user tower and item tower outputs + output_blocks: "user_tower" + output_blocks: "item_tower" + } + model_params { + # Common parameters can be configured here + # Specific parameters such as output_dim and similarity are handled by code defaults + } + } + + metrics { + recall_at_k { + top_k: 1 + } + } + metrics { + recall_at_k { + top_k: 5 + } + } + losses { + softmax_cross_entropy {} + } +} diff --git a/examples/modular/multi_task_rank/mmoe_taobao_backbone.config b/examples/modular/multi_task_rank/mmoe_taobao_backbone.config new file mode 100644 index 00000000..d8f9ed91 --- /dev/null +++ b/examples/modular/multi_task_rank/mmoe_taobao_backbone.config @@ -0,0 +1,258 @@ +train_input_path: "odps://pai_rec_test_dev/tables/taobao_multitask_sample_bucketized_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/taobao_multitask_sample_bucketized_v1/ds=20170513" +model_dir: "experiments/mmoe_taobao_backbone" + +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} + +eval_config { +} + +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: false + label_fields: "clk" + label_fields: "buy" + num_workers: 8 +} + +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} 
+feature_configs { + id_feature { + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + 
hash_bucket_size: 20 + embedding_dim: 16 + } +} + +model_config { + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "pid" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + group_type: DEEP + } + + multi_task_backbone { + backbone { + # Input layer: processes the feature group + blocks { + name: 'all' + inputs { + feature_group_name: 'all' + } + input_layer { + only_output_feature_list: false + } + } + + # MMoE module + blocks { + name: 'mmoe_module' + inputs { + block_name: 'all' + } + module { + class_name: 'MMoE' + mmoe { + expert_mlp { + hidden_units: [512, 256, 128] + } + num_expert: 3 + num_task: 2 + gate_mlp { + hidden_units: [256, 128] + } + } + } + } + } + model_params{ + # Task tower configuration + task_towers { + tower_name: "ctr" + label_name: "clk" + num_class: 1 + mlp { + hidden_units: [256, 128, 64] + activation: "nn.ReLU" + dropout_ratio: [0.0, 0.0, 0.0] + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } + } + task_towers { + tower_name: "cvr" + label_name: "buy" + num_class: 1 + mlp { + hidden_units: [256, 128, 64] + activation: "nn.ReLU" + dropout_ratio: [0.0, 0.0, 0.0] + } + metrics { + auc { + thresholds: 1000 + } + } + losses { + binary_cross_entropy {} + } + } + } + } +} diff --git a/examples/modular/rank/dcn_local_backbone.config b/examples/modular/rank/dcn_local_backbone.config new file mode 100644 index 00000000..641a5e5c --- /dev/null +++ b/examples/modular/rank/dcn_local_backbone.config @@ -0,0 +1,230 @@ +train_input_path: "data/taobao_data_train/*.parquet" +eval_input_path: "data/taobao_data_eval/*.parquet" +model_dir: "experiments/dcn_local_backbone" +train_config { 
+ sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_mode: FG_DAG + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + 
expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + hash_bucket_size: 20 + embedding_dim: 16 + } +} +model_config { + feature_groups { + group_name: "user" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "pid" + group_type: DEEP + } + feature_groups { + group_name: "item" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + group_type: DEEP + } + rank_backbone{ + backbone { + blocks { + name: "cross_net" + inputs { feature_group_name: "user" } + module { + class_name: "CrossNet" + cross_net { + num_layers: 3 + } + } + } + blocks { + name: "deep_net" + inputs { feature_group_name: "item" } + module { + 
class_name: "MLP" + mlp { + hidden_units: 512 + hidden_units: 256 + hidden_units: 128 + activation: "nn.ReLU" + } + } + } + blocks { + name: "dcn_output" + inputs { block_name: "cross_net" } + inputs { block_name: "deep_net" } + merge_inputs_into_list: false + module { + class_name: "MLP" + mlp { + hidden_units: 64 + activation: "nn.ReLU" + } + } + } + concat_blocks: "dcn_output" + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/dcn_local_backbone_recurrent.config b/examples/modular/rank/dcn_local_backbone_recurrent.config new file mode 100644 index 00000000..4411aa58 --- /dev/null +++ b/examples/modular/rank/dcn_local_backbone_recurrent.config @@ -0,0 +1,240 @@ +train_input_path: "data/taobao_data_train/*.parquet" +eval_input_path: "data/taobao_data_eval/*.parquet" +model_dir: "experiments/dcn_local_backbone_recurrent" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_mode: FG_DAG + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} +feature_configs { + id_feature 
{ + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + hash_bucket_size: 20 + embedding_dim: 
16 + } +} +model_config { + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + feature_names: "pid" + group_type: DEEP + } + feature_groups { + group_name: "deep" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + group_type: DEEP + } + rank_backbone{ + backbone { + blocks { + name: "dcn" + inputs { + feature_group_name: "all" + input_fn: "lambda x: [x, x]" + } + recurrent { + num_steps: 3 + fixed_input_index: 0 + module { + class_name: "Cross" + } + } + } + blocks { + name: "deep_net" + inputs { feature_group_name: "deep" } + module { + class_name: "MLP" + mlp { + hidden_units: 512 + hidden_units: 256 + hidden_units: 128 + activation: "nn.ReLU" + } + } + } + blocks { + name: "dcn_output" + inputs { block_name: "dcn" } + inputs { block_name: "deep_net" } + merge_inputs_into_list: false + module { + class_name: "MLP" + mlp { + hidden_units: 64 + activation: "nn.ReLU" + } + } + } + concat_blocks: "dcn_output" + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/deepfm_criteo_rankbackbone.config b/examples/modular/rank/deepfm_criteo_rankbackbone.config new file mode 100644 index 00000000..a6fd44d3 --- /dev/null +++ b/examples/modular/rank/deepfm_criteo_rankbackbone.config @@ -0,0 +1,460 @@ +train_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_train_hashed_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_val_test_hashed_v1" 
+model_dir: "experiments/deepfm_criteo" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { + num_steps: 100 +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "label" + num_workers: 8 +} +feature_configs { + raw_feature { + feature_name: "int_0" + } +} +feature_configs { + raw_feature { + feature_name: "int_1" + } +} +feature_configs { + raw_feature { + feature_name: "int_2" + } +} +feature_configs { + raw_feature { + feature_name: "int_3" + } +} +feature_configs { + raw_feature { + feature_name: "int_4" + } +} +feature_configs { + raw_feature { + feature_name: "int_5" + } +} +feature_configs { + raw_feature { + feature_name: "int_6" + } +} +feature_configs { + raw_feature { + feature_name: "int_7" + } +} +feature_configs { + raw_feature { + feature_name: "int_8" + } +} +feature_configs { + raw_feature { + feature_name: "int_9" + } +} +feature_configs { + raw_feature { + feature_name: "int_10" + } +} +feature_configs { + raw_feature { + feature_name: "int_11" + } +} +feature_configs { + raw_feature { + feature_name: "int_12" + } +} +feature_configs { + id_feature { + feature_name: "cat_0" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_1" + num_buckets: 39060 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_2" + num_buckets: 17295 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_3" + num_buckets: 7424 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_4" + num_buckets: 20265 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_5" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_6" + num_buckets: 7122 + 
embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_7" + num_buckets: 1543 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_8" + num_buckets: 63 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_9" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_10" + num_buckets: 3067956 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_11" + num_buckets: 405282 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_12" + num_buckets: 10 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_13" + num_buckets: 2209 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_14" + num_buckets: 11938 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_15" + num_buckets: 155 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_16" + num_buckets: 4 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_17" + num_buckets: 976 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_18" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_19" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_20" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_21" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_22" + num_buckets: 590152 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_23" + num_buckets: 12973 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_24" + num_buckets: 108 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: 
"cat_25" + num_buckets: 36 + embedding_dim: 16 + } +} +model_config { + feature_groups { + group_name: "wide_features" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: WIDE + } + feature_groups { + group_name: "fm_features" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + feature_groups { + group_name: "deep_features" + feature_names: "int_0" + feature_names: "int_1" + feature_names: "int_2" + feature_names: "int_3" + feature_names: "int_4" + feature_names: "int_5" + feature_names: "int_6" + feature_names: "int_7" + feature_names: "int_8" + feature_names: "int_9" + feature_names: "int_10" + feature_names: "int_11" + feature_names: "int_12" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + 
feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + rank_backbone { + backbone { + blocks { + name: 'wide_features' + inputs { + feature_group_name: 'wide_features' + } + input_layer { + wide_output_dim: 1 + } + } + blocks { + name: 'wide_logit' + inputs { + block_name: 'wide_features' + } + lambda { + expression: 'lambda x: torch.sum(x, dim=-1, keepdim=True)' + } + } + blocks { + name: 'fm_features' + inputs { + feature_group_name: 'fm_features' + } + input_layer { + only_output_3d_tensor: false + } + } + blocks{ + name:'fm_reshape' + inputs{ + block_name: 'fm_features' + input_fn: 'lambda x: x.reshape(x.shape[0],26,16)' + } + } + blocks { + name: 'deep_features' + inputs { + feature_group_name: 'deep_features' + } + input_layer { + output_2d_tensor_and_feature_list: true + } + } + blocks { + name: 'fm' + inputs { + block_name: 'fm_reshape' + } + module { + class_name: 'FM' + fm { + } + } + } + blocks { + name: 'deep' + inputs { + block_name: 'deep_features' + } + module { + class_name: 'MLP' + mlp { + hidden_units: [256, 128, 64, 2] + activation: '' + } + } + } + concat_blocks: ['wide_logit', 'fm', 'deep'] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/masknet_criteo_backbone.config b/examples/modular/rank/masknet_criteo_backbone.config new file mode 100644 index 00000000..b24fb910 --- /dev/null +++ 
b/examples/modular/rank/masknet_criteo_backbone.config @@ -0,0 +1,395 @@ +train_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_train_hashed_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_val_test_hashed_v1" +model_dir: "experiments/masknet_criteo_backbone" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.0001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.0001 + } + constant_learning_rate { + } + } + num_epochs: 1 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "label" + num_workers: 8 +} + +# Numeric feature configs +feature_configs { + raw_feature { + feature_name: "int_0" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_1" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_2" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_3" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_4" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_5" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_6" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_7" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_8" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_9" + embedding_dim: 16 + normalizer: 
"method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_10" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_11" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_12" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} + +# Categorical feature configs +feature_configs { + id_feature { + feature_name: "cat_0" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_1" + num_buckets: 39060 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_2" + num_buckets: 17295 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_3" + num_buckets: 7424 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_4" + num_buckets: 20265 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_5" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_6" + num_buckets: 7122 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_7" + num_buckets: 1543 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_8" + num_buckets: 63 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_9" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_10" + num_buckets: 3067956 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_11" + num_buckets: 405282 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_12" + num_buckets: 10 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_13" + num_buckets: 2209 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_14" 
+ num_buckets: 11938 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_15" + num_buckets: 155 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_16" + num_buckets: 4 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_17" + num_buckets: 976 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_18" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_19" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_20" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_21" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_22" + num_buckets: 590152 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_23" + num_buckets: 12973 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_24" + num_buckets: 108 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_25" + num_buckets: 36 + embedding_dim: 16 + } +} + +model_config { + feature_groups { + group_name: "all_features" + feature_names: "int_0" + feature_names: "int_1" + feature_names: "int_2" + feature_names: "int_3" + feature_names: "int_4" + feature_names: "int_5" + feature_names: "int_6" + feature_names: "int_7" + feature_names: "int_8" + feature_names: "int_9" + feature_names: "int_10" + feature_names: "int_11" + feature_names: "int_12" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + 
feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + rank_backbone { + backbone { + blocks { + name: 'all_features' + inputs { + feature_group_name: 'all_features' + } + input_layer { + only_output_3d_tensor: false + } + } + blocks { + name: 'masknet' + inputs { + block_name: 'all_features' + } + module { + class_name: 'MaskNetModule' + mask_net_module { + n_mask_blocks: 3 + mask_block { + reduction_ratio: 3.0 + hidden_dim: 512 + } + use_parallel: true + top_mlp { + hidden_units: [256, 128, 64, 1] + activation: 'nn.ReLU' + dropout_ratio: [0.0, 0.0, 0.0, 0.0] + } + } + } + } + concat_blocks: ['masknet'] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/masknet_criteo_repeat_backbone.config b/examples/modular/rank/masknet_criteo_repeat_backbone.config new file mode 100644 index 00000000..34f8a1e8 --- /dev/null +++ b/examples/modular/rank/masknet_criteo_repeat_backbone.config @@ -0,0 +1,411 @@ +train_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_train_hashed_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_val_test_hashed_v1" +model_dir: "experiments/masknet_criteo_repeat_backbone" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.0001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.0001 + } + constant_learning_rate { + } + } + num_epochs: 1 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "label" + num_workers: 8 +} + +feature_configs { + raw_feature { + feature_name: "int_0" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + 
raw_feature { + feature_name: "int_1" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_2" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_3" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_4" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_5" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_6" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_7" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_8" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_9" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_10" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_11" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} +feature_configs { + raw_feature { + feature_name: "int_12" + embedding_dim: 16 + normalizer: "method=expression,expr=log(x+3)" + } +} + +# 类别特征配置 +feature_configs { + id_feature { + feature_name: "cat_0" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_1" + num_buckets: 39060 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_2" + num_buckets: 17295 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_3" + num_buckets: 
7424 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_4" + num_buckets: 20265 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_5" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_6" + num_buckets: 7122 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_7" + num_buckets: 1543 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_8" + num_buckets: 63 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_9" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_10" + num_buckets: 3067956 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_11" + num_buckets: 405282 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_12" + num_buckets: 10 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_13" + num_buckets: 2209 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_14" + num_buckets: 11938 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_15" + num_buckets: 155 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_16" + num_buckets: 4 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_17" + num_buckets: 976 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_18" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_19" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_20" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_21" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: 
"cat_22" + num_buckets: 590152 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_23" + num_buckets: 12973 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_24" + num_buckets: 108 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_25" + num_buckets: 36 + embedding_dim: 16 + } +} + +model_config { + # 特征组配置 - 包含所有特征 + feature_groups { + group_name: "all_features" + feature_names: "int_0" + feature_names: "int_1" + feature_names: "int_2" + feature_names: "int_3" + feature_names: "int_4" + feature_names: "int_5" + feature_names: "int_6" + feature_names: "int_7" + feature_names: "int_8" + feature_names: "int_9" + feature_names: "int_10" + feature_names: "int_11" + feature_names: "int_12" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + + rank_backbone { + backbone { + blocks { + name: 'all_features' + inputs { + feature_group_name: 'all_features' + } + input_layer { + only_output_3d_tensor: false + } + } + blocks { + name: 'repeated_mask_blocks' + inputs { + block_name: 'all_features' + input_fn: "lambda x: [x, x]" + } + repeat { + # 重复3次MaskBlock操作,相当于3层MaskBlock + num_repeat: 3 + # 输出时在最后一个维度进行拼接 + output_concat_axis: -1 + # 定义要重复的MaskBlock模块 + module { + class_name: 'MaskBlock' + mask_block { + reduction_ratio: 3.0 + hidden_dim: 512 + } + } + } 
+ } + blocks { + name: 'top_mlp' + inputs { + block_name: 'repeated_mask_blocks' + } + module { + class_name: 'MLP' + mlp { + hidden_units: [256, 128, 64, 1] + activation: 'nn.ReLU' + dropout_ratio: [0.0, 0.0, 0.0, 0.0] + use_bn: false + bias: true + } + } + } + concat_blocks: ['top_mlp'] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/multi_tower_din_taobao_rankbackbone.config b/examples/modular/rank/multi_tower_din_taobao_rankbackbone.config new file mode 100644 index 00000000..407ca1b4 --- /dev/null +++ b/examples/modular/rank/multi_tower_din_taobao_rankbackbone.config @@ -0,0 +1,274 @@ +train_input_path: "odps://pai_rec_test_dev/tables/taobao_multitask_sample_bucketized_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/taobao_multitask_sample_bucketized_v1/ds=20170513" +model_dir: "experiments/multi_tower_din_taobao_rankbackbone" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: false + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} 
+feature_configs { + id_feature { + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + 
hash_bucket_size: 20 + embedding_dim: 16 + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 100 + sequence_delim: "|" + features { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } + } + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + feature_names: "pid" + group_type: DEEP + } + feature_groups { + group_name: "seq" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "brand" + feature_names: "click_50_seq__adgroup_id" + feature_names: "click_50_seq__cate_id" + feature_names: "click_50_seq__brand" + group_type: SEQUENCE + } + rank_backbone{ + backbone{ + blocks { + name: 'tower' + inputs { + feature_group_name: 'deep' + } + module { + class_name: 'MLP' + mlp { + hidden_units: [512, 256, 128] + } + } + } + blocks { + name: 'din_attention' + inputs { + feature_group_name: 'seq' + } + module { + class_name: 'DIN' + din { + input: "seq" + attn_mlp { + hidden_units: [256, 64] + } + max_seq_length: 100 + } + } + } + blocks { + name: 'final_mlp' + inputs { + block_name: 'tower' + } + inputs { + block_name: 'din_attention' + } + module { + class_name: 'MLP' + mlp { + hidden_units: [64] + } + } + } + } + } + + metrics 
{ + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/multi_tower_taobao_local_rankbackbone.config b/examples/modular/rank/multi_tower_taobao_local_rankbackbone.config new file mode 100644 index 00000000..33445568 --- /dev/null +++ b/examples/modular/rank/multi_tower_taobao_local_rankbackbone.config @@ -0,0 +1,231 @@ +model_dir: "experiments/multi_tower_taobao_component" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_mode: FG_DAG + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + num_buckets: 1141730 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_segid" + expression: "user:cms_segid" + num_buckets: 98 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cms_group_id" + expression: "user:cms_group_id" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "final_gender_code" + expression: "user:final_gender_code" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "age_level" + expression: "user:age_level" + num_buckets: 8 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pvalue_level" + expression: "user:pvalue_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "shopping_level" + expression: "user:shopping_level" + num_buckets: 5 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "occupation" + expression: "user:occupation" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "new_user_class_level" + 
expression: "user:new_user_class_level" + num_buckets: 6 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "adgroup_id" + expression: "item:adgroup_id" + num_buckets: 846812 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cate_id" + expression: "item:cate_id" + num_buckets: 12961 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "campaign_id" + expression: "item:campaign_id" + num_buckets: 423438 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "customer" + expression: "item:customer" + num_buckets: 255877 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "brand" + expression: "item:brand" + num_buckets: 461498 + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "price" + expression: "item:price" + boundaries: [1.1, 2.2, 3.6, 5.2, 7.39, 9.5, 10.5, 12.9, 15, 17.37, 19, 20, 23.8, 25.8, 28, 29.8, 31.5, 34, 36, 38, 39, 40, 45, 48, 49, 51.6, 55.2, 58, 59, 63.8, 68, 69, 72, 78, 79, 85, 88, 90, 97.5, 98, 99, 100, 108, 115, 118, 124, 128, 129, 138, 139, 148, 155, 158, 164, 168, 171.8, 179, 188, 195, 198, 199, 216, 228, 238, 248, 258, 268, 278, 288, 298, 299, 316, 330, 352, 368, 388, 398, 399, 439, 478, 499, 536, 580, 599, 660, 699, 780, 859, 970, 1080, 1280, 1480, 1776, 2188, 2798, 3680, 5160, 8720] + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "pid" + expression: "context:pid" + hash_bucket_size: 20 + embedding_dim: 16 + } +} +model_config { + feature_groups { + group_name: "user" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "final_gender_code" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "pid" + group_type: DEEP + } + feature_groups { + group_name: "item" + feature_names: "adgroup_id" 
+ feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + group_type: DEEP + } + rank_backbone{ + backbone { + blocks { + name: "user_mlp" + inputs { feature_group_name: "user" } + module { + class_name: "MLP" + mlp { + hidden_units: 512 + hidden_units: 256 + hidden_units: 128 + activation: "nn.ReLU" + } + } + } + blocks { + name: "item_mlp" + inputs { feature_group_name: "item" } + module { + class_name: "MLP" + mlp { + hidden_units: 512 + hidden_units: 256 + hidden_units: 128 + activation: "nn.ReLU" + } + } + } + blocks { + name: "final_mlp" + inputs { block_name: "user_mlp" } + inputs { block_name: "item_mlp" } + merge_inputs_into_list: false + module { + class_name: "MLP" + mlp { + hidden_units: 64 + activation: "nn.ReLU" + } + } + } + concat_blocks: "final_mlp" + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/sequential_mlp_backbone.config b/examples/modular/rank/sequential_mlp_backbone.config new file mode 100644 index 00000000..8aae32ee --- /dev/null +++ b/examples/modular/rank/sequential_mlp_backbone.config @@ -0,0 +1,385 @@ +train_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_train_hashed_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_val_test_hashed_v1" +model_dir: "experiments/sequential_mlp_backbone1" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { + num_steps: 100 +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "label" + num_workers: 8 +} +feature_configs { + raw_feature { + feature_name: "int_0" + } +} +feature_configs { + raw_feature { + feature_name: "int_1" + } +} +feature_configs { + raw_feature { + feature_name: "int_2" + } +} 
+feature_configs { + raw_feature { + feature_name: "int_3" + } +} +feature_configs { + raw_feature { + feature_name: "int_4" + } +} +feature_configs { + raw_feature { + feature_name: "int_5" + } +} +feature_configs { + raw_feature { + feature_name: "int_6" + } +} +feature_configs { + raw_feature { + feature_name: "int_7" + } +} +feature_configs { + raw_feature { + feature_name: "int_8" + } +} +feature_configs { + raw_feature { + feature_name: "int_9" + } +} +feature_configs { + raw_feature { + feature_name: "int_10" + } +} +feature_configs { + raw_feature { + feature_name: "int_11" + } +} +feature_configs { + raw_feature { + feature_name: "int_12" + } +} +feature_configs { + id_feature { + feature_name: "cat_0" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_1" + num_buckets: 39060 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_2" + num_buckets: 17295 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_3" + num_buckets: 7424 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_4" + num_buckets: 20265 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_5" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_6" + num_buckets: 7122 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_7" + num_buckets: 1543 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_8" + num_buckets: 63 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_9" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_10" + num_buckets: 3067956 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_11" + num_buckets: 405282 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_12" + 
num_buckets: 10 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_13" + num_buckets: 2209 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_14" + num_buckets: 11938 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_15" + num_buckets: 155 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_16" + num_buckets: 4 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_17" + num_buckets: 976 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_18" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_19" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_20" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_21" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_22" + num_buckets: 590152 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_23" + num_buckets: 12973 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_24" + num_buckets: 108 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_25" + num_buckets: 36 + embedding_dim: 16 + } +} +model_config { + feature_groups { + group_name: "features" + feature_names: "int_0" + feature_names: "int_1" + feature_names: "int_2" + feature_names: "int_3" + feature_names: "int_4" + feature_names: "int_5" + feature_names: "int_6" + feature_names: "int_7" + feature_names: "int_8" + feature_names: "int_9" + feature_names: "int_10" + feature_names: "int_11" + feature_names: "int_12" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + 
feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + rank_backbone{ + backbone { + blocks { + name: 'mlp' + inputs { + feature_group_name: 'features' + } + layers { + module { + class_name: 'Linear' + st_params { + fields { + key:'in_features' + value:{number_value:429} + } + fields { + key: 'out_features' + value: { number_value: 256 } + } + } + } + } + layers { + module { + class_name: 'ReLU' + } + } + layers { + module { + class_name: 'Dropout' + st_params { + fields { + key: 'p' + value: { number_value: 0.5 } + } + } + } + } + layers{ + module { + class_name: 'Linear' + st_params { + fields { + key: 'in_features' + value: { number_value: 256 } + } + fields { + key: 'out_features' + value: { number_value: 1 } + } + } + } + } + } + concat_blocks: 'mlp' + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/examples/modular/rank/wide_and_deep_criteo_modular.config b/examples/modular/rank/wide_and_deep_criteo_modular.config new file mode 100644 index 00000000..11d4eb24 --- /dev/null +++ b/examples/modular/rank/wide_and_deep_criteo_modular.config @@ -0,0 +1,401 @@ +train_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_train_hashed_v1" +eval_input_path: "odps://pai_rec_test_dev/tables/criteo_terabyte_val_test_hashed_v1" +model_dir: "experiments/wide_and_deep_criteo_modular" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + 
} + num_epochs: 1 +} +eval_config { + num_steps: 100 +} +data_config { + batch_size: 8192 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "label" + num_workers: 8 +} +feature_configs { + raw_feature { + feature_name: "int_0" + } +} +feature_configs { + raw_feature { + feature_name: "int_1" + } +} +feature_configs { + raw_feature { + feature_name: "int_2" + } +} +feature_configs { + raw_feature { + feature_name: "int_3" + } +} +feature_configs { + raw_feature { + feature_name: "int_4" + } +} +feature_configs { + raw_feature { + feature_name: "int_5" + } +} +feature_configs { + raw_feature { + feature_name: "int_6" + } +} +feature_configs { + raw_feature { + feature_name: "int_7" + } +} +feature_configs { + raw_feature { + feature_name: "int_8" + } +} +feature_configs { + raw_feature { + feature_name: "int_9" + } +} +feature_configs { + raw_feature { + feature_name: "int_10" + } +} +feature_configs { + raw_feature { + feature_name: "int_11" + } +} +feature_configs { + raw_feature { + feature_name: "int_12" + } +} +feature_configs { + id_feature { + feature_name: "cat_0" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_1" + num_buckets: 39060 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_2" + num_buckets: 17295 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_3" + num_buckets: 7424 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_4" + num_buckets: 20265 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_5" + num_buckets: 3 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_6" + num_buckets: 7122 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_7" + num_buckets: 1543 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_8" + num_buckets: 63 + embedding_dim: 16 + } +} 
+feature_configs { + id_feature { + feature_name: "cat_9" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_10" + num_buckets: 3067956 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_11" + num_buckets: 405282 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_12" + num_buckets: 10 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_13" + num_buckets: 2209 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_14" + num_buckets: 11938 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_15" + num_buckets: 155 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_16" + num_buckets: 4 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_17" + num_buckets: 976 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_18" + num_buckets: 14 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_19" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_20" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_21" + num_buckets: 40000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_22" + num_buckets: 590152 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_23" + num_buckets: 12973 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_24" + num_buckets: 108 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "cat_25" + num_buckets: 36 + embedding_dim: 16 + } +} +model_config { + feature_groups { + group_name: "wide" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + 
feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: WIDE + } + feature_groups { + group_name: "deep" + feature_names: "int_0" + feature_names: "int_1" + feature_names: "int_2" + feature_names: "int_3" + feature_names: "int_4" + feature_names: "int_5" + feature_names: "int_6" + feature_names: "int_7" + feature_names: "int_8" + feature_names: "int_9" + feature_names: "int_10" + feature_names: "int_11" + feature_names: "int_12" + feature_names: "cat_0" + feature_names: "cat_1" + feature_names: "cat_2" + feature_names: "cat_3" + feature_names: "cat_4" + feature_names: "cat_5" + feature_names: "cat_6" + feature_names: "cat_7" + feature_names: "cat_8" + feature_names: "cat_9" + feature_names: "cat_10" + feature_names: "cat_11" + feature_names: "cat_12" + feature_names: "cat_13" + feature_names: "cat_14" + feature_names: "cat_15" + feature_names: "cat_16" + feature_names: "cat_17" + feature_names: "cat_18" + feature_names: "cat_19" + feature_names: "cat_20" + feature_names: "cat_21" + feature_names: "cat_22" + feature_names: "cat_23" + feature_names: "cat_24" + feature_names: "cat_25" + group_type: DEEP + } + rank_backbone { + backbone { + blocks { + name: 'wide' + inputs { + feature_group_name: 'wide' + } + input_layer { + wide_output_dim: 1 + only_output_feature_list: true + } + } + blocks { + name: 'deep_logit' + inputs { + feature_group_name: 'deep' + } + module { + class_name: 'MLP' + mlp { + hidden_units: [256, 256, 256, 1] + activation: 'nn.ReLU' + } + } + } + blocks { + 
name: 'final_logit' + inputs { + block_name: 'wide' + input_fn: 'lambda x: x.sum(dim=-1, keepdim=True)' + } + inputs { + block_name: 'deep_logit' + } + # 合并成list + merge_inputs_into_list: true + lambda { + expression: 'lambda xs: torch.cat(xs, dim=1)' + } + } + concat_blocks: 'final_logit' + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/launch.json b/launch.json new file mode 100644 index 00000000..56737ac2 --- /dev/null +++ b/launch.json @@ -0,0 +1,35 @@ +{ + "version": "0.2.0", + "configurations": [ + + + + { + "name": "tzrec with torchrun", + "type": "python", + "request": "launch", + "module": "torch.distributed.run", + "console": "integratedTerminal", + "cwd": "/nas/fengzuocheng/TorchEasyRec", + "args": [ + "--nproc_per_node=1", // 每个节点使用的GPU数量 + "--nnodes=1", // 节点总数 + "--node_rank=0", // 当前节点rank + "--master_addr=127.0.0.1", // 主节点地址 + "--master_port=209", // 主节点端口 + "tzrec/train_eval.py", // 训练脚本 + "--pipeline_config_path=examples/modular/rank/multi_tower_din_taobao_rankbackbone.config" ,// 配置文件路径 + // "--train_input_path=data/taobao_data_train/*.parquet", + // "--eval_input_path=data/taobao_data_eval/*.parquet", + // "--continue_train" + ], + "env": { + "PYTHONPATH": "/nas/fengzuocheng/TorchEasyRec", + "CUDA_VISIBLE_DEVICES": "1", // 指定可见GPU + "ODPS_CONFIG_FILE_PATH": "./odps_conf" + }, + "python": "/root/miniconda3/envs/tzrec/bin/python", + "stopOnEntry": false + } + ] +} diff --git a/requirements/runtime.txt b/requirements/runtime.txt index e8bf1062..5b152232 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -5,6 +5,7 @@ fbgemm-gpu==1.2.0 graphlearn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/graphlearn-1.3.6-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" graphlearn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/graphlearn-1.3.6-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" grpcio-tools<1.63.0 +networkx numpy<2 pandas psutil diff --git 
a/tzrec/layers/backbone.py b/tzrec/layers/backbone.py new file mode 100644 index 00000000..e69de29b diff --git a/tzrec/models/modular_match.py b/tzrec/models/modular_match.py new file mode 100644 index 00000000..315df990 --- /dev/null +++ b/tzrec/models/modular_match.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Union + +import torch +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.match_model import MatchModel +from tzrec.modules.backbone import Backbone +from tzrec.protos import simi_pb2 +from tzrec.protos.model_pb2 import ModelConfig + + +class ModularMatch(MatchModel): + """Match backbone model for flexible dual-tower matching with configurable backbone. + + This implementation supports various matching models (DSSM, DAT, etc.) by using + a flexible backbone network that can output features for different towers. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + # get backbone config + self._match_backbone_config = self._base_model_config.match_backbone + + # get model params + model_params = getattr(self._match_backbone_config, "model_params", None) + self._output_dim = 64 # default + self._similarity_type = simi_pb2.INNER_PRODUCT # default + self._temperature = 1.0 # default temperature + + # Try getting parameters from different sources + if model_params: + # Get parameters from model_paramsGet parameters from model_params (if any) + self._output_dim = getattr(model_params, "output_dim", self._output_dim) + if hasattr(model_params, "similarity"): + self._similarity_type = model_params.similarity + if hasattr(model_params, "temperature"): + self._temperature = model_params.temperature + + # Get parameters from kwargs (passed in at runtime)im", self._output_dim) + self._similarity_type = kwargs.get("similarity", self._similarity_type) + self._temperature = kwargs.get("temperature", self._temperature) + + # build backbone network + self._backbone_net = self.build_backbone_network() + + # get backbone output blocks configuration + self._output_blocks = self._get_output_blocks() + + self._user_tower_input = self._output_blocks.get("user", None) + self._item_tower_input = self._output_blocks.get("item", None) + + # if user/item tower input not explicitly set, setup default + if not self._user_tower_input and not self._item_tower_input: + self._setup_default_tower_inputs() + + def build_backbone_network(self) -> Backbone: + """Build backbone network.""" + wide_embedding_dim = ( + int(self.wide_embedding_dim) + if hasattr(self, "wide_embedding_dim") + else None + ) + wide_init_fn = self.wide_init_fn if hasattr(self, "wide_init_fn") else None + feature_groups = 
list(self._base_model_config.feature_groups) + + return Backbone( + config=self._match_backbone_config.backbone, + features=self._features, + embedding_group=None, + feature_groups=feature_groups, + wide_embedding_dim=wide_embedding_dim, + wide_init_fn=wide_init_fn, + ) + + def _get_output_blocks(self) -> Dict[str, str]: + """Get output blocks configuration for different towers. + + Returns: + Dict[str, str]: mapping from tower name to block name. + """ + output_blocks = {} + backbone_config = self._match_backbone_config.backbone + + # Check if there is output_blocks configuration + if hasattr(backbone_config, "output_blocks") and backbone_config.output_blocks: + output_block_list = list(backbone_config.output_blocks) + + # Try to infer user towers and item towers based on block names + for block_name in output_block_list: + if "user" in block_name.lower(): + output_blocks["user"] = block_name + elif "item" in block_name.lower() or "product" in block_name.lower(): + output_blocks["item"] = block_name + + # if not found, use first two blocks as user/item towers + if len(output_block_list) == 2 and len(output_blocks) == 0: + output_blocks["user"] = output_block_list[0] + output_blocks["item"] = output_block_list[1] + + return output_blocks + + def _setup_default_tower_inputs(self): + """Setup default tower inputs when not explicitly configured.""" + # default: use first two output blocks if available + backbone_output_names = self._backbone_net.get_output_block_names() + + if len(backbone_output_names) >= 2: + self._user_tower_input = backbone_output_names[0] + self._item_tower_input = backbone_output_names[1] + else: + # single output block, use it for both towers + self._user_tower_input = ( + backbone_output_names[0] if backbone_output_names else "shared" + ) + self._item_tower_input = self._user_tower_input + + def backbone( + self, batch: Batch + ) -> Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]: + """Get backbone output.""" + if 
self._backbone_net: + kwargs = { + "loss_modules": self._loss_modules, + "metric_modules": self._metric_modules, + "labels": self._labels, + } + return self._backbone_net( + batch=batch, + **kwargs, + ) + return None + + def _extract_tower_feature( + self, + backbone_output: Union[ + torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor] + ], + tower_input: str, + ) -> torch.Tensor: + """Extract tower-specific feature from backbone output. + + Args: + backbone_output: Output from backbone network. + tower_input: Name of the input for this tower. + + Returns: + torch.Tensor: Tower-specific feature tensor. + """ + if isinstance(backbone_output, dict): + # If backbone returns a dictionary, get it directly by name + if tower_input in backbone_output: + return backbone_output[tower_input] + else: + # If the specified tower_input is not found + for key in backbone_output.keys(): + if tower_input.lower() in key.lower(): + return backbone_output[key] + # If none are found, return the first value. + return list(backbone_output.values())[0] + elif isinstance(backbone_output, (list, tuple)): + if tower_input == self._user_tower_input and len(backbone_output) > 0: + return backbone_output[0] + elif tower_input == self._item_tower_input and len(backbone_output) > 1: + return backbone_output[1] + else: + return backbone_output[0] + else: + # If it is a single tensor, return directly + return backbone_output + + def user_tower(self, batch: Batch) -> torch.Tensor: + """Extract user embedding from backbone output. + + Args: + batch (Batch): input batch data. + + Returns: + torch.Tensor: user embedding tensor. 
+ """ + backbone_output = self.backbone(batch) + user_feature = self._extract_tower_feature( + backbone_output, self._user_tower_input + ) + + if user_feature.size(-1) != self._output_dim: + if not hasattr(self, "_user_projection_layer"): + self._user_projection_layer = nn.Linear( + user_feature.size(-1), self._output_dim + ) + if torch.cuda.is_available() and user_feature.is_cuda: + self._user_projection_layer = self._user_projection_layer.cuda() + user_emb = self._user_projection_layer(user_feature) + else: + user_emb = user_feature + + if self._similarity_type == simi_pb2.COSINE: + user_emb = nn.functional.normalize(user_emb, p=2, dim=-1) + + return user_emb + + def item_tower(self, batch: Batch) -> torch.Tensor: + """Extract item embedding from backbone output. + + Args: + batch (Batch): input batch data. + + Returns: + torch.Tensor: item embedding tensor. + """ + backbone_output = self.backbone(batch) + item_feature = self._extract_tower_feature( + backbone_output, self._item_tower_input + ) + + if item_feature.size(-1) != self._output_dim: + if not hasattr(self, "_item_projection_layer"): + self._item_projection_layer = nn.Linear( + item_feature.size(-1), self._output_dim + ) + if torch.cuda.is_available() and item_feature.is_cuda: + self._item_projection_layer = self._item_projection_layer.cuda() + item_emb = self._item_projection_layer(item_feature) + else: + item_emb = item_feature + + if self._similarity_type == simi_pb2.COSINE: + item_emb = nn.functional.normalize(item_emb, p=2, dim=-1) + + return item_emb + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. 
+ """ + user_emb = self.user_tower(batch) + item_emb = self.item_tower(batch) + + # compute similarity + hard_neg_indices = getattr(batch, "hard_neg_indices", None) + similarity = self.sim(user_emb, item_emb, hard_neg_indices) + + if self._temperature != 1.0: + similarity = similarity / self._temperature + + return {"similarity": similarity} + + def get_user_tower(self) -> nn.Module: + """Get user tower for inference. + + Returns: + nn.Module: user tower module for jit scripting. + """ + + class UserTowerInference(nn.Module): + def __init__(self, match_backbone_model): + super().__init__() + self.backbone_net = match_backbone_model._backbone_net + self._user_tower_input = match_backbone_model._user_tower_input + self._output_dim = match_backbone_model._output_dim + self._similarity_type = match_backbone_model._similarity_type + + if hasattr(match_backbone_model, "_user_projection_layer"): + self.user_projection_layer = ( + match_backbone_model._user_projection_layer + ) + else: + self.user_projection_layer = None + + def forward(self, batch: Batch) -> torch.Tensor: + backbone_output = self.backbone_net(batch=batch) + + if isinstance(backbone_output, dict): + if self._user_tower_input in backbone_output: + user_feature = backbone_output[self._user_tower_input] + else: + user_feature = list(backbone_output.values())[0] + elif isinstance(backbone_output, (list, tuple)): + user_feature = backbone_output[0] + else: + user_feature = backbone_output + + if self.user_projection_layer is not None: + user_emb = self.user_projection_layer(user_feature) + else: + user_emb = user_feature + + # normalize if using cosine similarity + if self._similarity_type == simi_pb2.COSINE: + user_emb = nn.functional.normalize(user_emb, p=2, dim=-1) + + return user_emb + + return UserTowerInference(self) + + def get_item_tower(self) -> nn.Module: + """Get item tower for inference. + + Returns: + nn.Module: item tower module for jit scripting. 
+ """ + + class ItemTowerInference(nn.Module): + def __init__(self, match_backbone_model): + super().__init__() + self.backbone_net = match_backbone_model._backbone_net + self._item_tower_input = match_backbone_model._item_tower_input + self._output_dim = match_backbone_model._output_dim + self._similarity_type = match_backbone_model._similarity_type + + if hasattr(match_backbone_model, "_item_projection_layer"): + self.item_projection_layer = ( + match_backbone_model._item_projection_layer + ) + else: + self.item_projection_layer = None + + def forward(self, batch: Batch) -> torch.Tensor: + backbone_output = self.backbone_net(batch=batch) + + if isinstance(backbone_output, dict): + if self._item_tower_input in backbone_output: + item_feature = backbone_output[self._item_tower_input] + else: + item_feature = list(backbone_output.values())[0] + elif isinstance(backbone_output, (list, tuple)): + item_feature = ( + backbone_output[1] + if len(backbone_output) > 1 + else backbone_output[0] + ) + else: + item_feature = backbone_output + + if self.item_projection_layer is not None: + item_emb = self.item_projection_layer(item_feature) + else: + item_emb = item_feature + + # normalize if using cosine similarity + if self._similarity_type == simi_pb2.COSINE: + item_emb = nn.functional.normalize(item_emb, p=2, dim=-1) + + return item_emb + + return ItemTowerInference(self) diff --git a/tzrec/models/modular_multi_task.py b/tzrec/models/modular_multi_task.py new file mode 100644 index 00000000..d45bcf9e --- /dev/null +++ b/tzrec/models/modular_multi_task.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +import torch +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.multi_task_rank import MultiTaskRank +from tzrec.modules.backbone import Backbone +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils.config_util import config_to_kwargs + + +class ModularMultiTask(MultiTaskRank): + """Multi-task backbone model. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + # build backbone network + self._backbone_net = self.build_backbone_network() + + # build task towers + self._task_towers = self.build_task_towers() + + def build_backbone_network(self): + """Build backbone network.""" + wide_embedding_dim = ( + int(self.wide_embedding_dim) + if hasattr(self, "wide_embedding_dim") + else None + ) + wide_init_fn = self.wide_init_fn if hasattr(self, "wide_init_fn") else None + feature_groups = list(self._base_model_config.feature_groups) + + return Backbone( + config=self._base_model_config.multi_task_backbone.backbone, + features=self._features, + embedding_group=None, + feature_groups=feature_groups, + wide_embedding_dim=wide_embedding_dim, + wide_init_fn=wide_init_fn, + ) + + def build_task_towers(self): + """Build task towers based on backbone output dimension.""" + # get backbone output dimension + backbone_output_dim = self._backbone_net.output_dim() + + task_towers = nn.ModuleDict() + for task_tower_cfg in self._task_tower_cfgs: + tower_name = task_tower_cfg.tower_name + num_class = task_tower_cfg.num_class + + # Check whether there is a custom MLP configuration + if task_tower_cfg.HasField("mlp"): + from tzrec.modules.mlp import MLP + + mlp_config = config_to_kwargs(task_tower_cfg.mlp) + task_tower = nn.Sequential( + MLP(in_features=backbone_output_dim, **mlp_config), + nn.Linear(mlp_config["hidden_units"][-1], num_class), + ) + else: + # Connect directly to the output layer + task_tower = nn.Linear(backbone_output_dim, num_class) + + task_towers[tower_name] = task_tower + + return task_towers + + def backbone(self, batch: Batch) -> torch.Tensor: + """Get backbone output.""" + if self._backbone_net: + kwargs = { + "loss_modules": self._loss_modules, + 
"metric_modules": self._metric_modules, + "labels": self._labels, + } + return self._backbone_net( + batch=batch, + **kwargs, + ) + return None + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + # get backbone output + backbone_output = self.backbone(batch) + + # Process backbone output: it may be + # a single tensor or a list of tensors + if isinstance(backbone_output, (list, tuple)): + # The backbone returns a list (such as the MMoE module), + # which needs to correspond one-to-one with the task tower. + if len(backbone_output) != len(self._task_tower_cfgs): + raise ValueError( + f"The number of backbone outputs ({len(backbone_output)}) and " + f"task towers ({len(self._task_tower_cfgs)}) must be equal" + ) + task_input_list = backbone_output + else: + # Backbone returns a single tensor, + # which is copied to all task towers + task_input_list = [backbone_output] * len(self._task_tower_cfgs) + + # Generate predictions through each mission tower + tower_outputs = {} + for i, task_tower_cfg in enumerate(self._task_tower_cfgs): + tower_name = task_tower_cfg.tower_name + task_input = task_input_list[i] + tower_output = self._task_towers[tower_name](task_input) + tower_outputs[tower_name] = tower_output + + # Convert to final prediction format + return self._multi_task_output_to_prediction(tower_outputs) diff --git a/tzrec/models/modular_rank.py b/tzrec/models/modular_rank.py new file mode 100644 index 00000000..d2103a87 --- /dev/null +++ b/tzrec/models/modular_rank.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +import torch +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.rank_model import RankModel +from tzrec.modules.backbone import Backbone +from tzrec.protos.model_pb2 import ModelConfig + + +class ModularRank(RankModel): + """Ranking backbone model.""" + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + self._feature_dict = features + self._backbone_output = None + self._backbone_net = self.build_backbone_network() + + # Use the final output dimension of backbone and consider the impact of top_mlp + output_dims = self._backbone_net.output_dim() + + self.output_mlp = nn.Linear(output_dims, self._num_class) + + def build_backbone_network(self) -> Backbone: + """Build backbone.""" + wide_embedding_dim = ( + int(self.wide_embedding_dim) + if hasattr(self, "wide_embedding_dim") + else None + ) + wide_init_fn = self.wide_init_fn if hasattr(self, "wide_init_fn") else None + feature_groups = list(self._base_model_config.feature_groups) + return Backbone( + config=self._base_model_config.rank_backbone.backbone, + features=self._feature_dict, + embedding_group=None, # Backbone create the EmbeddingGroup itself + feature_groups=feature_groups, + wide_embedding_dim=wide_embedding_dim, + wide_init_fn=wide_init_fn, + ) + + def backbone( + self, 
+ batch: Batch, + ) -> Optional[nn.Module]: + """Get backbone.""" + if self._backbone_output: + return self._backbone_output + if self._backbone_net: + kwargs = { + "loss_modules": self._loss_modules, + "metric_modules": self._metric_modules, + "labels": self._labels, + } + return self._backbone_net( + batch=batch, + **kwargs, + ) + return None + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + output = self.backbone(batch=batch) + y = self.output_mlp(output) + return self._output_to_prediction(y) diff --git a/tzrec/modules/__init__.py b/tzrec/modules/__init__.py index f971bcbd..98c1f956 100644 --- a/tzrec/modules/__init__.py +++ b/tzrec/modules/__init__.py @@ -8,3 +8,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# from .backbone_module import FM, Add +# from .cross import Cross, CrossNet +from .fm import FactorizationMachine as FM +from .interaction import Cross, CrossV2 +from .masknet import MaskBlock, MaskNetModule +from .mlp import MLP +from .mmoe import MMoE +from .sequence import DINEncoder as DIN + +# from .fm import FactorizationMachine as FM +__all__ = [ + "MLP", + "Add", + "FM", + "DIN", + "MMoE", + "Cross", + "CrossV2", + "MaskNetModule", + "MaskBlock", +] diff --git a/tzrec/modules/backbone.py b/tzrec/modules/backbone.py new file mode 100644 index 00000000..de269d60 --- /dev/null +++ b/tzrec/modules/backbone.py @@ -0,0 +1,1944 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import networkx as nx +import torch +from networkx.drawing.nx_agraph import to_agraph +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.modules.embedding import EmbeddingGroup +from tzrec.modules.mlp import MLP +from tzrec.protos import backbone_pb2, torch_layer_pb2 +from tzrec.protos.model_pb2 import FeatureGroupConfig +from tzrec.utils.backbone_utils import Parameter +from tzrec.utils.config_util import config_to_kwargs +from tzrec.utils.dimension_inference import ( + DimensionInferenceEngine, + DimensionInfo, + create_dimension_info_from_embedding, +) +from tzrec.utils.lambda_inference import LambdaOutputDimInferrer +from tzrec.utils.load_class import load_torch_layer # pyre ignore[21] +from tzrec.utils.logging_util import logger + +# Constants for auto-inferred parameters +# Input dimension related parameters +INPUT_DIM_PARAMS = ["in_features", "input_dim", "feature_dim", "mask_input_dim"] + +# Sequence dimension related parameters +SEQUENCE_QUERY_PARAMS = ["sequence_dim", "query_dim"] + +# All parameters that support automatic inference +AUTO_INFER_PARAMS = INPUT_DIM_PARAMS + SEQUENCE_QUERY_PARAMS + + +class LambdaWrapper(nn.Module): + """Lambda expression wrapper for dimension inference and execution.""" + + def __init__(self, expression: str, name: str = "lambda_wrapper") -> None: + super().__init__() + self.expression = expression + self.name = name + self._lambda_fn = None + self._compile_function() + 
+ def _compile_function(self) -> None: + """Compiling Lambda Functions.""" + try: + self._lambda_fn = eval(self.expression) + if not callable(self._lambda_fn): + raise ValueError( + f"Expression does not evaluate to callable: {self.expression}" + ) + except Exception as e: + logger.error(f"Failed to compile lambda function '{self.expression}': {e}") + raise + + def forward( + self, x: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]: + """Executing lambda expressions.""" + if self._lambda_fn is None: + raise ValueError("Lambda function not compiled") + return self._lambda_fn(x) + + def infer_output_dim(self, input_dim_info: DimensionInfo) -> DimensionInfo: + """Inferring output dims using LambdaOutputDimInferrer.""" + try: + inferrer = LambdaOutputDimInferrer() + output_dim_info = inferrer.infer_output_dim(input_dim_info, self.expression) + logger.debug( + f"Lambda wrapper {self.name} inferred output dim: {output_dim_info}" + ) + return output_dim_info + except Exception as e: + logger.warning( + f"Failed to infer output dim for lambda {self.name}: {e}, using input dim" # NOQA + ) + return input_dim_info + + def __repr__(self) -> str: + return f"LambdaWrapper(name={self.name}, expression='{self.expression}')" + + +class Package(nn.Module): + """A sub DAG for reuse.""" + + __packages = {} + + @staticmethod + def has_backbone_block(name: str) -> bool: + """Return True if the backbone block with the given name exists.""" + if "backbone" not in Package.__packages: + return False + backbone = Package.__packages["backbone"] + return backbone.has_block(name) + + @staticmethod + def backbone_block_outputs( + name: str, + ) -> Optional[Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]]: + """Get the outputs of a backbone block by name. + + Args: + name (str): The name of the backbone block to retrieve outputs for. 
+ + Returns: + Any: The output of the specified backbone block, or None if the backbone + package doesn't exist or the block is not found. + """ + if "backbone" not in Package.__packages: + return None + backbone = Package.__packages["backbone"] + return backbone.block_outputs(name) + + def __init__( + self, + config: backbone_pb2.BlockPackage, + features: List[BaseFeature], + embedding_group: EmbeddingGroup, + feature_groups: List[FeatureGroupConfig], + wide_embedding_dim: Optional[int] = None, + wide_init_fn: Optional[str] = None, + ) -> None: + super().__init__() + self._config = config + self._features = features + self._embedding_group = embedding_group + self._feature_groups = feature_groups + self._wide_embedding_dim = wide_embedding_dim + self._wide_init_fn = wide_init_fn + # build DAG using networkx DiGraph + self.G = nx.DiGraph() + self._name_to_blocks = {} + + self._name_to_layer = nn.ModuleDict() # Layer corresponding to each Block name + self._name_to_customize = {} # Whether each Block is a custom implementation + + # Dimension inference engine + self.dim_engine = DimensionInferenceEngine() + + self._name_to_output_dim = {} + self._name_to_input_dim = {} + + self.reset_input_config(None) + self._block_outputs = {} + self._package_input = None + self._feature_group_inputs = {} + input_feature_groups = self._feature_group_inputs + + # ======= step 1: Register all nodes ======= + for block in config.blocks: + if len(block.inputs) == 0: + raise ValueError("block takes at least one input: %s" % block.name) + self._name_to_blocks[block.name] = block + self.G.add_node(block.name) + + # ======= step 2: Complete all DAG edges ======== + for block in config.blocks: + name = block.name + for input_node in block.inputs: + input_type = input_node.WhichOneof( + "name" + ) # feature_group_name / block_name + input_name = getattr(input_node, input_type) + if input_type == "feature_group_name": + # If not registered, register it as an input node. 
+ # "feature_group_name" requires adding a new DAG node. + if input_name not in self._name_to_blocks: + new_block = backbone_pb2.Block() + new_block.name = input_name + input_cfg = backbone_pb2.Input() + input_cfg.feature_group_name = input_name + new_block.inputs.append(input_cfg) + new_block.input_layer.CopyFrom(backbone_pb2.InputLayer()) + self._name_to_blocks[input_name] = new_block + self.G.add_node(input_name) + self.G.add_edge(input_name, name) + elif input_type == "package_name": + # The package is the sub-DAG as the input of the Block + raise NotImplementedError + else: + # block-to-block + if input_name in self._name_to_blocks: + self.G.add_edge(input_name, name) + else: + raise KeyError( + f"input name `{input_name}` not found in blocks/feature_groups" # NOQA + ) + # ========== step 3: After topological sorting, define_layer in order ========== + self.topo_order = nx.topological_sort(self.G) + self.topo_order_list = list(self.topo_order) + A = to_agraph(self.G) + A.layout("dot") + import hashlib + import time + + config_info = f"{config.name}_{len(config.blocks)}_{len(self._name_to_layer)}" + config_hash = hashlib.md5(config_info.encode()).hexdigest()[:8] + timestamp = int(time.time()) + + dag_filename = f"dag_{config.name}_{config_hash}_{timestamp}.png" + A.draw(dag_filename) + for block_name in self.topo_order_list: + block = self._name_to_blocks[block_name] + layer = block.WhichOneof("layer") + if layer in {"input_layer", "raw_input", "embedding_layer"}: + # Register input-related layer, needs 1 input + if len(block.inputs) != 1: + raise ValueError( + "input layer `%s` takes only one input" % block.name + ) + one_input = block.inputs[0] + name = one_input.WhichOneof("name") + if name != "feature_group_name": + raise KeyError( + "`feature_group_name` should be set for input layer: " + + block.name + ) + group = one_input.feature_group_name + + if group in input_feature_groups: + # Already exists, do not register again + if layer == "input_layer": + 
logger.warning( + "input `%s` already exists in other block" % group + ) + elif layer == "raw_input": + raise NotImplementedError + elif layer == "embedding_layer": + raise NotImplementedError + else: + input_fn = EmbeddingGroup( + features=self._features, + feature_groups=self._feature_groups, + wide_embedding_dim=self._wide_embedding_dim, + wide_init_fn=self._wide_init_fn, + ) + if layer == "input_layer": + # Use dimension inference engine + dim_info = create_dimension_info_from_embedding( + input_fn, + group, + batch_size=None, + ) + self.dim_engine.register_output_dim(block.name, dim_info) + self._name_to_output_dim[block.name] = ( + dim_info.get_feature_dim() + ) + + input_feature_groups[group] = ( + embedding_group # not a layer is a dim + ) + elif layer == "raw_input": + raise NotImplementedError + else: # embedding_layer + raise NotImplementedError + self._name_to_layer[block.name] = input_fn + # If module is None, it may be a sequential module + elif layer is not None: + # Use the dimension inference engine to handle multiple input dimensions + input_dim_infos = [] + + for input_node in block.inputs: + if (len(block.inputs)) > 1: + logger.debug( + f"Processing multiple inputs for block {block.name}: {[getattr(n, n.WhichOneof('name')) for n in block.inputs]}" # NOQA + ) + input_type = input_node.WhichOneof("name") + input_name = getattr(input_node, input_type) + # Parse input_fn & input_slice + input_fn = getattr(input_node, "input_fn", None) + input_slice = getattr(input_node, "input_slice", None) + + if input_type == "package_name": + # package is a sub-DAG as input to Block + raise NotImplementedError + else: # block_name or feature_group_name + # Get input dimension info from dimension inference engine + input_dim_info = self.dim_engine.get_output_dim(input_name) + + # If it is a recurrent or repeat layer + # To ensure the latest output dimensions, + # need to do some processing first. 
+ if input_name in self._name_to_blocks: + input_block = self._name_to_blocks[input_name] + input_layer_type = input_block.WhichOneof("layer") + if input_layer_type in ["recurrent", "repeat"]: + # Get the latest output dimension + if input_name in self._name_to_output_dim: + latest_output_dim = self._name_to_output_dim[ + input_name + ] + latest_dim_info = DimensionInfo(latest_output_dim) + logger.info( + f"Overriding dim_engine cache for {input_layer_type} layer {input_name}: {latest_output_dim}" # NOQA + ) + # Updated dimension inference engine + self.dim_engine.register_output_dim( + input_name, latest_dim_info + ) + input_dim_info = latest_dim_info + else: + logger.warning( + f"{input_layer_type} layer {input_name} not found in _name_to_output_dim" # NOQA + ) + # Apply input_fn and input_slice transformations + if input_fn or input_slice: + input_dim_info = self.dim_engine.apply_input_transforms( + input_dim_info, input_fn, input_slice + ) + + input_dim_infos.append(input_dim_info) + + # Merge dimension info of multiple inputs + if len(input_dim_infos) == 1: + merged_input_dim = input_dim_infos[0] + else: + # Determine the merging method based on block configuration + merge_mode = ( + "list" + if getattr(block, "merge_inputs_into_list", False) + else "concat" + ) + merged_input_dim = self.dim_engine.merge_input_dims( + input_dim_infos, merge_mode + ) + + # Register input dimension + self.dim_engine.register_input_dim(block.name, merged_input_dim) + self._name_to_input_dim[block.name] = merged_input_dim.get_total_dim() + + # Add debug info + logger.info( + f"Block {block.name} input dimensions: merged_input_dim={merged_input_dim}, total_dim={merged_input_dim.get_total_dim()}" # NOQA + ) + if merged_input_dim.is_list: + logger.info( + f" - is_list=True, dims_list={merged_input_dim.to_list()}" + ) + else: + logger.info( + f" - is_list=False, feature_dim={merged_input_dim.get_feature_dim()}" # NOQA + ) + + # define layer + self.define_layers(layer, block, 
block.name) + + # Register the layer to the dimension inference engine + if block.name in self._name_to_layer: + layer_obj = self._name_to_layer[block.name] + self.dim_engine.register_layer(block.name, layer_obj) + + # Lambda module require dimension inference + if isinstance(layer_obj, LambdaWrapper): + output_dim_info = layer_obj.infer_output_dim(merged_input_dim) + logger.info( + f"Lambda layer {block.name} inferred output dim: {output_dim_info}" # NOQA + ) + else: + # Check if it is already a recurrent or repeat layer + # if so skip output dimension inference + if layer in {"recurrent", "repeat"}: + # Output dimension is already set in define_layers, + # no need to infer again + output_dim_info = self.dim_engine.get_output_dim(block.name) + if output_dim_info is None: + # If not in dimension inference engine, + # get from self._name_to_output_dim + if block.name in self._name_to_output_dim: + output_dim = self._name_to_output_dim[block.name] + output_dim_info = DimensionInfo(output_dim) + self.dim_engine.register_output_dim( + block.name, output_dim_info + ) + logger.info( + f"{layer.capitalize()} layer {block.name} output dim restored from compatibility field: {output_dim}" # NOQA + ) + else: + raise ValueError( + f"{layer.capitalize()} layer {block.name} missing output dimension" # NOQA + ) + else: + logger.info( + f"{layer.capitalize()} layer {block.name} output dim already set: {output_dim_info}" # NOQA + ) + else: + # Inferred output dimensions + output_dim_info = self.dim_engine.infer_layer_output_dim( + layer_obj, merged_input_dim + ) + + self.dim_engine.register_output_dim(block.name, output_dim_info) + self._name_to_output_dim[block.name] = ( + output_dim_info.get_feature_dim() + ) + + logger.info( + f"Block {block.name} output dimensions: output_dim_info={output_dim_info}, feature_dim={output_dim_info.get_feature_dim()}" # NOQA + ) + else: + # Check if it is a recurrent or repeat layer, and if so, + # do not overwrite the set output dimension. 
+ layer_type = layer + if layer_type in ["recurrent", "repeat"]: + # The output dimensions of the recurrent layer have been set + # in define_layers and are no need to overwrite. + existing_output_dim_info = self.dim_engine.get_output_dim( + block.name + ) + existing_output_dim = self._name_to_output_dim.get(block.name) + logger.info( + f"[SKIP OVERRIDE] {layer_type.capitalize()} layer {block.name} - keeping existing output dim: engine={existing_output_dim_info}, compat={existing_output_dim}" # NOQA + ) + logger.info( + f"Skipping override for {layer_type} layer {block.name} - keeping existing output dimensions" # NOQA + ) + else: + # Use input dimensions as output dimensions + self.dim_engine.register_output_dim( + block.name, merged_input_dim + ) + self._name_to_output_dim[block.name] = ( + merged_input_dim.get_feature_dim() + ) + + logger.info( + f"Block {block.name} (no layer) output dimensions: output_dim_info={merged_input_dim}, feature_dim={merged_input_dim.get_feature_dim()}" # NOQA + ) + else: # layer is None, e.g. 
sequential + if len(block.inputs) == 0: + # sequential block without inputs, use input_dim_info + raise ValueError( + f"Sequential block {block.name} has no input dimensions registered" # NOQA + ) + else: + # sequential block with inputs, use merged input dimensions + for input_node in block.inputs: + input_type = input_node.WhichOneof("name") + input_name = getattr(input_node, input_type) + # Parsing input_fn & input_slice does + # not support input_fn & input_slice in sequential + input_fn = getattr(input_node, "input_fn", None) + input_slice = getattr(input_node, "input_slice", None) + + if input_type == "package_name": + # The package is the sub-DAG as the input of the Block + # Nested packages in sequential modules + input_dim_info = self.dim_engine.get_output_dim(input_name) + raise NotImplementedError + else: # block_name or feature_group_name + # Get input dimension info from dimension inference engine + input_dim_info = self.dim_engine.get_output_dim(input_name) + # Dimension inference for sequential layers + prev_output_dim_info = input_dim_info + prev_output_dim = input_dim_info.get_feature_dim() + last_output_dim_info = None + last_output_dim = None + for i, layer_cnf in enumerate(block.layers): + layer = layer_cnf.WhichOneof("layer") + name_i = "%s_l%d" % (block.name, i) # e.g. 
block1_l0 + # Register input dimension + self.dim_engine.register_input_dim(name_i, prev_output_dim_info) + self._name_to_input_dim[name_i] = prev_output_dim + # Define layer + self.define_layers(layer, layer_cnf, name_i) + # Register layer to dimension inference engine + if name_i in self._name_to_layer: + layer_obj = self._name_to_layer[name_i] + self.dim_engine.register_layer(name_i, layer_obj) + # Infer output dimension + if isinstance(layer_obj, LambdaWrapper): + output_dim_info = layer_obj.infer_output_dim( + prev_output_dim_info + ) + else: + output_dim_info = self.dim_engine.infer_layer_output_dim( + layer_obj, prev_output_dim_info + ) + self.dim_engine.register_output_dim(name_i, output_dim_info) + self._name_to_output_dim[name_i] = ( + output_dim_info.get_feature_dim() + ) + # Update prev to current output + prev_output_dim_info = output_dim_info + prev_output_dim = output_dim_info.get_feature_dim() + last_output_dim_info = output_dim_info + last_output_dim = output_dim_info.get_feature_dim() + else: + raise ValueError( + f"Sequential layer {name_i} not found in _name_to_layer" + ) + # The block output dimension is the last layer output + if last_output_dim_info is not None: + self.dim_engine.register_output_dim( + block.name, last_output_dim_info + ) + self._name_to_output_dim[block.name] = last_output_dim + logger.info( + f"Sequential block {block.name} output dim set to {last_output_dim}" # NOQA + ) + else: + raise ValueError( + f"Cannot determine output dimension for sequential block {block.name}" # NOQA + ) + + # ======= Post-processing, output node inference ======= + input_feature_groups = self._feature_group_inputs + num_groups = len(input_feature_groups) # Number of input_feature_groups + # Subtract the number of input feature groups, + # blocks contain feature_groups e.g. 
feature group user
        num_blocks = len(self._name_to_blocks) - num_groups
        assert num_blocks > 0, "there must be at least one block in backbone"
        # num_pkg_input = 0
        # Processing multiple pkgs is not yet supported
        # Optional: Check package inputs

        # If concat_blocks is not configured,
        # automatically concatenate all leaf nodes of the DAG and output
        if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
            # Get all leaf nodes
            leaf = [node for node in self.G.nodes() if self.G.out_degree(node) == 0]
            logger.warning(
                (
                    f"{config.name} has no `concat_blocks` or `output_blocks`, "
                    f"try to concat all leaf blocks: {','.join(leaf)}"
                )
            )
            self._config.concat_blocks.extend(leaf)

        # NOTE(review): class-level registry (name-mangled). Assumes package
        # names are globally unique -- a duplicate name silently overwrites
        # the earlier registration. TODO confirm that is intended.
        Package.__packages[self._config.name] = self

        # Output dimension inference summary
        dim_summary = self.dim_engine.get_summary()
        logger.info(f"{config.name} dimension inference summary: {dim_summary}")

        # Output detailed dimension info for all blocks
        logger.info("=== Final dimension summary ===")
        for block_name in self.topo_order_list:
            if block_name in self._name_to_input_dim:
                input_dim = self._name_to_input_dim[block_name]
                output_dim = self._name_to_output_dim.get(block_name, "N/A")
                dim_engine_output = self.dim_engine.get_output_dim(block_name)
                logger.info(
                    f"Block {block_name}: input_dim={input_dim}, output_dim={output_dim}, dim_engine={dim_engine_output}"  # NOQA
                )

        logger.info(
            "%s layers: %s" % (config.name, ",".join(self._name_to_layer.keys()))
        )

    def get_output_block_names(self) -> List[str]:
        """Returns the final output block name list (prefer concat_blocks, otherwise output_blocks)."""  # NOQA
        blocks = list(getattr(self._config, "concat_blocks", []))
        if not blocks:
            blocks = list(getattr(self._config, "output_blocks", []))
        return blocks

    def get_dimension_summary(self) -> Dict[str, Any]:
        """Get detailed summary information of dimension inference.

        Returns:
            dict: the dim-engine summary extended with the config name, layer
                count, configured output/concat block lists, the per-block
                final output dims and their total.
        """
        summary = self.dim_engine.get_summary()
        summary.update(
            {
                "config_name": self._config.name,
                "total_layers": len(self._name_to_layer),
                "output_blocks": list(getattr(self._config, "output_blocks", [])),
                "concat_blocks": list(getattr(self._config, "concat_blocks", [])),
                "final_output_dims": self.output_block_dims(),
                "total_output_dim": self.total_output_dim(),
            }
        )
        return summary

    def output_block_dims(self) -> List[int]:
        """Return a list of dimensions of the final output blocks, e.g. [160, 96]."""
        blocks = self.get_output_block_names()
        dims = []
        for block in blocks:
            # Prefer the dim engine; fall back to the compatibility mapping
            # (_name_to_output_dim) before declaring the block unknown.
            dim_info = self.dim_engine.get_output_dim(block)
            logger.info(f"Output block `{block}` dimension info: {dim_info}")
            if dim_info is not None:
                dims.append(dim_info.get_feature_dim())
            elif block in self._name_to_output_dim:
                dims.append(self._name_to_output_dim[block])
            else:
                raise ValueError(f"block `{block}` not in output dims")
        return dims

    def total_output_dim(self) -> int:
        """Return the total dimension of the final output after concatenation."""
        return sum(self.output_block_dims())

    def define_layers(
        self, layer: str, layer_cnf: backbone_pb2.Block, name: str
    ) -> None:
        """Define layers.

        Args:
            layer (str): the type of layer, e.g., 'module', 'recurrent', 'repeat'.
            layer_cnf (backbone_pb2.Block): the configuration of the layer.
                class_name: "MLP" mlp {
                    hidden_units: 512
                    hidden_units: 256
                    hidden_units: 128
                    activation: "nn.ReLU"
                }
            name (str): the name of the layer. e.g., 'user_mlp'.
+ """ + if layer == "module": + layer_cls, customize = self.load_torch_layer( + layer_cnf.module, name, self._name_to_input_dim.get(name, None) + ) + self._name_to_layer[name] = layer_cls + self._name_to_customize[name] = customize + elif layer == "recurrent": + torch_layer = layer_cnf.recurrent.module + # Get the input dimension info of the parent layer, + # used for child layer dimension inference + parent_input_dim_info = self.dim_engine.block_input_dims.get(name) + parent_input_dim = self._name_to_input_dim.get(name, None) + + # Check if there is a fixed_input_index configuration + fixed_input_index = getattr(layer_cnf.recurrent, "fixed_input_index", None) + + # If fixed_input_index exists and parent_input_dim_info is a list, + # special handling is needed + child_input_dim_info = parent_input_dim_info + child_input_dim = parent_input_dim + + if fixed_input_index is not None and parent_input_dim_info is not None: + if parent_input_dim_info.is_list: + # Take the dimension specified by fixed_input_index from the list + dims_list = parent_input_dim_info.to_list() + if fixed_input_index < len(dims_list): + fixed_dim = dims_list[fixed_input_index] + child_input_dim_info = DimensionInfo(fixed_dim) + child_input_dim = fixed_dim + logger.info( + f"Recurrent layer {name} using fixed_input_index={fixed_input_index}, child input_dim={fixed_dim}" # NOQA + ) + else: + logger.warning( + f"fixed_input_index={fixed_input_index} out of range for input dims: {dims_list}" # NOQA + ) + + # record the output dimension of the last child layer + last_output_dim_info = None + last_output_dim = None + + for i in range(layer_cnf.recurrent.num_steps): + name_i = "%s_%d" % (name, i) + + # Register input dimension info for each child layer + if child_input_dim_info is not None: + self.dim_engine.register_input_dim(name_i, child_input_dim_info) + if child_input_dim is not None: + self._name_to_input_dim[name_i] = child_input_dim + + # Load the child layer, passing the correct input_dim 
parameter + layer_obj, customize = self.load_torch_layer( + torch_layer, name_i, child_input_dim + ) + self._name_to_layer[name_i] = layer_obj + self._name_to_customize[name_i] = customize + + # Register the child layer with the dimension inference engine + self.dim_engine.register_layer(name_i, layer_obj) + + # Infer the output dimension of the child layer + if child_input_dim_info is not None: + if isinstance(layer_obj, LambdaWrapper): + output_dim_info = layer_obj.infer_output_dim( + child_input_dim_info + ) + else: + output_dim_info = self.dim_engine.infer_layer_output_dim( + layer_obj, child_input_dim_info + ) + + self.dim_engine.register_output_dim(name_i, output_dim_info) + self._name_to_output_dim[name_i] = output_dim_info.get_feature_dim() + + # Record the output dimension of the last child layer + last_output_dim_info = output_dim_info + last_output_dim = output_dim_info.get_feature_dim() + else: + raise ValueError( + f"Cannot determine output dimension for layer {name_i}" + ) + + # Set the output dimension of the parent layer (recurrent layer) to + # the output dimension of the last child layer + if last_output_dim_info is not None: + # Updates the dimension inference engine and self._name_to_output_dim + self.dim_engine.register_output_dim(name, last_output_dim_info) + self._name_to_output_dim[name] = last_output_dim + logger.info( + f"Recurrent layer {name} output dim set to {last_output_dim} (from last child layer)" # NOQA + ) + logger.info(f" - last_output_dim_info: {last_output_dim_info}") + logger.info( + f" - Updated _name_to_output_dim[{name}]: {self._name_to_output_dim[name]}" # NOQA + ) + + # Verify that the update was successful + updated_dim_info = self.dim_engine.get_output_dim(name) + logger.info( + f"[VERIFY] Updated dim_engine output for {name}: {updated_dim_info}" + ) + else: + raise ValueError(f"Cannot determine input dimension for layer {name}") + elif layer == "repeat": + torch_layer = layer_cnf.repeat.module + # Get the input 
dimension information of the parent layer + # for dimension inference of the child layer + parent_input_dim_info = self.dim_engine.block_input_dims.get(name) + parent_input_dim = self._name_to_input_dim.get(name, None) + + # Used to record the output dimension of the last child layer + last_output_dim_info = None + last_output_dim = None + + for i in range(layer_cnf.repeat.num_repeat): + name_i = "%s_%d" % (name, i) + + # Register input dimension info for each child layer + if parent_input_dim_info is not None: + self.dim_engine.register_input_dim(name_i, parent_input_dim_info) + if parent_input_dim is not None: + self._name_to_input_dim[name_i] = parent_input_dim + + # Load the child layer, + # passing the correct input_dim parameter + layer_obj, customize = self.load_torch_layer( + torch_layer, name_i, parent_input_dim + ) + self._name_to_layer[name_i] = layer_obj + self._name_to_customize[name_i] = customize + + # Register child layer to dimension inference engine + self.dim_engine.register_layer(name_i, layer_obj) + + # Infer the output dimension of the child layer + if parent_input_dim_info is not None: + if isinstance(layer_obj, LambdaWrapper): + output_dim_info = layer_obj.infer_output_dim( + parent_input_dim_info + ) + else: + output_dim_info = self.dim_engine.infer_layer_output_dim( + layer_obj, parent_input_dim_info + ) + + self.dim_engine.register_output_dim(name_i, output_dim_info) + self._name_to_output_dim[name_i] = output_dim_info.get_feature_dim() + + # Record the output dimension of the last child layer + last_output_dim_info = output_dim_info + last_output_dim = output_dim_info.get_feature_dim() + else: + raise ValueError( + f"Cannot determine output dimension for layer {name_i}" + ) + + # Calculate the output dimension of the parent layer (repeat layer), + # taking into account the output_concat_axis configuration + if last_output_dim_info is not None: + final_output_dim_info = last_output_dim_info + final_output_dim = last_output_dim + + # Check 
if output_concat_axis is configured + # + # e.g., repeat maskblock 2 times and concatenate in + # the last dimension (output_concat_axis: -1). + # Equivalent to: [maskblock1, maskblock2] in the last dimension cat + if ( + hasattr(layer_cnf.repeat, "output_concat_axis") + and layer_cnf.repeat.output_concat_axis is not None + ): + axis = layer_cnf.repeat.output_concat_axis + num_repeat = layer_cnf.repeat.num_repeat + + # IF in the last dimension splicing (axis=-1), + # you need to multiply the dimension by the number of repeats + if axis == -1: + # The output dimension of a single child layer + # multiplied by repeat times + if last_output_dim is None: + raise ValueError( + f"Repeat layer {name}: last_output_dim is None, cannot infer final_output_dim" # NOQA + ) + if isinstance(last_output_dim, int): + final_output_dim = last_output_dim * num_repeat + final_output_dim_info = DimensionInfo(final_output_dim) + logger.info( + f"Repeat layer {name} with output_concat_axis={axis}: " + f"single_output_dim={last_output_dim} * num_repeat={num_repeat} = {final_output_dim}" # NOQA + ) + else: + # For the splicing of other axes, remain unchanged for now + # and require more complex dimension inference logic. 
+ logger.warning( + f"Repeat layer {name} with output_concat_axis={axis}: " + f"non-last axis concatenation not fully supported, using single layer output dim={last_output_dim}" # NOQA + ) + else: + # If output_concat_axis is not configured, return as list format + num_repeat = layer_cnf.repeat.num_repeat + # Create dimension information in list format, + # containing num_repeat identical sub-layer output dimensions + list_dims = [last_output_dim] * num_repeat + final_output_dim_info = DimensionInfo(list_dims, is_list=True) + + # final_output_dim, by default uses the total dimension of the list + # In actual use, the correct dimension information should + # be obtained through the dimension inference engine + final_output_dim = sum(list_dims) # pyre-ignore[6] + + logger.info( + f"Repeat layer {name} without output_concat_axis: returns list of {num_repeat} outputs, " # NOQA + f"each with dim={last_output_dim}, list_dims={list_dims}" + ) + + self.dim_engine.register_output_dim(name, final_output_dim_info) + self._name_to_output_dim[name] = final_output_dim + logger.info( + f"Repeat layer {name} final output dim set to {final_output_dim}" + ) + else: + raise ValueError(f"Cannot determine output dimension for layer {name}") + elif layer == "lambda": + expression = getattr(layer_cnf, "lambda").expression + lambda_layer = LambdaWrapper(expression, name=name) + self._name_to_layer[name] = lambda_layer + self._name_to_customize[name] = True + + def load_torch_layer( + self, + layer_conf: torch_layer_pb2.TorchLayer, + name: str, + input_dim: Optional[int] = None, + ) -> Tuple[Optional[nn.Module], bool]: + """Dynamically load and initialize a torch layer based on configuration. + + Args: + layer_conf: Layer configuration containing class name and parameters. + name (str): Name of the layer to be created. + input_dim (int, optional): Input dimension for the layer. 
+ + Returns: + tuple: A tuple containing (layer_instance, customize_flag) where + layer_instance is the initialized layer object and customize_flag + indicates if it's a custom implementation. + + Raises: + ValueError: If the layer class name is invalid or layer creation fails. + """ + # customize indicates whether it is a custom implementation + layer_cls, customize = load_torch_layer(layer_conf.class_name) + if layer_cls is None: + raise ValueError("Invalid torch layer class name: " + layer_conf.class_name) + param_type = layer_conf.WhichOneof("params") + # st_params is a parameter configured + # in the google.protobuf.Struct object format; + # can also pass parameters to the loaded Layer object + # in a custom protobuf message format. + if customize: + if param_type is None: # No additional parameters + # Get the constructor signature + sig = inspect.signature(layer_cls.__init__) + kwargs = {} + elif param_type == "st_params": + params = Parameter(layer_conf.st_params, True) + kwargs = config_to_kwargs(params) # pyre-ignore[6] + sig = inspect.signature(layer_cls.__init__) + # If param_type points to some other field in oneof, + # the code dynamically gets the value of that field via getattr, + # assuming it is a Protocol Buffer message (is_struct=False). 
+ else: + pb_params = getattr(layer_conf, param_type) + params = Parameter(pb_params, False) + sig = inspect.signature(layer_cls.__init__) + kwargs = config_to_kwargs(params) # pyre-ignore[6] + + # Check if you need to automatically infer the input dimension parameters + input_dim_params_in_sig = [ + param for param in INPUT_DIM_PARAMS if param in sig.parameters + ] + if input_dim_params_in_sig: + input_dim_params_missing = [ + param for param in INPUT_DIM_PARAMS if param not in kwargs + ] + if input_dim_params_missing: + # Get input dimensions from the dimension inference engine + input_dim_info = self.dim_engine.block_input_dims.get(name) + if input_dim_info is not None: + # For modules that receive multiple independent tensors, + # check whether sum operation should be avoided + should_use_single_dim = False + + # Check method: whether the forward method + # accepts multiple tensor parameters + if hasattr(layer_cls, "forward"): + try: + forward_sig = inspect.signature(layer_cls.forward) + forward_params = [ + p + for p in forward_sig.parameters.keys() + if p != "self" + ] + # If forward method has 2 or more non-self parameters, + # it may be multiple tensor inputs + if len(forward_params) >= 2: + should_use_single_dim = True + logger.info( + f"Detected multi-tensor input module {layer_cls.__name__} with {len(forward_params)} forward parameters" # NOQA + ) + except Exception as err: + raise ValueError( + f"Failed to inspect forward method of {layer_cls.__name__} for dimension inference" # NOQA + ) from err + if ( + should_use_single_dim + and input_dim_info.is_list + and isinstance(input_dim_info.dim, (list, tuple)) + ): + # For forward modules that require multiple tensor inputs, + # use the dimensions in list format. 
+ for idx, param_name in enumerate(input_dim_params_in_sig): + kwargs[param_name] = input_dim_info.dim[idx] + logger.info( + f"Layer {name} ({layer_cls.__name__}) auto-inferred {param_name}={input_dim_info.dim[idx]} from input dim list" # NOQA + ) + else: + # For other modules, use the total dimension + feature_dim = input_dim_info.get_feature_dim() + for param_name in input_dim_params_in_sig: + kwargs[param_name] = feature_dim + logger.info( + f"Layer {name} ({layer_cls.__name__}) auto-inferred {param_name}={feature_dim} from dim_engine" # NOQA + ) + else: + logger.error( + f"Layer {name} ({layer_cls.__name__}) dimension inference failed - no input_dim available" # NOQA + ) + logger.error( + f" - input_dim_info from dim_engine: {input_dim_info}" + ) + logger.error(f" - input_dim: {input_dim}") + logger.error( + f" - block_input_dims keys: {list(self.dim_engine.block_input_dims.keys())}" # NOQA + ) + if name in self._name_to_input_dim: + logger.error( + f" - _name_to_input_dim[{name}]: {self._name_to_input_dim[name]}" # NOQA + ) + raise ValueError( + f"Cannot automatically infer {', '.join(missing_params)} for {layer_cls.__name__} {name}. " # NOQA + "Please ensure correct input feature groups are configured or manually specify these parameters." 
# NOQA + ) + + # sequence_dim and query_dim are automatically inferred + sequence_dim_missing = ( + SEQUENCE_QUERY_PARAMS[0] in sig.parameters + and SEQUENCE_QUERY_PARAMS[0] not in kwargs + ) + query_dim_missing = ( + SEQUENCE_QUERY_PARAMS[1] in sig.parameters + and SEQUENCE_QUERY_PARAMS[1] not in kwargs + ) + + if sequence_dim_missing or query_dim_missing: + # Get the input information of the current block + block_config = self._name_to_blocks[name] + input_dims = self._infer_sequence_query_dimensions(block_config, name) + + if input_dims: + sequence_dim, query_dim = input_dims + if sequence_dim_missing: + kwargs[SEQUENCE_QUERY_PARAMS[0]] = sequence_dim + if query_dim_missing: + kwargs[SEQUENCE_QUERY_PARAMS[1]] = query_dim + logger.info( + f"Auto-inferred dimensions for {layer_cls.__name__} {name}: " # NOQA + f"{SEQUENCE_QUERY_PARAMS[0]}={sequence_dim if sequence_dim_missing else 'provided'}, " # NOQA + f"{SEQUENCE_QUERY_PARAMS[1]}={query_dim if query_dim_missing else 'provided'}" # NOQA + ) + else: + missing_params = [] + if sequence_dim_missing: + missing_params.append(SEQUENCE_QUERY_PARAMS[0]) + if query_dim_missing: + missing_params.append(SEQUENCE_QUERY_PARAMS[1]) + raise ValueError( + f"Cannot automatically infer {', '.join(missing_params)} for {layer_cls.__name__} {name}. " # NOQA + "Please ensure correct input feature groups are configured or manually specify these parameters." 
# NOQA + ) + layer = layer_cls(**kwargs) + return layer, customize + elif param_type is None: # internal torch layer + layer = layer_cls() + return layer, customize + else: # st_params parameter + assert param_type == "st_params", ( + "internal torch layer only support st_params as parameters" + ) + try: + kwargs = convert_to_dict(layer_conf.st_params) + logger.info( + "call %s layer with params %r" % (layer_conf.class_name, kwargs) + ) + layer = layer_cls(**kwargs) + except TypeError as e: + logger.warning(e) + args = map(format_value, layer_conf.st_params.values()) + logger.info( + "try to call %s layer with params %r" + % (layer_conf.class_name, args) + ) + layer = layer_cls(*args, name=name) + return layer, customize + + def reset_input_config(self, config: backbone_pb2.BlockPackage) -> None: + """Reset the input configuration for this package. + + Args: + config: The new input configuration to set. + """ + self.input_config = config + + def _infer_sequence_query_dimensions(self, block_config, block_name): + """Inference module sequence_dim and query_dim. + + e.g. 
infer DINEncoder's sequence_dim and query_dim
        Args:
            block_config: Block configuration
            block_name: Block name

        Returns:
            tuple: (sequence_dim, query_dim) or None if inference fails
        """
        sequence_dim = None
        query_dim = None

        # Analyze the input and infer the dimension based on feature_group_name
        for input_node in block_config.inputs:
            input_type = input_node.WhichOneof("name")
            input_name = getattr(input_node, input_type)

            if input_type == "feature_group_name":
                # get the sequence and query dimensions from the embedding group
                dims = self._try_get_sequence_query_dims_from_group(input_name)
                if dims:
                    sequence_dim, query_dim = dims
                    logger.info(
                        f"Auto-inferred dimensions from {input_name}: "
                        f"sequence_dim={sequence_dim}, query_dim={query_dim}"
                    )
                    return sequence_dim, query_dim
            else:
                raise NotImplementedError

        # Check the inference results
        # NOTE(review): sequence_dim/query_dim are only assigned right before
        # the early return above, so this success branch appears unreachable;
        # only the warning path below can execute. Confirm before relying on it.
        if sequence_dim is not None and query_dim is not None:
            return sequence_dim, query_dim
        else:
            logger.warning(
                f"Could not infer sequence/query dimensions for {block_name}: "
                f"sequence_dim={sequence_dim}, query_dim={query_dim}"
            )
            return None

    def _try_get_sequence_query_dims_from_group(
        self, group_name: str
    ) -> Optional[Tuple[int, int]]:
        """Get the sequence and query dimensions from the embedding group.

        Args:
            group_name: embedding group name

        Returns:
            tuple: (sequence_dim, query_dim) or None if failed
        """
        # Check if group exists
        if group_name not in self._name_to_layer:
            logger.debug(f"Group {group_name} not found in _name_to_layer")
            return None

        layer = self._name_to_layer[group_name]

        # Check if there is a group_total_dim method
        if not hasattr(layer, "group_total_dim"):
            logger.debug(f"Group {group_name} does not have group_total_dim method")
            return None

        # Trying to get the dimensions of .sequence and .query subgroups
        # (naming convention: `<group>.sequence` / `<group>.query` -- TODO
        # confirm against the EmbeddingGroup implementation).
        sequence_group_name = f"{group_name}.sequence"
        query_group_name = f"{group_name}.query"

        try:
            sequence_dim = layer.group_total_dim(sequence_group_name)
            query_dim = layer.group_total_dim(query_group_name)
            return sequence_dim, query_dim
        except (KeyError, AttributeError, ValueError) as e:
            # Expected failure modes: subgroup missing or dims undefined.
            logger.debug(
                f"Could not get .sequence/.query dimensions for {group_name}: {type(e).__name__}: {e}"  # NOQA
            )
            return None
        except Exception as e:
            logger.warning(
                f"Unexpected error getting dimensions for {group_name}: {type(e).__name__}: {e}"  # NOQA
            )
            return None

    def set_package_input(self, pkg_input: torch.Tensor) -> None:
        """Set the package input for this package.

        Args:
            pkg_input: The input data to be used by this package.
        """
        self._package_input = pkg_input

    def has_block(self, name: str) -> bool:
        """Check if a block with the given name exists in this package.

        Args:
            name (str): The name of the block to check for.

        Returns:
            bool: True if the block exists, False otherwise.
        """
        return name in self._name_to_blocks

    def block_outputs(self, name):
        """Get the output of a specific block by name.

        Args:
            name (str): The name of the block to retrieve outputs for.

        Returns:
            Any: The output of the specified block, or None if not found.
+ """ + return self._block_outputs.get(name, None) + + def block_input( + self, config: backbone_pb2.Block, block_outputs: dict, **kwargs: dict + ) -> list: + """Process and merge inputs for a block based on its configuration. + + Args: + config: Block configuration containing input specifications. + block_outputs (dict): Dictionary of outputs from previously executed blocks. + **kwargs: Additional keyword arguments passed to downstream components. + + Returns: + torch.Tensor or list: Processed and merged input data ready for the block. + """ + inputs = [] + # Traverse each input node configured by config.inputs + for input_node in config.inputs: + input_type = input_node.WhichOneof("name") + input_name = getattr(input_node, input_type) + + if input_type == "use_package_input": + input_feature = self._package_input + input_name = "package_input" + + elif input_type == "package_name": + if input_name not in Package.__packages: + raise KeyError(f"package name `{input_name}` does not exist") + package = Package.__packages[input_name] + if input_node.HasField("reset_input"): + package.reset_input_config(input_node.reset_input) + if input_node.HasField("package_input"): + pkg_input_name = input_node.package_input + if pkg_input_name in block_outputs: + pkg_input = block_outputs[pkg_input_name] + else: + if pkg_input_name not in Package.__packages: + raise KeyError( + f"package name `{pkg_input_name}` does not exist" + ) + inner_package = Package.__packages[pkg_input_name] + pkg_input = inner_package() + if input_node.HasField("package_input_fn"): + fn = eval(input_node.package_input_fn) + pkg_input = fn(pkg_input) + package.set_package_input(pkg_input) + input_feature = package(**kwargs) + + elif input_name in block_outputs: + input_feature = block_outputs[input_name] + + else: + input_feature = Package.backbone_block_outputs(input_name) + + if input_feature is None: + raise KeyError(f"input name `{input_name}` does not exist") + + if getattr(input_node, "ignore_input", 
False): + continue + + # Get an element of the input tuple/list as input through slice syntax + if input_node.HasField("input_slice"): + fn = eval("lambda x: x" + input_node.input_slice.strip()) + input_feature = fn(input_feature) + + if input_node.HasField("input_fn"): + # Specify a lambda function to perform transformation on the input. + # e.g.,input_fn: 'lambda x: [x]' + fn = eval(input_node.input_fn) + input_feature = fn(input_feature) + # Need to recalculate input_dim + inputs.append(input_feature) + + # merge inputs + if getattr(config, "merge_inputs_into_list", False): + output = inputs + else: + try: + # merge_inputs need self define,e.g. torch.cat + # Assuming config.input_concat_axis is defined, usually 1 + output = merge_inputs( + inputs, + axis=getattr(config, "input_concat_axis", 1), + msg=config.name, + ) + except ValueError as e: + msg = getattr(e, "message", str(e)) + logger.error(f"merge inputs of block {config.name} failed: {msg}") + raise e + # To perform additional transformations on the merged multi-channel + # input results, you need to configure it in the format of a lambda function. + if config.HasField("extra_input_fn"): + fn = eval(config.extra_input_fn) + output = fn(output) + + return output + + def forward( + self, batch: Batch, **kwargs: dict + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Execute forward pass through the package DAG. + + Args: + batch (Any, optional): Input batch data. Defaults to None. + **kwargs: Additional keyword arguments passed to layers. + + Returns: + torch.Tensor or List[torch.Tensor]: Output tensor(s) from the package. + + Raises: + ValueError: If required output blocks are not found. + KeyError: If input names are invalid or not found. 
+ """ + block_outputs = {} + self._block_outputs = block_outputs # reset + blocks = self.topo_order_list + logger.info(self._config.name + " topological order: " + ",".join(blocks)) + + for block in blocks: # Traverse blocks + if block not in self._name_to_blocks: + # package block + assert block in Package.__packages, "invalid block: " + block + continue + config = self._name_to_blocks[block] + # Case 1: sequential layers + if hasattr(config, "layers") and config.layers: + logger.info("call sequential %d layers" % len(config.layers)) + output = self.block_input(config, block_outputs, **kwargs) + for i, layer in enumerate(config.layers): + name_i = "%s_l%d" % (block, i) + output = self.call_layer(output, layer, name_i, **kwargs) + block_outputs[block] = output + continue + + # Case 2: single layer just one of layer + layer_type = config.WhichOneof("layer") + if layer_type is None: # identity layer + output = self.block_input(config, block_outputs, **kwargs) + block_outputs[block] = output + elif layer_type == "raw_input": + block_outputs[block] = self._name_to_layer[block] + elif layer_type == "input_layer": + if ( + block in self._name_to_layer + and self._name_to_layer[block] is not None + ): + input_fn = self._name_to_layer[block] # embedding group + else: + input_fn = self._embedding_group + # no block input itself + input_config = config.input_layer + if self.input_config is not None: + input_config = self.input_config + if hasattr(input_fn, "reset"): + input_fn.reset(input_config) + if batch is not None: + embedding_outputs = input_fn(batch) + if ( + isinstance(embedding_outputs, dict) + and block in embedding_outputs + ): + block_outputs[block] = embedding_outputs[block] + else: + # If the returned value is not a dictionary or does not + # have a corresponding key, use the entire output. 
+ block_outputs[block] = embedding_outputs + if isinstance(block_outputs[block], torch.Tensor): + logger.info( + f"block_outputs[{block}]shape: {block_outputs[block].shape}" # NOQA + ) + else: + logger.info( + f"block_outputs[{block}] type: {type(block_outputs[block])}" + ) + else: + embedding_outputs = input_fn(input_config) + if ( + isinstance(embedding_outputs, dict) + and block in embedding_outputs + ): + block_outputs[block] = embedding_outputs[block] + else: + block_outputs[block] = embedding_outputs + elif layer_type == "embedding_layer": + input_fn = self._name_to_layer[block] + feature_group = config.inputs[0].feature_group_name + inputs, _, weights = self._feature_group_inputs[feature_group] + block_outputs[block] = input_fn([inputs, weights]) + else: + # Custom module, e.g. mlp + inputs = self.block_input(config, block_outputs, **kwargs) + output = self.call_layer(inputs, config, block, **kwargs) + block_outputs[block] = output + + # Collect outputs + outputs = [] + for output in getattr(self._config, "output_blocks", []): + if output in block_outputs: + outputs.append(block_outputs[output]) + else: + raise ValueError("No output `%s` of backbone to be concat" % output) + if outputs: + return outputs + + for output in getattr(self._config, "concat_blocks", []): + if output in block_outputs: + outputs.append(block_outputs[output]) + else: + raise ValueError("No output `%s` of backbone to be concat" % output) + + try: + logger.info(f"Number of outputs to merge: {len(outputs)}") + # Log each output's shape + for i, out in enumerate(outputs): + if isinstance(out, torch.Tensor): + logger.info(f"Output {i} shape: {out.shape}") + elif isinstance(out, (list, tuple)): + logger.info(f"Output {i} is a list/tuple with {len(out)} elements.") + else: + logger.info(f"Output {i} is of type {type(out)}") + # merge_inputs + output = merge_inputs(outputs, msg="backbone") + except Exception as e: + logger.error("merge backbone's output failed: %s", str(e)) + raise e + return 
output + + def _determine_input_format(self, layer_obj, inputs:Union[torch.Tensor, dict])-> Union[torch.Tensor, dict]: + """Determine the input format required by the module. + + Args: + layer_obj: The layer object to call + inputs: Input data (may be a tensor dict or a single tensor) + + Returns: + Input suitable for this layer + """ + try: + # Check the module's forward method signature + if hasattr(layer_obj, "forward"): + sig = inspect.signature(layer_obj.forward) + params = list(sig.parameters.keys()) + if "self" in params: + params.remove("self") + + # If the forward method has multiple parameters, + # it may require a dictionary input + if len(params) > 1: + logger.debug( + f"Layer {layer_obj.__class__.__name__} has multiple forward parameters: {params}" # NOQA + ) + # Check if a specific parameter name implies + # that a dictionary input is required + dict_indicators = [ + "grouped_features", + "feature_dict", + "inputs_dict", + "batch", + ] + if any(indicator in params for indicator in dict_indicators): + logger.info( + f"Layer {layer_obj.__class__.__name__} likely needs dict input" # NOQA + ) + return inputs # Return to original dictionary format + + # Check whether it is a sequence-related module + class_name = layer_obj.__class__.__name__ + sequence_modules = [ + "DINEncoder", + "SimpleAttention", + "PoolingEncoder", + "DIN", + ] + if any(seq_name in class_name for seq_name in sequence_modules): + logger.info( + f"Layer {class_name} is a sequence module, using dict input" + ) + return inputs # Sequence modules usually require a dictionary input + + # check if need dict format input + dict_attributes = SEQUENCE_QUERY_PARAMS + ["attention"] + if any(hasattr(layer_obj, attr) for attr in dict_attributes): + logger.info( + f"Layer {class_name} has sequence attributes, using dict input" + ) + return inputs + + # Default: If inputs is a dictionary and has only one value, + # extract that value + if isinstance(inputs, dict): + if len(inputs) == 1: + single_key 
= list(inputs.keys())[0] + single_value = inputs[single_key] + logger.debug( + f"Extracting single tensor from dict for {layer_obj.__class__.__name__}" # NOQA + ) + return single_value + else: + # In the case of multiple values, try concatenation + logger.debug( + f"Multiple values in dict, trying to concatenate for {layer_obj.__class__.__name__}" # NOQA + ) + tensor_list = list(inputs.values()) + if all(isinstance(t, torch.Tensor) for t in tensor_list): + try: + # Check if all tensors have + # the same number of dimensions + # except the last dimension + first_shape = tensor_list[0].shape + batch_size = first_shape[0] + + # If the number of dimensions is different, + # try flattening and then concatenating + flattened_tensors = [] + for t in tensor_list: + if len(t.shape) != len(first_shape): + # Flatten all dimensions except + # the batch dimension + flattened = t.view(batch_size, -1) + flattened_tensors.append(flattened) + else: + # If the number of dimensions is the same + # but the shape is different, flatten it + if t.shape[:-1] != first_shape[:-1]: + flattened = t.view(batch_size, -1) + flattened_tensors.append(flattened) + else: + flattened_tensors.append(t) + + result = torch.cat(flattened_tensors, dim=-1) + logger.debug( + f"Successfully concatenated tensors, final shape: {result.shape}" # NOQA + ) + return result + except Exception as e: + logger.debug( + f"Failed to concatenate tensors: {e}, " + f"using first tensor" + ) + return tensor_list[0] + else: + # If the concatenation cannot be done, + # return the original dictionary. + # If it is not a dictionary, return it directly. 
                        return inputs
            return inputs

        except Exception as e:
            logger.warning(
                f"Error determining input format for "
                f"{layer_obj.__class__.__name__}: {e}"
            )
            return inputs  # Returns the original input on error

    def call_torch_layer(self, inputs, name: str, **kwargs):  # pyre-ignore[2]
        """Call predefined torch Layer.

        Heuristically adapts the input format via ``_determine_input_format``
        and retries with the untouched input when the adapted call fails.

        Raises:
            RuntimeError: If the layer fails with every attempted input format.
        """
        layer = self._name_to_layer[name]
        cls = layer.__class__.__name__

        # Determine input format
        processed_inputs = self._determine_input_format(layer, inputs)

        # First try the processed input format
        if self._try_call_layer(layer, processed_inputs, name, cls):
            return self._last_output

        # If that fails and the input format has been modified,
        # try the original input format
        if processed_inputs is not inputs:
            logger.info(f"Retrying {name} with original input format")
            if self._try_call_layer(layer, inputs, name, cls):
                logger.info(f"Successfully called {name} with original input format")
                return self._last_output
            else:
                logger.error(f"Both input formats failed for {name}")
                raise RuntimeError(
                    f"Layer {name} failed with both processed and original input formats"  # NOQA
                )
        else:
            # If the input format has not changed,
            # throw an exception directly
            raise RuntimeError(f"Layer {name} ({cls}) failed to execute")

    def _try_call_layer(
        self, layer, inputs, name: str, cls: str
    ) -> bool:  # pyre-ignore[2]
        """Attempt to call the layer.

        On success the result is stashed in ``self._last_output`` for the
        caller to pick up.

        Args:
            layer: the layer object to call
            inputs: input tensor data
            name: layer name
            cls: layer class name

        Returns:
            bool: Returns True on success, False on failure
        """
        try:
            # Check the module's forward method signature
            # to determine how to pass parameters
            if hasattr(layer, "forward"):
                sig = inspect.signature(layer.forward)
                params = list(sig.parameters.keys())
                # parameters without default values
                required_params = [
                    p
                    for p in sig.parameters.values()
                    if p.default == inspect.Parameter.empty and p.name != "self"
                ]
                if "self" in params:
                    params.remove("self")

                # If inputs is a list/tuple and the layer expects
                # multiple arguments, try spreading it out.
                if (
                    isinstance(inputs, (list, tuple))
                    and len(params) > 1
                    and (
                        len(inputs) == len(params)
                        or len(required_params) >= len(inputs)
                    )
                ):
                    self._last_output = layer(*inputs)
                    logger.debug(
                        f"Layer {name} ({cls}) called successfully with {len(inputs)} separate arguments"  # NOQA
                    )
                else:
                    # Default: single parameter passing
                    self._last_output = layer(inputs)
                    logger.debug(
                        f"Layer {name} ({cls}) called successfully with input type: {type(inputs)}"  # NOQA
                    )
            else:
                # no forward method, directly use
                self._last_output = layer(inputs)
                logger.debug(
                    f"Layer {name} ({cls}) called successfully with input type: {type(inputs)}"  # NOQA
                )
            return True
        except Exception as e:
            # NOTE(review): intentionally broad — a failure here triggers the
            # retry logic in call_torch_layer; the error is logged, not lost.
            msg = getattr(e, "message", str(e))
            logger.error(f"Call layer {name} ({cls}) failed: {msg}")
            return False

    def call_layer(
        self,
        inputs: torch.Tensor,
        config: backbone_pb2.Block,
        name: str,
        **kwargs: dict,
    ) -> torch.Tensor:
        """Call a layer based on its configuration type.

        Args:
            inputs: Input data to be processed by the layer.
            config: Layer configuration containing layer type and parameters.
            name (str): Name of the layer to be called.
            **kwargs: Additional keyword arguments passed to the layer.

        Returns:
            Output from the called layer.
+ + Raises: + NotImplementedError: If the layer type is not supported. + """ + layer_name = config.WhichOneof("layer") + if layer_name == "module": + return self.call_torch_layer(inputs, name, **kwargs) + elif layer_name == "recurrent": + return self._call_recurrent_layer(inputs, config, name, **kwargs) + elif layer_name == "repeat": + return self._call_repeat_layer(inputs, config, name, **kwargs) + elif layer_name == "lambda": + if name in self._name_to_layer and isinstance( + self._name_to_layer[name], LambdaWrapper + ): + lambda_wrapper = self._name_to_layer[name] + return lambda_wrapper(inputs) + else: + # execution lambda expression + conf = getattr(config, "lambda") + fn = eval(conf.expression) + return fn(inputs) + raise NotImplementedError("Unsupported backbone layer:" + layer_name) + + def _call_recurrent_layer( + self, + inputs: torch.Tensor, + config: backbone_pb2.Block, + name: str, + **kwargs: dict, + ) -> torch.Tensor: + """Call recurrent layer by iterating through all steps. + + Args: + inputs: Input data to be processed by the recurrent layer. + config: Recurrent layer configuration. + name (str): Name of the recurrent layer. + **kwargs: Additional keyword arguments passed to sub-layers. + + Returns: + Output from the last step of the recurrent layer. + """ + recurrent_config = config.recurrent + + # Fixed import index, default -1, display missing fixed import + fixed_input_index = -1 + if hasattr(recurrent_config, "fixed_input_index"): + fixed_input_index = recurrent_config.fixed_input_index + + # If there is a fixed input index, the input must be a list or tuple. 
+ if fixed_input_index >= 0: + assert isinstance(inputs, (tuple, list)), ( + f"{name} inputs must be a list when using fixed_input_index" + ) + # Initialize output to input + output = inputs + for i in range(recurrent_config.num_steps): + name_i = f"{name}_{i}" + if name_i in self._name_to_layer: + # Calling child layer + output_i = self.call_torch_layer(output, name_i, **kwargs) + + if fixed_input_index >= 0: + # In case of fixed input index: + # update all inputs except the fixed index + j = 0 + for idx in range(len(output)): + if idx == fixed_input_index: + continue # Skip fixed input index + + if isinstance(output_i, (tuple, list)): + output[idx] = output_i[j] + else: + output[idx] = output_i + j += 1 + else: + # without fixed input index: directly replace the entire output + output = output_i + else: + logger.warning(f"Recurrent sub-layer {name_i} not found, skipping") + + if fixed_input_index >= 0: + # Delete the element corresponding to the fixed input index + output = list(output) + del output[fixed_input_index] + + if len(output) == 1: + return output[0] + return output + + return output + + def _call_repeat_layer( + self, + inputs: torch.Tensor, + config: backbone_pb2.Block, + name: str, + **kwargs: dict, + ) -> torch.Tensor: + """Call repeat layer by iterating through all repetitions. + + Args: + inputs: Input data to be processed by the repeat layer. + config: Repeat layer configuration. + name (str): Name of the repeat layer. + **kwargs: Additional keyword arguments passed to sub-layers. + + Returns: + Output based on configuration: single tensor, concatenated tensor, or + list of tensors. 
+ """ + repeat_config = config.repeat + n_loop = repeat_config.num_repeat + outputs = [] + + # execute repeat + for i in range(n_loop): + name_i = f"{name}_{i}" + ly_inputs = inputs + + # Processing input_slice configuration + if hasattr(repeat_config, "input_slice") and repeat_config.input_slice: + fn = eval("lambda x, i: x" + repeat_config.input_slice.strip()) + ly_inputs = fn(ly_inputs, i) + + # Processing input_fn configuration + if hasattr(repeat_config, "input_fn") and repeat_config.input_fn: + fn = eval(repeat_config.input_fn) + ly_inputs = fn(ly_inputs, i) + + # Calling child layer + if name_i in self._name_to_layer: + output = self.call_torch_layer(ly_inputs, name_i, **kwargs) + outputs.append(output) + else: + logger.warning(f"Repeat sub-layer {name_i} not found, skipping") + + # Output format determined by configuration + if len(outputs) == 1: + return outputs[0] + + if ( + hasattr(repeat_config, "output_concat_axis") + and repeat_config.output_concat_axis is not None + ): + axis = repeat_config.output_concat_axis + return torch.cat(outputs, dim=axis) + + return outputs + + +class Backbone(nn.Module): + """Configurable Backbone Network.""" + + def __init__( + self, + config: backbone_pb2.BackboneTower, + features: List[BaseFeature], + embedding_group: EmbeddingGroup, + feature_groups: List[FeatureGroupConfig], + wide_embedding_dim: Optional[int] = None, + wide_init_fn: Optional[str] = None, + ) -> None: + super().__init__() + self._config = config + main_pkg = backbone_pb2.BlockPackage() + main_pkg.name = "backbone" + main_pkg.blocks.MergeFrom(config.blocks) + # If concat_blocks is not configured, + # concatenate all leaf nodes of the DAG and output them. 
+ if config.concat_blocks: + main_pkg.concat_blocks.extend(config.concat_blocks) + if config.output_blocks: + # If the output of multiple blocks does not need + # to be concat together, but as a list type + # Use output_blocks instead of concat_blocks + main_pkg.output_blocks.extend(config.output_blocks) + + self._main_pkg = Package( + main_pkg, + features, + embedding_group, + feature_groups, + wide_embedding_dim, + wide_init_fn, + ) + for pkg in config.packages: + Package(pkg, features, embedding_group) # Package is a sub-DAG + + # initial top_mlp + self._top_mlp = None + if self._config.HasField("top_mlp"): + params = Parameter.make_from_pb(self._config.top_mlp) + + # Get total output dimensions from main_pkg + total_output_dim = self._main_pkg.total_output_dim() + + kwargs = config_to_kwargs(params) + self._top_mlp = MLP(in_features=total_output_dim, **kwargs) + + def forward(self, batch: Batch, **kwargs: dict) -> torch.Tensor: # pyre-ignore[2] + """Forward pass through the backbone network. + + Args: + batch (Any, optional): Input batch data. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Output tensor from the backbone network. 
+ """ + output = self._main_pkg(batch, **kwargs) + + if hasattr(self, "_top_mlp") and self._top_mlp is not None: + if isinstance(output, (list, tuple)): + output = torch.cat(output, dim=-1) + output = self._top_mlp(output) + return output + + def output_dim(self) -> int: + """Get the final output dimension, taking into account of top_mlp.""" + if hasattr(self, "_top_mlp") and self._top_mlp is not None: + if hasattr(self._top_mlp, "output_dim"): + return self._top_mlp.output_dim() + elif hasattr(self._top_mlp, "hidden_units") and self._top_mlp.hidden_units: + # Returns the hidden_units of the last layer + return self._top_mlp.hidden_units[-1] + else: + # Trying to get the output dimension of the last layer from mlp + if hasattr(self._top_mlp, "mlp") and len(self._top_mlp.mlp) > 0: + last_layer = self._top_mlp.mlp[-1] + if hasattr(last_layer, "perceptron"): + # Get the output dimension of the last Perceptron linear layer + linear_layers = [ + module + for module in last_layer.perceptron + if isinstance(module, nn.Linear) + ] + if linear_layers: + return linear_layers[-1].out_features + elif isinstance(last_layer, nn.Linear): + return last_layer.out_features + + # If there is no top_mlp, return the output dimensions of main_pkg + return self._main_pkg.total_output_dim() + + +def merge_inputs( + inputs: List, axis: int = -1, msg: str = "" +) -> Union[List, torch.Tensor]: + """Merge multiple inputs and apply different logic based on input types and count. + + Args: + inputs (list): Inputs to merge; can be a list of lists or a list of tensors. + - If all elements are lists, merged into a single list. + - If elements are a mix of lists and non-list items, + non-list items are wrapped into single-element lists before merging. + - If all tensors, they are concatenated along the specified axis. + axis (int): Axis along which to concatenate tensors, + effective only when inputs are tensors. Default is -1. + - If axis = -1, concatenation is along the last dimension. 
+ - If inputs are lists, this parameter is ignored. + msg (str): Additional log message to identify the context of the operation. + Default is an empty string. + + Returns: + list or torch.Tensor: + - lists, returns the merged list. + - tensors, returns the tensor concatenated along the specified axis. + - If inputs contain only one element, returns that element (no merge). + + Raises: + ValueError: If inputs is an empty list (length 0), + indicating there are no inputs to merge. + """ + if len(inputs) == 0: + raise ValueError("no inputs to be concat:" + msg) + if len(inputs) == 1: + return inputs[0] + from functools import reduce + + if all(isinstance(x, list) for x in inputs): + # merge multiple lists into a list + return reduce(lambda x, y: x + y, inputs) + + if any(isinstance(x, list) for x in inputs): + logger.warning("%s: try to merge inputs into list" % msg) + return reduce( + lambda x, y: x + y, [e if isinstance(e, list) else [e] for e in inputs] + ) + + if axis != -1: + logger.info("concat inputs %s axis=%d" % (msg, axis)) + return torch.cat(inputs, dim=axis) + + +def format_value(value:Union[str,int,list,dict]) -> Union[str,int,list,dict]: + """Format the input value based on its type. + + Args: + value: The value to format. + + Returns: + The formatted value. + """ + if isinstance(value, str): + return value + if isinstance(value, float): + int_v = int(value) + return int_v if int_v == value else value + if isinstance(value, list): + return [format_value(v) for v in value] + if isinstance(value, dict): + return convert_to_dict(value) + return value + + +def convert_to_dict(struct) -> dict: + """Convert a struct_pb2.Struct object to a Python dictionary. + + Args: + struct: A struct_pb2.Struct object. + + Returns: + dict: The converted Python dictionary. 
+ """ + kwargs = {} + for key, value in struct.items(): + kwargs[str(key)] = format_value(value) + return kwargs diff --git a/tzrec/protos/backbone.proto b/tzrec/protos/backbone.proto new file mode 100644 index 00000000..f7289de5 --- /dev/null +++ b/tzrec/protos/backbone.proto @@ -0,0 +1,117 @@ +syntax = "proto2"; +package tzrec.protos; + +import "tzrec/protos/torch_layer.proto"; +import "tzrec/protos/module.proto"; + +message InputLayer { + optional bool do_batch_norm = 1; + optional bool do_layer_norm = 2; + optional float dropout_rate = 3; + optional float feature_dropout_rate = 4; + optional bool only_output_feature_list = 5; + optional bool only_output_3d_tensor = 6; + optional bool output_2d_tensor_and_feature_list = 7; + optional bool output_seq_and_normal_feature = 8; + optional uint32 wide_output_dim = 9; + optional bool concat_seq_feature = 10 [default = true]; +} + +message RawInputLayer { +} + +message EmbeddingLayer { + required uint32 embedding_dim = 1; + optional uint32 vocab_size = 2; + optional string combiner = 3 [default = 'weight']; + optional bool concat = 4 [default = true]; +} + +message Lambda { + required string expression = 1; +} + +message Input { + oneof name { + string feature_group_name = 1; + string block_name = 2; + string package_name = 3; + bool use_package_input = 4; + } + optional string input_fn = 11; + optional string input_slice = 12; + optional bool ignore_input = 13 [default = false]; + optional InputLayer reset_input = 14; + optional string package_input = 15; + optional string package_input_fn = 16; +} + +message RecurrentLayer { + required uint32 num_steps = 1 [default = 1]; + optional uint32 fixed_input_index = 2; + required TorchLayer module = 3; +} + +message RepeatLayer { + required uint32 num_repeat = 1 [default = 1]; + // default output the list of multiple outputs + optional int32 output_concat_axis = 2; + required TorchLayer module = 3; + optional string input_slice = 4; + optional string input_fn = 5; +} + 
+message Layer { + oneof layer { + Lambda lambda = 1; + TorchLayer module = 2; + RecurrentLayer recurrent = 3; + RepeatLayer repeat = 4; + } +} + +message Block { + required string name = 1; + // the input names of feature groups or other blocks + repeated Input inputs = 2; + optional int32 input_concat_axis = 3 [default = -1]; + optional bool merge_inputs_into_list = 4; + optional string extra_input_fn = 5; + + // sequential layers + repeated Layer layers = 100; + + // only take effect when there are no layers + oneof layer { + InputLayer input_layer = 101; + Lambda lambda = 102; + TorchLayer module = 103; + RecurrentLayer recurrent = 104; + RepeatLayer repeat = 105; + } +} + +// a package of blocks for reuse; e.g. call in a contrastive learning manner +message BlockPackage { + // package name + required string name = 1; + // a few blocks generating a DAG + repeated Block blocks = 2; + // the names of output blocks, will be merge into a tensor + repeated string concat_blocks = 3; + // the names of output blocks, return as a list or single tensor + repeated string output_blocks = 4; +} + +message BackboneTower { + // a few sub DAGs + repeated BlockPackage packages = 1; + // a few blocks generating a DAG + repeated Block blocks = 2; + // the names of output blocks, will be merge into a tensor + repeated string concat_blocks = 3; + // the names of output blocks, return as a list or single tensor + repeated string output_blocks = 4; + // optional top mlp layer + optional MLP top_mlp = 5; +} diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index b3e9c993..003d9181 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -9,6 +9,8 @@ import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; import "tzrec/protos/seq_encoder.proto"; import "tzrec/protos/module.proto"; +import "tzrec/protos/backbone.proto"; +import "tzrec/protos/tower.proto"; enum FeatureGroupType { DEEP = 0; @@ -36,11 +38,34 @@ enum Kernel { CUDA = 2; } +// 
configure backbone network common parameters +message ModelParams { + optional float l2_regularization = 1; + repeated string outputs = 2; +} + +message ModularRank { + required BackboneTower backbone = 1; + optional ModelParams model_params = 2; +} +message ModularMatch { + required BackboneTower backbone = 1; + optional ModelParams model_params = 2; +} +message ModularMultiTask { + required BackboneTower backbone = 1; + optional ModelParams model_params = 2; + repeated TaskTower task_towers = 3; +} message ModelConfig { repeated FeatureGroupConfig feature_groups = 1; oneof model { + ModularRank rank_backbone = 1001; + ModularMatch match_backbone = 1002; + ModularMultiTask multi_task_backbone = 1003; + DLRM dlrm = 100; DeepFM deepfm = 101; MultiTower multi_tower = 102; diff --git a/tzrec/protos/module.proto b/tzrec/protos/module.proto index a83a7af4..dbe0c732 100644 --- a/tzrec/protos/module.proto +++ b/tzrec/protos/module.proto @@ -251,3 +251,19 @@ message HSTU { // output postprocessor required GROutputPostprocessor output_postprocessor = 6; } + +message FM { + // optional bool use_variant = 1; + // optional float l2_regularization = 5 [default = 1e-4]; +} + +message MMoEModule { + // mmoe expert module definition + required MLP expert_mlp = 1; + // number of mmoe experts + required uint32 num_expert = 3 [default=3]; + // task tower + required uint32 num_task = 4; + // mmoe gate module definition + optional MLP gate_mlp = 2; +} diff --git a/tzrec/protos/torch_layer.proto b/tzrec/protos/torch_layer.proto new file mode 100644 index 00000000..31f0ded1 --- /dev/null +++ b/tzrec/protos/torch_layer.proto @@ -0,0 +1,22 @@ +syntax = "proto2"; +package tzrec.protos; + +import "google/protobuf/struct.proto"; +import "tzrec/protos/module.proto"; +import "tzrec/protos/seq_encoder.proto"; + +message TorchLayer { + required string class_name = 1; + oneof params { + google.protobuf.Struct st_params = 2; + FM fm = 10; + MLP mlp = 11; + DINEncoder din = 12; + MMoEModule mmoe = 
14; + Cross cross = 15; + CrossV2 cross_v2 = 16; + // DCNv2Net dcnv2_net = 17; + MaskNetModule mask_net_module = 18; + MaskBlock mask_block = 19; + } +} diff --git a/tzrec/utils/backbone_utils.py b/tzrec/utils/backbone_utils.py new file mode 100644 index 00000000..5a202d07 --- /dev/null +++ b/tzrec/utils/backbone_utils.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common util functions used by layers.""" + +from google.protobuf import struct_pb2 +from google.protobuf.descriptor import FieldDescriptor + + +def is_proto_message(pb_obj, field) -> bool: + """Check if a given field in a Protocol Buffer object is a message type field. + + This utility function is designed to handle Protocol Buffer object dynamic + attributes and type checking, ensuring that fields conform to specific + message types. + + Args: + pb_obj: The Protocol Buffer object to inspect. + field: The field name to check for message type. + + Returns: + bool: True if the field is a Protocol Buffer message type, False otherwise. + """ + if not hasattr(pb_obj, "DESCRIPTOR"): + return False + if field not in pb_obj.DESCRIPTOR.fields_by_name: + return False + field_type = pb_obj.DESCRIPTOR.fields_by_name[field].type + return field_type == FieldDescriptor.TYPE_MESSAGE + + +class Parameter(object): + """A utility class for encapsulating and managing parameters. 

    This class supports handling both structured parameters and Protocol Buffer (PB)
    message type parameters. It provides convenient methods and properties for
    accessing, modifying, and validating parameters, while supporting nested
    structures and default value handling.

    Attributes:
        params: The parameter data (dict for struct or PB message object).
        is_struct: Boolean indicating if this is a struct-type parameter.
    """

    def __init__(self, params, is_struct) -> None:
        self.params = params
        self.is_struct = is_struct

    @staticmethod
    def make_from_pb(config):
        """Create a Parameter instance from a Protocol Buffer configuration.

        Args:
            config: The Protocol Buffer configuration object.

        Returns:
            Parameter: A new Parameter instance with is_struct=False.
        """
        return Parameter(config, False)

    def get_pb_config(self):
        """Get the Protocol Buffer configuration object.

        Returns:
            The Protocol Buffer configuration object.

        Raises:
            AssertionError: If this Parameter instance is a struct type.
        """
        assert not self.is_struct, "Struct parameter can not convert to pb config"
        return self.params

    def __getattr__(self, key):
        # NOTE: only invoked for attributes not found normally; a missing
        # struct key deliberately yields None instead of AttributeError.
        if self.is_struct:
            if key not in self.params:
                return None
            value = self.params[key]
            if isinstance(value, struct_pb2.Struct):
                # Wrap nested Structs so attribute access keeps chaining.
                return Parameter(value, True)
            else:
                return value
        value = getattr(self.params, key)
        if is_proto_message(self.params, key):
            return Parameter(value, False)
        return value

    def __getitem__(self, key):
        return self.__getattr__(key)

    def get_or_default(self, key, def_val):
        """Get parameter value or return default if not present or empty.

        Args:
            key: The parameter key to retrieve.
            def_val: The default value to return if key is not found or empty.

        Returns:
            The parameter value if present and non-empty, otherwise def_val.
        """
        if self.is_struct:
            if key in self.params:
                if def_val is None:
                    return self.params[key]
                value = self.params[key]
                if isinstance(value, float):
                    # Struct numbers are floats; coerce to the default's type
                    # (e.g. int defaults yield ints).
                    return type(def_val)(value)
                return value
            return def_val
        else:  # pb message
            value = getattr(self.params, key, def_val)
            if hasattr(value, "__len__"):  # repeated
                return value if len(value) > 0 else def_val
            try:
                # HasField raises ValueError for fields without presence.
                if self.params.HasField(key):
                    return value
            except ValueError:
                pass
            return def_val  # maybe not equal to the default value of msg field

    def check_required(self, keys) -> None:
        """Check that required keys are present in the struct parameters.

        No-op for PB-message parameters (proto enforces required fields).

        Args:
            keys: A key name or list/tuple of key names to check for presence.

        Raises:
            KeyError: If any required key is missing from the struct parameters.
        """
        if not self.is_struct:
            return
        if not isinstance(keys, (list, tuple)):
            keys = [keys]
        for key in keys:
            if key not in self.params:
                raise KeyError("%s must be set in params" % key)

    def has_field(self, key) -> bool:
        """Check if the parameter has the specified field.

        Args:
            key: The field name to check.

        Returns:
            bool: True if the field exists, False otherwise.
        """
        if self.is_struct:
            return key in self.params
        else:
            return self.params.HasField(key)


def params_to_dict(parameter) -> dict:
    """Convert Parameter object to a dictionary."""

    def convert(param) -> dict:
        # Recursively unwrap Parameter / Struct wrappers into plain dicts.
        if isinstance(param, Parameter):
            if param.is_struct:
                return {key: convert(value) for key, value in param.params.items()}
            else:  # PB message
                result = {}
                for field in param.params.DESCRIPTOR.fields:
                    key = field.name
                    value = getattr(param.params, key, None)
                    # NOTE(review): proto scalar fields are never None, so
                    # this check mainly guards the getattr default — scalar
                    # defaults (0, "", False) are still included. Confirm
                    # that behavior is intended.
                    if value is not None:
                        if is_proto_message(param.params, key):
                            result[key] = convert(Parameter(value, False))
                        elif isinstance(value, struct_pb2.Struct):
                            result[key] = convert(Parameter(value, True))
                        else:
                            result[key] = value
                return result
        elif isinstance(param, struct_pb2.Struct):
            return {key: convert(value) for key, value in param.fields.items()}
        else:
            return param

    return convert(parameter)
diff --git a/tzrec/utils/dimension_inference.py b/tzrec/utils/dimension_inference.py
new file mode 100644
index 00000000..7ef869eb
--- /dev/null
+++ b/tzrec/utils/dimension_inference.py
@@ -0,0 +1,434 @@
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Enhanced dimension inference utilities for backbone blocks."""

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch.nn as nn

if TYPE_CHECKING:
    # Imported for type checking only: EmbeddingGroup is referenced solely in
    # annotations, so a runtime import (and potential import cycle) is avoided.
    from tzrec.modules.embedding import EmbeddingGroup


class DimensionInfo:
    """Class representing dimension information."""

    def __init__(
        self,
        dim: Union[int, List[int], Tuple[int, ...]],
        shape: Optional[Tuple[int, ...]] = None,
        is_list: bool = False,
        feature_dim: Optional[int] = None,
    ) -> None:
        """Initialize DimensionInfo.

        Args:
            dim: Dimension information, int (single dim) or a list/tuple
                (multiple dims).
            shape: The complete tensor shape information (if available).
            is_list: Indicates whether the output is of a list type.
            feature_dim: Explicitly specified feature dim to override inference.
        """
        self.dim = dim
        self.shape = shape
        self.is_list = is_list
        self._feature_dim = feature_dim

    def __repr__(self) -> str:
        return (
            f"DimensionInfo(dim={self.dim}, shape={self.shape}, "
            f"is_list={self.is_list}, feature_dim={self._feature_dim})"
        )

    def get_feature_dim(self) -> Union[int, List[int], Tuple[int, ...]]:
        """Get feature dimension (last dimension)."""
        # Prefer an explicitly specified feature dimension.
        if self._feature_dim is not None:
            return self._feature_dim

        if isinstance(self.dim, (list, tuple)):
            if self.is_list:
                # List outputs report the sum of all element dimensions.
                return sum(self.dim)
            else:
                # Tensor outputs report the last (feature) dimension.
                return self.dim[-1] if self.dim else 0
        return self.dim

    def get_total_dim(self) -> Union[int, List[int], Tuple[int, ...]]:
        """Get the total dimension (for operations such as concat)."""
        if isinstance(self.dim, (list, tuple)):
            return sum(self.dim)
        return self.dim

    def to_list(self) -> List[int]:
        """Convert to list format."""
        if isinstance(self.dim, (list, tuple)):
            return list(self.dim)
        return [self.dim]

    def with_shape(self, shape: Tuple[int, ...]) -> "DimensionInfo":
        """Returns a new DimensionInfo with the specified shape information."""
        feature_dim = shape[-1] if shape else self.get_feature_dim()
        return DimensionInfo(
            dim=self.dim, shape=shape, is_list=self.is_list, feature_dim=feature_dim
        )

    def estimate_shape(
        self, batch_size: Optional[int] = None, seq_len: Optional[int] = None
    ) -> Tuple[int, ...]:
        """Estimate shape based on known information.

        Args:
            batch_size: The batch size.
            seq_len: The sequence length (if applicable).

        Returns:
            The estimated shape as a tuple.
        """
        if self.shape is not None:
            return self.shape

        feature_dim = self.get_feature_dim()

        if batch_size is not None:
            if seq_len is not None:
                # 3D (batch_size, seq_len, feature_dim)
                return (batch_size, seq_len, feature_dim)  # pyre-ignore [7]
            else:
                # 2D (batch_size, feature_dim)
                return (batch_size, feature_dim)  # pyre-ignore [7]
        else:
            # Only the feature dimension is known.
            return (feature_dim,)


class DimensionInferenceEngine:
    """Dimension inference engine, manages and infers dim information between blocks."""

    def __init__(self) -> None:
        self.block_input_dims: Dict[str, DimensionInfo] = {}
        self.block_output_dims: Dict[str, DimensionInfo] = {}
        self.block_layers: Dict[str, nn.Module] = {}
        self.logger = logging.getLogger(__name__)

    def register_input_dim(self, block_name: str, dim_info: DimensionInfo) -> None:
        """Register the input dimension of the block."""
        self.block_input_dims[block_name] = dim_info
        # Use the engine logger consistently (was module-level logging before).
        self.logger.debug(f"Registered input dim for {block_name}: {dim_info}")

    def register_output_dim(self, block_name: str, dim_info: DimensionInfo) -> None:
        """Register the output dimension of the block."""
        self.block_output_dims[block_name] = dim_info
        self.logger.debug(f"Registered output dim for {block_name}: {dim_info}")

    def register_layer(self, block_name: str, layer: nn.Module) -> None:
        """Register the layer corresponding to the block."""
        self.block_layers[block_name] = layer

    def get_output_dim(self, block_name: str) -> Optional[DimensionInfo]:
        """Get the output dimension of the block, or None if unregistered."""
        # Annotation fixed: dict.get returns None for unknown blocks.
        return self.block_output_dims.get(block_name)

    def infer_layer_output_dim(
        self, layer: nn.Module, input_dim: DimensionInfo
    ) -> DimensionInfo:
        """Infer the output dimensions of a layer.

        Layers exposing an ``output_dim()`` method are trusted; for any other
        layer the output dimension is assumed to equal the input dimension.
        """
        if hasattr(layer, "output_dim") and callable(layer.output_dim):
            try:
                output_dim = layer.output_dim()
                # Keep leading (batch/seq) dims and replace the feature dim.
                input_shape = input_dim.shape
                if input_shape is not None:
                    output_shape = input_shape[:-1] + (output_dim,)
                else:
                    output_shape = input_dim.estimate_shape()
                    if output_shape:
                        output_shape = output_shape[:-1] + (output_dim,)
                    else:
                        output_shape = None

                return DimensionInfo(
                    dim=output_dim, shape=output_shape, feature_dim=output_dim
                )
            except Exception as e:
                self.logger.warning(
                    f"Failed to call output_dim on {type(layer).__name__}: {e}"
                )

        # Default: output dimension is the same as the input dimension.
        self.logger.warning(
            f"Unknown layer type {type(layer).__name__}, "
            "assuming output dim == input dim"
        )
        return input_dim

    def apply_input_transforms(
        self,
        input_dim: DimensionInfo,
        input_fn: Optional[str] = None,
        input_slice: Optional[str] = None,
    ) -> DimensionInfo:
        """Apply input_slice then input_fn transforms to the input dims."""
        current_dim = input_dim

        if input_slice is not None:
            current_dim = self._apply_input_slice(current_dim, input_slice)

        if input_fn is not None:
            current_dim = self._apply_input_fn(current_dim, input_fn)

        return current_dim

    def _apply_input_slice(
        self, dim_info: DimensionInfo, input_slice: str
    ) -> DimensionInfo:
        """Apply an ``input_slice`` expression such as "[1]" or "[0:2]".

        On any parse/apply failure the original ``dim_info`` is returned.
        """
        try:
            expr = input_slice.strip()
            if expr.startswith("[") and expr.endswith("]"):
                # Parse "[i]" / "[a:b]" / "[a:b:c]" manually. The previous
                # implementation eval'd f"slice{expr}" (e.g. `slice[0:2]`),
                # which is a TypeError: builtin `slice` is not subscriptable,
                # so every bracketed slice silently fell back to the input.
                inner = expr[1:-1]
                if ":" in inner:
                    parts = [
                        int(p) if p.strip() else None for p in inner.split(":")
                    ]
                    slice_expr = slice(*parts)
                else:
                    slice_expr = int(inner)
            else:
                # Bare expression form, e.g. "1" or "slice(1, 3)".
                # NOTE(security): eval on config-supplied text; trusted input only.
                slice_expr = eval(expr)

            if isinstance(slice_expr, int):
                # Single index selects one entry from a multi-dim input.
                if isinstance(dim_info.dim, (list, tuple)):
                    return DimensionInfo(dim_info.dim[slice_expr])
                raise ValueError(
                    f"Cannot apply index {slice_expr} to scalar dimension "
                    f"{dim_info.dim}"
                )
            elif isinstance(slice_expr, slice):
                # Slice keeps a sub-list of dimensions.
                if isinstance(dim_info.dim, (list, tuple)):
                    return DimensionInfo(dim_info.dim[slice_expr], is_list=True)
                raise ValueError(
                    f"Cannot apply slice {slice_expr} to scalar dimension "
                    f"{dim_info.dim}"
                )
            else:
                self.logger.warning(f"Unsupported slice expression: {input_slice}")
                return dim_info

        except Exception as e:
            self.logger.error(f"Failed to apply input_slice {input_slice}: {e}")
            return dim_info

    def _apply_input_fn(self, dim_info: DimensionInfo, input_fn: str) -> DimensionInfo:
        """Apply an ``input_fn`` transform via dummy-tensor inference.

        Falls back to the untouched ``dim_info`` if inference fails.
        """
        try:
            from tzrec.utils.lambda_inference import infer_lambda_output_dim

            result = infer_lambda_output_dim(dim_info, input_fn)
            self.logger.info(
                f"Successfully inferred output dim using dummy tensor for "
                f"'{input_fn}': {result}"
            )
            return result
        except Exception as e:
            self.logger.warning(
                f"Failed to apply input_fn {input_fn}: {e}; keeping input dim"
            )
            return dim_info

    def merge_input_dims(
        self, input_dims: List[DimensionInfo], merge_mode: str = "concat"
    ) -> DimensionInfo:
        """Merge multiple input dimensions.

        Args:
            input_dims: Dimension infos to merge (must be non-empty).
            merge_mode: One of "concat" (sum dims), "list" (keep as list)
                or "stack" (all inputs must share the same feature dim).

        Raises:
            ValueError: On empty input or an unsupported merge mode.
        """
        if not input_dims:
            raise ValueError("No input dimensions to merge")

        if len(input_dims) == 1:
            return input_dims[0]

        if merge_mode == "concat":
            # Concatenation along the feature axis: dimensions add up.
            total_dim = sum(dim_info.get_total_dim() for dim_info in input_dims)
            return DimensionInfo(total_dim)

        elif merge_mode == "list":
            # Keep individual dimensions as a flat list.
            dims = []
            for dim_info in input_dims:
                if dim_info is not None:
                    dims.extend(dim_info.to_list())
            return DimensionInfo(dims, is_list=True)

        elif merge_mode == "stack":
            # Stacking requires identical feature dims; the feature dim itself
            # is unchanged (a new leading axis is added at runtime).
            if not all(
                dim_info.get_feature_dim() == input_dims[0].get_feature_dim()
                for dim_info in input_dims
            ):
                raise ValueError(
                    "All inputs must have same feature dimension for stacking"
                )
            feature_dim = input_dims[0].get_feature_dim()
            return DimensionInfo(feature_dim)

        else:
            raise ValueError(f"Unsupported merge mode: {merge_mode}")

    def get_summary(self) -> Dict[str, Any]:
        """Get summary information about dimension inference."""
        return {
            "total_blocks": len(self.block_output_dims),
            "input_dims": {
                name: str(dim) for name, dim in self.block_input_dims.items()
            },
            "output_dims": {
                name: str(dim) for name, dim in self.block_output_dims.items()
            },
        }


def create_dimension_info_from_embedding(
    embedding_group: "EmbeddingGroup",
    group_name: str,
    batch_size: Optional[int] = None,
) -> DimensionInfo:
    """Create dimension information from an embedding group.

    Args:
        embedding_group: The embedding group object.
        group_name: The name of the group.
        batch_size: The batch size (optional, used for estimating the full shape).

    Returns:
        A DimensionInfo object containing feature dimension information;
        a zero-dim DimensionInfo is returned on lookup failure.
    """
    try:
        total_dim = embedding_group.group_total_dim(group_name)
        estimated_shape = (batch_size, total_dim) if batch_size is not None else None
        return DimensionInfo(
            dim=total_dim,
            shape=estimated_shape,
            feature_dim=total_dim,  # Explicitly pin the feature dimension.
        )
    except Exception as e:
        logging.error(f"Failed to get dimension from embedding group {group_name}: {e}")
        return DimensionInfo(0, feature_dim=0)
import traceback
import unittest


class DimensionInferenceTest(unittest.TestCase):
    """Test class for DIN automatic dimension inference functionality."""

    def test_din_module_import(self):
        """Test DIN module import functionality."""
        try:
            from tzrec.utils.load_class import load_torch_layer

            # Test loading DINEncoder
            din_cls, is_customize = load_torch_layer("DIN")
            self.assertEqual(
                is_customize, True, "DINEncoder should be a customized class"
            )

            self.assertIsNotNone(din_cls, "DINEncoder should not be None")

            # Check parameters of DINEncoder
            import inspect

            sig = inspect.signature(din_cls.__init__)
            self.assertEqual(
                len(sig.parameters), 7, "DINEncoder should have 7 parameters"
            )
            self.assertEqual(
                list(sig.parameters.keys()),
                [
                    "self",
                    "sequence_dim",
                    "query_dim",
                    "input",
                    "attn_mlp",
                    "max_seq_length",
                    "kwargs",
                ],
            )

        except Exception as e:
            # Print the traceback BEFORE failing: self.fail() raises
            # immediately, so any statement placed after it is unreachable.
            traceback.print_exc()
            self.fail(f"Error importing DINEncoder: {e}")

    def test_dimension_inference(self):
        """Test dimension inference functionality."""
        try:
            from tzrec.modules.sequence import DINEncoder
            from tzrec.utils.dimension_inference import (
                DimensionInferenceEngine,
                DimensionInfo,
            )

            # Create a dimension inference engine
            engine = DimensionInferenceEngine()

            # Create a DINEncoder (provide necessary parameters)
            din = DINEncoder(
                sequence_dim=128,
                query_dim=96,
                input="seq",
                attn_mlp={"hidden_units": [256, 64]},
                max_seq_length=100,
            )

            self.assertEqual(din.output_dim(), 128)

            # Test input dimension info
            input_total_dim = 224
            input_dim_info = DimensionInfo(
                dim=input_total_dim,
                shape=(32, input_total_dim),
                feature_dim=input_total_dim,
            )

            # Infer output dimension
            output_dim_info = engine.infer_layer_output_dim(din, input_dim_info)

            # Validate inference result
            expected_output_dim = 128
            actual_output_dim = output_dim_info.get_feature_dim()
            self.assertEqual(
                actual_output_dim,
                expected_output_dim,
                f"Expected output dim {expected_output_dim}, got {actual_output_dim}",
            )

        except Exception as e:
            # Emit the traceback before self.fail() aborts the test.
            traceback.print_exc()
            self.fail(f"Dimension inference failed: {e}")

    def test_automatic_dimension_inference(self):
        """Test automatic dimension inference (simulate backbone scenario)."""
        try:
            import inspect

            from tzrec.modules.sequence import DINEncoder

            # Simulate the process of automatic dimension inference
            din_cls = DINEncoder
            sig = inspect.signature(din_cls.__init__)

            self.assertEqual(
                [p for p in sig.parameters.keys() if p != "self"],
                [
                    "sequence_dim",
                    "query_dim",
                    "input",
                    "attn_mlp",
                    "max_seq_length",
                    "kwargs",
                ],
            )

            # Simulate kwargs dictionary (result of proto configuration parsing)
            kwargs = {
                "input": "seq",
                "attn_mlp": {"hidden_units": [256, 64]},
                "max_seq_length": 100,
            }

            # Simulate logic for automatic dimension inference
            if "sequence_dim" not in kwargs:
                kwargs["sequence_dim"] = 128

            if "query_dim" not in kwargs:
                kwargs["query_dim"] = 96

            self.assertEqual(kwargs["sequence_dim"], 128)
            self.assertEqual(kwargs["query_dim"], 96)
            self.assertEqual(kwargs["input"], "seq")
            self.assertEqual(kwargs["attn_mlp"], {"hidden_units": [256, 64]})
            self.assertEqual(kwargs["max_seq_length"], 100)

            # Create DINEncoder instance
            din = din_cls(**kwargs)
            self.assertEqual(din.output_dim(), 128)

        except Exception as e:
            # Emit the traceback before self.fail() aborts the test.
            traceback.print_exc()
            self.fail(f"Automatic dimension inference failed: {e}")


if __name__ == "__main__":
    unittest.main()
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Lambda expression dimension inference module."""

import logging
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from tzrec.utils.dimension_inference import DimensionInfo


class LambdaOutputDimInferrer:
    """Lambda expression output dimension inferer.

    Infer the output dimensions by creating a dummy tensor and
    executing the lambda expression.
    """

    def __init__(self) -> None:
        """Initialize the Lambda output dimension inferrer."""
        self.logger = logging.getLogger(__name__)

    def infer_output_dim(
        self,
        input_dim_info: DimensionInfo,
        lambda_fn_str: str,
        dummy_batch_size: int = 2,
        dummy_seq_len: Optional[int] = None,
    ) -> DimensionInfo:
        """Infer the output dimensions of a lambda expression.

        Args:
            input_dim_info: The input dimension information.
            lambda_fn_str: The lambda expression string, such as
                "lambda x: x.sum(dim=1)".
            dummy_batch_size: The batch size used to create a dummy tensor.
            dummy_seq_len: The sequence length used to create a dummy tensor (optional).

        Returns:
            The inferred output dimension information; falls back to
            ``input_dim_info`` when inference fails.
        """
        # If the input shape carries a concrete batch size, use it for the dummy.
        shape = input_dim_info.shape
        if shape is not None and len(shape) > 0 and shape[0] is not None:
            dummy_batch_size = shape[0]
        try:
            # 1. Create a dummy tensor
            dummy_tensor = self._create_dummy_tensor(
                input_dim_info, dummy_batch_size, dummy_seq_len
            )

            # 2. Compile the Lambda function
            lambda_fn = self._compile_lambda_function(lambda_fn_str)

            # 3. Execute the Lambda function
            with torch.no_grad():  # No gradient computation needed
                output_tensor = lambda_fn(dummy_tensor)

            # 4. Parse the output and create a DimensionInfo
            return self._analyze_output(output_tensor, input_dim_info)

        except Exception as e:
            self.logger.error(
                f"Failed to infer output dim for lambda '{lambda_fn_str}': {e}"
            )
            # Return the input dimension as fallback on error
            self.logger.warning("Falling back to input dimension")
            return input_dim_info

    def _create_dummy_tensor(
        self,
        input_dim_info: DimensionInfo,
        batch_size: int,
        seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Create a dummy tensor for testing."""

        def flatten_shape(s: Any) -> Tuple[int, ...]:  # pyre-ignore[2]
            # Expand any nested shape structure and keep only ints.
            result = []
            for item in s:
                if isinstance(item, (list, tuple)):
                    result.extend(flatten_shape(item))
                else:
                    result.append(item)
            return tuple(result)

        if input_dim_info.shape is not None:
            # Full shape info available: reuse it with the dummy batch size.
            shape = input_dim_info.shape
            if len(shape) > 0:
                shape = (batch_size,) + shape[1:]
            shape = flatten_shape(shape)
        else:
            # Compute a shape from the feature dimension alone.
            feature_dim = input_dim_info.get_feature_dim()

            if seq_len is not None:
                # 3D: (batch_size, seq_len, feature_dim)
                shape = (batch_size, seq_len, feature_dim)
            else:
                # 2D: (batch_size, feature_dim)
                shape = (batch_size, feature_dim)
            shape = flatten_shape(shape)

        dummy_tensor = torch.randn(shape, dtype=torch.float32)
        self.logger.debug(f"Created dummy tensor with shape: {shape}")
        return dummy_tensor

    def _compile_lambda_function(
        self, lambda_fn_str: str
    ) -> Union[
        Callable[[torch.Tensor], torch.Tensor],
        Callable[[Iterable[torch.Tensor]], torch.Tensor],
    ]:
        """Compile a lambda function string into a callable.

        NOTE(security): ``eval`` executes arbitrary python code. The expression
        is expected to come from trusted model-config files only — never route
        user-supplied input through this path.

        Raises:
            ValueError: If the string does not evaluate to a callable.
        """
        try:
            lambda_fn_str = lambda_fn_str.strip()

            lambda_fn = eval(lambda_fn_str)

            if not callable(lambda_fn):
                raise ValueError(
                    f"Lambda expression does not evaluate to a callable: "
                    f"{lambda_fn_str}"
                )

            return lambda_fn  # pyre-ignore[7]

        except Exception as e:
            self.logger.error(
                f"Failed to compile lambda function '{lambda_fn_str}': {e}"
            )
            raise ValueError(f"Invalid lambda expression: {lambda_fn_str}") from e

    def _analyze_output(
        self, output_tensor: torch.Tensor, input_dim_info: DimensionInfo
    ) -> DimensionInfo:
        """Analyze the output tensor and create DimensionInfo."""
        if isinstance(output_tensor, (list, tuple)):
            # List/tuple output: collect the last dim of each element.
            if len(output_tensor) == 0:
                return DimensionInfo(0, is_list=True)

            dims = []
            shapes = []
            for item in output_tensor:
                if isinstance(item, torch.Tensor):
                    dims.append(item.shape[-1] if len(item.shape) > 0 else 1)
                    shapes.append(item.shape)
                else:
                    # Non-tensor element: counted as a single scalar dim.
                    dims.append(1)
                    shapes.append((1,))

            return DimensionInfo(
                dim=dims,
                # Only report a shape when all elements agree on it.
                shape=shapes[0] if len(set(shapes)) == 1 else None,
                is_list=True,
                feature_dim=sum(dims),
            )

        elif isinstance(output_tensor, torch.Tensor):
            # Standard tensor output
            output_shape = tuple(output_tensor.shape)
            feature_dim = output_shape[-1] if len(output_shape) > 0 else 1

            return DimensionInfo(
                dim=feature_dim, shape=output_shape, feature_dim=feature_dim
            )

        else:
            # Any other output type is treated as a scalar.
            self.logger.warning(f"Unexpected output type: {type(output_tensor)}")
            return DimensionInfo(1, feature_dim=1)


class LambdaLayer(nn.Module):
    """Lambda expression layer, providing output_dim method."""

    def __init__(
        self,
        lambda_fn_str: str,
        input_dim_info: Optional[DimensionInfo],
        name: str = "lambda_layer",
    ) -> None:
        """Initialize the Lambda layer.

        Args:
            lambda_fn_str: lambda expression string.
            input_dim_info: Input dimension information used to infer the
                output dimension; may be None and supplied later through
                ``set_input_dim_info`` (annotation fixed to Optional — the
                body explicitly handles None).
            name: Layer name.
        """
        super().__init__()
        self.lambda_fn_str = lambda_fn_str
        self.name = name
        self._input_dim_info = input_dim_info
        self._output_dim_info = None
        self._lambda_fn = None

        # Compile eagerly so invalid expressions fail at construction time.
        self._compile_function()

        # If input dims are already known, infer the output dims immediately.
        if input_dim_info is not None:
            self._infer_output_dim()

    def _compile_function(self) -> None:
        """Compile lambda function."""
        inferrer = LambdaOutputDimInferrer()
        self._lambda_fn = inferrer._compile_lambda_function(self.lambda_fn_str)

    def _infer_output_dim(self) -> None:
        """Infer output dimension."""
        if self._input_dim_info is None:
            raise ValueError(
                "Cannot infer output dimension without input dimension info"
            )

        inferrer = LambdaOutputDimInferrer()
        self._output_dim_info = inferrer.infer_output_dim(
            self._input_dim_info, self.lambda_fn_str
        )

    def set_input_dim_info(self, input_dim_info: DimensionInfo) -> None:
        """Set input dimension info and re-infer output dimension."""
        self._input_dim_info = input_dim_info
        self._infer_output_dim()

    def output_dim(self) -> int:
        """Get the output feature dimension.

        Raises:
            ValueError: If the input dimension info has not been provided yet.
        """
        if self._output_dim_info is None:
            raise ValueError(
                f"Output dimension not available for {self.name}. "
                "Make sure to set input_dim_info first."
            )
        return self._output_dim_info.get_feature_dim()

    def get_output_dim_info(self) -> DimensionInfo:
        """Get the output dimension info.

        Raises:
            ValueError: If the input dimension info has not been provided yet.
        """
        if self._output_dim_info is None:
            raise ValueError(
                f"Output dimension not available for {self.name}. "
                "Make sure to set input_dim_info first."
            )
        return self._output_dim_info

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """Apply the compiled lambda to the input."""
        if self._lambda_fn is None:
            raise ValueError("Lambda function not compiled")
        return self._lambda_fn(x)

    def __repr__(self) -> str:
        return f"LambdaLayer(name={self.name}, lambda_fn='{self.lambda_fn_str}')"


def create_lambda_layer_from_input_fn(
    input_fn_str: str, input_dim_info: DimensionInfo, name: str = "input_fn_layer"
) -> LambdaLayer:
    """Create a Lambda layer from an input_fn string.

    Convert the input_fn in the backbone configuration
    into a layer with an output_dim method.
    """
    return LambdaLayer(
        lambda_fn_str=input_fn_str, input_dim_info=input_dim_info, name=name
    )


def infer_lambda_output_dim(
    input_dim_info: DimensionInfo, lambda_fn_str: str
) -> DimensionInfo:
    """Infer the output dimensions of a lambda expression."""
    inferrer = LambdaOutputDimInferrer()
    return inferrer.infer_output_dim(input_dim_info, lambda_fn_str)
+ +"""Test lambda layer dimension inference in backbone.""" + +import logging +import unittest + +import torch + +from tzrec.modules.backbone import LambdaWrapper +from tzrec.utils.dimension_inference import DimensionInfo +from tzrec.utils.lambda_inference import LambdaOutputDimInferrer + +logging.basicConfig(level=logging.DEBUG) + + +class TestLambdaDimensionInference(unittest.TestCase): + """Test the dimension inference function of the lambda module.""" + + def test_lambda_wrapper_simple(self): + """Testing simple lambda expressions.""" + # create input dimension info + input_dim = DimensionInfo(16, shape=(32, 16)) + + # create lambda wrapper + lambda_wrapper = LambdaWrapper("lambda x: x", "identity") + + # infer output dimension + output_dim = lambda_wrapper.infer_output_dim(input_dim) + + self.assertEqual(output_dim.shape, (32, 16)) + self.assertEqual(output_dim.get_total_dim(), 16) + self.assertEqual(output_dim.get_feature_dim(), 16) + + def test_lambda_wrapper_sum(self): + """Testing the lambda expression for the sum operation.""" + # 3D tensor + input_dim = DimensionInfo( + 16, shape=(32, 10, 16) + ) # batch_size=32, seq_len=10, feature_dim=16 + + # create lambda wrapper - Summing over the sequence dimension + lambda_wrapper = LambdaWrapper("lambda x: x.sum(dim=1)", "sum_seq") + + # infer output dimension + output_dim = lambda_wrapper.infer_output_dim(input_dim) + + # sum over the sequence dimension, should get (32, 16) + self.assertEqual(output_dim.get_feature_dim(), 16) + self.assertEqual(output_dim.shape, (32, 16)) + + def test_lambda_wrapper_list_conversion(self): + """Testing lambda expressions converted to lists.""" + # create input dimension info + input_dim = DimensionInfo(16, shape=(32, 16)) + + # create lambda wrapper - convert to list + lambda_wrapper = LambdaWrapper("lambda x: [x]", "to_list") + + # infer output dimension + output_dim = lambda_wrapper.infer_output_dim(input_dim) + + # After conversion to list, the dimensions + # should be 
maintained but marked as list type + self.assertEqual(output_dim.get_feature_dim(), 16) + self.assertTrue(output_dim.is_list) + + def test_lambda_wrapper_execution(self): + """Test the execution function of the lambda wrapper.""" + # create lambda wrapper + lambda_wrapper = LambdaWrapper("lambda x: x * 2", "multiply") + + # create test input + test_input = torch.randn(4, 8) + + # execute + output = lambda_wrapper(test_input) + + # expected output + expected = test_input * 2 + torch.testing.assert_close(output, expected) + + def test_direct_inferrer(self): + """Testing LambdaOutputDimInferrer.""" + # create inferrer + inferrer = LambdaOutputDimInferrer() + + # create input dimension info + input_dim = DimensionInfo(16, shape=(32, 16)) + + test_cases = [ + ("lambda x: x", 16), + ("lambda x: x.sum(dim=-1)", 32), + ("lambda x: x.sum(dim=-1, keepdim=True)", 1), + ("lambda x: [x]", 16), + ] + + for lambda_expr, expected_feature_dim in test_cases: + with self.subTest(lambda_expr=lambda_expr): + output_dim = inferrer.infer_output_dim(input_dim, lambda_expr) + + if expected_feature_dim is not None: + self.assertEqual(output_dim.get_feature_dim(), expected_feature_dim) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/utils/load_class.py b/tzrec/utils/load_class.py index fe488e6d..e66eef0c 100644 --- a/tzrec/utils/load_class.py +++ b/tzrec/utils/load_class.py @@ -169,3 +169,33 @@ def load_by_path(path): except pydoc.ErrorDuringImport: print("load %s failed: %s" % (path, traceback.format_exc())) return None + + +def load_torch_layer(name): + """Load torch layer class. + + Args: + name (str): Module class name, e.g. 'Linear' or 'YourCustomLayer' + + Return: + (layer_class, is_customize) + module_class: The class object (e.g., torch.nn.Linear) + is_customize: True if loaded from custom namespace, False if from torch.nn + """ + name = name.strip() + if name == "" or name is None: + return None + + path = "tzrec.modules." 
+ name + try: + cls = pydoc.locate(path) + if cls is not None: + return cls, True + path = "torch.nn." + name + return pydoc.locate(path), False + except pydoc.ErrorDuringImport: + print("load torch layer %s failed" % name) + import logging + + logging.error("load torch layer %s failed: %s" % (name, traceback.format_exc())) + return None, False diff --git a/tzrec/utils/load_class_test.py b/tzrec/utils/load_class_test.py index 617a0d7b..9fc034c8 100644 --- a/tzrec/utils/load_class_test.py +++ b/tzrec/utils/load_class_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at