[PyTorch] [CI] Capture subprocess stderr in distributed tests for better CI error reporting #2802
Conversation
Distributed tests launch subprocesses via torch.distributed.launch/torchrun. When these fail, pytest only captures the CalledProcessError from the parent process, not the actual worker traceback. This makes CI JUnit XML reports show "exit code 1" with no useful error detail.

Add run_distributed() utility to tests/pytorch/utils.py that captures stderr while letting stdout stream to the terminal. On failure, the worker's stderr (containing the actual Python traceback) is included in the AssertionError, which pytest writes into the JUnit XML report.

Behavior:
- Interactive use: stdout streams in real time (unchanged), stderr shown on failure
- CI/JUnit XML: failure reports now include the actual worker traceback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
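A minimal sketch of what such a utility can look like. The name, the `valid_returncodes` parameter, and the 4000-character stderr tail come from this PR's description and diagram; the exact error formatting here is an assumption, not the merged code:

```python
import subprocess


def run_distributed(args, valid_returncodes=(0,), **kwargs):
    """Run a distributed test command, capturing stderr for CI reports.

    stdout is not redirected, so it streams to the terminal in real time.
    On failure, the tail of the worker's stderr (which holds the actual
    Python traceback) is embedded in the AssertionError, which pytest
    then writes into the JUnit XML report.
    """
    result = subprocess.run(args, stderr=subprocess.PIPE, text=True, **kwargs)
    if result.returncode not in valid_returncodes:
        tail = (result.stderr or "")[-4000:]  # last 4000 chars, per the PR
        raise AssertionError(
            f"Command {args!r} exited with code {result.returncode}.\n"
            f"--- captured stderr (tail) ---\n{tail}"
        )
    return result
```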
Add --output-junit flag so ctest writes JUnit XML to /logs/, matching the pattern used by pytest tests. The XML is written before ctest exits, so it's captured even on test failure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary

This PR improves CI error reporting for distributed PyTorch tests by adding a run_distributed() utility that captures worker stderr for inclusion in JUnit XML reports.
Confidence Score: 5/5

Safe to merge: all findings are P2 style/cleanup issues with no functional impact. All remaining issues are non-blocking, including one missed migration that still works correctly via tests/pytorch/distributed/test_cast_master_weights_to_fp8.py.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PT as pytest (parent process)
    participant RD as run_distributed()
    participant SP as subprocess (torchrun/torch.distributed.run)
    participant W as GPU worker(s)
    PT->>RD: run_distributed(args, valid_returncodes, **kwargs)
    RD->>SP: subprocess.run(args, stderr=PIPE, text=True, **kwargs)
    SP->>W: launch worker processes
    W-->>SP: stdout → terminal (streamed)
    W-->>SP: stderr → PIPE (buffered)
    SP-->>RD: CompletedProcess(returncode, stderr)
    alt returncode in valid_returncodes
        RD-->>PT: return CompletedProcess
    else failure
        RD->>RD: build AssertionError with stderr[-4000:]
        RD-->>PT: raise AssertionError (captured in JUnit XML)
    end
```
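The split shown in the diagram, stdout streamed and stderr buffered, falls out of redirecting only stderr in the subprocess call. A small self-contained demonstration (the child command here is illustrative, not from the PR):

```python
import subprocess
import sys

# Only stderr is redirected: stdout inherits the parent's terminal, so
# worker progress streams in real time while the traceback is buffered
# for later inclusion in the failure report.
result = subprocess.run(
    [sys.executable, "-c", "import sys; print('progress'); sys.exit('boom')"],
    stderr=subprocess.PIPE,
    text=True,
)
# sys.exit(str) writes the message to stderr and exits with status 1
assert result.returncode == 1
assert "boom" in result.stderr
assert result.stdout is None  # stdout was never captured
```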
Reviews (2): Last reviewed commit: "Merge branch 'main' into sudhakars/impro..."
```python
        Use (0, 5) for inner pytest runs where 5 means all tests skipped.
        **kwargs: Passed through to subprocess.run (e.g. env, timeout).
    """
    result = subprocess.run(args, stderr=subprocess.PIPE, text=True, **kwargs)
```
**kwargs can silently conflict with stderr and text

If a caller ever passes stderr= or text= through **kwargs, Python will raise TypeError: subprocess.run() got multiple values for keyword argument 'stderr'. Consider explicitly popping or blocking those keys, or documenting the restriction:

```python
kwargs.pop("stderr", None)  # always captured internally
kwargs.pop("text", None)    # always text mode internally
result = subprocess.run(args, stderr=subprocess.PIPE, text=True, **kwargs)
```

None of the current call sites pass these, so this is not an immediate bug, just a fragile API surface.
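Putting the reviewer's suggestion into a full function, a hardened variant could pop the reserved keys before forwarding. The name run_distributed_safe and the error formatting are hypothetical, used here only to illustrate the fix:

```python
import subprocess


def run_distributed_safe(args, valid_returncodes=(0,), **kwargs):
    """Like run_distributed, but tolerant of stderr=/text= in **kwargs."""
    # Strip keys the utility always sets itself, so a caller passing
    # stderr= or text= gets them silently overridden instead of a
    # TypeError ("got multiple values for keyword argument") from
    # subprocess.run.
    kwargs.pop("stderr", None)  # always captured internally
    kwargs.pop("text", None)    # always text mode internally
    result = subprocess.run(args, stderr=subprocess.PIPE, text=True, **kwargs)
    if result.returncode not in valid_returncodes:
        raise AssertionError(
            f"exit code {result.returncode}:\n{(result.stderr or '')[-4000:]}"
        )
    return result
```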