koboldcpp/otherarch/rwkv_vocab.cpp
2023-06-13 20:06:19 +08:00

87 lines
No EOL
3.5 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <vector>
#include <string>
#include <fstream>
#include <iostream>
#include "expose.h"
// Token strings loaded by read_rwkv_vocab() / read_rwkv_world_vocab(), in file order.
std::vector<std::string> rwkv_vocab;

// Byte-level alphabet of the vocab file: special[i] is the UTF-8 text that
// stands for raw byte value i (read_rwkv_vocab() maps each entry back to byte i).
// Layout:
//   0-32    -> U+0100..U+0120
//   33-126  -> the ASCII character itself
//   127-160 -> U+0121..U+0142
//   161-172 -> the Latin-1 character itself
//   173     -> U+0143
//   174-255 -> the Latin-1 character itself
// Entries 144 (U+0132), 145 (U+0133) and 182 (U+00B6) are spelled as UTF-8
// byte escapes so that text tooling which rewrites or strips those characters
// cannot silently break the table.
std::vector<std::string> special = {
    // 0-32
    "Ā","ā","Ă","ă","Ą","ą","Ć","ć","Ĉ","ĉ","Ċ","ċ","Č","č","Ď","ď","Đ","đ","Ē","ē","Ĕ","ĕ","Ė","ė","Ę","ę","Ě","ě","Ĝ","ĝ","Ğ","ğ","Ġ",
    // 33-126
    "!","\"","#","$","%","&","\'","(",")","*","+",",","-",".","/","0","1","2","3","4","5","6","7","8","9",":",";","<","=",">","?","@",
    "A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z","[","\\","]","^","_","`",
    "a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z","{","|","}","~",
    // 127-160
    "ġ","Ģ","ģ","Ĥ","ĥ","Ħ","ħ","Ĩ","ĩ","Ī","ī","Ĭ","ĭ","Į","į","İ","ı","\xC4\xB2","\xC4\xB3","Ĵ","ĵ","Ķ","ķ","ĸ","Ĺ","ĺ","Ļ","ļ","Ľ","ľ","Ŀ","ŀ","Ł","ł",
    // 161-172
    "¡","¢","£","¤","¥","¦","§","¨","©","ª","«","¬",
    // 173
    "Ń",
    // 174-191
    "®","¯","°","±","²","³","´","µ","\xC2\xB6","·","¸","¹","º","»","¼","½","¾","¿",
    // 192-255
    "À","Á","Â","Ã","Ä","Å","Æ","Ç","È","É","Ê","Ë","Ì","Í","Î","Ï","Ð","Ñ","Ò","Ó","Ô","Õ","Ö","×","Ø","Ù","Ú","Û","Ü","Ý","Þ","ß",
    "à","á","â","ã","ä","å","æ","ç","è","é","ê","ë","ì","í","î","ï","ð","ñ","ò","ó","ô","õ","ö","÷","ø","ù","ú","û","ü","ý","þ","ÿ"
};
// Replaces, in place, every non-overlapping occurrence of `from` inside `str`
// with `to`, scanning left to right. An empty `from` leaves `str` untouched.
static void replaceAll(std::string& str, const std::string& from, const std::string& to) {
    if (!from.empty()) {
        for (size_t hit = str.find(from);
             hit != std::string::npos;
             // Resume after the inserted text so a `to` that contains `from`
             // (e.g. 'x' -> 'yx') is never matched again.
             hit = str.find(from, hit + to.length())) {
            str.replace(hit, from.length(), to);
        }
    }
}
// Returns the value (0-15) of the hexadecimal digit `c`, or -1 if `c` is not
// a hexadecimal digit. Accepts both upper and lower case.
static int hexDigitValue(char c) {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
    return -1;
}

