diff --git a/mojo/mojo_host_platform.bzl b/mojo/mojo_host_platform.bzl index a4dcd15..ce9fe43 100644 --- a/mojo/mojo_host_platform.bzl +++ b/mojo/mojo_host_platform.bzl @@ -12,7 +12,14 @@ def _log_result(rctx, binary, result): .format(binary, result.return_code, result.stdout, result.stderr), ) -def _get_amdgpu_constraint(series, gpu_mapping): +def _fail(rctx, msg): + if rctx.getenv("MOJO_IGNORE_UNKNOWN_GPUS") == "1": + # buildifier: disable=print + print("WARNING: ignoring unknown GPU, to support it, add it to the gpu_mapping in the MODULE.bazel: {}".format(msg)) + else: + fail(msg) + +def _get_amdgpu_constraint(rctx, series, gpu_mapping): for gpu_name, constraint in gpu_mapping.items(): if gpu_name in series: if constraint: @@ -20,21 +27,22 @@ def _get_amdgpu_constraint(series, gpu_mapping): else: return None - fail("Unrecognized amd-smi/rocm-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(series)) + _fail(rctx, "Unrecognized amd-smi/rocm-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(series)) + return None -def _get_rocm_constraint(blob, gpu_mapping): +def _get_rocm_constraint(rctx, blob, gpu_mapping): for value in blob.values(): series = value["Card Series"] - return _get_amdgpu_constraint(series, gpu_mapping) + return _get_amdgpu_constraint(rctx, series, gpu_mapping) fail("Unrecognized rocm-smi output, please report: {}".format(blob)) -def _get_amd_constraint(blob, gpu_mapping): +def _get_amd_constraint(rctx, blob, gpu_mapping): for value in blob: series = value["asic"]["market_name"] - return _get_amdgpu_constraint(series, gpu_mapping) + return _get_amdgpu_constraint(rctx, series, gpu_mapping) fail("Unrecognized amd-smi output, please report: {}".format(blob)) -def _get_nvidia_constraint(lines, gpu_mapping): +def _get_nvidia_constraint(rctx, lines, gpu_mapping): line = lines[0] for gpu_name, constraint in gpu_mapping.items(): if gpu_name in line: @@ -43,7 +51,8 @@ def _get_nvidia_constraint(lines, gpu_mapping): else: return None - fail("Unrecognized nvidia-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(lines)) + _fail(rctx, "Unrecognized nvidia-smi output, please add it to your gpu_mapping in the MODULE.bazel file: {}".format(lines)) + return None def _impl(rctx): constraints = [] @@ -67,7 +76,7 @@ def _impl(rctx): if len(lines) == 0: fail("nvidia-smi succeeded but had no GPUs, please report this issue") - constraint = _get_nvidia_constraint(lines, rctx.attr.gpu_mapping) + constraint = _get_nvidia_constraint(rctx, lines, rctx.attr.gpu_mapping) if constraint: constraints.extend([ "@mojo_gpu_toolchains//:nvidia_gpu", @@ -75,10 +84,10 @@ def _impl(rctx): constraint, ]) - if len(lines) > 1: - constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") - if len(lines) >= 4: - constraints.append("@mojo_gpu_toolchains//:has_4_gpus") + if len(lines) > 1: + constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") + if len(lines) >= 4: + constraints.append("@mojo_gpu_toolchains//:has_4_gpus") # AMD if amd_smi: @@ -86,40 +95,44 @@ def _impl(rctx): _log_result(rctx, amd_smi, result) if result.return_code == 0: - constraints.extend([ - "@mojo_gpu_toolchains//:amd_gpu", - "@mojo_gpu_toolchains//:has_gpu", - ]) - blob = json.decode(result.stdout) if len(blob) == 0: fail("amd-smi succeeded but didn't actually have any GPUs, please report this issue") - constraints.append(_get_amd_constraint(blob, rctx.attr.gpu_mapping)) - if len(blob) > 1: - constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") - if len(blob) >= 4: - constraints.append("@mojo_gpu_toolchains//:has_4_gpus") + amd_constraint = _get_amd_constraint(rctx, blob, rctx.attr.gpu_mapping) + if amd_constraint: + constraints.extend([ + amd_constraint, + "@mojo_gpu_toolchains//:amd_gpu", + "@mojo_gpu_toolchains//:has_gpu", + ]) + + if len(blob) > 1: + constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") + if len(blob) >= 4: + constraints.append("@mojo_gpu_toolchains//:has_4_gpus") elif rocm_smi: result = rctx.execute([rocm_smi, "--json", "--showproductname"]) _log_result(rctx, rocm_smi, result) if result.return_code == 0: - constraints.extend([ - "@mojo_gpu_toolchains//:amd_gpu", - "@mojo_gpu_toolchains//:has_gpu", - ]) - blob = json.decode(result.stdout) if len(blob.keys()) == 0: fail("rocm-smi succeeded but didn't actually have any GPUs, please report this issue") - constraints.append(_get_rocm_constraint(blob, rctx.attr.gpu_mapping)) - if len(blob.keys()) > 1: - constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") - if len(blob.keys()) >= 4: - constraints.append("@mojo_gpu_toolchains//:has_4_gpus") + rocm_constraint = _get_rocm_constraint(rctx, blob, rctx.attr.gpu_mapping) + if rocm_constraint: + constraints.extend([ + rocm_constraint, + "@mojo_gpu_toolchains//:amd_gpu", + "@mojo_gpu_toolchains//:has_gpu", + ]) + + if len(blob.keys()) > 1: + constraints.append("@mojo_gpu_toolchains//:has_multi_gpu") + if len(blob.keys()) >= 4: + constraints.append("@mojo_gpu_toolchains//:has_4_gpus") rctx.file("WORKSPACE.bazel", "workspace(name = {})".format(rctx.attr.name)) rctx.file("BUILD.bazel", """ @@ -138,6 +151,7 @@ mojo_host_platform = repository_rule( implementation = _impl, configure = True, environ = [ + "MOJO_IGNORE_UNKNOWN_GPUS", "MOJO_VERBOSE_GPU_DETECT", ], attrs = {