From fdb2c8735066a788aadf8ab1f32d21d0812cd7c7 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Tue, 16 Apr 2024 23:40:48 +0800 Subject: [PATCH] llama : add qwen2moe (llama/6074) * support qwen2moe * fix-review * metal : support unary ops for nelements % 4 != 0 * metal : require contiguousness for float4 unary kernels * metal : require contiguousness for float4 unary kernels (cont) * fix-review * names : for brevity "SHARED_EXP" -> "SHEXP" * llama : reuse build_moe_ffn() * llama : add model type name --------- Co-authored-by: Georgi Gerganov --- ggml-metal.m | 57 +++++++++++++++++++++++++++++++++++------------- ggml-metal.metal | 26 ++++++++++++++++++++++ 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index b43dfc3..0ec47fe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -42,8 +42,11 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX, GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, @@ -475,8 +478,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); @@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + case GGML_UNARY_OP_TANH: { id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; @@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_GELU_QUICK: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_SILU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 1d05087..d7ae372 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -262,6 +271,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -271,6 +289,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + ++kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) {