talk-llama : sync llama.cpp

This commit is contained in:
Georgi Gerganov 2024-06-16 13:10:54 +03:00
parent 4942b1b428
commit 061eeb9f61
6 changed files with 9112 additions and 3693 deletions

File diff suppressed because it is too large Load Diff

View File

@ -81,9 +81,12 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8, LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10, LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_OLMO = 11, LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_DBRX = 12, LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
}; };
// note: these values should be synchronized with ggml_rope // note: these values should be synchronized with ggml_rope
@ -95,7 +98,7 @@ extern "C" {
LLAMA_ROPE_TYPE_GLM = 4, LLAMA_ROPE_TYPE_GLM = 4,
}; };
enum llama_token_type { enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
LLAMA_TOKEN_TYPE_UNDEFINED = 0, LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1, LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2, LLAMA_TOKEN_TYPE_UNKNOWN = 2,
@ -105,6 +108,20 @@ extern "C" {
LLAMA_TOKEN_TYPE_BYTE = 6, LLAMA_TOKEN_TYPE_BYTE = 6,
}; };
// Per-token attribute bits (the return type of llama_token_get_attr).
// Every enumerator other than UNDEFINED is a distinct single-bit mask.
enum llama_token_attr {
    LLAMA_TOKEN_ATTR_UNDEFINED    = 0x0000, // no bit set
    LLAMA_TOKEN_ATTR_UNKNOWN      = 0x0001, // bit 0
    LLAMA_TOKEN_ATTR_UNUSED       = 0x0002, // bit 1
    LLAMA_TOKEN_ATTR_NORMAL       = 0x0004, // bit 2
    LLAMA_TOKEN_ATTR_CONTROL      = 0x0008, // bit 3 - SPECIAL?
    LLAMA_TOKEN_ATTR_USER_DEFINED = 0x0010, // bit 4
    LLAMA_TOKEN_ATTR_BYTE         = 0x0020, // bit 5
    LLAMA_TOKEN_ATTR_NORMALIZED   = 0x0040, // bit 6
    LLAMA_TOKEN_ATTR_LSTRIP       = 0x0080, // bit 7
    LLAMA_TOKEN_ATTR_RSTRIP       = 0x0100, // bit 8
    LLAMA_TOKEN_ATTR_SINGLE_WORD  = 0x0200, // bit 9
};
// model file types // model file types
enum llama_ftype { enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0, LLAMA_FTYPE_ALL_F32 = 0,
@ -242,6 +259,9 @@ extern "C" {
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split; const float * tensor_split;
// comma-separated list of RPC servers to use for offloading
const char * rpc_servers;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues. // If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted. // If it returns false, model loading is immediately aborted.
@ -260,6 +280,8 @@ extern "C" {
bool check_tensors; // validate model tensor data bool check_tensors; // validate model tensor data
}; };
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggerganov/llama.cpp/pull/7544
struct llama_context_params { struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model uint32_t n_ctx; // text context, 0 = from model
@ -286,14 +308,14 @@ extern "C" {
ggml_backend_sched_eval_callback cb_eval; ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data; void * cb_eval_user_data;
enum ggml_type type_k; // data type for K cache enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits) bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
// Abort callback // Abort callback
// if it returns true, execution of llama_decode() will be aborted // if it returns true, execution of llama_decode() will be aborted
@ -344,6 +366,9 @@ extern "C" {
// modifies a preceding LLAMA_GRETYPE_CHAR or // modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 6, LLAMA_GRETYPE_CHAR_ALT = 6,
// any character (.)
LLAMA_GRETYPE_CHAR_ANY = 7,
}; };
typedef struct llama_grammar_element { typedef struct llama_grammar_element {
@ -417,8 +442,8 @@ extern "C" {
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ -755,6 +780,12 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Get the number of threads used for generation of a single token.
LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
// Get the number of threads used for prompt and batch processing (multiple tokens).
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
// Set whether to use causal attention or not // Set whether to use causal attention or not
// If set to true, the model will only attend to the past tokens // If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@ -808,11 +839,14 @@ extern "C" {
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
// Check whether the token is a control token (as opposed to a renderable token)
LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@ -1026,49 +1060,9 @@ extern "C" {
llama_token token); llama_token token);
// //
// Beam search // Model split
// //
// View of a single beam as exposed to the beam-search callback
// (llama_beams_state::beam_views points at an array of these).
struct llama_beam_view {
const llama_token * tokens; // Tokens of this beam; they cannot be modified through this pointer.
size_t n_tokens; // Presumably the number of entries in tokens[] - confirm against the implementation.
float p; // Cumulative beam probability (renormalized relative to all beams)
bool eob; // Callback should set this to true when a beam is at end-of-beam.
};
// State handed to the beam_search_callback function on each invocation
// (llama_beam_search_callback_fn_t receives it by value).
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
// These pointers are valid only during the synchronous callback, so should not be saved.
struct llama_beams_state {
struct llama_beam_view * beam_views; // Array of beam views; its length is n_beams.
size_t n_beams; // Number of elements in beam_views[].
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
bool last_call; // True iff this is the last callback invocation.
};
// Pointer-to-function type for the beam_search_callback.
// The first argument (void * callback_data) is whatever custom data was passed to llama_beam_search;
// it is passed back to the callback, which avoids having to use global variables there.
// The second argument is the current llama_beams_state, passed by value.
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
/// @details Deterministically returns entire sentence constructed by a beam search.
/// The function's own return type is void. NOTE(review): the result is presumably
/// delivered through the callback - confirm against the implementation.
/// @param ctx Pointer to the llama_context.
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
/// @param callback_data A pointer that is simply passed back to callback.
/// @param n_beams Number of beams to use.
/// @param n_past Number of tokens already evaluated.
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
LLAMA_API void llama_beam_search(
struct llama_context * ctx,
llama_beam_search_callback_fn_t callback,
void * callback_data,
size_t n_beams,
int32_t n_past,
int32_t n_predict);
/// @details Build a split GGUF final path for this chunk. /// @details Build a split GGUF final path for this chunk.
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
// Returns the split_path length. // Returns the split_path length.

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,20 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <map>
#include <utility>
#include <vector> #include <vector>
#include <unordered_map>
#include <unordered_set>
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number; struct range_nfd {
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter; uint32_t first;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator; uint32_t last;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace; uint32_t nfd;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark; };
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol; static const uint32_t MAX_CODEPOINTS = 0x110000;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd; extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
extern const std::map<char32_t, char32_t> unicode_map_lowercase; extern const std::unordered_set<uint32_t> unicode_set_whitespace;
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
extern const std::vector<range_nfd> unicode_ranges_nfd;

View File

@ -1,4 +1,4 @@
#include "unicode.h" #include "unicode.h"
#include "unicode-data.h" #include "unicode-data.h"
#include <cassert> #include <cassert>
@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
// return result; // return result;
//} //}
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() { static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::unordered_map<uint32_t, int> cpt_types; std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
for (auto p : unicode_ranges_number) {
for (auto i = p.first; i <= p.second; ++i) { assert (unicode_ranges_flags.front().first == 0);
cpt_types[i] = CODEPOINT_TYPE_NUMBER; assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
} }
} }
for (auto p : unicode_ranges_letter) {
for (auto i = p.first; i <= p.second; ++i) { for (auto cpt : unicode_set_whitespace) {
cpt_types[i] = CODEPOINT_TYPE_LETTER; cpt_flags[cpt].is_whitespace = true;
}
} }
for (auto p : unicode_ranges_separator) {
for (auto i = p.first; i <= p.second; ++i) { for (auto p : unicode_map_lowercase) {
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR; cpt_flags[p.second].is_lowercase = true;
}
} }
for (auto p : unicode_ranges_accent_mark) {
for (auto i = p.first; i <= p.second; ++i) { for (auto p : unicode_map_uppercase) {
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK; cpt_flags[p.second].is_uppercase = true;
}
} }
for (auto p : unicode_ranges_punctuation) {
for (auto i = p.first; i <= p.second; ++i) { for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION; cpt_flags[range.nfd].is_nfd = true;
}
} }
for (auto p : unicode_ranges_symbol) {
for (auto i = p.first; i <= p.second; ++i) { return cpt_flags;
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
}
}
for (auto p : unicode_ranges_control) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
}
}
return cpt_types;
} }
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() { static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
std::unordered_map<uint8_t, std::string> map; std::unordered_map<uint8_t, std::string> map;
for (int ch = u'!'; ch <= u'~'; ++ch) { for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
for (int ch = u'¡'; ch <= u'¬'; ++ch) { for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
for (int ch = u'®'; ch <= u'ÿ'; ++ch) { for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() { static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
std::unordered_map<std::string, uint8_t> map; std::unordered_map<std::string, uint8_t> map;
for (int ch = u'!'; ch <= u'~'; ++ch) { for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
for (int ch = u'¡'; ch <= u'¬'; ++ch) { for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
for (int ch = u'®'; ch <= u'ÿ'; ++ch) { for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
}; };
auto _get_cpt_type = [&] (const size_t pos) -> int { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const char32_t cpt = _get_cpt(pos); const char32_t cpt = _get_cpt(pos);
const int cpt_type = _get_cpt_type(pos); const auto flags = _get_flags(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d // regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
} }
} }
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
// regex: <space>?\p{L}+ // regex: <space>?\p{L}+
if (cpt2_type == CODEPOINT_TYPE_LETTER) { if (flags2.is_letter) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_LETTER) { while (flags2.is_letter) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?\p{N}+ // regex: <space>?\p{N}+
if (cpt2_type == CODEPOINT_TYPE_NUMBER) { if (flags2.is_number) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_NUMBER) { while (flags2.is_number) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?[^\s\p{L}\p{N}]+ // regex: <space>?[^\s\p{L}\p{N}]+
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
num_whitespaces++; num_whitespaces++;
} }
@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
}; };
auto _get_cpt_type = [&] (const size_t pos) -> int { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const char32_t cpt = _get_cpt(pos); const char32_t cpt = _get_cpt(pos);
const int cpt_type = _get_cpt_type(pos); const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct? // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) { if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
pos++; pos++;
while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) { while (_get_flags(pos).is_letter) {
pos++; pos++;
} }
_add_token(pos); _add_token(pos);
@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: \p{N}{1,3} // regex: \p{N}{1,3}
if (cpt_type == CODEPOINT_TYPE_NUMBER) { if (flags.is_number) {
size_t ini = pos; size_t ini = pos;
while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) { while (_get_flags(pos).is_number) {
if (++pos - ini >= 3 ) { if (++pos - ini >= 3 ) {
_add_token(pos); _add_token(pos);
ini = pos; ini = pos;
@ -426,14 +418,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
char32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') { while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos); cpt2 = _get_cpt(++pos);
} }
@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0; size_t last_end_r_or_n = 0;
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
char32_t cpt2 = _get_cpt(pos+num_whitespaces); char32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') { if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1; last_end_r_or_n = pos + num_whitespaces + 1;
@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
} }
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) { std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> result; auto comp = [] (const uint32_t cpt, const range_nfd & range) {
result.reserve(cpts.size()); return cpt < range.first;
};
std::vector<uint32_t> result(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) { for (size_t i = 0; i < cpts.size(); ++i) {
auto it = unicode_map_nfd.find(cpts[i]); const uint32_t cpt = cpts[i];
if (it == unicode_map_nfd.end()) { auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
result.push_back(cpts[i]); result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
} else {
result.push_back(it->second);
}
} }
return result; return result;
} }
@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result; return result;
} }
int unicode_cpt_type(uint32_t cp) { codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map(); static const codepoint_flags undef(codepoint_flags::UNDEFINED);
const auto it = cpt_types.find(cp); static const auto cpt_flags = unicode_cpt_flags_array();
return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second; return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
} }
int unicode_cpt_type(const std::string & utf8) { codepoint_flags unicode_cpt_flags(const std::string & utf8) {
if (utf8.length() == 0) { static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return CODEPOINT_TYPE_UNIDENTIFIED; if (utf8.empty()) {
return undef; // undefined
} }
size_t offset = 0; size_t offset = 0;
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset)); return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
}
bool unicode_cpt_is_whitespace(uint32_t cp) {
static const std::unordered_set<uint32_t> is_whitespace = [] {
std::unordered_set<uint32_t> is_whitespace;
for (auto p : unicode_ranges_whitespace) {
for (auto i = p.first; i <= p.second; ++i) {
is_whitespace.insert(i);
}
}
return is_whitespace;
}();
return (bool)is_whitespace.count(cp);
} }
std::string unicode_byte_to_utf8(uint8_t byte) { std::string unicode_byte_to_utf8(uint8_t byte) {
@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", CODEPOINT_TYPE_NUMBER }, { "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", CODEPOINT_TYPE_LETTER }, { "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION }, { "\\p{P}", codepoint_flags::PUNCTUATION },
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
{ CODEPOINT_TYPE_NUMBER, 0xD1 }, { codepoint_flags::NUMBER, 0xD1 },
{ CODEPOINT_TYPE_LETTER, 0xD2 }, { codepoint_flags::LETTER, 0xD2 },
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 }, { codepoint_flags::PUNCTUATION, 0xD3 },
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9 { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex
@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue; continue;
} }
const int cpt_type = unicode_cpt_type(cpts[i]); const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) { if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(cpt_type); text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
} else { } else {
text_collapsed[i] = (char) 0xD0; // fallback text_collapsed[i] = (char) 0xD0; // fallback
} }

