Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 238 additions & 37 deletions train_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import sys
import time
import threading
from src.arrival_scale import validate_job_arrival_scale
from src.workloadgen_cli import add_workloadgen_args, build_workloadgen_cli_args

Expand Down Expand Up @@ -153,57 +154,255 @@ def build_command(
return command


def run_all_parallel(combinations, max_parallel, iter_limit_per_step, session, prices,
job_durations, jobs, hourly_jobs, job_arrival_scale, jobs_exact_replay, jobs_exact_replay_aggregate, plot_dashboard, dashboard_hours,
seeds, seed_sweep, evaluate_savings, eval_months, workloadgen_args):
active = [] # list of (proc, label)
current_env = os.environ.copy()
def make_log_dir(session):
    """Create (if needed) and return a timestamped directory for per-run logs.

    When *session* is set the directory lives under that session's tree
    (sessions/<session>/proc_logs/<ts>), otherwise under ./proc_logs/<ts>.
    """
    timestamp = str(int(time.time()))
    base = os.path.join("sessions", session, "proc_logs") if session else "proc_logs"
    log_dir = os.path.join(base, timestamp)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def label_to_filename(label):
    """Turn a human-readable run label into a flat .log filename."""
    sanitized = label.replace(", ", "_")
    sanitized = sanitized.replace("=", "")
    return sanitized + ".log"


def _elapsed_str(seconds):
m, s = divmod(int(seconds), 60)
h, m = divmod(m, 60)
return f"{h}h{m:02d}m{s:02d}s" if h else f"{m}m{s:02d}s"


def _run_plain(tasks, max_parallel, log_dir, launch):
    """Run *tasks* with plain line-based progress output (non-TTY / --no-tui mode).

    tasks:  list of (combo, seed) tuples, consumed FIFO
    launch: callable (combo, seed) -> (proc, label, log_fh, start_time)
    Returns the number of runs that exited non-zero.

    NOTE(review): the SOURCE span for this function contained stale deleted
    diff lines interleaved with the new code; this body is the clean new-side
    implementation with that residue removed.
    """
    pending = list(tasks)
    active = []  # (proc, label, log_fh, start_time)
    done_log = []  # (label, rc, elapsed)
    failure_count = 0
    total = len(pending)

    print(f"[run] logs -> {log_dir}/")

    try:
        while pending or active:
            # Top up free slots from the queue.
            while pending and len(active) < max_parallel:
                combo, seed = pending.pop(0)
                proc, label, fh, t0 = launch(combo, seed)
                print(f"[run] starting ({len(done_log) + len(active) + 1}/{total}): {label}")
                active.append((proc, label, fh, t0))

            # Reap any processes that finished since the last poll.
            still_running = []
            for proc, label, fh, t0 in active:
                if proc.poll() is not None:
                    fh.close()
                    rc = proc.returncode
                    if rc != 0:
                        failure_count += 1
                    elapsed = time.time() - t0
                    done_log.append((label, rc, elapsed))
                    status = "done" if rc == 0 else f"FAILED (rc={rc})"
                    print(f"[run] [{len(done_log)}/{total}] {status}: {label} ({_elapsed_str(elapsed)})")
                else:
                    still_running.append((proc, label, fh, t0))
            active = still_running

            if active:
                time.sleep(1)
    finally:
        # On interrupt or unexpected error, terminate whatever is still
        # running (escalating to kill after 5s) and close its log handle.
        for proc, label, fh, t0 in active:
            try:
                proc.terminate()
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
            except OSError:
                pass
            try:
                fh.close()
            except OSError:
                pass

    return failure_count


def _draw_tui(stdscr, active, done_log, n_pending, total, log_dir, input_buf=""):
    """Render one frame of the progress TUI onto *stdscr*.

    active:    list of (proc, label, log_fh, start_time) for running jobs
    done_log:  list of (label, rc, elapsed); rc == -1 marks a user-terminated run
    input_buf: digits typed so far at the "Terminate #" prompt
    All drawing errors are swallowed so a resize or tiny terminal skips the
    frame instead of crashing the scheduler loop.
    """
    import curses as _curses  # local import: curses is only needed in TUI mode
    try:
        stdscr.erase()
        h, w = stdscr.getmaxyx()
        row = 0

        # Header line: progress counters plus where the per-run logs live.
        hdr = (f"train_iter [{len(done_log)}/{total} done | {len(active)} running | "
               f"{n_pending} queued] logs: {log_dir}/")
        stdscr.addstr(row, 0, hdr[:w - 1], _curses.A_BOLD)
        row += 1
        stdscr.addstr(row, 0, "-" * min(w - 1, 80))
        row += 1

        # Reserve last line for the terminate prompt
        body_end = h - 2

        if active and row < body_end:
            stdscr.addstr(row, 0, "Running:")
            row += 1
            for i, (_, label, _, t0) in enumerate(active):
                if row >= body_end:
                    break
                # 1-based index shown so it matches what the user types to terminate.
                line = f" [{i + 1}] {_elapsed_str(time.time() - t0)} {label}"
                stdscr.addstr(row, 0, line[:w - 1])
                row += 1
            row += 1

        # Completed section gets whatever space remains above the prompt;
        # only the tail of done_log that fits is shown.
        max_show = body_end - row
        if done_log and max_show > 1 and row < body_end:
            stdscr.addstr(row, 0, "Completed:")
            row += 1
            max_show -= 1
            for label, rc, elapsed in done_log[-max_show:]:
                if row >= body_end:
                    break
                if rc == 0:
                    status = "done"
                elif rc == -1:
                    status = "terminated"
                else:
                    status = f"FAILED(rc={rc})"
                stdscr.addstr(row, 0, f" {status}: {label} ({_elapsed_str(elapsed)})"[:w - 1])
                row += 1

        # Terminate prompt at the last line
        prompt = f"Terminate #: {input_buf}_"
        stdscr.addstr(h - 1, 0, prompt[:w - 1])

        stdscr.refresh()
    except Exception:
        # Terminal resized/too small mid-draw: drop this frame rather than crash.
        pass


def _run_tui(stdscr, tasks, max_parallel, log_dir, launch):
    """Curses-driven scheduler loop: launch tasks, poll children, redraw, and
    let the user terminate a running job by typing its index + Enter.

    tasks:  list of (combo, seed) tuples, consumed FIFO
    launch: callable (combo, seed) -> (proc, label, log_fh, start_time)
    Returns the number of failed runs (user-terminated runs count as failures).
    """
    import curses as _curses  # local import: curses is only needed in TUI mode
    _curses.curs_set(0)
    stdscr.nodelay(True)  # non-blocking getkey() so polling never stalls on input

    pending = list(tasks)
    active = []  # (proc, label, log_fh, start_time)
    done_log = []  # (label, rc, elapsed)
    failure_count = 0
    total = len(pending)
    input_buf = ""  # digits typed at the "Terminate #" prompt

    while pending or active:
        # Top up free slots from the queue.
        while pending and len(active) < max_parallel:
            combo, seed = pending.pop(0)
            proc, label, fh, t0 = launch(combo, seed)
            active.append((proc, label, fh, t0))

        # Reap any processes that finished since the last poll.
        still_running = []
        for proc, label, fh, t0 in active:
            if proc.poll() is not None:
                fh.close()
                rc = proc.returncode
                if rc != 0:
                    failure_count += 1
                done_log.append((label, rc, time.time() - t0))
            else:
                still_running.append((proc, label, fh, t0))
        active = still_running

        _draw_tui(stdscr, active, done_log, len(pending), total, log_dir, input_buf)

        try:
            key = stdscr.getkey()
            if key in ("\n", "\r", "KEY_ENTER"):
                try:
                    idx = int(input_buf) - 1
                    if 0 <= idx < len(active):
                        proc, label, fh, t0 = active.pop(idx)
                        elapsed = time.time() - t0
                        proc.terminate()
                        # Reap in a background thread so the UI loop never
                        # blocks on wait(); rc -1 marks a user termination.
                        def _reap(proc, label, fh, elapsed):
                            try:
                                proc.wait()
                            except OSError:
                                pass
                            try:
                                fh.close()
                            except OSError:
                                pass
                            done_log.append((label, -1, elapsed))
                        threading.Thread(target=_reap, args=(proc, label, fh, elapsed), daemon=True).start()
                        failure_count += 1
                except ValueError:
                    pass  # non-numeric buffer: ignore Enter
                input_buf = ""
            elif key in ("KEY_BACKSPACE", "\x7f", "\b"):
                input_buf = input_buf[:-1]
            elif key == "\x1b":  # ESC
                input_buf = ""
            elif key.isdigit():
                input_buf += key
        except _curses.error:
            pass  # no key pending (nodelay mode)

        time.sleep(0.25)

    # Final frame plus a summary line; wait for a keypress so the user sees it.
    _draw_tui(stdscr, [], done_log, 0, total, log_dir)
    try:
        h, w = stdscr.getmaxyx()
        summary = f"All {total} runs done. {failure_count} failure(s). Press any key to exit."
        stdscr.addstr(h - 1, 0, summary[:w - 1], _curses.A_BOLD)
        stdscr.refresh()
    except Exception:
        pass
    if sys.stdin.isatty():
        stdscr.nodelay(False)
        stdscr.getch()

    return failure_count


def run_all_parallel(combinations, max_parallel, iter_limit_per_step, session, prices,
                     job_durations, jobs, hourly_jobs, job_arrival_scale, jobs_exact_replay,
                     jobs_exact_replay_aggregate, plot_dashboard, dashboard_hours,
                     seeds, seed_sweep, evaluate_savings, eval_months, workloadgen_args,
                     no_tui=False):
    """Run every (weight-combo, seed) training job, at most *max_parallel* at once.

    Child stdout+stderr are captured to per-run files under a fresh log dir.
    Uses the curses TUI when stdout is a TTY and --no-tui was not given,
    otherwise falls back to plain line output. Returns the failure count.

    NOTE(review): the SOURCE span for this function contained stale deleted
    diff lines interleaved with the new code; this body is the clean new-side
    implementation with that residue removed.
    """
    multi_seed = len(seeds) > 1
    current_env = os.environ.copy()
    log_dir = make_log_dir(session)
    tasks = list(itertools.product(combinations, seeds))

    def launch(combo, seed):
        # Start one training subprocess; returns (proc, label, log_fh, start_time).
        efficiency_weight, price_weight, idle_weight, job_age_weight, drop_weight = combo
        label = f"efficiency={efficiency_weight}, price={price_weight}, idle={idle_weight}, job_age={job_age_weight}, drop={drop_weight}"
        if multi_seed:
            label += f", seed={seed}"
        command = build_command(
            efficiency_weight, price_weight, idle_weight, job_age_weight, drop_weight,
            iter_limit_per_step, session, prices, job_durations, jobs, hourly_jobs,
            job_arrival_scale, jobs_exact_replay, jobs_exact_replay_aggregate,
            plot_dashboard, dashboard_hours, seed, seed_sweep,
            evaluate_savings, eval_months, workloadgen_args,
        )
        log_path = os.path.join(log_dir, label_to_filename(label))
        log_fh = open(log_path, "w")
        try:
            proc = subprocess.Popen(command, env=current_env, stdout=log_fh, stderr=subprocess.STDOUT)
        except OSError:
            # Popen failed (e.g. missing executable): don't leak the log handle.
            log_fh.close()
            raise
        return proc, label, log_fh, time.time()

    if not no_tui and sys.stdout.isatty():
        import curses
        failure_count = [0]  # list so the wrapper closure can write the result

        def _run(stdscr):
            failure_count[0] = _run_tui(stdscr, tasks, max_parallel, log_dir, launch)

        curses.wrapper(_run)  # restores the terminal even if _run_tui raises
        return failure_count[0]
    else:
        return _run_plain(tasks, max_parallel, log_dir, launch)

def parse_fixed_weights(fix_weights_str, fix_values_str):
if not fix_weights_str or not fix_values_str:
Expand Down Expand Up @@ -244,6 +443,7 @@ def main():
parser.add_argument("--parallel", type=int, default=1, metavar="N", help="Number of training runs to execute in parallel (default: 1, sequential)")
parser.add_argument("--evaluate-savings", action="store_true", help="Forward to train.py to evaluate savings compared to baseline.")
parser.add_argument("--eval-months", type=int, default=6, help="Number of months to evaluate savings over (forwarded to train.py)")
parser.add_argument("--no-tui", action="store_true", help="Disable interactive TUI; print plain progress lines instead (auto-disabled when not a TTY)")
add_workloadgen_args(parser)

parser.add_argument("--session", help="Session ID")
Expand Down Expand Up @@ -310,6 +510,7 @@ def main():
evaluate_savings=args.evaluate_savings,
eval_months=args.eval_months,
workloadgen_args=workloadgen_args,
no_tui=args.no_tui,
)
if failures:
print(f"{failures} run(s) failed")
Expand Down
Loading