diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5607386..ce28cb5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; - if (tensor->view_src != NULL && tensor->view_offs == 0) { + if (tensor->view_src != NULL) { assert(tensor->view_src->buffer->buft == buffer->buft); - tensor->backend = tensor->view_src->backend; - tensor->extra = tensor->view_src->extra; return; } @@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } -#if 0 -template -static __global__ void k_compute_batched_ptrs_id( - const void ** ptrs_src, void ** ptrs_dst, - int ne12, int ne13, - int ne23, - int nb02, int nb03, - int nb12, int nb13, - int nb2, int nb3, - int r2, int r3, - ggml_type src0_type, half * src0_as_f16, int64_t src0_ne, - const half * src1_f16, half * dst_f16, - const int32_t * ids, const int id, - Srcs... src0s) { - - int i = ids[id]; - - half * src0_f16; - const void * srcs_ar[] = { (const half *) src0s... }; - if (src0_type == GGML_TYPE_F16) { - src0_f16 = (half *) srcs_ar[i]; - } else { - src0_f16 = src0_as_f16; - if (threadIdx.x == 0 && threadIdx.y == 0) { - const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type); - to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget); - } - } - - int i13 = blockIdx.x * blockDim.x + threadIdx.x; - int i12 = blockIdx.y * blockDim.y + threadIdx.y; - - if (i13 >= ne13 || i12 >= ne12) { - return; - } - - int i03 = i13 / r3; - int i02 = i12 / r2; - - ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03; - ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2; - ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; -} - -static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { - const struct ggml_tensor * ids = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src00 = dst->src[2]; - - const int id = dst->op_params[0]; - - GGML_ASSERT(!ggml_is_transposed(src00)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); - const int64_t ne01 = src00->ne[1]; - const int64_t ne02 = src00->ne[2]; - const int64_t ne03 = src00->ne[3]; - - //const int64_t nb01 = src00->nb[1]; - const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02); - const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - //const int64_t nb11 = src1->nb[1]; - const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); - const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); - - const int64_t ne1 = ggml_nelements(src1); - const int64_t ne = ggml_nelements(dst); - - ggml_cuda_set_device(g_main_device); - cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); - - //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - //void * src0_ddq = src0_extra->data_device[g_main_device]; - //half * src0_as_f16 = (half *) src0_ddq; - - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - - // convert src1 to fp16 - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - - size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); - to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); - - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - // use cublasGemmBatchedEx - const int ne23 = ne12*ne13; - - const void ** ptrs_src = nullptr; - void ** ptrs_dst = nullptr; - - size_t ptrs_src_s = 0; - size_t ptrs_dst_s = 0; - - ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); - ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); - - int64_t src0_ne = ggml_nelements(src00); - half * src0_as_f16 = nullptr; - size_t src0_as = 0; - if (src00->type != GGML_TYPE_F16) { - src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as); - } - - static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6"); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>( - ptrs_src, ptrs_dst, - ne12, ne13, - ne23, - ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half), - nb12, nb13, - dst->nb[2], dst->nb[3], - r2, r3, - src00->type, src0_as_f16, src0_ne, - src1_as_f16, dst_f16, - (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id, - dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr, - dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr, - dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr, - dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr - ); - CUDA_CHECK(cudaGetLastError()); - - CUBLAS_CHECK( - cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00, - (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10, - &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, - ne23, - CUBLAS_COMPUTE_16F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_as_f16, src0_as); - } - if (ptrs_src_s != 0) { - ggml_cuda_pool_free(ptrs_src, ptrs_src_s); - } - if (ptrs_dst_s != 0) { - ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); - } - - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - - ggml_cuda_pool_free(src1_as_f16, src1_as); - ggml_cuda_pool_free(dst_f16, dst_as); -} -#endif - static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { -#if 0 - ggml_cuda_mul_mat_id_cublas(dst); - // TODO: mmq/mmv support -#endif const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * ids = dst->src[2]; + + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); const size_t nb11 = src1->nb[1]; const size_t nb1 = dst->nb[1]; - const struct ggml_tensor * ids = src0; const int32_t id = ((int32_t *) dst->op_params)[0]; - const int32_t n_as = ((int32_t *) dst->op_params)[1]; + const int32_t n_as = src0->ne[2]; std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); + ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; + char * src0_original = (char *) src0->data; char * src1_original = (char *) src1->data; char * dst_original = (char *) dst->data; + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[3] = src0->nb[2]; + if (src1->ne[1] == 1) { for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); GGML_ASSERT(row_id >= 0 && row_id < n_as); - const struct ggml_tensor * src0_row = dst->src[row_id + 2]; - + src0_row.data = src0_original + row_id*src0->nb[2]; src1_row.data = src1_original + i01*src1->nb[1]; dst_row.data = dst_original + i01*dst->nb[1]; - ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row); + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); } } else { ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); @@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.data = dst_contiguous.get(); for (int32_t row_id = 0; row_id < n_as; ++row_id) { - const struct ggml_tensor * src0_row = dst->src[row_id + 2]; - int64_t num_src1_rows = 0; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); @@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * continue; } + src0_row.data = src0_original + row_id*src0->nb[2]; + src1_row.ne[1] = num_src1_rows; dst_row.ne[1] = num_src1_rows; @@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; - ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row); + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); num_src1_rows = 0; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { @@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst)); - GGML_ASSERT(false); + CUDA_CHECK(err); } return true; diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu index 1333287..1641440 100644 --- a/ggml-cuda/argsort.cu +++ b/ggml-cuda/argsort.cu @@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; - if (col >= ncols) return; + if (col >= ncols_pad) { + return; + } const float * x_row = x + row * ncols; - int * dst_row = dst + row * ncols; + extern __shared__ int dst_row[]; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + __syncthreads(); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { ggml_cuda_swap(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { ggml_cuda_swap(dst_row[col], dst_row[ixj]); } } @@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n __syncthreads(); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; } static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 - GGML_ASSERT((ncols & (ncols - 1)) == 0); + const int ncols_pad = next_power_of_2(ncols); - const dim3 block_dims(ncols, 1, 1); + const dim3 block_dims(ncols_pad, 1, 1); const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); } else { GGML_ASSERT(false); } diff --git a/ggml-metal.m b/ggml-metal.m index a08abbc..419d8b9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute( { //GGML_ASSERT(ne00 == ne10); //GGML_ASSERT(ne03 == ne13); - - GGML_ASSERT(src0t == GGML_TYPE_I32); - - const int n_as = ((int32_t *) dst->op_params)[1]; - - // TODO: make this more general - GGML_ASSERT(n_as <= 8); + const int n_as = src0->ne[2]; // max size of the src1ids array in the kernel shared buffer GGML_ASSERT(ne11 <= 4096); - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); + // src2 = ids + const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20); + const int64_t ne21 = src2->ne[1]; + const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); + const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); + const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); + const uint64_t nb21 = src2->nb[1]; + const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); + const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - GGML_ASSERT(!ggml_is_transposed(src2)); + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(src1t == GGML_TYPE_F32); - const uint r2 = ne12/ne22; - const uint r3 = ne13/ne23; - // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel int ne11_mm_min = n_as; @@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute( const int idx = ((int32_t *) dst->op_params)[0]; // batch size - GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(ne21 == ne11); // ? + GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting + const uint r2 = 1; + const uint r3 = 1; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute( // indirect matrix multiplication // !!! if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne20 % 32 == 0 && ne20 >= 64 && + ne00 % 32 == 0 && ne00 >= 64 && ne11 > ne11_mm_min) { // some Metal matrix data types require aligned pointers @@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (src2->type) { + switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; @@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; - // TODO: how to make this an array? read Metal docs - for (int j = 0; j < 8; ++j) { - // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 - struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; - - size_t offs_src_cur = 0; - id id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; - } + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:19]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; // use custom matrix x vector kernel - switch (src2t) { + switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); @@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute( } }; - if (ggml_is_quantized(src2t)) { - GGML_ASSERT(ne20 >= nth0*nth1); + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); } const int64_t _ne1 = 1; // kernels needs a reference in constant memory @@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; - // TODO: how to make this an array? read Metal docs - for (int j = 0; j < 8; ++j) { - // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 - struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:21]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:22]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:23]; - size_t offs_src_cur = 0; - id id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - - if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || - src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K || - src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) { - const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) { - const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) { + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q3_K) { + else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } - else if (src2t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { const int64_t ny = (_ne1 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; @@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute( enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + id pipeline = nil; switch (order) { @@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; } break; case GGML_OP_LEAKY_RELU: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 744b2a8..9a29f57 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -13,8 +13,8 @@ using namespace metal; #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; // general-purpose kernel for addition, multiplication and division of two tensors @@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32( // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols) return; + if (col >= ncols_pad) return; - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0, @@ -5785,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -5804,22 +5821,14 @@ kernel void kernel_mul_mm_id( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // expert id const int32_t id = tgpig.z/(ne12*ne13); + device const uchar * src0 = src0s + id*nb02; tgpig.z = tgpig.z%(ne12*ne13); @@ -5834,7 +5843,7 @@ kernel void kernel_mul_mm_id( } kernel_mul_mm_id_impl( - src0s[id], + src0, src1, src1ids, dst, @@ -5960,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m // typedef void (mat_mm_id_t)( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -5979,14 +5989,6 @@ typedef void (mat_mm_id_t)( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar *, uint3, uint, uint); @@ -6022,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel [[host_name("kernel_mul_mv_id_f32_f32")]] kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6045,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f32_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6091,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32( [[host_name("kernel_mul_mv_id_f16_f32")]] kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6114,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f16_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6160,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32( [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6183,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q8_0_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6223,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32( [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6246,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6286,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32( [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6309,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6349,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32( [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6372,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6412,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32( [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6435,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6475,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32( [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6498,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q2_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6538,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32( [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6561,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q3_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6601,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32( [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6624,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q4_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6664,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32( [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6687,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q5_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6727,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32( [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6750,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q6_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6790,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32( [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel void kernel_mul_mv_id_iq2_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6813,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6855,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel void kernel_mul_mv_id_iq2_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6878,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6920,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel void kernel_mul_mv_id_iq3_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6943,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6985,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel void kernel_mul_mv_id_iq3_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7008,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7050,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32( [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel void kernel_mul_mv_id_iq2_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7073,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7115,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32( [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel void kernel_mul_mv_id_iq1_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7138,28 +7005,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq1_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7178,9 +7036,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32( [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel void kernel_mul_mv_id_iq1_m_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7201,28 +7060,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq1_m_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7241,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32( [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel void kernel_mul_mv_id_iq4_nl_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7264,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq4_nl_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7306,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel void kernel_mul_mv_id_iq4_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7329,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; #if QK_K == 64 kernel_mul_mv_iq4_nl_f32_impl( #else kernel_mul_mv_iq4_xs_f32_impl( #endif - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, diff --git a/ggml.c b/ggml.c index 7471e79..c9b0a6a 100644 --- a/ggml.c +++ b/ggml.c @@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec( // ggml_mul_mat_id +// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed +// this will allow computing all the used experts in a single matrix multiplication struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, - struct ggml_tensor * const as[], - int n_as, + struct ggml_tensor * as, struct ggml_tensor * ids, int id, struct ggml_tensor * b) { GGML_ASSERT(ids->type == GGML_TYPE_I32); - GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); - GGML_ASSERT(ids->ne[1] == b->ne[1]); + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); - GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); - GGML_ASSERT(id >= 0 && id < ids->ne[0]); + GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id + GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat bool is_node = false; - if (as[0]->grad || b->grad) { + if (as->grad || b->grad) { is_node = true; } - const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); ggml_set_op_params_i32(result, 0, id); - ggml_set_op_params_i32(result, 1, n_as); result->op = GGML_OP_MUL_MAT_ID; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = ids; + result->src[0] = as; result->src[1] = b; - - for (int i = 0; i < n_as; i++) { - struct ggml_tensor * a = as[i]; - GGML_ASSERT(ggml_are_same_shape(as[0], a)); - GGML_ASSERT(ggml_can_mul_mat(a, b)); - GGML_ASSERT(!ggml_is_transposed(a)); - result->src[i + 2] = a; - } + result->src[2] = ids; return result; } @@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - - const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS + const struct ggml_tensor * ids = dst->src[2]; GGML_TENSOR_BINARY_OP_LOCALS @@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; + // broadcast is not supported with mmid + assert(ne12 == 1); + assert(ne13 == 1); // row groups const int id = ggml_get_op_params_i32(dst, 0); - const int n_as = ggml_get_op_params_i32(dst, 1); + const int n_as = src0->ne[2]; char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : @@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id( continue; } - const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; + size_t src0_offset = cur_a*src0->nb[2]; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id( continue; } - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - // block-tiling attempt const int64_t blck_0 = 16; const int64_t blck_1 = 16; @@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id( const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); // broadcast src0 into src1 - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; + //const int64_t i03 = i13/r3; + //const int64_t i02 = i12/r2; const int64_t i1 = i11; const int64_t i2 = i12; const int64_t i3 = i13; - const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03); + const char * src0_row = (const char *) src0->data + src0_offset; // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using @@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa case GGML_OP_MUL_MAT_ID: { cur = 0; - const struct ggml_tensor * src0 = node->src[2]; + const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); } - const int n_as = ggml_get_op_params_i32(node, 1); + const int n_as = src0->ne[2]; cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows diff --git a/ggml.h b/ggml.h index 5d4a4ce..5cef45c 100644 --- a/ggml.h +++ b/ggml.h @@ -1164,8 +1164,7 @@ extern "C" { // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, - struct ggml_tensor * const as[], - int n_as, + struct ggml_tensor * as, struct ggml_tensor * ids, int id, struct ggml_tensor * b);