Skip to content

Commit 0eb27cc

Browse files
authored
Support multiplier memory limits (#395)
1 parent 1645393 commit 0eb27cc

File tree

3 files changed

+120
-29
lines changed

3 files changed

+120
-29
lines changed

coq_tools/diagnose_error.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def make_reg_string(output, strict_whitespace=False):
271271

272272

273273
TIMEOUT = {}
274+
MEMORY_USAGE = {}
274275

275276

276277
def get_timeout(coqc=None):
@@ -287,6 +288,21 @@ def reset_timeout():
287288
TIMEOUT = {}
288289

289290

291+
def get_memory_usage(key):
292+
return MEMORY_USAGE.get(key)
293+
294+
295+
def set_memory_usage(key, value):
296+
if value is None:
297+
return
298+
MEMORY_USAGE[key] = value
299+
300+
301+
def reset_memory_usage():
302+
global MEMORY_USAGE
303+
MEMORY_USAGE = {}
304+
305+
290306
def timeout_Popen_communicate(log, *args, **kwargs):
291307
ret = {"value": ("", ""), "returncode": None, "rusage": None}
292308
timeout = kwargs.pop("timeout", None)
@@ -298,9 +314,34 @@ def timeout_Popen_communicate(log, *args, **kwargs):
298314
# Extract memory limit parameters
299315
max_mem_rss = kwargs.pop("max_mem_rss", None)
300316
max_mem_as = kwargs.pop("max_mem_as", None)
317+
max_mem_rss_multiplier = kwargs.pop("max_mem_rss_multiplier", None)
318+
max_mem_as_multiplier = kwargs.pop("max_mem_as_multiplier", None)
319+
memory_usage_key = kwargs.pop("memory_usage_key", None)
301320
cgroup = kwargs.pop("cgroup", None)
302321
mem_limit_method = kwargs.pop("mem_limit_method", "prlimit")
303322

323+
def resolve_dynamic_limit(limit_value, multiplier_value, limit_kind):
324+
if multiplier_value is None:
325+
return limit_value
326+
if multiplier_value <= 0:
327+
return limit_value
328+
if memory_usage_key is None:
329+
return None
330+
usage = get_memory_usage((memory_usage_key, limit_kind))
331+
if usage is None:
332+
return None
333+
resolved = int(math.ceil(usage * multiplier_value))
334+
return resolved if resolved > 0 else None
335+
336+
max_mem_rss = resolve_dynamic_limit(max_mem_rss, max_mem_rss_multiplier, "rss")
337+
max_mem_as = resolve_dynamic_limit(max_mem_as, max_mem_as_multiplier, "as")
338+
339+
def record_memory_usage(usage_bytes):
340+
if memory_usage_key is None or usage_bytes is None:
341+
return
342+
set_memory_usage((memory_usage_key, "rss"), usage_bytes)
343+
set_memory_usage((memory_usage_key, "as"), usage_bytes)
344+
304345
# Get command from args
305346
cmd = list(args[0]) if args else []
306347
cg_path = None
@@ -333,15 +374,19 @@ def get_peak_rss():
333374
thread.join(timeout)
334375
if not thread.is_alive():
335376
cleanup_cgroup(cg_path)
336-
return (ret["value"], ret["returncode"], get_peak_rss())
377+
peak_rss_bytes = get_peak_rss()
378+
record_memory_usage(peak_rss_bytes)
379+
return (ret["value"], ret["returncode"], peak_rss_bytes)
337380

338381
p.terminate()
339382
thread.join()
340383
cleanup_cgroup(cg_path)
384+
peak_rss_bytes = get_peak_rss()
385+
record_memory_usage(peak_rss_bytes)
341386
return (
342387
tuple(map((lambda s: (s if s else "") + TIMEOUT_POSTFIX), ret["value"])),
343388
ret["returncode"],
344-
get_peak_rss(),
389+
peak_rss_bytes,
345390
)
346391

347392

@@ -577,11 +622,17 @@ def get_coq_output(
577622
return COQ_OUTPUT[key][1]
578623

579624
start = time.time()
580-
extra_kwargs = {
581-
k: kwargs[k]
582-
for k in ["max_mem_rss", "max_mem_as", "cgroup", "mem_limit_method"]
583-
if k in kwargs
584-
}
625+
extra_keys = [
626+
"max_mem_rss",
627+
"max_mem_as",
628+
"cgroup",
629+
"mem_limit_method",
630+
"max_mem_rss_multiplier",
631+
"max_mem_as_multiplier",
632+
"memory_usage_key",
633+
]
634+
extra_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
635+
extra_kwargs.setdefault("memory_usage_key", coqc_prog)
585636
((stdout, stderr), returncode, peak_rss_bytes) = (
586637
memory_robust_timeout_Popen_communicate(
587638
kwargs["log"],

coq_tools/find_bug.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@
7878
from .strip_comments import strip_comments
7979
from .strip_newlines import strip_newlines
8080
from .util import (
81+
MEMORY_LIMIT_METHODS,
8182
PY3,
8283
BooleanOptionalAction,
8384
list_diff,
85+
parse_memory_limit_with_multiplier,
8486
yes_no_prompt,
85-
parse_memory_bytes,
86-
MEMORY_LIMIT_METHODS,
8787
)
8888

8989
if PY3:
@@ -4132,6 +4132,17 @@ def prepend_coqbin(prog):
41324132
args = adjust_no_error_defaults(args)
41334133
bug_file_name = args.bug_file.name
41344134
output_file_name = args.output_file
4135+
4136+
try:
4137+
max_mem_rss_value, max_mem_rss_multiplier = parse_memory_limit_with_multiplier(
4138+
args.max_mem_rss, "--max-mem-rss"
4139+
)
4140+
max_mem_as_value, max_mem_as_multiplier = parse_memory_limit_with_multiplier(
4141+
args.max_mem_as, "--max-mem-as"
4142+
)
4143+
except ValueError as exc:
4144+
parser.error(str(exc))
4145+
41354146
env = {
41364147
"only_inline": args.only_inline,
41374148
"fast_merge_imports": args.fast_merge_imports,
@@ -4276,14 +4287,10 @@ def prepend_coqbin(prog):
42764287
"add_proof_using_before_admit": args.add_proof_using_before_admit,
42774288
"prefer_final_proof_using": args.prefer_final_proof_using,
42784289
"remove_non_definitions": args.remove_non_definitions,
4279-
"max_mem_rss": (
4280-
parse_memory_bytes(args.max_mem_rss)
4281-
if args.max_mem_rss is not None
4282-
else None
4283-
),
4284-
"max_mem_as": (
4285-
parse_memory_bytes(args.max_mem_as) if args.max_mem_as is not None else None
4286-
),
4290+
"max_mem_rss": max_mem_rss_value,
4291+
"max_mem_as": max_mem_as_value,
4292+
"max_mem_rss_multiplier": max_mem_rss_multiplier,
4293+
"max_mem_as_multiplier": max_mem_as_multiplier,
42874294
"cgroup": args.cgroup,
42884295
"mem_limit_method": args.mem_limit_method,
42894296
}
@@ -4332,7 +4339,11 @@ def prepend_coqbin(prog):
43324339
level=LOG_ALWAYS,
43334340
)
43344341
env["mem_limit_method"] = "prlimit"
4335-
if env["mem_limit_method"] == "cgexec" and env["cgroup"] is None and env["max_mem_rss"] is None:
4342+
if (
4343+
env["mem_limit_method"] == "cgexec"
4344+
and env["cgroup"] is None
4345+
and env["max_mem_rss"] is None
4346+
):
43364347
env["log"](
43374348
"\nError: --mem-limit-method=cgexec requires --cgroup or --max-mem-rss.",
43384349
force_stdout=True,

coq_tools/util.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,9 @@ def parse_memory_bytes(mem_str: str) -> int:
474474
return int(number * multiplier)
475475

476476

477-
def wrap_with_prlimit(cmd: List[str], as_bytes: int = None, rss_bytes: int = None) -> List[str]:
477+
def wrap_with_prlimit(
478+
cmd: List[str], as_bytes: int = None, rss_bytes: int = None
479+
) -> List[str]:
478480
"""
479481
Wrap command with prlimit to set resource limits.
480482
@@ -518,7 +520,9 @@ def wrap_with_ulimit(cmd: List[str], as_bytes: int = None) -> List[str]:
518520
return ["sh", "-c", f"ulimit -v {kb}; exec {escaped_cmd}"]
519521

520522

521-
def wrap_with_cgexec(cmd: List[str], cgroup: str, controllers: str = "memory") -> List[str]:
523+
def wrap_with_cgexec(
524+
cmd: List[str], cgroup: str, controllers: str = "memory"
525+
) -> List[str]:
522526
"""
523527
Wrap command with cgexec to run in an existing cgroup.
524528
@@ -541,6 +545,35 @@ def wrap_with_cgexec(cmd: List[str], cgroup: str, controllers: str = "memory") -
541545
return ["cgexec", "-g", f"{controllers}:{cgroup}"] + cmd
542546

543547

548+
def parse_memory_limit_with_multiplier(value, arg_name):
549+
"""returns (absolute_bytes_or_minus_one_or_none, multiplier_or_none)"""
550+
if value is None:
551+
return None, None
552+
value_str = str(value).strip()
553+
if not value_str:
554+
return None, None
555+
lower_val = value_str.lower()
556+
if lower_val.endswith("x"):
557+
multiplier_str = lower_val[:-1].strip()
558+
if not multiplier_str:
559+
raise ValueError(
560+
f"{arg_name} multiplier must include a numeric value before 'x'"
561+
)
562+
try:
563+
multiplier = float(multiplier_str)
564+
except ValueError as exc:
565+
raise ValueError(
566+
f"{arg_name} multiplier '{value_str}' is not a valid number"
567+
) from exc
568+
if multiplier <= 0:
569+
raise ValueError(f"{arg_name} multiplier '{value_str}' must be positive")
570+
return None, multiplier
571+
try:
572+
return parse_memory_bytes(value_str), None
573+
except ValueError as exc:
574+
raise ValueError(f"{arg_name} has invalid value '{value_str}': {exc}") from exc
575+
576+
544577
def wrap_with_cgexec_and_create(
545578
cmd: List[str],
546579
mem_bytes: int,
@@ -607,10 +640,7 @@ def wrap_with_systemd_run(
607640
Note:
608641
Requires systemd and user session (--user) or root privileges.
609642
"""
610-
has_limits = any(
611-
v is not None and v > 0
612-
for v in [mem_bytes, as_bytes, swap_bytes]
613-
)
643+
has_limits = any(v is not None and v > 0 for v in [mem_bytes, as_bytes, swap_bytes])
614644
if not has_limits:
615645
return cmd
616646

@@ -687,7 +717,9 @@ def apply_memory_limit(
687717

688718
elif method == "ulimit":
689719
if rss_bytes is not None and rss_bytes > 0:
690-
raise ValueError("ulimit method only supports --max-mem-as, not --max-mem-rss")
720+
raise ValueError(
721+
"ulimit method only supports --max-mem-as, not --max-mem-rss"
722+
)
691723
return wrap_with_ulimit(cmd, as_bytes=as_bytes), None
692724

693725
elif method == "cgexec":
@@ -702,10 +734,7 @@ def apply_memory_limit(
702734

703735
elif method == "systemd-run":
704736
return wrap_with_systemd_run(
705-
cmd,
706-
mem_bytes=rss_bytes,
707-
as_bytes=as_bytes,
708-
swap_bytes=swap_bytes
737+
cmd, mem_bytes=rss_bytes, as_bytes=as_bytes, swap_bytes=swap_bytes
709738
), None
710739

711740
else:

0 commit comments

Comments
 (0)