// unicode.cpp
1 #include "unicode.h" 2 #include "unicode-data.h" 3 4 #include <cassert> 5 #include <cstddef> 6 #include <cstdint> 7 #include <map> 8 #include <regex> 9 #include <stdexcept> 10 #include <string> 11 #include <unordered_map> 12 #include <unordered_set> 13 #include <utility> 14 #include <vector> 15 #include <locale> 16 #include <codecvt> 17 18 static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) { 19 std::string result; 20 for (size_t i = 0; i < cps.size(); ++i) { 21 result.append(unicode_cpt_to_utf8(cps[i])); 22 } 23 return result; 24 } 25 26 static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) { 27 assert(offset < utf8.size()); 28 if (!(utf8[offset + 0] & 0x80)) { 29 auto result = utf8[offset + 0]; 30 offset += 1; 31 return result; 32 } 33 if (!(utf8[offset + 0] & 0x40)) { 34 throw std::invalid_argument("invalid character"); 35 } 36 if (!(utf8[offset + 0] & 0x20)) { 37 if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) { 38 throw std::invalid_argument("invalid character"); 39 } 40 auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); 41 offset += 2; 42 return result; 43 } 44 if (!(utf8[offset + 0] & 0x10)) { 45 if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) { 46 throw std::invalid_argument("invalid character"); 47 } 48 auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); 49 offset += 3; 50 return result; 51 } 52 if (!(utf8[offset + 0] & 0x08)) { 53 if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! 
((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) { 54 throw std::invalid_argument("invalid character"); 55 } 56 auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); 57 offset += 4; 58 return result; 59 } 60 throw std::invalid_argument("failed to convert utf8 to codepoint"); 61 } 62 63 //static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) { 64 // std::vector<uint16_t> result; 65 // if (/* 0x0000 <= cp && */ cp <= 0xffff) { 66 // result.emplace_back(cp); 67 // return result; 68 // } 69 // if (0x10000 <= cp && cp <= 0x10ffff) { 70 // result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); 71 // result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); 72 // return result; 73 // } 74 // throw std::invalid_argument("failed to convert codepoint to utf16"); 75 //} 76 77 //static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) { 78 // std::vector<uint16_t> result; 79 // for (size_t i = 0; i < cps.size(); ++i) { 80 // auto temp = unicode_cpt_to_utf16(cps[i]); 81 // result.insert(result.end(), temp.begin(), temp.end()); 82 // } 83 // return result; 84 //} 85 86 //static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) { 87 // assert(offset < utf16.size()); 88 // if (((utf16[0] >> 10) << 10) != 0xd800) { 89 // auto result = utf16[offset + 0]; 90 // offset += 1; 91 // return result; 92 // } 93 // 94 // if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { 95 // throw std::invalid_argument("invalid character"); 96 // } 97 // 98 // auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); 99 // offset += 2; 100 // return result; 101 //} 102 103 //static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) { 104 // std::vector<uint32_t> result; 105 // size_t offset = 0; 106 // while (offset < utf16.size()) { 107 // 
result.push_back(unicode_cpt_from_utf16(utf16, offset)); 108 // } 109 // return result; 110 //} 111 112 static std::vector<codepoint_flags> unicode_cpt_flags_array() { 113 std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); 114 115 assert (unicode_ranges_flags.front().first == 0); 116 assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS); 117 for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) { 118 const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags 119 const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags 120 for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) { 121 cpt_flags[cpt] = range_ini.second; 122 } 123 } 124 125 for (auto cpt : unicode_set_whitespace) { 126 cpt_flags[cpt].is_whitespace = true; 127 } 128 129 for (auto p : unicode_map_lowercase) { 130 cpt_flags[p.second].is_lowercase = true; 131 } 132 133 for (auto p : unicode_map_uppercase) { 134 cpt_flags[p.second].is_uppercase = true; 135 } 136 137 for (auto &range : unicode_ranges_nfd) { // start, last, nfd 138 cpt_flags[range.nfd].is_nfd = true; 139 } 140 141 return cpt_flags; 142 } 143 144 static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() { 145 std::unordered_map<uint8_t, std::string> map; 146 for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' 
to u'~' 147 assert(0 <= ch && ch < 256); 148 map[ch] = unicode_cpt_to_utf8(ch); 149 } 150 for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬' 151 assert(0 <= ch && ch < 256); 152 map[ch] = unicode_cpt_to_utf8(ch); 153 } 154 for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ' 155 assert(0 <= ch && ch < 256); 156 map[ch] = unicode_cpt_to_utf8(ch); 157 } 158 auto n = 0; 159 for (int ch = 0; ch < 256; ++ch) { 160 if (map.find(ch) == map.end()) { 161 map[ch] = unicode_cpt_to_utf8(256 + n); 162 ++n; 163 } 164 } 165 return map; 166 } 167 168 static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() { 169 std::unordered_map<std::string, uint8_t> map; 170 for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~' 171 assert(0 <= ch && ch < 256); 172 map[unicode_cpt_to_utf8(ch)] = ch; 173 } 174 for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬' 175 assert(0 <= ch && ch < 256); 176 map[unicode_cpt_to_utf8(ch)] = ch; 177 } 178 for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ' 179 assert(0 <= ch && ch < 256); 180 map[unicode_cpt_to_utf8(ch)] = ch; 181 } 182 auto n = 0; 183 for (int ch = 0; ch < 256; ++ch) { 184 if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) { 185 map[unicode_cpt_to_utf8(256 + n)] = ch; 186 ++n; 187 } 188 } 189 return map; 190 } 191 192 static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { 193 std::wstring_convert<std::codecvt_utf8<wchar_t>> conv; 194 return conv.from_bytes(s); 195 } 196 197 static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) { 198 std::vector<std::string> bpe_encoded_words; 199 for (const auto & word : bpe_words) { 200 std::string text_utf; 201 auto utf_word = unicode_cpts_from_utf8(word); 202 for (size_t i = 0; i < utf_word.size(); ++i) { 203 text_utf += unicode_cpt_to_utf8(utf_word[i]); 204 } 205 206 std::string encoded_token; 207 for (char & c : text_utf) { 208 encoded_token += unicode_byte_to_utf8(c); 209 } 210 
bpe_encoded_words.emplace_back(encoded_token); 211 } 212 return bpe_encoded_words; 213 } 214 215 // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ 216 static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) { 217 std::vector<size_t> bpe_offsets; // store the offset of each word 218 bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size 219 220 const auto cpts = unicode_cpts_from_utf8(text); 221 222 size_t start = 0; 223 for (auto offset : offsets) { 224 const size_t offset_ini = start; 225 const size_t offset_end = start + offset; 226 assert(offset_end <= cpts.size()); 227 start = offset_end; 228 229 auto _get_cpt = [&] (const size_t pos) -> uint32_t { 230 return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; 231 }; 232 233 auto _get_flags = [&] (const size_t pos) -> codepoint_flags { 234 static const codepoint_flags undef(codepoint_flags::UNDEFINED); 235 return (offset_ini <= pos && pos < offset_end) ? 
unicode_cpt_flags(cpts[pos]) : undef; 236 }; 237 238 size_t _prev_end = offset_ini; 239 auto _add_token = [&] (const size_t end) -> size_t { 240 assert(_prev_end <= end && end <= offset_end); 241 size_t len = end - _prev_end; 242 if (len > 0) { 243 bpe_offsets.push_back(len); 244 } 245 _prev_end = end; 246 //if (len > 0) { 247 // std::string s = ""; 248 // for(size_t p = end-len; p < end; p++) 249 // s += unicode_cpt_to_utf8(cpts[p]); 250 // printf(">>> '%s'\n", s.c_str()); 251 //} 252 return len; 253 }; 254 255 for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { 256 const uint32_t cpt = _get_cpt(pos); 257 const auto flags = _get_flags(pos); 258 259 // regex: 's|'t|'re|'ve|'m|'ll|'d 260 if (cpt == '\'' && pos+1 < offset_end) { 261 uint32_t cpt_next = _get_cpt(pos+1); 262 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { 263 pos += _add_token(pos+2); 264 continue; 265 } 266 if (pos+2 < offset_end) { 267 uint32_t cpt_next_next = _get_cpt(pos+2); 268 if ((cpt_next == 'r' && cpt_next_next == 'e') || 269 (cpt_next == 'v' && cpt_next_next == 'e') || 270 (cpt_next == 'l' && cpt_next_next == 'l')) { 271 pos += _add_token(pos+3); 272 continue; 273 } 274 } 275 } 276 277 auto flags2 = (cpt == ' ' ? 
_get_flags(pos+1) : flags); 278 // regex: <space>?\p{L}+ 279 if (flags2.is_letter) { 280 pos += (cpt == ' '); 281 while (flags2.is_letter) { 282 flags2 = _get_flags(++pos); 283 } 284 _add_token(pos); 285 continue; 286 } 287 // regex: <space>?\p{N}+ 288 if (flags2.is_number) { 289 pos += (cpt == ' '); 290 while (flags2.is_number) { 291 flags2 = _get_flags(++pos); 292 } 293 _add_token(pos); 294 continue; 295 } 296 // regex: <space>?[^\s\p{L}\p{N}]+ 297 if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) { 298 pos += (cpt == ' '); 299 while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) { 300 flags2 = _get_flags(++pos); 301 } 302 _add_token(pos); 303 continue; 304 } 305 306 size_t num_whitespaces = 0; 307 while (_get_flags(pos+num_whitespaces).is_whitespace) { 308 num_whitespaces++; 309 } 310 311 // regex: \s+(?!\S) 312 if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { 313 pos += num_whitespaces - 1; 314 _add_token(pos); 315 continue; 316 } 317 318 // regex: \s+ 319 if (num_whitespaces > 0) { 320 pos += num_whitespaces; 321 _add_token(pos); 322 continue; 323 } 324 325 // no matches 326 _add_token(++pos); 327 } 328 } 329 330 return bpe_offsets; 331 } 332 333 // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" 334 static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) { 335 std::vector<size_t> bpe_offsets; // store the offset of each word 336 bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size 337 338 const auto cpts = unicode_cpts_from_utf8(text); 339 340 size_t start = 0; 341 for (auto offset : offsets) { 342 const size_t offset_ini = start; 343 const size_t offset_end = start + offset; 344 assert(offset_end <= cpts.size()); 345 start = offset_end; 346 347 auto _get_cpt = [&] (const size_t 
pos) -> uint32_t { 348 return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; 349 }; 350 351 auto _get_flags = [&] (const size_t pos) -> codepoint_flags { 352 static const codepoint_flags undef(codepoint_flags::UNDEFINED); 353 return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef; 354 }; 355 356 size_t _prev_end = offset_ini; 357 auto _add_token = [&] (const size_t end) -> size_t { 358 assert(_prev_end <= end && end <= offset_end); 359 size_t len = end - _prev_end; 360 if (len > 0) { 361 bpe_offsets.push_back(len); 362 } 363 _prev_end = end; 364 //if (len > 0) { 365 // std::string s = ""; 366 // for(size_t p = end-len; p < end; p++) 367 // s += unicode_cpt_to_utf8(cpts[p]); 368 // printf(">>> '%s'\n", s.c_str()); 369 //} 370 return len; 371 }; 372 373 for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { 374 const uint32_t cpt = _get_cpt(pos); 375 const auto flags = _get_flags(pos); 376 377 // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive 378 if (cpt == '\'' && pos+1 < offset_end) { 379 uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); 380 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { 381 pos += _add_token(pos+2); 382 continue; 383 } 384 if (pos+2 < offset_end) { 385 uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); 386 if ((cpt_next == 'r' && cpt_next_next == 'e') || 387 (cpt_next == 'v' && cpt_next_next == 'e') || 388 (cpt_next == 'l' && cpt_next_next == 'l')) { 389 pos += _add_token(pos+3); 390 continue; 391 } 392 } 393 } 394 395 // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct? 
396 if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) { 397 if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters 398 pos++; 399 while (_get_flags(pos).is_letter) { 400 pos++; 401 } 402 _add_token(pos); 403 continue; 404 } 405 } 406 407 // regex: \p{N}{1,3} 408 if (flags.is_number) { 409 size_t ini = pos; 410 while (_get_flags(pos).is_number) { 411 if (++pos - ini >= 3 ) { 412 _add_token(pos); 413 ini = pos; 414 } 415 } 416 _add_token(pos); 417 continue; 418 } 419 420 // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* 421 auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); 422 if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) { 423 pos += (cpt == ' '); 424 while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) { 425 flags2 = _get_flags(++pos); 426 } 427 uint32_t cpt2 = _get_cpt(pos); 428 while (cpt2 == '\r' || cpt2 == '\n') { 429 cpt2 = _get_cpt(++pos); 430 } 431 _add_token(pos); 432 continue; 433 } 434 435 size_t num_whitespaces = 0; 436 size_t last_end_r_or_n = 0; 437 while (_get_flags(pos+num_whitespaces).is_whitespace) { 438 uint32_t cpt2 = _get_cpt(pos+num_whitespaces); 439 if (cpt2 == '\r' || cpt2 == '\n') { 440 last_end_r_or_n = pos + num_whitespaces + 1; 441 } 442 num_whitespaces++; 443 } 444 445 // regex: \s*[\r\n]+ 446 if (last_end_r_or_n > 0) { 447 pos = last_end_r_or_n; 448 _add_token(pos); 449 continue; 450 } 451 452 // regex: \s+(?!\S) 453 if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { 454 pos += num_whitespaces - 1; 455 _add_token(pos); 456 continue; 457 } 458 459 // regex: \s+ 460 if (num_whitespaces > 0) { 461 pos += num_whitespaces; 462 _add_token(pos); 463 continue; 464 } 465 466 // no matches 467 _add_token(++pos); 468 } 469 } 470 471 return bpe_offsets; 472 } 473 474 // use std::wregex to split the text 475 static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring 
& regex_expr, const std::vector<size_t> & offsets) { 476 std::wregex expr(regex_expr); 477 std::vector<size_t> bpe_offsets; // store the offset of each word 478 bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size 479 size_t start = 0; 480 for (auto offset : offsets) { 481 std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); 482 std::wcregex_iterator end; 483 484 int64_t start_idx = 0; 485 while (it != end) { 486 std::wcmatch match = *it; 487 if (match.position() > start_idx) { 488 bpe_offsets.emplace_back(match.position() - start_idx); 489 } 490 bpe_offsets.emplace_back(match.length()); 491 start_idx = match.position() + match.length(); 492 ++it; 493 } 494 495 if (start_idx < (int64_t) offset) { 496 bpe_offsets.emplace_back(offset - start_idx); 497 } 498 start += offset; 499 } 500 501 return bpe_offsets; 502 } 503 504 // use std::regex to split the text 505 static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) { 506 std::regex expr(regex_expr); 507 std::vector<size_t> bpe_offsets; // store the offset of each word 508 bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size 509 size_t start = 0; 510 for (auto offset : offsets) { 511 std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); 512 std::cregex_iterator end; 513 514 int64_t start_idx = 0; 515 while (it != end) { 516 std::cmatch match = *it; 517 if (match.position() > start_idx) { 518 bpe_offsets.emplace_back(match.position() - start_idx); 519 } 520 bpe_offsets.emplace_back(match.length()); 521 start_idx = match.position() + match.length(); 522 ++it; 523 } 524 525 if (start_idx < (int64_t) offset) { 526 bpe_offsets.emplace_back(offset - start_idx); 527 } 528 start += offset; 529 } 530 531 return bpe_offsets; 532 } 533 534 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const 
std::string & regex_expr, const std::vector<size_t> & offsets) { 535 std::vector<size_t> bpe_offsets; 536 537 if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { 538 bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); 539 } else if ( 540 regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || 541 regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { 542 543 bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); 544 } 545 546 return bpe_offsets; 547 } 548 549 // 550 // interface 551 // 552 553 std::string unicode_cpt_to_utf8(uint32_t cp) { 554 std::string result; 555 556 if (/* 0x00 <= cp && */ cp <= 0x7f) { 557 result.push_back(cp); 558 return result; 559 } 560 if (0x80 <= cp && cp <= 0x7ff) { 561 result.push_back(0xc0 | ((cp >> 6) & 0x1f)); 562 result.push_back(0x80 | (cp & 0x3f)); 563 return result; 564 } 565 if (0x800 <= cp && cp <= 0xffff) { 566 result.push_back(0xe0 | ((cp >> 12) & 0x0f)); 567 result.push_back(0x80 | ((cp >> 6) & 0x3f)); 568 result.push_back(0x80 | (cp & 0x3f)); 569 return result; 570 } 571 if (0x10000 <= cp && cp <= 0x10ffff) { 572 result.push_back(0xf0 | ((cp >> 18) & 0x07)); 573 result.push_back(0x80 | ((cp >> 12) & 0x3f)); 574 result.push_back(0x80 | ((cp >> 6) & 0x3f)); 575 result.push_back(0x80 | (cp & 0x3f)); 576 return result; 577 } 578 579 throw std::invalid_argument("invalid codepoint"); 580 } 581 582 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) { 583 auto comp = [] (const uint32_t cpt, const range_nfd & range) { 584 return cpt < range.first; 585 }; 586 std::vector<uint32_t> result(cpts.size()); 587 for (size_t i = 0; i < cpts.size(); ++i) { 588 const uint32_t cpt = cpts[i]; 589 auto it = 
std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1; 590 result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt; 591 } 592 return result; 593 } 594 595 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) { 596 std::vector<uint32_t> result; 597 size_t offset = 0; 598 while (offset < utf8.size()) { 599 result.push_back(unicode_cpt_from_utf8(utf8, offset)); 600 } 601 return result; 602 } 603 604 codepoint_flags unicode_cpt_flags(const uint32_t cp) { 605 static const codepoint_flags undef(codepoint_flags::UNDEFINED); 606 static const auto cpt_flags = unicode_cpt_flags_array(); 607 return cp < cpt_flags.size() ? cpt_flags[cp] : undef; 608 } 609 610 codepoint_flags unicode_cpt_flags(const std::string & utf8) { 611 static const codepoint_flags undef(codepoint_flags::UNDEFINED); 612 if (utf8.empty()) { 613 return undef; // undefined 614 } 615 size_t offset = 0; 616 return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset)); 617 } 618 619 std::string unicode_byte_to_utf8(uint8_t byte) { 620 static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map(); 621 return map.at(byte); 622 } 623 624 uint8_t unicode_utf8_to_byte(const std::string & utf8) { 625 static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map(); 626 return map.at(utf8); 627 } 628 629 uint32_t unicode_tolower(uint32_t cp) { 630 auto it = unicode_map_lowercase.find(cp); 631 return it == unicode_map_lowercase.end() ? 
cp : it->second; 632 } 633 634 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { 635 // unicode categories 636 static const std::map<std::string, int> k_ucat_enum = { 637 { "\\p{N}", codepoint_flags::NUMBER }, 638 { "\\p{L}", codepoint_flags::LETTER }, 639 { "\\p{P}", codepoint_flags::PUNCTUATION }, 640 }; 641 642 static const std::map<int, int> k_ucat_cpt = { 643 { codepoint_flags::NUMBER, 0xD1 }, 644 { codepoint_flags::LETTER, 0xD2 }, 645 { codepoint_flags::PUNCTUATION, 0xD3 }, 646 }; 647 648 static const std::map<int, std::string> k_ucat_map = { 649 { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 650 { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z 651 { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} 652 }; 653 654 // compute collapsed codepoints only if needed by at least one regex 655 bool need_collapse = false; 656 for (auto & regex_expr : regex_exprs) { 657 // search for unicode categories 658 for (const auto & ucat : k_ucat_enum) { 659 if (std::string::npos != regex_expr.find(ucat.first)) { 660 need_collapse = true; 661 break; 662 } 663 } 664 } 665 666 const auto cpts = unicode_cpts_from_utf8(text); 667 668 // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte 669 // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 670 std::string text_collapsed; 671 if (need_collapse) { 672 // collapse all unicode categories 673 text_collapsed.resize(cpts.size()); 674 675 for (size_t i = 0; i < cpts.size(); ++i) { 676 // keep single-byte codepoints as is 677 if (cpts[i] < 128) { 678 text_collapsed[i] = cpts[i]; 679 continue; 680 } 681 682 const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag(); 683 684 if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) { 685 text_collapsed[i] = k_ucat_cpt.at(cpt_flag); 686 } else { 
687 text_collapsed[i] = (char) 0xD0; // fallback 688 } 689 } 690 } 691 692 std::vector<size_t> bpe_offsets = { cpts.size() }; 693 694 for (auto & regex_expr : regex_exprs) { 695 // first, see if we have an efficient custom regex implementation 696 auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); 697 698 if (!tmp.empty()) { 699 bpe_offsets = std::move(tmp); 700 continue; 701 } 702 703 // fallback to general-purpose std::regex / std::wregex 704 try { 705 // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category 706 // with the corresponding collapsed representation 707 bool use_collapsed = false; 708 for (auto & ucat : k_ucat_enum) { 709 if (std::string::npos != regex_expr.find(ucat.first)) { 710 use_collapsed = true; 711 break; 712 } 713 } 714 715 if (use_collapsed) { 716 // sanity-check that the original regex does not contain any non-ASCII characters 717 const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); 718 for (size_t i = 0; i < cpts_regex.size(); ++i) { 719 if (cpts_regex[i] >= 128) { 720 throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); 721 } 722 } 723 724 // generate a collapsed representation of the regex 725 std::string regex_expr_collapsed; 726 727 // track if we are inside [], because nested [] are not allowed 728 bool inside = false; 729 for (size_t i = 0; i < regex_expr.size(); ++i) { 730 if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { 731 regex_expr_collapsed += '['; 732 inside = true; 733 continue; 734 } 735 736 if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { 737 regex_expr_collapsed += ']'; 738 inside = false; 739 continue; 740 } 741 742 if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && 743 regex_expr[i + 1] == 'p' && 744 regex_expr[i + 2] == '{' && 745 regex_expr[i + 4] == '}') { 746 const std::string pat = regex_expr.substr(i, 5); 747 if (k_ucat_enum.find(pat) != 
k_ucat_enum.end()) { 748 if (!inside) { 749 regex_expr_collapsed += '['; 750 } 751 regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); 752 regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); 753 if (!inside) { 754 regex_expr_collapsed += ']'; 755 } 756 i += 4; 757 continue; 758 } 759 } 760 761 regex_expr_collapsed += regex_expr[i]; 762 } 763 764 //printf("text_collapsed: %s\n", text_collapsed.c_str()); 765 //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); 766 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); 767 } else { 768 // no unicode category used, we can use std::wregex directly 769 const std::wstring wtext = unicode_wstring_from_utf8(text); 770 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); 771 772 //printf("text: %s\n", text.c_str()); 773 //printf("regex_expr: %s\n", regex_expr.c_str()); 774 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); 775 } 776 } catch (std::regex_error & e) { 777 fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); 778 fprintf(stderr, "Regex error: %s\n", e.what()); 779 throw std::runtime_error("Failed to process regex"); 780 } 781 } 782 783 std::vector<std::string> bpe_words; 784 bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size 785 786 size_t start = 0; 787 for (size_t & offset : bpe_offsets) { 788 bpe_words.emplace_back(); 789 for (size_t i = start; i < start + offset; ++i) { 790 bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); 791 } 792 start += offset; 793 } 794 795 return unicode_byte_encoding_process(bpe_words); 796 }