From c1320c1f0c9de3dc7d04b1ccf9fb2a38da8c840f Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Sun, 14 Apr 2024 10:42:29 +0800 Subject: [PATCH] fix memcpy() crash, add missed cmd in guide, fix softmax (llama/6622) * disable mmap to fix memcpy crash, add missed cmd in guide, fix softmax * refactor to disable mmap for SYCL backend * fix compile error in other os * refactor the solution, use host buf to fix it, instead of disable mmap * keep to support mmap() * use host buff to reduce malloc times * revert to malloc/free solution, for threaad safe --- ggml-sycl.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 55a1eed..86091cf 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 #define SYCL_ROPE_BLOCK_SIZE 256 -#define SYCL_SOFT_MAX_BLOCK_SIZE 1024 #define SYCL_ALIBI_BLOCK_SIZE 32 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 #define SYCL_QUANTIZE_BLOCK_SIZE 256 @@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { int nth = WARP_SIZE; - while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2; + int max_block_size = g_work_group_size; + while (nth < ncols_x && nth < max_block_size) nth *= 2; + if (nth>max_block_size) nth = max_block_size; + const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, 1, nrows_x); const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE); - static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); @@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { + if (ncols_x > max_block_size) { + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + return; + } switch (ncols_x) { case 32: soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, @@ -16814,11 +16821,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; SYCL_CHECK( CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); - + char* host_buf = (char*)malloc(size); + memcpy(host_buf, data, size); SYCL_CHECK( CHECK_TRY_ERROR((*stream) - .memcpy((char *)tensor->data + offset, data, size) + .memcpy((char *)tensor->data + offset, host_buf, size) .wait())); + free(host_buf); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__