Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions mlx/backend/metal/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ MetalAllocator::MetalAllocator()
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
bool is_vm = std::get<std::string>(info.at("device_name")) ==
"Apple Paravirtual device";
if (is_vm) {
Expand Down Expand Up @@ -96,11 +94,42 @@ size_t MetalAllocator::get_memory_limit() {

/**
 * Set the wired-memory limit and return the limit that was in effect before.
 *
 * The new limit is always recorded. Applying it to the residency set is
 * deferred (and flagged via wired_limit_pending_apply_) while residency sets
 * are disabled or a Metal capture is active; on_capture_stop() applies it
 * later.
 *
 * @param limit New wired limit.
 * @return The previous wired limit.
 */
size_t MetalAllocator::set_wired_limit(size_t limit) {
  std::unique_lock lk(mutex_);
  size_t previous = wired_limit_;
  wired_limit_ = limit;
  // During active Metal capture, avoid issuing residency-set mutations so the
  // capture remains replayable for derived counters.
  if (!metal::residency_sets_enabled() || metal::is_capture_active()) {
    wired_limit_pending_apply_ = true;
    return previous;
  }
  // Resize before the attach check: resize() is what lazily creates the
  // underlying Metal residency set, so checking first would find no set on the
  // very first non-zero limit and never attach it (assuming
  // mtl_residency_set() exposes that lazily created set).
  residency_set_.resize(wired_limit_);
  // Attach the residency set lazily so normal execution (wired limit == 0)
  // does not add queue residency-set state that is never used.
  if (wired_limit_ > 0 && !residency_set_attached_ &&
      residency_set_.mtl_residency_set() != nullptr) {
    device(mlx::core::Device::gpu)
        .set_residency_set(residency_set_.mtl_residency_set());
    residency_set_attached_ = true;
  }
  wired_limit_pending_apply_ = false;
  return previous;
}

/**
 * Apply a wired limit whose application was deferred by set_wired_limit().
 *
 * Does nothing when no application is pending or residency sets are still
 * disabled.
 */
void MetalAllocator::on_capture_stop() {
  std::unique_lock lk(mutex_);
  if (!wired_limit_pending_apply_ || !metal::residency_sets_enabled()) {
    return;
  }
  // Resize before the attach check: resize() is what lazily creates the
  // underlying Metal residency set, so a limit first requested during a
  // capture would otherwise find no set here and never be attached (assuming
  // mtl_residency_set() exposes that lazily created set).
  residency_set_.resize(wired_limit_);
  if (wired_limit_ > 0 && !residency_set_attached_ &&
      residency_set_.mtl_residency_set() != nullptr) {
    device(mlx::core::Device::gpu)
        .set_residency_set(residency_set_.mtl_residency_set());
    residency_set_attached_ = true;
  }
  wired_limit_pending_apply_ = false;
}

Buffer MetalAllocator::malloc(size_t size) {
// Metal doesn't like empty buffers
if (size == 0) {
Expand Down
3 changes: 3 additions & 0 deletions mlx/backend/metal/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MetalAllocator : public allocator::Allocator {
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_wired_limit(size_t limit);
void on_capture_stop();
void clear_cache();

private:
Expand Down Expand Up @@ -68,6 +69,8 @@ class MetalAllocator : public allocator::Allocator {
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool wired_limit_pending_apply_{false};
bool residency_set_attached_{false};
size_t num_resources_{0};
size_t resource_limit_{0};

Expand Down
56 changes: 53 additions & 3 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,15 @@ void Device::new_queue(int index) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
attach_residency_set_to_existing_queues_if_needed_();
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
if (residency_set_ != nullptr && !metal::is_capture_active()) {
attach_residency_set_to_stream_(stream_map_.find(index)->second);
}
}

/** Return the command queue that backs `stream`. */
MTL::CommandQueue* Device::get_queue(Stream stream) {
  // Catch up on any residency-set attachment that was postponed earlier.
  attach_residency_set_to_existing_queues_if_needed_();
  auto& device_stream = get_stream_(stream.index);
  return device_stream.queue;
}

Expand All @@ -387,6 +389,7 @@ bool Device::command_buffer_needs_commit(int index) {
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
attach_residency_set_to_existing_queues_if_needed_();
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
Expand Down Expand Up @@ -785,6 +788,28 @@ MTL::ComputePipelineState* Device::get_kernel(
base_name, default_library_, hash_name, func_consts, linked_functions);
}

/**
 * Add the device residency set to one stream's queue, at most once.
 *
 * No-op when there is no residency set, the stream already has it attached,
 * residency sets are disabled, or a Metal capture is active.
 */
void Device::attach_residency_set_to_stream_(DeviceStream& stream) {
  const bool nothing_to_attach =
      residency_set_ == nullptr || stream.residency_set_attached;
  if (nothing_to_attach) {
    return;
  }
  const bool mutations_blocked =
      !metal::residency_sets_enabled() || metal::is_capture_active();
  if (mutations_blocked) {
    return;
  }
  stream.queue->addResidencySet(residency_set_);
  stream.residency_set_attached = true;
}

/**
 * Perform a deferred residency-set attachment on every known stream.
 *
 * Only acts when an attachment is pending, a residency set exists, residency
 * sets are enabled and no capture is active; the pending flag is cleared once
 * all streams have been visited.
 */
void Device::attach_residency_set_to_existing_queues_if_needed_() {
  if (!residency_set_pending_attach_ || residency_set_ == nullptr ||
      !metal::residency_sets_enabled() || metal::is_capture_active()) {
    return;
  }
  for (auto& entry : stream_map_) {
    attach_residency_set_to_stream_(entry.second);
  }
  residency_set_pending_attach_ = false;
}

void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
Expand All @@ -794,10 +819,35 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
return;
}
residency_set_ = residency_set;
if (!metal::residency_sets_enabled() || metal::is_capture_active()) {
residency_set_pending_attach_ = true;
return;
}
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
attach_residency_set_to_stream_(stream);
}
residency_set_pending_attach_ = false;
}

/**
 * Detach the residency set from every queue that currently has it and mark
 * the attachment as pending so it can be restored later.
 */
void Device::on_capture_start() {
  if (residency_set_ == nullptr) {
    return;
  }
  // Avoid queue-level residency state in captures; this can break
  // derived-counter replay in Xcode.
  for (auto& entry : stream_map_) {
    auto& stream = entry.second;
    if (stream.residency_set_attached) {
      stream.queue->removeResidencySet(residency_set_);
      stream.residency_set_attached = false;
    }
  }
  residency_set_pending_attach_ = true;
}

// Capture-stop hook: re-attaches the residency set to the existing queues if
// an attachment is still pending (the helper itself checks that residency
// sets are enabled and that no capture is active).
void Device::on_capture_stop() {
attach_residency_set_to_existing_queues_if_needed_();
}

Device& device(mlx::core::Device) {
Expand Down
7 changes: 7 additions & 0 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ struct DeviceStream {
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
bool residency_set_attached{false};
};

class MLX_API Device {
Expand Down Expand Up @@ -206,8 +207,13 @@ class MLX_API Device {
void add_temporaries(std::vector<array> arrays, int index);

void set_residency_set(const MTL::ResidencySet* residency_set);
void on_capture_start();
void on_capture_stop();

private:
void attach_residency_set_to_stream_(DeviceStream& stream);
void attach_residency_set_to_existing_queues_if_needed_();

DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
Expand Down Expand Up @@ -255,6 +261,7 @@ class MLX_API Device {
std::unordered_map<std::string, MTL::ComputePipelineState*>>
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
bool residency_set_pending_attach_{false};
std::string arch_;
int arch_gen_;
int max_ops_per_buffer_;
Expand Down
29 changes: 29 additions & 0 deletions mlx/backend/metal/metal.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <atomic>
#include <memory>

#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"

namespace mlx::core::metal {
namespace {
// File-local switch read by residency_sets_enabled(); starts out enabled.
std::atomic<bool> g_residency_sets_enabled{true};
} // namespace

/** Always reports true in this translation unit. */
bool is_available() {
  return true;
}

/** Read the current value of the residency-set switch (relaxed load). */
bool residency_sets_enabled() {
  const bool enabled =
      g_residency_sets_enabled.load(std::memory_order_relaxed);
  return enabled;
}

/** Overwrite the residency-set switch (relaxed store). */
void set_residency_sets_enabled(bool enabled) {
  g_residency_sets_enabled.store(enabled, std::memory_order_relaxed);
}

void start_capture(std::string path, NS::Object* object) {
auto pool = new_scoped_memory_pool();
set_residency_sets_enabled(false);
// Detach queue residency sets before capture starts to keep traces replayable
// when collecting derived counters.
metal::device(mlx::core::Device::gpu).on_capture_start();

auto descriptor = MTL::CaptureDescriptor::alloc()->init();
descriptor->setCaptureObject(object);
Expand All @@ -29,6 +46,9 @@ void start_capture(std::string path, NS::Object* object) {
bool started = manager->startCapture(descriptor, &error);
descriptor->release();
if (!started) {
set_residency_sets_enabled(true);
metal::device(mlx::core::Device::gpu).on_capture_stop();
metal::allocator().on_capture_stop();
std::ostringstream msg;
msg << "[metal::start_capture] Failed to start: "
<< error->localizedDescription()->utf8String();
Expand All @@ -45,6 +65,15 @@ void stop_capture() {
auto pool = new_scoped_memory_pool();
auto manager = MTL::CaptureManager::sharedCaptureManager();
manager->stopCapture();
set_residency_sets_enabled(true);
metal::device(mlx::core::Device::gpu).on_capture_stop();
metal::allocator().on_capture_stop();
}

/** Ask the shared Metal capture manager whether a capture is in progress. */
bool is_capture_active() {
  // Keep the scoped pool alive for the duration of the Metal calls below.
  auto pool = new_scoped_memory_pool();
  return MTL::CaptureManager::sharedCaptureManager()->isCapturing();
}

} // namespace mlx::core::metal
3 changes: 3 additions & 0 deletions mlx/backend/metal/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ MLX_API bool is_available();
/** Capture a GPU trace, saving it to an absolute file `path` */
MLX_API void start_capture(std::string path = "");
MLX_API void stop_capture();
MLX_API bool is_capture_active();
MLX_API bool residency_sets_enabled();
MLX_API void set_residency_sets_enabled(bool enabled);

/** Get information about the GPU and system settings. */
MLX_API const
Expand Down
7 changes: 7 additions & 0 deletions mlx/backend/metal/no_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ bool is_available() {

// Capture entry points are no-ops in this file.
void start_capture(std::string /* path */) {
}

void stop_capture() {
}

// A capture is never reported as active here.
bool is_capture_active() {
  return false;
}

// Residency sets are always reported as disabled here.
bool residency_sets_enabled() {
  return false;
}

// The requested value is ignored; residency_sets_enabled() stays false.
void set_residency_sets_enabled(bool /* enabled */) {
}

const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
Expand Down
38 changes: 27 additions & 11 deletions mlx/backend/metal/resident.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal.h"

namespace mlx::core::metal {

ResidencySet::ResidencySet(MTL::Device* d) {
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
ResidencySet::ResidencySet(MTL::Device* d) : device_(d) {}

void ResidencySet::ensure_wired_set_() {
if (wired_set_ != nullptr || device_ == nullptr) {
return;
}
if (!metal::residency_sets_enabled()) {
return;
}
if (!device_->supportsFamily(MTL::GPUFamilyMetal3)) {
return;
} else if (__builtin_available(macOS 15, iOS 18, *)) {
}
if (__builtin_available(macOS 15, iOS 18, *)) {
// Avoid creating residency sets while a Metal capture is active since this
// can make derived-counter replay unstable in Xcode.
if (metal::is_capture_active()) {
return;
}
auto pool = new_scoped_memory_pool();
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
NS::Error* error;
wired_set_ = d->newResidencySet(desc, &error);
wired_set_ = device_->newResidencySet(desc, &error);
desc->release();
if (!wired_set_) {
std::ostringstream msg;
Expand All @@ -27,6 +42,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {

void ResidencySet::insert(MTL::Allocation* buf) {
if (!wired_set_) {
unwired_set_.insert(buf);
return;
}
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
Expand All @@ -38,26 +54,26 @@ void ResidencySet::insert(MTL::Allocation* buf) {
}

/**
 * Stop tracking `buf`.
 *
 * Buffers held in the unwired set are removed from it regardless of whether a
 * wired set exists, because insert() stores buffers there when no wired set
 * is available. Otherwise the buffer is removed from the wired set, if there
 * is one, and the change is committed.
 */
void ResidencySet::erase(MTL::Allocation* buf) {
  if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
    unwired_set_.erase(it);
    return;
  }
  // Not in the unwired set: only the wired set could still hold it.
  if (!wired_set_) {
    return;
  }
  wired_set_->removeAllocation(buf);
  wired_set_->commit();
}

void ResidencySet::resize(size_t size) {
if (!wired_set_) {
return;
}

if (capacity_ == size) {
return;
}
capacity_ = size;
ensure_wired_set_();
if (!wired_set_) {
return;
}

size_t current_size = wired_set_->allocatedSize();

Expand Down
3 changes: 3 additions & 0 deletions mlx/backend/metal/resident.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class ResidencySet {
void resize(size_t size);

private:
void ensure_wired_set_();

MTL::Device* device_{nullptr};
MTL::ResidencySet* wired_set_{nullptr};
std::unordered_set<const MTL::Allocation*> unwired_set_;
size_t capacity_{0};
Expand Down