#include "utils.h"

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <map>
#include <regex>
#include <string>
#include <vector>

// Replaces every occurrence of `needle` in `str` with `replacement`, in place.
//
// The scan resumes just past each inserted replacement, so text introduced by
// a replacement is never re-scanned. An empty `needle` is a no-op: it would
// otherwise match at every position and the loop would never terminate.
void utreplace(std::string & str, const std::string & needle, const std::string & replacement) {
    if (needle.empty()) {
        return;
    }

    size_t pos = 0;
    while ((pos = str.find(needle, pos)) != std::string::npos) {
        str.replace(pos, needle.length(), replacement);
        pos += replacement.length();
    }
}

// Reads the file `fname` and parses it as a flat JSON object mapping string
// keys to integer values.
//
// This is a minimal hand-written scanner, not a general JSON parser:
//   - the content must start with '{', otherwise an empty map is returned;
//   - only the space character is skipped as whitespace around ':' and values;
//   - entries whose value cannot be converted by std::stoi are silently dropped;
//   - in keys, the sequences \u0120, \u010a and \" are rewritten to a space,
//     a newline and a double quote respectively.
//
// Prints a message to stderr and calls exit(1) if the file cannot be opened.
std::map<std::string, int32_t> json_parse(const std::string & fname) {
    std::map<std::string, int32_t> result;

    // read file into string
    std::string json;
    {
        std::ifstream ifs(fname);
        if (!ifs) {
            fprintf(stderr, "Failed to open %s\n", fname.c_str());
            exit(1);
        }

        json = std::string((std::istreambuf_iterator<char>(ifs)),
                           (std::istreambuf_iterator<char>()));
    }

    if (json.empty() || json[0] != '{') {
        return result;
    }

    // parse json
    {
        bool has_key  = false; // true once a key has been read and its value is pending
        bool in_token = false; // true while inside a quoted string

        std::string str_key;
        std::string str_val;

        const size_t n = json.size();
        for (size_t i = 1; i < n; ++i) {
            if (!in_token) {
                if (json[i] == ' ') {
                    continue;
                }
                if (json[i] == '"') {
                    in_token = true;
                    continue;
                }
                // any other character between tokens is ignored
                continue;
            }

            if (json[i] == '\\' && i + 1 < n) {
                // keep the backslash and fall through to append the escaped
                // character as well, so an escaped quote does not end the token
                if (!has_key) {
                    str_key += json[i];
                } else {
                    str_val += json[i];
                }
                ++i;
            } else if (json[i] == '"') {
                if (!has_key) {
                    // closing quote of a key: move on to the value
                    has_key = true;
                    ++i;
                    while (i < n && json[i] == ' ') ++i;
                    ++i; // :
                    while (i < n && json[i] == ' ') ++i;

                    if (i < n && json[i] == '"') {
                        // quoted value: keep collecting it as a token
                        in_token = true;
                        continue;
                    }

                    // bare value: runs up to the next ',' or '}' (or the end of
                    // the input if the file is truncated)
                    while (i < n && json[i] != ',' && json[i] != '}') {
                        str_val += json[i++];
                    }
                    has_key = false;
                } else {
                    // closing quote of a quoted value
                    has_key = false;
                }

                ::utreplace(str_key, "\\u0120", " " ); // \u0120 -> space
                ::utreplace(str_key, "\\u010a", "\n"); // \u010a -> new line
                ::utreplace(str_key, "\\\"",    "\""); // \\\"   -> "

                try {
                    result[str_key] = std::stoi(str_val);
                } catch (...) {
                    // deliberate best effort: entries without an integer value are skipped
                    //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
                }

                str_key.clear();
                str_val.clear();
                in_token = false;
                continue;
            }

            if (!has_key) {
                str_key += json[i];
            } else {
                str_val += json[i];
            }
        }
    }

    return result;
}

// Appends `token` to this vocabulary's list of special tokens.
void gpt_vocab::add_special_token(const std::string & token) {
    special_tokens.push_back(token);
}

// Appends the UTF-8 encoding (1 to 4 bytes) of the code point `ch` to `out`.
// For values above 0x10FFFF nothing is appended and a message is printed to stdout.
static void append_utf8(char32_t ch, std::string & out) {
    if (ch <= 0x7F) {
        out.push_back(static_cast<char>(ch));
    } else if (ch <= 0x7FF) {
        out.push_back(static_cast<char>(0xC0 | ((ch >> 6) & 0x1F)));
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0xFFFF) {
        out.push_back(static_cast<char>(0xE0 | ((ch >> 12) & 0x0F)));
        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0x10FFFF) {
        out.push_back(static_cast<char>(0xF0 | ((ch >> 18) & 0x07)));
        out.push_back(static_cast<char>(0x80 | ((ch >> 12) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    } else {
        printf("Invalid Unicode code point\n");
    }
}

// Converts `text` into a sequence of token ids from `vocab`.
//
// The text is first split into words with a regular expression; each word is
// then covered greedily, left to right, by the longest substring present in
// vocab.token_to_id. A character that starts no known token is reported on
// stderr and skipped.
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        const std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
        const std::regex re(pat);

        // iterate over successive matches directly on `text`; this avoids
        // copying the remaining suffix after every match
        for (std::sregex_iterator it(text.begin(), text.end(), re), last; it != last; ++it) {
            words.push_back(it->str());
        }
    }

    // find the longest tokens that form the words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
        const size_t n = word.size();

        size_t i = 0;
        while (i < n) {
            // try the longest candidate starting at i first, then shrink
            size_t len = n - i;
            for (; len > 0; --len) {
                const auto it = vocab.token_to_id.find(word.substr(i, len));
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    break;
                }
            }

            if (len > 0) {
                i += len;
            } else {
                // not even the single character at i is in the vocabulary
                fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).c_str());
                ++i;
            }
        }
    }

    return tokens;
}

// Returns true when `name` contains any of the weight-name fragments listed
// below, false otherwise.
bool should_transpose_layer(std::string name) {
    static const char * const k_fragments[] = {
        ".mlp.fc_in.weight",
        ".attn.out_proj.weight",
        ".attn.q_proj.weight",
        ".attn.k_proj.weight",
        ".attn.v_proj.weight",
        "/attn/c_attn/w",
        "/attn/c_proj/w",
        "/mlp/c_fc/w",
        "/mlp/c_proj/w",
    };

    for (const char * fragment : k_fragments) {
        if (name.find(fragment) != std::string::npos) {
            return true;
        }
    }
    return false;
}