// Decodes a string of hexadecimal digit pairs into the raw bytes they spell
// (e.g. "4869" -> "Hi"). The input is consumed two characters at a time:
//   - a chunk of two hex digits yields one byte;
//   - a chunk with only one leading hex digit (odd-length input, or a digit
//     followed by a non-digit) yields a byte holding that single digit;
//   - a chunk that does not start with a hex digit ends decoding, so trailing
//     junk such as a '\r' from CRLF line endings is ignored instead of
//     raising an exception.
// Never throws on malformed input.
static std::string hexToUnicode(const std::string& hexString) {
    std::string bytes;
    bytes.reserve((hexString.length() + 1) / 2);
    for (size_t i = 0; i < hexString.length(); i += 2) {
        const int first = hexDigitValue(hexString[i]);
        if (first < 0) {
            break; // not hex: treat the rest of the line as garbage
        }
        int value = first;
        if (i + 1 < hexString.length()) {
            const int second = hexDigitValue(hexString[i + 1]);
            if (second >= 0) {
                value = first * 16 + second;
            }
        }
        bytes += static_cast<char>(value);
    }
    return bytes;
}
// Decodes one vocab-file line back into raw bytes using `table`, where
// table[i] is the UTF-8 text that stands for byte value i.
//
// The line is walked once, one UTF-8 sequence at a time. A sequence found in
// the table is emitted as its byte; anything else is copied through unchanged.
// A single pass is required: applying the table entries as successive
// search-and-replace passes lets bytes produced by an earlier pass combine
// into a later entry's pattern (e.g. "Ã¤" would end up as the single byte
// 0xE4 instead of the two bytes 0xC3 0xA4).
static std::string rwkv_decode_vocab_line(const std::string& encoded, const std::vector<std::string>& table)
{
    std::string decoded;
    decoded.reserve(encoded.size());
    size_t pos = 0;
    while (pos < encoded.size())
    {
        // Sequence length implied by the UTF-8 lead byte; stray continuation
        // bytes and other invalid lead bytes are handled as single bytes.
        const unsigned char lead = static_cast<unsigned char>(encoded[pos]);
        size_t len = 1;
        if ((lead & 0xE0) == 0xC0) { len = 2; }
        else if ((lead & 0xF0) == 0xE0) { len = 3; }
        else if ((lead & 0xF8) == 0xF0) { len = 4; }
        if (len > encoded.size() - pos)
        {
            len = encoded.size() - pos; // truncated sequence at end of line
        }

        bool mapped = false;
        for (size_t i = 0; i < table.size(); ++i)
        {
            if (table[i].size() == len && encoded.compare(pos, len, table[i]) == 0)
            {
                decoded.push_back(static_cast<char>(i));
                mapped = true;
                break;
            }
        }
        if (!mapped)
        {
            decoded.append(encoded, pos, len);
        }
        pos += len;
    }
    return decoded;
}

// Loads "rwkv_vocab.embd" (looked up next to executable_path) and appends one
// decoded entry per line to rwkv_vocab. Each line is converted from the
// printable alphabet in `special` back to raw bytes. If the file cannot be
// opened a message is printed and rwkv_vocab is left unchanged.
void read_rwkv_vocab()
{
    auto filepath = executable_path+ "rwkv_vocab.embd";
    printf("\nReading vocab from %s",filepath.c_str());
    std::ifstream myfile(filepath);
    if (!myfile.is_open())
    {
        std::cout << "Unable to open RWKV vocab file";
        return;
    }

    std::string line;
    // NOTE(review): this loop shape appends one final empty entry when the file
    // is empty or ends with a newline. It is kept as-is because the number of
    // entries is presumably relied on for token-id lookups elsewhere - confirm
    // before switching to `while (std::getline(...))`.
    while (myfile.good())
    {
        std::getline(myfile, line);
        rwkv_vocab.push_back(rwkv_decode_vocab_line(line, special));
    }
    // The stream closes itself when it goes out of scope.
}
// Loads "rwkv_world_vocab.embd" (looked up next to executable_path) into
// rwkv_vocab. The file holds one entry per line, written in hexadecimal; each
// line is decoded with hexToUnicode(). A "<<UNUSED_TOKEN>>" placeholder is
// appended first, so the file's first line lands one slot after it. If the
// file cannot be opened a message is printed and rwkv_vocab is left unchanged.
void read_rwkv_world_vocab() //its in hexadecimal
{
    auto filepath = executable_path+ "rwkv_world_vocab.embd";
    printf("\nReading world vocab from %s",filepath.c_str());
    std::ifstream myfile(filepath);
    if (!myfile.is_open())
    {
        std::cout << "Unable to open RWKV world vocab file";
        return;
    }

    rwkv_vocab.push_back("<<UNUSED_TOKEN>>");

    std::string line;
    // NOTE(review): this loop shape appends one final empty entry when the file
    // is empty or ends with a newline. It is kept as-is because the number of
    // entries is presumably relied on for token-id lookups elsewhere - confirm
    // before switching to `while (std::getline(...))`.
    while (myfile.good())
    {
        std::getline(myfile, line);
        rwkv_vocab.push_back(hexToUnicode(line));
    }
    // The stream closes itself when it goes out of scope.
}