Add KerasWeightsPickleScan.

The new scanner opens a .keras archive as a zip file, looks for a member
named "model.weights.npz", and hands that member to the existing numpy
scanner (scan_numpy). This patch also exports the scanner from
modelscan.scanners and enables it for ".keras" files in the default
settings.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Leading indentation and blank context lines
were reconstructed from the hunk line counts; verify them against the
target tree before applying.
NOTE(review): in KerasWeightsPickleScan.scan the local name "zip" shadows
the builtin; consider renaming it in a follow-up change.

diff --git a/modelscan/scanners/__init__.py b/modelscan/scanners/__init__.py
index fc09d38b..616acf51 100644
--- a/modelscan/scanners/__init__.py
+++ b/modelscan/scanners/__init__.py
@@ -9,4 +9,4 @@
     SavedModelLambdaDetectScan,
     SavedModelTensorflowOpScan,
 )
-from modelscan.scanners.keras.scan import KerasLambdaDetectScan
+from modelscan.scanners.keras.scan import KerasLambdaDetectScan, KerasWeightsPickleScan
diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py
index 1e88c389..a6ba049e 100644
--- a/modelscan/scanners/keras/scan.py
+++ b/modelscan/scanners/keras/scan.py
@@ -6,10 +6,11 @@
 
 from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError
 from modelscan.skip import ModelScanSkipped, SkipCategories
-from modelscan.scanners.scan import ScanResults
+from modelscan.scanners.scan import ScanResults, ScanBase
 from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
 from modelscan.model import Model
 from modelscan.settings import SupportedModelFormats
+from modelscan.tools.picklescanner import scan_numpy
 
 
 logger = logging.getLogger("modelscan")
@@ -136,3 +137,51 @@ def name() -> str:
     @staticmethod
     def full_name() -> str:
         return "modelscan.scanners.KerasLambdaDetectScan"
+
+
+class KerasWeightsPickleScan(ScanBase):
+    def scan(self, model: Model) -> Optional[ScanResults]:
+        if SupportedModelFormats.KERAS.value not in [
+            format_property.value for format_property in model.get_context("formats")
+        ]:
+            return None
+
+        try:
+            with zipfile.ZipFile(model.get_stream(), "r") as zip:
+                file_names = zip.namelist()
+                for file_name in file_names:
+                    if file_name == "model.weights.npz":
+                        with zip.open(file_name, "r") as weights_file:
+                            # Create a new Model instance for the weights file
+                            weights_model = Model(
+                                f"{model.get_source()}:{file_name}", weights_file
+                            )
+                            # Use the existing numpy scanner to check for malicious pickle content
+                            results = scan_numpy(
+                                model=weights_model,
+                                settings=self._settings,
+                            )
+                            return self.label_results(results)
+        except zipfile.BadZipFile as e:
+            return ScanResults(
+                [],
+                [],
+                [
+                    ModelScanSkipped(
+                        self.name(),
+                        SkipCategories.BAD_ZIP,
+                        f"Skipping zip file due to error: {e}",
+                        f"{model.get_source()}",
+                    )
+                ],
+            )
+
+        return ScanResults([], [], [])
+
+    @staticmethod
+    def name() -> str:
+        return "keras_weights"
+
+    @staticmethod
+    def full_name() -> str:
+        return "modelscan.scanners.KerasWeightsPickleScan"
diff --git a/modelscan/settings.py b/modelscan/settings.py
index 2b99b4a9..3faa1221 100644
--- a/modelscan/settings.py
+++ b/modelscan/settings.py
@@ -37,6 +37,10 @@ class SupportedModelFormats:
             "enabled": True,
             "supported_extensions": [".keras"],
         },
+        "modelscan.scanners.KerasWeightsPickleScan": {
+            "enabled": True,
+            "supported_extensions": [".keras"],
+        },
         "modelscan.scanners.SavedModelLambdaDetectScan": {
             "enabled": True,
             "supported_extensions": [".pb"],