ggml : sync ggml-metal.m

This commit is contained in:
Georgi Gerganov 2024-01-18 11:03:13 +02:00
parent 1f50a7d29f
commit fb466b3417
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -170,9 +170,6 @@ struct ggml_metal_context {
id<MTLCommandQueue> queue; id<MTLCommandQueue> queue;
id<MTLLibrary> library; id<MTLLibrary> library;
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
dispatch_queue_t d_queue; dispatch_queue_t d_queue;
int n_buffers; int n_buffers;
@ -241,21 +238,19 @@ static void * ggml_metal_host_malloc(size_t n) {
static struct ggml_metal_context * ggml_metal_init(int n_cb) { static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: allocating\n", __func__); GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
id<MTLDevice> device; #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
NSString * s;
#if TARGET_OS_OSX
// Show all the Metal device instances in the system // Show all the Metal device instances in the system
NSArray * devices = MTLCopyAllDevices(); NSArray * devices = MTLCopyAllDevices();
for (device in devices) { for (id<MTLDevice> device in devices) {
s = [device name]; NSString * s = [device name];
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
} }
[devices release]; // since it was created by a *Copy* C method
#endif #endif
// Pick and show default Metal device // Pick and show default Metal device
device = MTLCreateSystemDefaultDevice(); id<MTLDevice> device = MTLCreateSystemDefaultDevice();
s = [device name]; NSString * s = [device name];
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
// Configure context // Configure context
@ -715,44 +710,41 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
static bool ggml_metal_graph_compute( static bool ggml_metal_graph_compute(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf) {
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const int n_nodes = gf->n_nodes;
edesc.dispatchType = MTLDispatchTypeSerial; edesc.dispatchType = MTLDispatchTypeSerial;
// create multiple command buffers and enqueue them // create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel // then, we encode the graph into the command buffers in parallel
const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb; const int n_cb = ctx->n_cb;
for (int i = 0; i < n_cb; ++i) {
ctx->command_buffers[i] = [ctx->queue commandBuffer];
// enqueue the command buffers in order to specify their execution order
[ctx->command_buffers[i] enqueue];
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
}
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
dispatch_async(ctx->d_queue, ^{ id<MTLCommandBuffer> command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;
// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
}
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
const int cb_idx = iter;
size_t offs_src0 = 0; size_t offs_src0 = 0;
size_t offs_src1 = 0; size_t offs_src1 = 0;
size_t offs_dst = 0; size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx]; id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx]; id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
for (int ind = node_start; ind < node_end; ++ind) { for (int i = node_start; i < node_end; ++i) {
const int i = ind;
if (i == -1) { if (i == -1) {
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue; continue;
@ -2240,24 +2232,19 @@ static bool ggml_metal_graph_compute(
#endif #endif
} }
if (encoder != nil) {
[encoder endEncoding]; [encoder endEncoding];
encoder = nil;
}
[command_buffer commit]; [command_buffer commit];
}); });
}
// wait for all threads to finish // Wait for completion and check status of each command buffer
dispatch_barrier_sync(ctx->d_queue, ^{});
// check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881) // needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) {
[ctx->command_buffers[i] waitUntilCompleted];
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> command_buffer = command_buffers[i];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) { if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
return false; return false;
@ -2265,7 +2252,6 @@ static bool ggml_metal_graph_compute(
} }
return true; return true;
}
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////