From 0cb820e0f9e52e4aea07abd59d4220323b7bd849 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sun, 14 May 2023 18:46:19 +0300 Subject: [PATCH] talk-llama : fix build + sync latest llama.cpp --- examples/talk-llama/llama.cpp | 527 +++++++++++++++++------------ examples/talk-llama/llama.h | 38 ++- examples/talk-llama/talk-llama.cpp | 8 +- 3 files changed, 335 insertions(+), 238 deletions(-) diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index dfc68ed..98f49ab 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -9,6 +9,9 @@ #include "llama.h" #include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif #include <array> #include <ctime> @@ -50,49 +53,49 @@ static const size_t MB = 1024*1024; static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0() { - static std::map<e_model, size_t> _MEM_REQ_SCRATCH0 = { + static std::map<e_model, size_t> k_sizes = { { MODEL_7B, 512ull * MB }, { MODEL_13B, 512ull * MB }, { MODEL_30B, 512ull * MB }, { MODEL_65B, 1024ull * MB }, }; - return _MEM_REQ_SCRATCH0; + return k_sizes; } static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1() { - static std::map<e_model, size_t> _MEM_REQ_SCRATCH1 = { + static std::map<e_model, size_t> k_sizes = { { MODEL_7B, 512ull * MB }, { MODEL_13B, 512ull * MB }, { MODEL_30B, 512ull * MB }, { MODEL_65B, 1024ull * MB }, }; - return _MEM_REQ_SCRATCH1; + return k_sizes; } // 2*n_embd*n_ctx*n_layer*sizeof(float16) static const std::map<e_model, size_t> & MEM_REQ_KV_SELF() { - static std::map<e_model, size_t> _MEM_REQ_KV_SELF = { + static std::map<e_model, size_t> k_sizes = { { MODEL_7B, 1026ull * MB }, { MODEL_13B, 1608ull * MB }, { MODEL_30B, 3124ull * MB }, { MODEL_65B, 5120ull * MB }, }; - return _MEM_REQ_KV_SELF; + return k_sizes; } // this is mostly needed for temporary mul_mat buffers to dequantize the data // not actually needed if BLAS is disabled static const std::map<e_model, size_t> & MEM_REQ_EVAL() { - static std::map<e_model, size_t> _MEM_REQ_EVAL = { + static std::map<e_model, size_t> k_sizes = { { MODEL_7B, 768ull * MB }, { MODEL_13B, 1024ull * MB }, { MODEL_30B, 1280ull * MB }, { MODEL_65B, 1536ull * MB }, }; - return _MEM_REQ_EVAL; + return k_sizes; } // default hparams (LLaMA 7B) @@ -402,6 +405,7 @@ enum llama_file_version { LLAMA_FILE_VERSION_GGML, LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab LLAMA_FILE_VERSION_GGJT_V1, // added padding + LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format }; struct llama_file_loader { @@ -432,6 +436,8 @@ struct llama_file_loader { file_version = LLAMA_FILE_VERSION_GGMF_V1; } else if (magic == 'ggjt' && version == 1) { file_version = LLAMA_FILE_VERSION_GGJT_V1; + } else if (magic == 'ggjt' && version == 2) { + file_version = LLAMA_FILE_VERSION_GGJT_V2; } else { throw format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", magic, version); @@ -482,7 +488,6 @@ struct llama_file_loader { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -527,8 +532,8 @@ struct llama_file_saver { write_vocab(); } void write_magic() { - file.write_u32('ggjt'); // magic - file.write_u32(1); // version + file.write_u32(LLAMA_FILE_MAGIC); // magic + file.write_u32(LLAMA_FILE_VERSION); // version } void write_hparams(enum llama_ftype new_ftype) { const llama_hparams & hparams = any_file_loader->hparams; @@ -558,7 +563,6 @@ struct llama_file_saver { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -585,12 +589,12 @@ struct llama_model_loader { std::unique_ptr<llama_mmap> mapping;
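
// Aside: a quick sanity check of the MEM_REQ_KV_SELF() table above against the
// formula in its comment, 2*n_embd*n_ctx*n_layer*sizeof(float16). This is a
// sketch assuming the LLaMA-7B shapes n_embd = 4096, n_layer = 32 and a
// 2048-token context (values not shown in this patch):
//
//     2 * 4096 * 2048 * 32 * 2 bytes = 1024 MB
//
// which lines up with the 1026ull * MB entry for MODEL_7B (the extra 2 MB
// being slack). The 13B/30B/65B rows scale the same way with their n_embd and
// n_layer.
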
llama_model_loader(const std::string & fname_base, bool use_mmap, bool vocab_only) { - auto first_file = new llama_file_loader(fname_base.c_str(), 0, tensors_map); + auto * first_file = new llama_file_loader(fname_base.c_str(), 0, tensors_map); file_loaders.emplace_back(first_file); uint32_t n_parts = vocab_only ? 1 : guess_n_parts(); for (uint32_t i = 1; i < n_parts; i++) { std::string fname = fname_base + "." + std::to_string(i); - auto ith_file = new llama_file_loader(fname.c_str(), i, tensors_map); + auto * ith_file = new llama_file_loader(fname.c_str(), i, tensors_map); file_loaders.emplace_back(ith_file); if (ith_file->hparams != first_file->hparams) { throw format("llama.cpp: hparams inconsistent between files"); @@ -637,7 +641,7 @@ struct llama_model_loader { } } - struct ggml_tensor * get_tensor(const std::string & name, std::vector<uint32_t> ne) { + struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) { auto it = tensors_map.name_to_idx.find(name); if (it == tensors_map.name_to_idx.end()) { throw format("llama.cpp: tensor '%s' is missing from model", name.c_str()); @@ -659,13 +663,14 @@ struct llama_model_loader { LLAMA_ASSERT(lt.ne.size() == 1); tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0)); } + ggml_set_name(tensor, lt.name.c_str()); LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor lt.ggml_tensor = tensor; num_ggml_tensors_created++; return tensor; } - void done_getting_tensors() { + void done_getting_tensors() const { if (num_ggml_tensors_created != tensors_map.tensors.size()) { throw std::string("llama.cpp: file contained more tensors than expected"); } @@ -727,8 +732,7 @@ struct llama_model_loader { LLAMA_ASSERT(offset == lt.size); } else if (lt.split_type == SPLIT_BY_COLUMNS) { // Let's load the data into temporary buffers to ensure the OS performs large loads.
- std::vector<llama_buffer> tmp_bufs; - tmp_bufs.resize(lt.shards.size()); + std::vector<llama_buffer> tmp_bufs(lt.shards.size()); for (size_t i = 0; i < lt.shards.size(); i++) { llama_load_tensor_shard & shard = lt.shards.at(i); llama_file & file = file_loaders.at(shard.file_idx)->file; @@ -799,6 +803,8 @@ static bool kv_cache_init( cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + ggml_set_name(cache.k, "cache_k"); + ggml_set_name(cache.v, "cache_v"); return true; } @@ -807,7 +813,8 @@ struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.n_ctx =*/ 512, /*.n_parts =*/ -1, - /*.seed =*/ 0, + /*.gpu_layers =*/ 0, + /*.seed =*/ -1, /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, @@ -837,9 +844,11 @@ static const char *llama_file_version_name(llama_file_version version) { switch (version) { case LLAMA_FILE_VERSION_GGML: return "'ggml' (old version with low tokenizer quality and no mmap support)"; case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; - case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (latest)"; - default: LLAMA_ASSERT(false); + case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; + case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (latest)"; } + + return "unknown"; } static const char *llama_ftype_name(enum llama_ftype ftype) { @@ -850,7 +859,6 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: return "mostly Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; @@ -872,6 +880,7 @@ static void llama_model_load_internal( const std::string & fname, llama_context & lctx, int n_ctx, + int n_gpu_layers, ggml_type memory_type, bool use_mmap, bool use_mlock, @@ -916,13 +925,22 @@ static void llama_model_load_internal( fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } + if (file_version != LLAMA_FILE_VERSION_GGJT_V2) { + if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { + throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)"); + } + } + if (vocab_only) { return; } auto & ctx = model.ctx; - size_t ctx_size, mmapped_size; + size_t ctx_size; + size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0); @@ -968,8 +986,6 @@ static void llama_model_load_internal( // prepare memory for the weights { - const auto & hparams = model.hparams; - const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; @@ -1011,6 +1027,35 @@ static void llama_model_load_internal( ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ?
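
// Aside: the default seed moves from 0 to -1 here, and (further down in this
// patch) llama_init_from_file / llama_set_rng_seed treat any negative seed as
// "randomize from time(NULL)". A minimal caller-side sketch, assuming the
// llama_init_from_file(path, params) signature from llama.h (the model path
// is a placeholder):
//
//     llama_context_params params = llama_context_default_params();
//     params.n_gpu_layers = 0;  // new field added by this patch
//     params.seed = -1;         // -1 now means "pick a time-based seed"
//     llama_context * lctx = llama_init_from_file("ggml-model.bin", params);
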
&lctx.model.mlock_mmap : NULL); model.mapping = std::move(ml->mapping); +#ifdef GGML_USE_CUBLAS + { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu); + + size_t vram_total = 0; + + for (int i = 0; i < n_gpu; ++i) { + const auto & layer = model.layers[i]; + + ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq); + ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk); + ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv); + ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo); + ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1); + ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2); + ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3); + } + if (n_gpu_layers > (int) hparams.n_layer) { + fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__); + ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output); + } + + fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); + } +#else + (void) n_gpu_layers; +#endif // loading time will be recalculate after the first eval, so // we take page faults deferred by mmap() into consideration @@ -1021,6 +1066,7 @@ static bool llama_model_load( const std::string & fname, llama_context & lctx, int n_ctx, + int n_gpu_layers, ggml_type memory_type, bool use_mmap, bool use_mlock, @@ -1028,7 +1074,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, + llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::string & err) { @@ -1050,6 +1096,13 @@ static bool llama_eval_internal( const int n_tokens, const int n_past, const int n_threads) { + + // enforce that the first token is BOS + if (n_past == 0 && tokens[0] != llama_token_bos()) { + fprintf(stderr, "%s: first token must be BOS\n", __func__); + return false; + } + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1057,7 +1110,7 @@ static bool llama_eval_internal( const auto & model = lctx.model; const auto & hparams = model.hparams; - auto & kv_self = model.kv_self; + const auto & kv_self = model.kv_self; LLAMA_ASSERT(!!kv_self.ctx); @@ -1085,6 +1138,7 @@ static bool llama_eval_internal( gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 
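
// Aside: with the new check above, the very first llama_eval() call (n_past
// == 0) must see BOS as tokens[0]. A sketch of a caller that builds the batch
// by hand (llama_tokenize(..., add_bos = true) takes care of this
// automatically):
//
//     std::vector<llama_token> tokens;
//     tokens.push_back(llama_token_bos()); // id 1, see llama_token_bos() below
//     // ... append the prompt tokens, then:
//     llama_eval(lctx, tokens.data(), tokens.size(), /*n_past =*/ 0, n_threads);
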
1 : n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); @@ -1109,8 +1163,10 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + ggml_set_name(Qcur, "Qcur"); + ggml_set_name(Kcur, "Kcur"); // store key and value to memory { @@ -1131,6 +1187,7 @@ static bool llama_eval_internal( ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + ggml_set_name(Q, "Q"); struct ggml_tensor * K = ggml_permute(ctx0, @@ -1138,21 +1195,28 @@ static bool llama_eval_internal( ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), n_embd/n_head, n_head, n_past + N), 0, 2, 1, 3); + ggml_set_name(K, "K"); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); + ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); + + // KQ_scaled shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + // split cached V into n_head heads struct ggml_tensor * V = @@ -1161,9 +1225,11 @@ static bool llama_eval_internal( n_ctx*ggml_element_size(kv_self.v), n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + ggml_set_name(V, "V"); #if 1 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); #else // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation @@ -1174,11 +1240,13 @@ static bool llama_eval_internal( // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); // cur = KQV_merged.contiguous().view(n_embd, N) cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + ggml_set_name(cur, "KQV_merged_contiguous"); // projection (no bias) cur = ggml_mul_mat(ctx0, @@ -1250,7 +1318,7 @@ 
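
// Aside: the Qcur/Kcur/KQ/KQ_scaled/KQ_masked/KQ_soft_max chain named above is
// standard masked scaled-dot-product attention; in LaTeX, with
// d_k = n_embd/n_head:
//
//     \mathrm{Attention}(Q,K,V) =
//         \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V
//
// where M is the causal mask applied by ggml_diag_mask_inf_inplace()
// (-inf for positions after n_past + current row, 0 elsewhere), and the
// _inplace variants simply avoid allocating new tensors for each step.
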
static bool llama_eval_internal( lctx.use_buf(ctx0, -1); // logits -> probs - //inpL = ggml_soft_max(ctx0, inpL); + //inpL = ggml_soft_max_inplace(ctx0, inpL); // run the computation ggml_build_forward_expand(&gf, inpL); @@ -1288,7 +1356,7 @@ static bool llama_eval_internal( } // extract embeddings - if (lctx.embedding.size()) { + if (!lctx.embedding.empty()) { auto & embedding_out = lctx.embedding; embedding_out.resize(n_embd); @@ -1339,6 +1407,8 @@ struct llama_sp_symbol { size_t n; }; +static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable"); + struct llama_sp_bigram { struct comparator { bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) { @@ -1371,7 +1441,7 @@ struct llama_tokenizer { sym.prev = index - 1; sym.next = offs == text.size() ? -1 : index + 1; index++; - symbols_.emplace_back(std::move(sym)); + symbols_.emplace_back(sym); } // seed the work queue with all possible 2-character tokens. @@ -1462,12 +1532,12 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co llama_tokenizer tokenizer(vocab); std::vector<llama_vocab::id> output; - if (text.size() == 0) { + if (text.empty()) { return output; } if (bos) { - output.push_back(1); + output.push_back(llama_token_bos()); } tokenizer.tokenize(text, output); @@ -1690,7 +1760,7 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array } } -void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) { +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; } @@ -1698,7 +1768,7 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat const int64_t t_start_sample_us = ggml_time_us(); for (size_t i = 0; i < candidates->size; ++i) { - auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); + const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); if (token_iter == last_tokens + last_tokens_size) { continue; } @@ -1719,7 +1789,7 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat } } -void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } @@ -1776,7 +1846,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k(nullptr, candidates, int(k)); + llama_sample_top_k(nullptr, candidates, int(k), 1); if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } @@ -1842,7 +1912,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da const int64_t t_start_sample_us = ggml_time_us(); // Find max element - auto max_iter = std::max_element(candidates->data,
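
// Aside: the "negative logit fix" referenced by the llama.h doc comment later
// in this patch means the repetition-penalty loop above rescales each matched
// candidate roughly like this (a sketch of the effect; the rescaling lines
// themselves are not shown in this hunk):
//
//     if (candidates->data[i].logit <= 0) {
//         candidates->data[i].logit *= penalty; // dividing would *boost* it
//     } else {
//         candidates->data[i].logit /= penalty;
//     }
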
candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); @@ -1885,7 +1955,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; @@ -1896,7 +1965,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp.c_str(), /*use_mmap*/ false, + std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false, /*vocab_only*/ false)); llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype); @@ -1950,7 +2019,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else if (tensor.type == GGML_TYPE_F16) { f32_conv_buf.resize(nelements * sizeof(float)); f32_data = (float *) f32_conv_buf.addr; - auto f16_data = (const ggml_fp16_t *) tensor.data; + const auto * f16_data = (const ggml_fp16_t *) tensor.data; for (size_t i = 0; i < nelements; i++) { f32_data[i] = ggml_fp16_to_fp32(f16_data[i]); } @@ -1981,21 +2050,31 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s size_t first = counter; counter += chunk_size; if (first >= nelements) { if (!local_hist.empty()) { - for (int j=0; j<int(local_hist.size()); ++j) hist_cur[j] += local_hist[j]; + for (int j=0; j<int(local_hist.size()); ++j) { + hist_cur[j] += local_hist[j]; + } printf("size = %8.2f MB -> %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0); @@ -2041,7 +2120,7 @@ struct llama_context * llama_init_from_file( llama_context * ctx = new llama_context; - if (params.seed <= 0) { + if (params.seed < 0) { params.seed = time(NULL); } @@ -2067,7 +2146,7 @@ struct llama_context * llama_init_from_file( ggml_type memory_type = params.f16_kv ?
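
// Aside: the hunk above belongs to the multi-threaded quantization loop. The
// pattern, reduced to a sketch using the names visible in context (the
// locking details and the quantization call itself are assumptions, not part
// of this diff): workers claim fixed-size chunks from a shared counter,
// quantize into per-thread histograms, and merge when they run out of work.
//
//     auto compute = [&]() {
//         std::vector<int64_t> local_hist;
//         while (true) {
//             std::unique_lock<std::mutex> lock(mutex);
//             size_t first = counter; counter += chunk_size;  // claim a chunk
//             if (first >= nelements) {
//                 for (int j = 0; j < int(local_hist.size()); ++j) {
//                     hist_cur[j] += local_hist[j];           // merge and stop
//                 }
//                 break;
//             }
//             lock.unlock();
//             // ... quantize elements [first, first + chunk_size) into
//             //     new_data, counting bucket hits in local_hist ...
//         }
//     };
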
GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type, + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); @@ -2193,7 +2272,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model); model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false)); - size_t ctx_size, mmapped_size; + size_t ctx_size; + size_t mmapped_size; model_loader->calc_sizes(&ctx_size, &mmapped_size); base_buf.resize(ctx_size); @@ -2232,8 +2312,12 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i])); } - std::string name(length, 0); - fin.read(&name[0], length); + std::string name; + { + char buf[1024]; + fin.read(buf, length); + name = std::string(buf, length); + } // check for lora suffix and get the type of tensor const std::string lora_suffix = ".lora"; @@ -2248,7 +2332,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * base_name.erase(pos); // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); - if (model_tensors.find(base_name.data()) == model_tensors.end()) { + if (model_tensors.find(base_name) == model_tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); return 1; } @@ -2328,7 +2412,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * if (scaling != 1.0f) { ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); - BA = ggml_scale(lora_ctx, BA, scale_tensor); + BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor); } ggml_tensor * r; @@ -2350,8 +2434,9 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * lora_tensors.clear(); n_tensors++; - if (n_tensors % 4 == 0) + if (n_tensors % 4 == 0) { fprintf(stderr, "."); + } } } @@ -2376,21 +2461,21 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor } } -int llama_get_kv_cache_token_count(struct llama_context * ctx) { +int llama_get_kv_cache_token_count(const struct llama_context * ctx) { return ctx->model.kv_self.n; } -#define LLAMA_MAX_RNG_STATE 64*1024 +#define LLAMA_MAX_RNG_STATE (64*1024) void llama_set_rng_seed(struct llama_context * ctx, int seed) { - if (seed <= 0) { + if (seed < 0) { seed = time(NULL); } ctx->rng.seed(seed); } // Returns the *maximum* size of the state -size_t llama_get_state_size(struct llama_context * ctx) { +size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes.
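
// Aside: the LoRA application above implements the usual low-rank update.
// With the rank r and alpha read from the adapter file (not shown in this
// hunk) and scaling = alpha / r, the patched weight is, in LaTeX:
//
//     W' = W + \frac{\alpha}{r}\,BA
//
// which is why BA is built with ggml_mul_mat() and then multiplied by
// scale_tensor whenever scaling != 1.0f.
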
const size_t s_rng_size = sizeof(size_t); @@ -2421,8 +2506,8 @@ size_t llama_get_state_size(struct llama_context * ctx) { } // Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { - uint8_t * out = dest; +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + uint8_t * out = dst; // copy rng { @@ -2482,7 +2567,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); + char buffer[4096]; + ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; @@ -2506,10 +2593,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); ggml_graph_compute(cpy_ctx, &gf); + + ggml_free(cpy_ctx); } } - const size_t written = out - dest; + const size_t written = out - dst; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(written <= max_size); @@ -2519,15 +2608,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { // Sets the state reading from the specified source address size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { - const uint8_t * in = src; + const uint8_t * inp = src; // set rng { size_t rng_size; char rng_buf[LLAMA_MAX_RNG_STATE]; - memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); - memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; std::stringstream rng_ss; rng_ss.str(std::string(&rng_buf[0], rng_size)); @@ -2541,30 +2630,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t logits_cap; size_t logits_size; - memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap); - memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size); + memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); + memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); if (logits_size) { ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); } - in += logits_cap * sizeof(float); + inp += logits_cap * sizeof(float); } // set embeddings { size_t embedding_size; - memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size); + memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); if (embedding_size) { - memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); - in += embedding_size * sizeof(float); + memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); + inp += embedding_size * sizeof(float); } } @@ -2579,25 +2668,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t kv_size; int kv_ntok; - memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); - memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); + memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); if (kv_size) { 
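
// Aside: a round-trip sketch for the state API above, using only functions
// from this file. llama_get_state_size() is an upper bound; the exact byte
// counts come back from the copy/set calls:
//
//     std::vector<uint8_t> state(llama_get_state_size(lctx));
//     const size_t written = llama_copy_state_data(lctx, state.data());
//     // ... later, on a context created with the same model and params:
//     const size_t read = llama_set_state_data(lctx, state.data());
//     // both written and read are <= state.size()
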
LLAMA_ASSERT(kv_self.buf.size == kv_size); const size_t elt_size = ggml_element_size(kv_self.k); + char buffer[4096]; + ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); - kin3d->data = (void *) in; - in += ggml_nbytes(kin3d); + kin3d->data = (void *) inp; + inp += ggml_nbytes(kin3d); ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); - vin3d->data = (void *) in; - in += ggml_nbytes(vin3d); + vin3d->data = (void *) inp; + inp += ggml_nbytes(vin3d); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, n_embd, kv_ntok, n_layer, @@ -2611,12 +2702,13 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); ggml_graph_compute(cpy_ctx, &gf); + ggml_free(cpy_ctx); } ctx->model.kv_self.n = kv_ntok; } - const size_t nread = in - src; + const size_t nread = inp - src; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(nread <= max_size); @@ -2624,134 +2716,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { return nread; } -int llama_eval( - struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, - int n_past, - int n_threads) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); - return 1; - } - // get a more accurate load time, upon first eval - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; - } - return 0; -} - -int llama_tokenize( - struct llama_context * ctx, - const char * text, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - auto res = llama_tokenize(ctx->vocab, text, add_bos); - - if (n_max_tokens < (int) res.size()) { - fprintf(stderr, "%s: too many tokens\n", __func__); - return -((int) res.size()); - } - - for (size_t i = 0; i < res.size(); i++) { - tokens[i] = res[i]; - } - - return res.size(); -} - -int llama_n_vocab(struct llama_context * ctx) { - return ctx->vocab.id_to_token.size(); -} - -int llama_n_ctx(struct llama_context * ctx) { - return ctx->model.hparams.n_ctx; -} - -int llama_n_embd(struct llama_context * ctx) { - return ctx->model.hparams.n_embd; -} - -float * llama_get_logits(struct llama_context * ctx) { - return ctx->logits.data(); -} - -float * llama_get_embeddings(struct llama_context * ctx) { - return ctx->embedding.data(); -} - -const char * llama_token_to_str(struct llama_context * ctx, llama_token token) { - if (token >= llama_n_vocab(ctx)) { - return nullptr; - } - - return ctx->vocab.id_to_token[token].tok.c_str(); -} - -llama_token llama_token_bos() { - return 1; -} - -llama_token llama_token_eos() { - return 2; -} - -llama_token llama_token_nl() { - return 13; -} - - -void llama_print_timings(struct llama_context * ctx) { - const int64_t t_end_us = ggml_time_us(); - - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_eval = std::max(1, ctx->n_eval); - const int32_t n_p_eval = std::max(1, ctx->n_p_eval); - - fprintf(stderr, "\n"); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample); - fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d 
tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval); - fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); -} - -void llama_reset_timings(struct llama_context * ctx) { - ctx->t_start_us = ggml_time_us(); - ctx->t_sample_us = ctx->n_sample = 0; - ctx->t_eval_us = ctx->n_eval = 0; - ctx->t_p_eval_us = ctx->n_p_eval = 0; -} - -const char * llama_print_system_info(void) { - static std::string s; - - s = ""; - s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; - s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; - - return s.c_str(); -} - -// For internal test use -std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) { - return ctx->model.tensors_by_name; -} - bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks const uint32_t magic = file.read_u32(); const uint32_t version = file.read_u32(); - if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) { + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); return false; } @@ -2792,7 +2756,7 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi const size_t n_state_size_cur = file.size - file.tell(); const size_t n_state_size_max = llama_get_state_size(ctx); - if (n_state_size_cur > n_state_size_max) { + if (n_state_size_cur > n_state_size_max) { fprintf(stderr, "%s : the state size in session file is too big!
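
// Aside: a caller-side sketch for the session API being patched here (the
// path is a placeholder). tokens_out must hold n_token_capacity entries, and
// the real count comes back through n_token_count_out:
//
//     std::vector<llama_token> session_tokens(llama_n_ctx(lctx));
//     size_t n_token_count = 0;
//     if (llama_load_session_file(lctx, "session.bin", session_tokens.data(),
//                                 session_tokens.size(), &n_token_count)) {
//         session_tokens.resize(n_token_count);
//     }
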
max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); return false; } @@ -2829,4 +2793,135 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi } return true; -} \ No newline at end of file +} + +int llama_eval( + struct llama_context * ctx, + const llama_token * tokens, + int n_tokens, + int n_past, + int n_threads) { + if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} + +int llama_tokenize( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + auto res = llama_tokenize(ctx->vocab, text, add_bos); + + if (n_max_tokens < (int) res.size()) { + fprintf(stderr, "%s: too many tokens\n", __func__); + return -((int) res.size()); + } + + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int llama_n_vocab(const struct llama_context * ctx) { + return ctx->vocab.id_to_token.size(); +} + +int llama_n_ctx(const struct llama_context * ctx) { + return ctx->model.hparams.n_ctx; +} + +int llama_n_embd(const struct llama_context * ctx) { + return ctx->model.hparams.n_embd; +} + +float * llama_get_logits(struct llama_context * ctx) { + return ctx->logits.data(); +} + +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + +const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { + if (token >= llama_n_vocab(ctx)) { + return nullptr; + } + + return ctx->vocab.id_to_token[token].tok.c_str(); +} + +llama_token llama_token_bos() { + return 1; +} + +llama_token llama_token_eos() { + return 2; +} + +llama_token llama_token_nl() { + return 13; +} + + +void llama_print_timings(struct llama_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + const int32_t n_sample = std::max(1, ctx->n_sample); + const int32_t n_eval = std::max(1, ctx->n_eval); + const int32_t n_p_eval = std::max(1, ctx->n_p_eval); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample); + fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval); + fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); +} + +void llama_reset_timings(struct llama_context * ctx) { + ctx->t_start_us = ggml_time_us(); + ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_eval_us = ctx->n_eval = 0; + ctx->t_p_eval_us = ctx->n_p_eval = 0; +} + +const char * llama_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; + s += "AVX512_VNNI = " + 
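
// Aside: llama_tokenize() above returns a negative value when the buffer is
// too small, and the magnitude of that value is the required size. A sketch
// of the intended retry pattern:
//
//     std::vector<llama_token> toks(32);
//     int n = llama_tokenize(lctx, "Hello world", toks.data(), toks.size(), true);
//     if (n < 0) {
//         toks.resize(-n);
//         n = llama_tokenize(lctx, "Hello world", toks.data(), toks.size(), true);
//     }
//     toks.resize(n);
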
std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +// For internal test use +std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) { + return ctx->model.tensors_by_name; +} diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 2f12090..21cba8c 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -19,7 +19,7 @@ # define LLAMA_API #endif -#define LLAMA_FILE_VERSION 1 +#define LLAMA_FILE_VERSION 2 #define LLAMA_FILE_MAGIC 'ggjt' #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml' #define LLAMA_SESSION_MAGIC 'ggsn' @@ -54,9 +54,10 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { - int n_ctx; // text context - int n_parts; // -1 for default - int seed; // RNG seed, 0 for random + int n_ctx; // text context + int n_parts; // -1 for default + int n_gpu_layers; // number of layers to store in VRAM + int seed; // RNG seed, -1 for random bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one @@ -78,7 +79,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors @@ -122,19 +123,19 @@ extern "C" { int n_threads); // Returns the number of tokens in the KV cache - LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); + LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); // Sets the current rng seed. LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed); // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens - LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); + LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); // Copies the state to the specified destination address. // Destination needs to have allocated enough memory.
// Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); + LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read @@ -143,6 +144,7 @@ extern "C" { // Save/load session file LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + // Run the llama inference to obtain the logits and probabilities for the next token. // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls @@ -166,9 +168,9 @@ extern "C" { int n_max_tokens, bool add_bos); - LLAMA_API int llama_n_vocab(struct llama_context * ctx); - LLAMA_API int llama_n_ctx (struct llama_context * ctx); - LLAMA_API int llama_n_embd (struct llama_context * ctx); + LLAMA_API int llama_n_vocab(const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + LLAMA_API int llama_n_embd (const struct llama_context * ctx); // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row @@ -182,7 +184,7 @@ extern "C" { LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context - LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); + LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token); // Special tokens LLAMA_API llama_token llama_token_bos(); @@ -192,25 +194,25 @@ extern "C" { // Sampling functions /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); + LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 7960ab7..45b8cb7 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -560,7 +560,7 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), embd.begin(), embd.end()); n_past += embd.size(); - + embd.clear(); if (done) break; @@ -577,7 +577,7 @@ int main(int argc, char ** argv) { if (!path_session.empty() && need_to_save_session) { need_to_save_session = false; llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size()); - } + } llama_token id = 0; @@ -609,8 +609,8 @@ int main(int argc, char ** argv) { id = llama_sample_token_greedy(ctx_llama, &candidates_p); } else { // Temperature sampling - llama_sample_top_k(ctx_llama, &candidates_p, top_k); - llama_sample_top_p(ctx_llama, &candidates_p, top_p); + llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1); + llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1); llama_sample_temperature(ctx_llama, &candidates_p, temp); id = llama_sample_token(ctx_llama, &candidates_p); }
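
Note: llama.h sits inside an extern "C" block and is consumed from C, where
default arguments such as "min_keep = 1" are not valid; removing them, as this
patch does, forces every caller to pass min_keep explicitly, which is exactly
what the talk-llama.cpp hunk above now does. A typical sampling chain under
the new signatures (the parameter values are illustrative, not mandated by
this patch):

    llama_sample_repetition_penalty(ctx_llama, &candidates_p,
                                    last_tokens.data(), last_tokens.size(), 1.10f);
    llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
    llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
    llama_sample_temperature(ctx_llama, &candidates_p, temp);
    llama_token id = llama_sample_token(ctx_llama, &candidates_p);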