From 5a03adbefb8963c46b13c846cd56c6fbce4fae26 Mon Sep 17 00:00:00 2001 From: ankit-amazon <125257518+ankit-amazon@users.noreply.github.com> Date: Tue, 9 May 2023 18:32:41 +0530 Subject: [PATCH 1/3] Improve compliant and noncompliant examples for python/pytorch-miss-call-to-eval@v1.0 Addressed review comments --- .../pytorch_miss_call_to_eval.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py index 8b42d83..41c7878 100644 --- a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py +++ b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py @@ -4,14 +4,29 @@ # {fact rule=pytorch-miss-call-to-eval@v1.0 defects=1} def pytorch_miss_call_to_eval_noncompliant(model): import torch - # Noncompliant: miss call to `eval()` after load. model.load_state_dict(torch.load("model.pth")) + # Noncompliant: miss call to `eval()` after load. + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", + "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + x, y = test_data[0][0], test_data[0][1] + with torch.no_grad(): + pred = model(x) + predicted, actual = classes[pred[0].argmax(0)], classes[y] + print(f'Predicted: "{predicted}", Actual: "{actual}"') # {/fact} # {fact rule=pytorch-miss-call-to-eval@v1.0 defects=0} def pytorch_miss_call_to_eval_compliant(model): + import torch model.load_state_dict(torch.load("model.pth")) + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", + "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] # Compliant: `eval()` is called after load. model.eval() + x, y = test_data[0][0], test_data[0][1] + with torch.no_grad(): + pred = model(x) + predicted, actual = classes[pred[0].argmax(0)], classes[y] + print(f'Predicted: "{predicted}", Actual: "{actual}"') # {/fact} From bfd94cac1a0f0322bdfb00e0d0500c3f6bba1cb8 Mon Sep 17 00:00:00 2001 From: ankit-amazon <125257518+ankit-amazon@users.noreply.github.com> Date: Tue, 9 May 2023 21:16:43 +0530 Subject: [PATCH 2/3] Update pytorch_miss_call_to_eval.py --- .../pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py index 41c7878..1c18e83 100644 --- a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py +++ b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py @@ -5,7 +5,7 @@ def pytorch_miss_call_to_eval_noncompliant(model): import torch model.load_state_dict(torch.load("model.pth")) - # Noncompliant: miss call to `eval()` after load. + # Noncompliant: miss call to `eval()` before evaluating the model. classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] x, y = test_data[0][0], test_data[0][1] @@ -22,7 +22,7 @@ def pytorch_miss_call_to_eval_compliant(model): model.load_state_dict(torch.load("model.pth")) classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] - # Compliant: `eval()` is called after load. + # Compliant: `eval()` is called before evaluating the model. model.eval() x, y = test_data[0][0], test_data[0][1] with torch.no_grad(): From 77b491df2539fe3d1eaae1b213068b7046f66b41 Mon Sep 17 00:00:00 2001 From: ankit-amazon <125257518+ankit-amazon@users.noreply.github.com> Date: Tue, 9 May 2023 21:21:38 +0530 Subject: [PATCH 3/3] Update pytorch_miss_call_to_eval.py --- .../pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py index 1c18e83..b1ae941 100644 --- a/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py +++ b/src/python/detectors/pytorch_miss_call_to_eval/pytorch_miss_call_to_eval.py @@ -22,7 +22,7 @@ def pytorch_miss_call_to_eval_compliant(model): model.load_state_dict(torch.load("model.pth")) classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] - # Compliant: `eval()` is called before evaluating the model. + # Compliant: `eval()` is called before evaluating the model. model.eval() x, y = test_data[0][0], test_data[0][1] with torch.no_grad():