Skip to content

Commit e914ced

Browse files
hyeontaekGoogle-ML-Automation
authored andcommitted
[PjRt-IFRT] Create ifrt::PjRtExecutable only from ifrt::PjRtCompiler and CompileOnlyIfrtCompiler
This change migrates direct calls to `ifrt::PjRtExecutable::Create()` outside to use a public IFRT API `ifrt::PjRtCompiler::Compile()` instead. This change should be no-op in practice. For PjRt-IFRT, it now performs IFRT device ID -> PjRt device ID conversion in `xla::CompileOptions::executable_build_options` (which was missing before) and thus can handle a client using a different device ID mapping. PiperOrigin-RevId: 844980890
1 parent 52ee821 commit e914ced

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

jaxlib/py_client.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,27 +499,33 @@ PyClient::CompileAndLoadIfrtProgram(
499499
ifrt::DeviceListRef executable_devices, xla::CompileOptions options) {
500500
mlir::OwningOpRef<mlir::ModuleOp> clone(module.clone());
501501
module = *clone;
502-
ifrt::ExecutableRef executable_ref;
502+
ifrt::ExecutableRef ifrt_executable;
503503
{
504504
TF_ASSIGN_OR_RETURN(
505505
auto topology,
506506
client->ifrt_client()->GetTopologyForDevices(executable_devices));
507507
auto xla_options = std::make_unique<ifrt::XlaCompileOptions>(
508508
options, std::move(executable_devices));
509-
#if JAX_IFRT_VERSION_NUMBER >= 38
509+
#if JAX_IFRT_VERSION_NUMBER >= 42
510510
TF_ASSIGN_OR_RETURN(
511-
executable_ref,
511+
ifrt_executable,
512+
client->ifrt_client()->GetDefaultCompiler()->Compile(
513+
std::make_unique<xla::ifrt::HloProgram>(std::move(module)),
514+
*topology, std::move(xla_options)));
515+
#elif JAX_IFRT_VERSION_NUMBER >= 38
516+
TF_ASSIGN_OR_RETURN(
517+
ifrt_executable,
512518
ifrt::PjRtExecutable::Create(std::move(module), std::move(options),
513519
*topology->description()));
514520
#else
515521
TF_ASSIGN_OR_RETURN(
516522
auto pjrt_executable,
517523
PjRtCompile(std::move(options), module, *topology->description()));
518-
TF_ASSIGN_OR_RETURN(executable_ref, ifrt::PjRtExecutable::Create(
519-
std::move(pjrt_executable)));
524+
TF_ASSIGN_OR_RETURN(ifrt_executable, ifrt::PjRtExecutable::Create(
525+
std::move(pjrt_executable)));
520526
#endif
521527
}
522-
return make_nb_class<PyExecutable>(executable_ref);
528+
return make_nb_class<PyExecutable>(ifrt_executable);
523529
}
524530

525531
/* static */ absl::StatusOr<nb_class_ptr<PyLoadedExecutable>>

jaxlib/py_compile_only_client.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ limitations under the License.
4040
#include "xla/python/compile_only_ifrt/client.h"
4141
#include "xla/python/ifrt/device_list.h"
4242
#include "xla/python/ifrt/executable.h"
43+
#include "xla/python/ifrt/hlo/hlo_program.h"
4344
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
4445
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
4546
#include "xla/python/pjrt_ifrt/xla_compiler.h"
@@ -79,7 +80,13 @@ absl::StatusOr<nb_class_ptr<PyExecutable>> CompileOnlyPyClient::CompileUnloaded(
7980

8081
auto xla_options = std::make_unique<ifrt::XlaCompileOptions>(
8182
options, std::move(executable_devices));
82-
#if JAX_IFRT_VERSION_NUMBER >= 38
83+
#if JAX_IFRT_VERSION_NUMBER >= 42
84+
TF_ASSIGN_OR_RETURN(
85+
ifrt_executable,
86+
ifrt_client->GetDefaultCompiler()->Compile(
87+
std::make_unique<xla::ifrt::HloProgram>(std::move(module)),
88+
ifrt_client->topology(), std::move(xla_options)));
89+
#elif JAX_IFRT_VERSION_NUMBER >= 38
8390
TF_ASSIGN_OR_RETURN(
8491
ifrt_executable,
8592
ifrt::PjRtExecutable::Create(std::move(module), std::move(options),

0 commit comments

Comments
 (0)