From f62a546e03a0bae558eead57113056ee5614ea4c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 5 Oct 2024 12:36:40 +0300
Subject: [PATCH] whisper : fix excessive memory usage (#2443)

* whisper : fix KV cache allocation

* whisper : reduce memory overhead from unused input tensors
---
 src/whisper.cpp | 71 +++++++++++++++++++++++++++----------------------
 1 file changed, 39 insertions(+), 32 deletions(-)

diff --git a/src/whisper.cpp b/src/whisper.cpp
index 9c7c66b..101c4f5 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
         } \
     } while (0)
 
-//#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
@@ -817,6 +816,9 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // number of decoders for which we have constructed the KV cache
+    int32_t kv_self_n_dec = 0;
+
     // unified self-attention KV cache for all decoders
     whisper_kv_cache kv_self;
 
@@ -2096,9 +2098,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
                         0, 2, 1, 3);
 
             if (wctx.params.flash_attn) {
@@ -2125,9 +2125,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
             } else {
                 struct ggml_tensor * K =
                     ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Kcur,
-                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+                            ggml_cast(ctx0,
+                                ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+                                wctx.itype),
                             0, 2, 1, 3);
 
                 // K * Q
@@ -2136,22 +2136,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
                 struct ggml_tensor * V =
-                    ggml_cpy(ctx0,
+                    ggml_cast(ctx0,
                             ggml_permute(ctx0,
                                 ggml_reshape_3d(ctx0,
                                     Vcur,
                                     n_state_head, n_head, n_ctx),
                                 1, 2, 0, 3),
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
-                            );
+                            wctx.itype);
 
                 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
             }
         }
 
@@ -2181,11 +2178,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         layer.mlp_ln_b);
             }
 
-#ifdef WHISPER_USE_FLASH_FF
-            cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
@@ -2202,7 +2194,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                     cur);
 
             cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
@@ -2578,9 +2569,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -2687,9 +2676,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -3403,14 +3390,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
     }
 
-    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
-    // in theory, there can be a case where this is not enough, but in practice it should always be enough
-    const int factor = 3;
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                ctx->model.hparams.n_text_state,
                ctx->model.hparams.n_text_layer,
-                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -5775,13 +5761,34 @@ int whisper_full_with_state(
             }
             WHISPER_LOG_DEBUG("\n\n");
 
+            // recreate the KV cache if the number of decoders has changed
+            if (state->kv_self_n_dec < n_decoders_cur) {
+                WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+                whisper_kv_cache_free(state->kv_self);
+
+                // overallocate to workaround KV cache fragmentation issues
+                const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+                if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                            ctx->model.hparams.n_text_state,
+                            ctx->model.hparams.n_text_layer,
+                            GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                    WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                    whisper_free_state(state);
+                    return -7;
+                }
+
+                state->kv_self_n_dec = n_decoders_cur;
+            }
+
             whisper_kv_cache_clear(state->kv_self);
 
             whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
             if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                return -7;
+                return -8;
             }
 
             {
@@ -6081,7 +6088,7 @@ int whisper_full_with_state(
 
                     if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                        return -8;
+                        return -9;
                     }
 
                     const int64_t t_start_sample_us = ggml_time_us();
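
The KV-cache part of this patch boils down to a lazy, grow-only allocation policy: whisper_init_state() now sizes the self-attention cache for a single decoder, and whisper_full_with_state() rebuilds it only when a run requests more decoders, overallocating by a factor of n_decoders_cur + 2 as fragmentation headroom. Below is a minimal standalone C++ sketch of that policy under simplified assumptions; the names (DummyState, DummyKVCache, kv_init, kv_free, ensure_kv_cache, pad256) are hypothetical stand-ins, and the buffer math is a placeholder for the real ggml-backed allocation.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for the real ggml-backed K/V buffers.
struct DummyKVCache {
    std::vector<float> buf;
};

// Equivalent of GGML_PAD(n, 256): round up to a multiple of 256.
static int64_t pad256(int64_t n) { return ((n + 255) / 256) * 256; }

static bool kv_init(DummyKVCache & cache, int64_t n_ctx, int64_t n_state, int64_t n_layer) {
    // one K and one V slot per (ctx, state, layer) element
    cache.buf.assign(static_cast<size_t>(2 * n_ctx * n_state * n_layer), 0.0f);
    return true;
}

static void kv_free(DummyKVCache & cache) {
    cache.buf.clear();
    cache.buf.shrink_to_fit();
}

struct DummyState {
    DummyKVCache kv_self;
    int32_t kv_self_n_dec = 0; // number of decoders the current cache was sized for
};

// Rebuild the cache only when more decoders are requested than it was built for;
// overallocate (n_decoders_cur + 2) contexts as fragmentation headroom.
static bool ensure_kv_cache(DummyState & state, int32_t n_decoders_cur,
                            int64_t n_text_ctx, int64_t n_text_state, int64_t n_text_layer) {
    if (state.kv_self_n_dec >= n_decoders_cur) {
        return true; // current cache is already large enough
    }

    kv_free(state.kv_self);

    const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;

    if (!kv_init(state.kv_self, pad256(n_text_ctx)*factor, n_text_state, n_text_layer)) {
        return false;
    }

    state.kv_self_n_dec = n_decoders_cur;
    return true;
}

int main() {
    DummyState state;

    ensure_kv_cache(state, /*n_decoders_cur =*/ 1, 448, 384, 4); // init-time 1x allocation
    ensure_kv_cache(state, /*n_decoders_cur =*/ 5, 448, 384, 4); // beam search: grow to (5 + 2)x

    std::printf("cache sized for %d decoders, %zu floats\n",
            (int) state.kv_self_n_dec, state.kv_self.buf.size());

    return 0;
}

The grow-only check means a greedy run (one decoder) never pays for beam-search headroom, while repeated beam-search runs reuse the larger cache instead of reallocating on every call.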