Skip to content

Commit 19f4c67

Browse files
ICGogGoogle-ML-Automation
authored andcommitted
Add platform name to xla::ifrt::Device
PiperOrigin-RevId: 844067251
1 parent e914ced commit 19f4c67

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

jaxlib/py_array.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ limitations under the License.
4747
#include "absl/strings/str_cat.h"
4848
#include "absl/strings/str_format.h"
4949
#include "absl/strings/str_join.h"
50+
#include "absl/strings/string_view.h"
5051
#include "absl/types/span.h"
5152
#include "llvm/Support/Casting.h"
5253
#include "nanobind/nanobind.h"
@@ -2304,11 +2305,19 @@ absl::Status PyArray::Register(nb::module_& m) {
23042305
nb::is_method());
23052306
type.attr("platform") = nb::cpp_function(
23062307
[](PyArray self) {
2307-
if (self.ifrt_array()->client()->platform_name() == "cuda" ||
2308-
self.ifrt_array()->client()->platform_name() == "rocm") {
2308+
#if JAX_IFRT_VERSION_NUMBER >= 43
2309+
const xla::ifrt::DeviceListRef& devices =
2310+
self.ifrt_array()->sharding().devices();
2311+
absl::string_view platform_name =
2312+
devices->devices().front()->PlatformName();
2313+
#else
2314+
absl::string_view platform_name =
2315+
self.ifrt_array()->client()->platform_name();
2316+
#endif
2317+
if (platform_name == "cuda" || platform_name == "rocm") {
23092318
return std::string_view("gpu");
23102319
} else {
2311-
return self.ifrt_array()->client()->platform_name();
2320+
return platform_name;
23122321
}
23132322
},
23142323
nb::is_method());

jaxlib/py_device.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "absl/status/status.h"
2929
#include "absl/status/statusor.h"
3030
#include "absl/strings/str_join.h"
31+
#include "absl/strings/string_view.h"
3132
#include "llvm/Support/Casting.h"
3233
#include "nanobind/nanobind.h"
3334
#include "nanobind/stl/optional.h" // IWYU pragma: keep
@@ -68,11 +69,15 @@ std::string_view PyDevice::platform() const {
6869
// but we haven't yet updated JAX clients that
6970
// expect "gpu". Migrate users and remove this
7071
// code.
71-
if (client_->platform_name() == "cuda" ||
72-
client_->platform_name() == "rocm") {
72+
#if JAX_IFRT_VERSION_NUMBER >= 43
73+
absl::string_view platform_name = device_->PlatformName();
74+
#else
75+
absl::string_view platform_name = client_->platform_name();
76+
#endif
77+
if (platform_name == "cuda" || platform_name == "rocm") {
7378
return std::string_view("gpu");
7479
} else {
75-
return client_->platform_name();
80+
return platform_name;
7681
}
7782
}
7883

0 commit comments

Comments
 (0)