Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelscan/scanners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
SavedModelLambdaDetectScan,
SavedModelTensorflowOpScan,
)
from modelscan.scanners.keras.scan import KerasLambdaDetectScan
from modelscan.scanners.keras.scan import KerasLambdaDetectScan, KerasWeightsPickleScan
51 changes: 50 additions & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.scan import ScanResults, ScanBase
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats
from modelscan.tools.picklescanner import scan_numpy


logger = logging.getLogger("modelscan")
Expand Down Expand Up @@ -136,3 +137,51 @@ def name() -> str:
@staticmethod
def full_name() -> str:
return "modelscan.scanners.KerasLambdaDetectScan"


class KerasWeightsPickleScan(ScanBase):
def scan(self, model: Model) -> Optional[ScanResults]:
if SupportedModelFormats.KERAS.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

try:
with zipfile.ZipFile(model.get_stream(), "r") as zip:
file_names = zip.namelist()
for file_name in file_names:
if file_name == "model.weights.npz":
with zip.open(file_name, "r") as weights_file:
# Create a new Model instance for the weights file
weights_model = Model(
f"{model.get_source()}:{file_name}", weights_file
)
# Use the existing numpy scanner to check for malicious pickle content
results = scan_numpy(
model=weights_model,
settings=self._settings,
)
return self.label_results(results)
except zipfile.BadZipFile as e:
return ScanResults(
[],
[],
[
ModelScanSkipped(
self.name(),
SkipCategories.BAD_ZIP,
f"Skipping zip file due to error: {e}",
f"{model.get_source()}",
)
],
)

return ScanResults([], [], [])

@staticmethod
def name() -> str:
return "keras_weights"

@staticmethod
def full_name() -> str:
return "modelscan.scanners.KerasWeightsPickleScan"
4 changes: 4 additions & 0 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class SupportedModelFormats:
"enabled": True,
"supported_extensions": [".keras"],
},
"modelscan.scanners.KerasWeightsPickleScan": {
"enabled": True,
"supported_extensions": [".keras"],
},
"modelscan.scanners.SavedModelLambdaDetectScan": {
"enabled": True,
"supported_extensions": [".pb"],
Expand Down
Loading