From 0b8a195fb2e861a55453ddc14d1bffa09069a29d Mon Sep 17 00:00:00 2001 From: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com> Date: Thu, 1 Jun 2023 14:57:10 +0100 Subject: [PATCH 1/4] Add validation in `TabularFeatures.from_schema` that all cols are processed --- transformers4rec/torch/features/tabular.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/transformers4rec/torch/features/tabular.py b/transformers4rec/torch/features/tabular.py index a9a99abaa8..c8bd198416 100644 --- a/transformers4rec/torch/features/tabular.py +++ b/transformers4rec/torch/features/tabular.py @@ -157,6 +157,7 @@ def from_schema( # type: ignore Returns ``TabularFeatures`` from a dataset schema """ maybe_continuous_module, maybe_categorical_module = None, None + processed_features = [] if continuous_tags: if continuous_soft_embeddings: maybe_continuous_module = cls.SOFT_EMBEDDING_MODULE_CLASS.from_schema( @@ -168,10 +169,22 @@ def from_schema( # type: ignore maybe_continuous_module = cls.CONTINUOUS_MODULE_CLASS.from_schema( schema, tags=continuous_tags, **kwargs ) + processed_features.extend(schema.select_by_tag(continuous_tags).column_names) if categorical_tags: maybe_categorical_module = cls.EMBEDDING_MODULE_CLASS.from_schema( schema, tags=categorical_tags, **kwargs ) + processed_features.extend(schema.select_by_tag(categorical_tags).column_names) + + unprocessed_features = set(schema.column_names).difference(set(processed_features)) + if unprocessed_features: + raise ValueError( + "Schema provided to `TabularFeatures` includes features " + "without CONTINUOUS or CATEGORICAL tags. " + "Please ensure all columns have either one of these tags " + "or are excluded from the schema. " + f"\nUnproceesed features: {unprocessed_features} " + ) output = cls( continuous_module=maybe_continuous_module, From f1258ba80998589e68ad906dd7da1b6a294faf8e Mon Sep 17 00:00:00 2001 From: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com> Date: Thu, 1 Jun 2023 14:58:13 +0100 Subject: [PATCH 2/4] Add check to `Model.forward` that the inputs match the input_schema --- transformers4rec/torch/model/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index f72719cb71..928d3fe3aa 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -538,6 +538,21 @@ def __init__( self.top_k = top_k def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs): + model_expected_features = set(self.input_schema.column_names) + call_input_features = set(inputs.keys()) + if model_expected_features != call_input_features: + raise ValueError( + "Model forward called with different set of features " + "compared with the input schema it was configured with " + "Please check that the inputs passed to the model are only " + "those required by the model." + f"\nModel expected features:\n\t{model_expected_features}" + f"\nCall input features:\n\t{call_input_features}" + f"\nFeatures expected by model input schema only:" + f"\n\t{model_expected_features.difference(call_input_features)}" + f"\nFeatures provided in inputs only:" + f"\n\t{call_input_features.difference(model_expected_features)}" + ) # Convert inputs to float32 which is the default type, expected by PyTorch for name, val in inputs.items(): if torch.is_floating_point(val): From 9c134b6f27477a3325963014d10c8a8c100502e0 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Thu, 8 Jun 2023 16:07:29 +0100 Subject: [PATCH 3/4] Update tabular.py indentation --- transformers4rec/torch/features/tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/features/tabular.py b/transformers4rec/torch/features/tabular.py index 74e15a2d22..aee7832a2b 100644 --- a/transformers4rec/torch/features/tabular.py +++ b/transformers4rec/torch/features/tabular.py @@ -182,7 +182,7 @@ def from_schema( # type: ignore schema, tags=categorical_tags, **kwargs ) processed_features.extend(schema.select_by_tag(categorical_tags).column_names) - if pretrained_embeddings_tags: + if pretrained_embeddings_tags: maybe_pretrained_module = cls.PRETRAINED_EMBEDDING_MODULE_CLASS.from_schema( schema, tags=pretrained_embeddings_tags, **kwargs ) From 7f62635e9fbc036763da54057fffd019eaf3bf23 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Thu, 8 Jun 2023 16:09:56 +0100 Subject: [PATCH 4/4] Only run input features check if not training or testing --- transformers4rec/torch/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 928d3fe3aa..9cf817d257 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -540,7 +540,7 @@ def __init__( def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs): model_expected_features = set(self.input_schema.column_names) call_input_features = set(inputs.keys()) - if model_expected_features != call_input_features: + if not (training or testing) and model_expected_features != call_input_features: raise ValueError( "Model forward called with different set of features " "compared with the input schema it was configured with "