View File

@ -4,24 +4,56 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define CODEPOINT_TYPE_UNIDENTIFIED 0 struct codepoint_flags {
#define CODEPOINT_TYPE_NUMBER 1 enum {
#define CODEPOINT_TYPE_LETTER 2 UNDEFINED = 0x0001,
#define CODEPOINT_TYPE_SEPARATOR 3 NUMBER = 0x0002, // regex: \p{N}
#define CODEPOINT_TYPE_ACCENT_MARK 4 LETTER = 0x0004, // regex: \p{L}
#define CODEPOINT_TYPE_PUNCTUATION 5 SEPARATOR = 0x0008, // regex: \p{Z}
#define CODEPOINT_TYPE_SYMBOL 6 ACCENT_MARK = 0x0010, // regex: \p{M}
#define CODEPOINT_TYPE_CONTROL 7 PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
};
// codepoint type
uint16_t is_undefined : 1;
uint16_t is_number : 1; // regex: \p{N}
uint16_t is_letter : 1; // regex: \p{L}
uint16_t is_separator : 1; // regex: \p{Z}
uint16_t is_accent_mark : 1; // regex: \p{M}
uint16_t is_punctuation : 1; // regex: \p{P}
uint16_t is_symbol : 1; // regex: \p{S}
uint16_t is_control : 1; // regex: \p{C}
// helper flags
uint16_t is_whitespace : 1; // regex: \s
uint16_t is_lowercase : 1;
uint16_t is_uppercase : 1;
uint16_t is_nfd : 1;
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
inline uint16_t as_uint() const {
return *reinterpret_cast<const uint16_t*>(this);
}
inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
}
};
std::string unicode_cpt_to_utf8(uint32_t cp); std::string unicode_cpt_to_utf8(uint32_t cp);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8); std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
int unicode_cpt_type(uint32_t cp); codepoint_flags unicode_cpt_flags(const uint32_t cp);
int unicode_cpt_type(const std::string & utf8); codepoint_flags unicode_cpt_flags(const std::string & utf8);
bool unicode_cpt_is_whitespace(uint32_t cp);
std::string unicode_byte_to_utf8(uint8_t byte); std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8); uint8_t unicode_utf8_to_byte(const std::string & utf8);