sync write_output_files function

This commit is contained in:
Concedo 2025-11-30 17:21:38 +08:00
parent ef992b4ab7
commit 95be49ac19

View file

@ -1012,8 +1012,8 @@ void process_shaders() {
}
void write_output_files() {
std::stringstream hdr = make_generic_stringstream();
std::stringstream src = make_generic_stringstream();
std::ofstream hdr(target_hpp, std::ios::binary);
std::ofstream src(target_cpp, std::ios::binary);
hdr << "#include <cstdint>\n\n";
src << "#include \"" << basename(target_hpp) << "\"\n\n";
@ -1031,7 +1031,7 @@ void write_output_files() {
hdr << "extern const uint64_t " << name << "_len;\n";
hdr << "extern const unsigned char " << name << "_data[];\n\n";
if (input_filepath != "") {
//if (input_filepath != "") {
std::string data = read_binary_file(path);
if (data.empty()) {
continue;
@ -1045,6 +1045,9 @@ void write_output_files() {
if ((i + 1) % 12 == 0) src << "\n";
}
src << std::dec << "\n};\n\n";
//}
if (!no_clean) {
std::remove(path.c_str());
}
}
@ -1054,9 +1057,9 @@ void write_output_files() {
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
if (basename(input_filepath) != op_file) {
continue;
}
// if (basename(input_filepath) != op_file) {
// continue;
// }
std::stringstream data = make_generic_stringstream();
std::stringstream len = make_generic_stringstream();
data << "const void * " << op << "_data[2][2][2][2] = ";
@ -1122,147 +1125,19 @@ void write_output_files() {
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n";
if (basename(input_filepath) == "mul_mat_vec.comp") {
// if (basename(input_filepath) == "mul_mat_vec.comp") {
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
// }
}
}
if (input_filepath == "") {
write_file_if_changed(target_hpp, hdr.str());
}
if (target_cpp != "") {
write_binary_file(target_cpp, src.str());
}
}
// Writes the combined shader header (target_hpp) and source (target_cpp) files.
//
// For every (name, path) entry in shader_fnames the SPIR-V file at `path` is
// read and emitted as a `<name>_data` byte array (source) together with its
// `<name>_len` constant (header). Afterwards the lookup tables for the
// element-wise ops and the arr_dmmv_* arrays are appended.
//
// Failures are reported on stderr: if either output file cannot be opened the
// function returns without writing anything; a SPIR-V file that cannot be
// opened, sized or fully read is skipped.
void write_output_files_combined() {
    FILE* hdr = fopen(target_hpp.c_str(), "w");
    if (!hdr) {
        std::cerr << "Error opening output file: " << target_hpp << " (" << strerror(errno) << ")\n";
        return;
    }
    FILE* src = fopen(target_cpp.c_str(), "w");
    if (!src) {
        std::cerr << "Error opening output file: " << target_cpp << " (" << strerror(errno) << ")\n";
        fclose(hdr);
        return;
    }

    fprintf(hdr, "#include <cstdint>\n\n");
    fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());

    // Sort so the generated files do not depend on the order entries were added.
    std::sort(shader_fnames.begin(), shader_fnames.end());
    for (const auto& pair : shader_fnames) {
        const std::string& name = pair.first;
#ifdef _WIN32
        std::string path = pair.second;
        std::replace(path.begin(), path.end(), '/', '\\' );
#else
        const std::string& path = pair.second;
#endif

        FILE* spv = fopen(path.c_str(), "rb");
        if (!spv) {
            std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
            continue;
        }

        // ftell() returns -1 on failure; check it before converting to an
        // unsigned size so an error cannot turn into a huge allocation.
        long end_pos = -1;
        if (fseek(spv, 0, SEEK_END) != 0 || (end_pos = ftell(spv)) < 0 || fseek(spv, 0, SEEK_SET) != 0) {
            std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
            fclose(spv);
            continue;
        }
        const size_t size = static_cast<size_t>(end_pos);

        std::vector<unsigned char> data(size);
        const size_t read_size = fread(data.data(), 1, size, spv);
        fclose(spv);
        if (read_size != size) {
            std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
            continue;
        }

        fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
        fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);

        // Emit the bytes as hex literals, twelve per line.
        fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
        for (size_t i = 0; i < size; ++i) {
            fprintf(src, "0x%02x,", data[i]);
            if ((i + 1) % 12 == 0) fprintf(src, "\n");
        }
        fprintf(src, "\n};\n\n");

        if (!no_clean) {
            std::remove(path.c_str());
        }
    }

    // Tables indexed as [t0][t1][t2][rte]; each type index selects the _f32 or
    // _f16 suffix and rte == 1 appends "_rte" to the shader name.
    std::string suffixes[2] = {"_f32", "_f16"};
    for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
        fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
        fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);

        std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
        std::string len  = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";

        // Both initializers share the same brace structure.
        const auto append_both = [&data, &len](const char* text) {
            data += text;
            len  += text;
        };

        append_both("{");
        for (uint32_t t0 = 0; t0 < 2; ++t0) {
            append_both("{");
            for (uint32_t t1 = 0; t1 < 2; ++t1) {
                append_both("{");
                for (uint32_t t2 = 0; t2 < 2; ++t2) {
                    append_both("{");
                    for (uint32_t rte = 0; rte < 2; ++rte) {
                        const std::string shader = op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
                        data += shader + "_data,";
                        len  += shader + "_len,";
                    }
                    append_both("}, ");
                }
                append_both("}, ");
            }
            append_both("}, ");
        }
        append_both("};\n");

        fputs(data.c_str(), src);
        fputs(len.c_str(), src);
    }

    std::vector<std::string> btypes = {"f16", "f32"};

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
    btypes.push_back("q8_1");
#endif

    for (const std::string& btype : btypes) {
        for (const auto& tname : type_names) {
            if (btype == "q8_1" && !is_legacy_quant(tname)) {
                continue;
            }
            fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
            fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());

            // Each array lists the plain, subgroup and subgroup_no_shmem variants.
            const std::string prefix = "mul_mat_vec_" + tname + "_" + btype + "_f32";
            std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {" + prefix + "_data, " + prefix + "_subgroup_data, " + prefix + "_subgroup_no_shmem_data};\n";
            std::string len  = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {" + prefix + "_len, " + prefix + "_subgroup_len, " + prefix + "_subgroup_no_shmem_len};\n";
            fputs(data.c_str(), src);
            fputs(len.c_str(), src);
        }
    }

    fclose(hdr);
    fclose(src);
}
} // namespace
@ -1314,8 +1189,7 @@ int main(int argc, char** argv) {
process_shaders();
//write_output_files();
write_output_files_combined();
write_output_files();
return EXIT_SUCCESS;
}