File tree Expand file tree Collapse file tree 2 files changed +20
-6
lines changed
Expand file tree Collapse file tree 2 files changed +20
-6
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ limitations under the License.
4747#include " absl/strings/str_cat.h"
4848#include " absl/strings/str_format.h"
4949#include " absl/strings/str_join.h"
50+ #include " absl/strings/string_view.h"
5051#include " absl/types/span.h"
5152#include " llvm/Support/Casting.h"
5253#include " nanobind/nanobind.h"
@@ -2304,11 +2305,19 @@ absl::Status PyArray::Register(nb::module_& m) {
23042305 nb::is_method ());
23052306 type.attr (" platform" ) = nb::cpp_function (
23062307 [](PyArray self) {
2307- if (self.ifrt_array ()->client ()->platform_name () == " cuda" ||
2308- self.ifrt_array ()->client ()->platform_name () == " rocm" ) {
2308+ #if JAX_IFRT_VERSION_NUMBER >= 43
2309+ const xla::ifrt::DeviceListRef& devices =
2310+ self.ifrt_array ()->sharding ().devices ();
2311+ absl::string_view platform_name =
2312+ devices->devices ().front ()->PlatformName ();
2313+ #else
2314+ absl::string_view platform_name =
2315+ self.ifrt_array ()->client ()->platform_name ();
2316+ #endif
2317+ if (platform_name == " cuda" || platform_name == " rocm" ) {
23092318 return std::string_view (" gpu" );
23102319 } else {
2311- return self. ifrt_array ()-> client ()-> platform_name () ;
2320+ return platform_name;
23122321 }
23132322 },
23142323 nb::is_method ());
Original file line number Diff line number Diff line change @@ -28,6 +28,7 @@ limitations under the License.
2828#include " absl/status/status.h"
2929#include " absl/status/statusor.h"
3030#include " absl/strings/str_join.h"
31+ #include " absl/strings/string_view.h"
3132#include " llvm/Support/Casting.h"
3233#include " nanobind/nanobind.h"
3334#include " nanobind/stl/optional.h" // IWYU pragma: keep
@@ -68,11 +69,15 @@ std::string_view PyDevice::platform() const {
6869 // but we haven't yet updated JAX clients that
6970 // expect "gpu". Migrate users and remove this
7071 // code.
71- if (client_->platform_name () == " cuda" ||
72- client_->platform_name () == " rocm" ) {
72+ #if JAX_IFRT_VERSION_NUMBER >= 43
73+ absl::string_view platform_name = device_->PlatformName ();
74+ #else
75+ absl::string_view platform_name = client_->platform_name ();
76+ #endif
77+ if (platform_name == " cuda" || platform_name == " rocm" ) {
7378 return std::string_view (" gpu" );
7479 } else {
75- return client_-> platform_name () ;
80+ return platform_name;
7681 }
7782}
7883
You can’t perform that action at this time.
0 commit comments