diff --git a/common/profiler.cpp b/common/profiler.cpp
index 1bdd88b6..99be85cb 100644
--- a/common/profiler.cpp
+++ b/common/profiler.cpp
@@ -32,6 +32,8 @@
 #include
 #include
 #include
+#include
+#include
 
 const char * device_name() {
     static char device_name[256];
@@ -84,7 +86,8 @@ uint32_t device_cpu_cores() {
 }
 
 static float device_flops(struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t, profiler_backend_type btype, int n_threads) {
-    const int n_embd = llama_n_embd(model);
+    const int n_repeat = 1;
+    const int n_embd = llama_n_embd(model);
     std::vector<float> matrix_A(n_embd * n_embd, 1.0f);
     std::vector<float> matrix_B(n_embd * n_embd, 1.0f / n_embd);
 
@@ -110,11 +113,8 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
         return 0.0f;
     }
 
-    size_t ctx_size = 0;
-    ctx_size += 2 * ggml_tensor_overhead(); // tensors
-
     struct ggml_init_params params = {
-        /*.mem_size =*/ ctx_size,
+        /*.mem_size =*/ 2 * ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
     };
@@ -130,16 +130,20 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
     struct ggml_cgraph * gf = NULL;
     struct ggml_context * ctx_cgraph = NULL;
+    struct ggml_tensor * cur = NULL;
     {
         struct ggml_init_params params0 = {
-            /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+            /*.mem_size =*/ ggml_tensor_overhead() * (n_repeat + 2) + ggml_graph_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
         };
         ctx_cgraph = ggml_init(params0);
 
         gf = ggml_new_graph(ctx_cgraph);
-        struct ggml_tensor * cur = ggml_mul_mat(ctx_cgraph, tensor_a, tensor_b);
+        cur = ggml_mul_mat(ctx_cgraph, tensor_a, tensor_b);
+        for (int i = 0; i < n_repeat - 1; i++) {
+            cur = ggml_mul_mat(ctx_cgraph, tensor_a, cur);
+        }
         ggml_build_forward_expand(gf, cur);
     }
 
@@ -151,14 +155,14 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
     }
 
     // warm-up
-    ggml_backend_graph_compute(backend, gf);
+    // ggml_backend_graph_compute(backend, gf);
 
     const int64_t t_start = ggml_time_us();
     ggml_backend_graph_compute(backend, gf);
     const int64_t t_end = ggml_time_us();
 
     double elapsed_seconds = ((double)t_end - (double)t_start) / 1e6; // convert to seconds
-    double flops = (2.0 * (double)n_embd * (double)n_embd * (double)n_embd) / elapsed_seconds / 1e9; // convert to GFLOPS
+    double flops = (2.0 * n_repeat * (double)n_embd * (double)n_embd * (double)n_embd) / elapsed_seconds / 1e9; // convert to GFLOPS
 
     ggml_free(ctx_cgraph);
     ggml_gallocr_free(allocr);
@@ -195,6 +199,125 @@ float device_cuda_flops(struct llama_model * model, enum ggml_type src0t, enum g
     return 0.0f;
 }
 
+float device_inp_embd_delay(struct llama_model * model, enum ggml_type src0t, int n_tokens, int n_threads) {
+    const int n_vocab = llama_n_vocab(model);
+    const int n_embd = llama_n_embd(model);
+
+    ggml_backend_t backend = ggml_backend_cpu_init();
+    if (!backend) {
+        LOG_INF("%s: ggml backend init failed\n", __func__);
+        return 0.0f;
+    }
+
+    size_t ctx_size = 0;
+    ctx_size += 2 * ggml_tensor_overhead(); // tensors
+
+    struct ggml_init_params params = {
+        /*.mem_size =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
+    };
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
+    struct ggml_tensor * tok_embd = ggml_new_tensor_2d(ctx, src0t, n_embd, n_vocab);
+
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+    std::vector<int32_t> matrix_A(n_tokens);
+    for (int i = 0; i < n_tokens; ++i) {
+        matrix_A[i] = i % n_vocab;
+    }
+
+    const size_t embd_size = n_vocab * n_embd;
+    void * matrix_B = nullptr;
+
+    // quantization and dequantization functions
+    ggml_type_traits_t qfns = ggml_internal_get_type_traits(src0t);
+    if (!qfns.from_float || !qfns.to_float) {
+        LOG_INF("Unsupported or uninitialized quantization type: %d\n", src0t);
+        ggml_free(ctx);
+        ggml_backend_buffer_free(buffer);
+        ggml_backend_free(backend);
+        return 0.0f;
+    }
+
+    size_t QK_K = 0;
+    switch (src0t) {
+        case GGML_TYPE_F32: {
+            matrix_B = malloc(embd_size * sizeof(float));
+            float * matrix_B_f32 = static_cast<float *>(matrix_B);
+            for (size_t i = 0; i < embd_size; ++i) {
+                matrix_B_f32[i] = static_cast<float>(rand()) / RAND_MAX;
+            }
+            break;
+        }
+        case GGML_TYPE_F16: {
+            matrix_B = malloc(embd_size * sizeof(ggml_fp16_t));
+            std::vector<float> temp_f32(embd_size);
+            for (size_t i = 0; i < embd_size; ++i) {
+                temp_f32[i] = static_cast<float>(rand()) / RAND_MAX;
+            }
+            ggml_fp32_to_fp16_row(temp_f32.data(), static_cast<ggml_fp16_t *>(matrix_B), embd_size);
+            break;
+        }
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
+            QK_K = 256;
+            matrix_B = malloc((embd_size / QK_K) * ggml_type_size(src0t));
+            break;
+        default:
+            LOG_INF("Unsupported type: %d\n", src0t);
+            ggml_free(ctx);
+            ggml_backend_buffer_free(buffer);
+            ggml_backend_free(backend);
+            return 0.0f;
+    }
+
+    ggml_backend_tensor_set(inp_tokens, matrix_A.data(), 0, ggml_nbytes(inp_tokens));
+    ggml_backend_tensor_set(tok_embd, matrix_B, 0, ggml_nbytes(tok_embd));
+
+    struct ggml_cgraph * gf = NULL;
+    struct ggml_context * ctx_cgraph = NULL;
+    {
+        struct ggml_init_params params0 = {
+            /*.mem_size =*/ ggml_tensor_overhead() * 3 + ggml_graph_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+        };
+        ctx_cgraph = ggml_init(params0);
+
+        gf = ggml_new_graph(ctx_cgraph);
+        struct ggml_tensor * cur = ggml_get_rows(ctx_cgraph, tok_embd, inp_tokens);
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+
+    // warm-up
+    // ggml_backend_graph_compute(backend, gf);
+
+    const int64_t t_start = ggml_time_us();
+    ggml_backend_graph_compute(backend, gf);
+    const int64_t t_end = ggml_time_us();
+
+    double elapsed_ms = ((double)t_end - (double)t_start) / 1e3; // convert to ms
+
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+    ggml_backend_free(backend);
+
+    return (float)elapsed_ms;
+}
+
 uint64_t device_physical_memory(bool available) {
     uint64_t memory = 0;
 
@@ -357,41 +480,107 @@ uint64_t device_disk_read_bw(const char * test_file, size_t buffer_size_mb) {
     return speed;
 }
 
-uint64_t device_memory_bw(size_t buffer_size_mb) {
-    uint64_t speed = 0;
-    size_t test_size = buffer_size_mb * 1024 * 1024; // convert MB to bytes
+float device_memory_bw(int n_thread) {
+    size_t buffer_size = 5L * 1024 * 1024; // 5m
+    std::vector<std::thread> thread_pool;
+    std::vector<double> results(n_thread);
+    std::vector<char *> buffers(n_thread);
 
-    try {
-        // allocate memory for speed test
-        std::vector<char> buffer(test_size, 1);
-
-        // measure write speed
-        auto start_time = std::chrono::high_resolution_clock::now();
-        memset(buffer.data(), 0xAB, buffer.size());
-        auto end_time = std::chrono::high_resolution_clock::now();
-        std::chrono::duration<double> elapsed_time = end_time - start_time;
-        double write_speed = static_cast<double>(test_size) / elapsed_time.count();
-
-        // measure read speed
-        start_time = std::chrono::high_resolution_clock::now();
-        volatile char temp = 0;
-        for (size_t i = 0; i < buffer.size(); i += 64) {
-            temp += buffer[i]; // read in steps of cache line size to minimize cache thrashing
-        }
-        end_time = std::chrono::high_resolution_clock::now();
-        elapsed_time = end_time - start_time;
-        double read_speed = static_cast<double>(test_size) / elapsed_time.count();
-
-        // average speed
-        speed = static_cast<uint64_t>((write_speed + read_speed) / 2.0);
-
-        buffer.clear();
-        buffer.shrink_to_fit();
-    } catch (const std::exception &e) {
-        LOG_ERR("Exception while calculating memory speed: %s\n", e.what());
+    for (int i = 0; i < n_thread; ++i) {
+        buffers[i] = new char[buffer_size];
     }
 
-    return speed;
+    auto memory_bw_test = [](char * buffer, size_t size, double & result) {
+        // read test
+        volatile char temp = 0;
+        auto start = std::chrono::high_resolution_clock::now();
+        for (size_t i = 0; i < size; i += 64) {
+            temp += buffer[i];
+        }
+        auto end = std::chrono::high_resolution_clock::now();
+        std::chrono::duration<double> elapsed = end - start;
+        result = size / elapsed.count() / 1e9; // GB/s
+    };
+
+    for (int i = 0; i < n_thread; ++i) {
+        thread_pool.emplace_back(memory_bw_test, buffers[i], buffer_size, std::ref(results[i]));
+    }
+    for (auto & t : thread_pool) {
+        t.join();
+    }
+
+    double bandwidth = 0.0f;
+    for (double result : results) {
+        bandwidth += result;
+    }
+
+    for (int i = 0; i < n_thread; ++i) {
+        delete[] buffers[i];
+    }
+
+    return static_cast<float>(bandwidth);
+}
+
+float device_cuda_memory_bw(struct llama_model * model) {
+#ifndef GGML_USE_CUDA
+    return 0.0f;
+#endif
+
+    const int n_embd = llama_n_embd(model) * 2;
+    std::vector<float> matrix_A(n_embd * n_embd, 1.0f);
+
+    ggml_backend_t backend = ggml_backend_cuda_init(0);
+    if (!backend) {
+        LOG_INF("%s: ggml backend init failed\n", __func__);
+        return 0.0f;
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
+    };
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * tensor_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
+    tensor_a->op = GGML_OP_READ;
+
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+    ggml_backend_tensor_set(tensor_a, matrix_A.data(), 0, ggml_nbytes(tensor_a));
+
+    struct ggml_cgraph * gf = NULL;
+    struct ggml_context * ctx_cgraph = NULL;
+    {
+        struct ggml_init_params params0 = {
+            /*.mem_size =*/ ggml_tensor_overhead() + ggml_graph_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+        };
+        ctx_cgraph = ggml_init(params0);
+
+        gf = ggml_new_graph(ctx_cgraph);
+        ggml_build_forward_expand(gf, tensor_a);
+    }
+
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    const int64_t t_start = ggml_time_us();
+    ggml_backend_graph_compute(backend, gf);
+    const int64_t t_end = ggml_time_us();
+
+    double elapsed_s = ((double)t_end - (double)t_start) / 1e6;
+    size_t total_bytes = n_embd * n_embd * sizeof(float);
+    float bandwidth = (total_bytes / elapsed_s) / 1e9; // GB/s
+
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+    ggml_backend_free(backend);
+
+    return bandwidth;
 }
 
 int device_has_metal(void) {
@@ -433,6 +622,68 @@ void device_get_props(struct llama_model * model, int device, struct ggml_backen
     ggml_backend_dev_get_props(dev, props);
 }
 
+static float device_compute_delay(struct device_info & dev_info, int n_layers) {
+    struct model_flops n_flops = dev_info.model_flops;
+    struct cpu_props cpu = dev_info.cpu_props;
+
+    double total_latency = 0.0f;
+#ifdef GGML_USE_CUDA
+    struct gpu_props gpu = dev_info.gpu_props;
+
+    total_latency += (double)n_flops.layer_f32_f32 / (double)gpu.cuda_flops_f32_f32 / 1e9;
+    total_latency += (double)n_flops.layer_f16_f32 / (double)gpu.cuda_flops_f16_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q4k_f32 / (double)gpu.cuda_flops_q4k_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q6k_f32 / (double)gpu.cuda_flops_q6k_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q80_f32 / (double)gpu.cuda_flops_q80_f32 / 1e9;
+#else
+    total_latency += (double)n_flops.layer_f32_f32 / (double)cpu.flops_f32_f32 / 1e9;
+    total_latency += (double)n_flops.layer_f16_f32 / (double)cpu.flops_f16_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q4k_f32 / (double)cpu.flops_q4k_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q6k_f32 / (double)cpu.flops_q6k_f32 / 1e9;
+    total_latency += (double)n_flops.layer_q80_f32 / (double)cpu.flops_q80_f32 / 1e9;
+#endif
+
+    total_latency *= n_layers;
+
+    total_latency += (double)n_flops.output_f32_f32 / (double)cpu.flops_f32_f32 / 1e9;
+    total_latency += (double)n_flops.output_f16_f32 / (double)cpu.flops_f16_f32 / 1e9;
+    total_latency += (double)n_flops.output_q4k_f32 / (double)cpu.flops_q4k_f32 / 1e9;
+    total_latency += (double)n_flops.output_q6k_f32 / (double)cpu.flops_q6k_f32 / 1e9;
+    total_latency += (double)n_flops.output_q80_f32 / (double)cpu.flops_q80_f32 / 1e9;
+
+    total_latency *= 1000; // convert to ms
+
+    total_latency += n_flops.inp_embd_ms;
+
+    return static_cast<float>(total_latency);
+}
+
+// estimate the memory access delay, except for the input embedding because it has been considered in n_flops.inp_embd_ms
+static float device_memory_access_delay(struct device_info & dev_info, int n_layers) {
+    struct model_params n_params = dev_info.model_params;
+
+    int64_t total_params = 0;
+    total_params += n_params.layer_f32 * 4 +
+                    n_params.layer_f16 * 2 +
+                    n_params.layer_q4k / 2 +
+                    n_params.layer_q6k * 3 / 8 +
+                    n_params.layer_q80;
+
+    total_params *= n_layers;
+
+    total_params += n_params.output_f32 * 4 +
+                    n_params.output_f16 * 2 +
+                    n_params.output_q4k / 2 +
+                    n_params.output_q6k * 3 / 8 +
+                    n_params.output_q80;
+
+#ifdef GGML_USE_CUDA
+    return (double)total_params / 1e6 / dev_info.gpu_props.read_bandwidth;
+#else
+    return (double)total_params / 1e6 / dev_info.memory.read_bandwidth;
+#endif
+}
+
 void device_print_props(struct device_info * dev_info_set, int n, struct llama_model * model) {
     LOG_INF("\n-------------------------------------------------------------------------------------------\n");
     LOG_INF("| Property                     ");
@@ -520,9 +771,9 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
     }
     LOG_INF("\n");
 
-    LOG_INF("| Mem Bandwidth (GB/s)         ");
+    LOG_INF("| Mem Read Bandwidth (GB/s)    ");
     for (int i = 0; i < n; ++i) {
-        LOG_INF("| %-10.2f ", dev_info_set[i].memory.bandwidth);
+        LOG_INF("| %-10.2f ", dev_info_set[i].memory.read_bandwidth);
     }
     LOG_INF("\n");
 
@@ -598,6 +849,12 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
     }
     LOG_INF("\n");
 
+    LOG_INF("| VRAM Read Bandwidth (GB/s)   ");
+    for (int i = 0; i < n; ++i) {
+        LOG_INF("| %-10.2f ", dev_info_set[i].gpu_props.read_bandwidth);
+    }
+    LOG_INF("\n");
+
     LOG_INF("| Metal flops (F32xF32, GFLOPS)");
     for (int i = 0; i < n; ++i) {
         LOG_INF("| %-10.1f ", dev_info_set[i].gpu_props.metal_flops_f32_f32);
@@ -758,13 +1015,14 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
     LOG_INF("| %-10" PRId64 " ", dev_info_set[0].model_params.output_q80);
     LOG_INF("\n");
 
-    model_flops ffo = dev_info_set[0].model_flops;
-    int64_t total_flops = ffo.output_f32_f32 + (ffo.layer_f32_f32 * llama_model_n_layers(model)); // todo
-    double cpu_flops_f16 = dev_info_set[0].cpu_props.flops_f16_f32 * 1e9;
+    float latency = 0.0f;
+    int n_layers = llama_model_n_layers(model);
+    latency += device_compute_delay(dev_info_set[0], n_layers);
+    latency += device_memory_access_delay(dev_info_set[0], n_layers);
 
-    // LOG_INF("| Token latency (ms)           ");
-    // LOG_INF("| %-10.2f ", total_flops / cpu_flops_f16 * 1000);
-    // LOG_INF("\n");
+    LOG_INF("| Token latency (ms)           ");
+    LOG_INF("| %-10.2f ", latency);
+    LOG_INF("\n");
 
     LOG_INF("-------------------------------------------------------------------------------------------\n\n");
 }
 
@@ -790,7 +1048,7 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
         + sizeof(float) * 5          // cpu_props.flops_f32_f32, cpu_props.flops_f16_f32, cpu_props.flops_q4k_f32, cpu_props.flops_q6k_f32, cpu_props.flops_q80_f32
         + sizeof(struct memory_info)
        + sizeof(struct gpu_support)
-        + sizeof(float) * 12;        // gpu_props.memory_free, gpu_props.memory_total,
+        + sizeof(float) * 13;        // gpu_props.memory_free, gpu_props.memory_total, gpu_props.read_bandwidth,
                                      // gpu_props.metal_flops_f32_f32, gpu_props.metal_flops_f16_f32, gpu_props.metal_flops_q4k_f32, gpu_props.metal_flops_q6k_f32, gpu_props.metal_flops_q80_f32,
                                      // gpu_props.cuda_flops_f32_f32, gpu_props.cuda_flops_f16_f32, gpu_props.cuda_flops_q4k_f32, gpu_props.cuda_flops_q6k_f32, gpu_props.cuda_flops_q80_f32
 
@@ -861,6 +1119,9 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
     memcpy(ptr, &dev_info->gpu_props.memory_total, sizeof(float));
     ptr += sizeof(float);
 
+    memcpy(ptr, &dev_info->gpu_props.read_bandwidth, sizeof(float));
+    ptr += sizeof(float);
+
     memcpy(ptr, &dev_info->gpu_props.metal_flops_f32_f32, sizeof(float));
     ptr += sizeof(float);
 
@@ -975,6 +1236,9 @@ void deserialize(const char * buffer, struct device_info * dev_info) {
     memcpy(&dev_info->gpu_props.memory_total, ptr, sizeof(float));
     ptr += sizeof(float);
 
+    memcpy(&dev_info->gpu_props.read_bandwidth, ptr, sizeof(float));
+    ptr += sizeof(float);
+
     memcpy(&dev_info->gpu_props.metal_flops_f32_f32, ptr, sizeof(float));
     ptr += sizeof(float);
 
diff --git a/common/profiler.h b/common/profiler.h
index a2395a14..744af877 100644
--- a/common/profiler.h
+++ b/common/profiler.h
@@ -26,18 +26,18 @@ struct cpu_props {
 };
 
 struct memory_info {
-    float total_physical;     // in GB
-    float available_physical; // in GB
-    float total_swap;         // in GB
-    float available_swap;     // in GB
-    float bandwidth;          // in GB/s
+    float total_physical;     // in GB
+    float available_physical; // in GB
+    float total_swap;         // in GB
+    float available_swap;     // in GB
+    float read_bandwidth;     // in GB/s
 
     memory_info() :
         total_physical    (0.0f),
         available_physical(0.0f),
         total_swap        (0.0f),
         available_swap    (0.0f),
-        bandwidth         (0.0f) {}
+        read_bandwidth    (0.0f) {}
 };
 
 struct gpu_support {
@@ -64,6 +64,7 @@ struct gpu_props {
     const char * description;
     float memory_free;         // in GB
     float memory_total;        // in GB
+    float read_bandwidth;      // in GB/s
     float metal_flops_f32_f32; // in GFLOPS
     float metal_flops_f16_f32; // in GFLOPS
     float metal_flops_q4k_f32; // in GFLOPS
@@ -80,6 +81,7 @@ struct gpu_props {
         description(""),
         memory_free        (0.0f),
         memory_total       (0.0f),
+        read_bandwidth     (1.0f),
         metal_flops_f32_f32(0.0f),
         metal_flops_f16_f32(0.0f),
         metal_flops_q4k_f32(0.0f),
@@ -93,6 +95,7 @@ struct gpu_props {
 };
 
 struct model_flops {
+    float inp_embd_ms;
     int64_t output_f32_f32;
     int64_t output_f16_f32;
     int64_t output_q4k_f32;
@@ -105,6 +108,7 @@ struct model_flops {
     int64_t layer_q80_f32;
 
     model_flops() :
+        inp_embd_ms(0.0f),
         output_f32_f32(0),
         output_f16_f32(0),
         output_q4k_f32(0),
@@ -193,10 +197,12 @@
 uint32_t device_cpu_cores      (void);
 float    device_cpu_flops      (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t, int n_threads);
 float    device_metal_flops    (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t);
 float    device_cuda_flops     (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t);
+float    device_inp_embd_delay (struct llama_model * model, enum ggml_type src0t, int n_tokens, int n_threads);
 uint64_t device_physical_memory(bool available);
 uint64_t device_swap_memory    (bool available);
 uint64_t device_disk_read_bw   (const char * test_file, size_t buffer_size_mb);
-uint64_t device_memory_bw      (size_t buffer_size_mb);
+float    device_memory_bw      (int n_thread);
+float    device_cuda_memory_bw (struct llama_model * model);
 void     device_get_props      (struct llama_model * model, int device, struct ggml_backend_dev_props * props);
 void     device_print_props    (struct device_info * dev_info_set, int n, struct llama_model * model);
 
diff --git a/include/llama.h b/include/llama.h
index 24663712..b75640b4 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -531,7 +531,8 @@ extern "C" {
             struct model_flops * n_flops,
             struct model_params * n_params,
             const int64_t n_input,
-            const int64_t n_history);
+            const int64_t n_history,
+            enum ggml_type * inp_embd_dtype);
 
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index fd7cb279..67e8b21a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3558,9 +3558,9 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
     dev_info->memory.total_physical     = round(device_physical_memory(false) / (double)(1 << 30) * 100) / 100;
     dev_info->memory.available_physical = round(device_physical_memory(true)  / (double)(1 << 30) * 100) / 100;
-    dev_info->memory.total_swap         = round(device_swap_memory (false) / (double)(1 << 30) * 100) / 100;
-    dev_info->memory.available_swap     = round(device_swap_memory (true)  / (double)(1 << 30) * 100) / 100;
-    dev_info->memory.bandwidth          = round(device_memory_bw (500) / (double)(1 << 30) * 100) / 100;
+    dev_info->memory.total_swap         = round(device_swap_memory(false) / (double)(1 << 30) * 100) / 100;
+    dev_info->memory.available_swap     = round(device_swap_memory(true)  / (double)(1 << 30) * 100) / 100;
+    dev_info->memory.read_bandwidth     = device_memory_bw(n_threads);
 
     dev_info->disk_read_bandwidth = round(device_disk_read_bw(test_file, 500) / (double)(1 << 30) * 100) / 100;
 
@@ -3584,6 +3584,7 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
     dev_info->gpu_props.description = gpu_props.description;
     dev_info->gpu_props.memory_free = round(gpu_props.memory_free / (double)(1 << 30) * 100) / 100;
     dev_info->gpu_props.memory_total = round(gpu_props.memory_total / (double)(1 << 30) * 100) / 100;
+    dev_info->gpu_props.read_bandwidth = device_cuda_memory_bw(model);
     dev_info->gpu_props.metal_flops_f32_f32 = device_metal_flops(model, GGML_TYPE_F32,  GGML_TYPE_F32);
     dev_info->gpu_props.metal_flops_f16_f32 = device_metal_flops(model, GGML_TYPE_F16,  GGML_TYPE_F32);
     dev_info->gpu_props.metal_flops_q4k_f32 = device_metal_flops(model, GGML_TYPE_Q4_K, GGML_TYPE_F32);
@@ -3598,7 +3599,9 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
     if (dev_info->rank == 0) {
         struct model_flops * n_flops = &dev_info->model_flops;
         struct model_params * n_params = &dev_info->model_params;
-        llama_model_n_flops(model, ml, n_flops, n_params, 1, 10);
+        enum ggml_type inp_embd_dtype = GGML_TYPE_F32;
+        llama_model_n_flops(model, ml, n_flops, n_params, 1, 32, &inp_embd_dtype);
+        n_flops->inp_embd_ms = device_inp_embd_delay(model, inp_embd_dtype, 1, n_threads);
     }
 }
 
@@ -20766,6 +20769,7 @@ static void count_n_params(struct model_params * n_params, enum ggml_type dtype,
             break;
         case GGML_TYPE_Q8_0:
             n_params->output_q80 += n_i64t;
+            break;
         default:
             throw std::runtime_error("Unrecognized weight type in PROFILER_LAYER_OUTPUT\n");
     }
@@ -20798,7 +20802,14 @@ static void count_n_params(struct model_params * n_params, enum ggml_type dtype,
     }
 }
 
-void llama_model_n_flops(struct llama_model * model, struct llama_model_loader * ml, struct model_flops * n_flops, struct model_params * n_params, const int64_t n_input, const int64_t n_history) {
+void llama_model_n_flops(
+        struct llama_model * model,
+        struct llama_model_loader * ml,
+        struct model_flops * n_flops,
+        struct model_params * n_params,
+        const int64_t n_input,
+        const int64_t n_history,
+        enum ggml_type * inp_embd_dtype) {
     const llama_hparams hparams = model->hparams;
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_vocab = hparams.n_vocab;
@@ -20904,6 +20915,7 @@ void llama_model_n_flops(struct llama_model * model, struct llama_model_loader *
             switch (it->second) {
                 case 1: { // "token_embd.weight"
                     count_n_params(n_params, cur->type, PROFILER_LAYER_INPUT, ggml_nelements(cur));
+                    *inp_embd_dtype = cur->type;
                     break;
                 }
                 case 2: { // "output_norm.weight"
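
Note: the token-latency estimate printed above sums device_compute_delay() (measured FLOPS per weight type) and device_memory_access_delay() (weight bytes divided by the measured read bandwidth from device_memory_bw() or device_cuda_memory_bw()). The standalone sketch below mirrors the multi-threaded read-bandwidth probe outside of ggml so it can be sanity-checked in isolation; the thread count, 5 MiB per-thread buffer, and 64-byte stride are assumptions mirroring the patch, not an exact copy of it.

// sketch_memory_bw.cpp -- illustrative only, not part of the patch
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

// time a strided read over one private buffer and return GB/s
static double read_bw_gbs(const char * buf, size_t size) {
    volatile char sink = 0;
    const auto t0 = std::chrono::high_resolution_clock::now();
    for (size_t i = 0; i < size; i += 64) {  // one touch per 64-byte cache line
        sink = sink + buf[i];
    }
    const auto t1 = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double> dt = t1 - t0;
    // counts the full buffer size, i.e. assumes each touch pulls in a whole cache line
    return (double)size / dt.count() / 1e9;
}

int main() {
    const int    n_thread    = 4;                 // assumption: thread count
    const size_t buffer_size = 5UL * 1024 * 1024; // assumption: 5 MiB per thread, as in the patch

    std::vector<std::vector<char>> buffers(n_thread, std::vector<char>(buffer_size, 1));
    std::vector<double>            results(n_thread, 0.0);
    std::vector<std::thread>       pool;

    // each thread times only its own scan, so thread-creation cost is excluded
    for (int i = 0; i < n_thread; ++i) {
        pool.emplace_back([&, i] { results[i] = read_bw_gbs(buffers[i].data(), buffer_size); });
    }
    for (auto & t : pool) {
        t.join();
    }

    double total = 0.0;
    for (double r : results) {
        total += r; // aggregate per-thread bandwidth, as device_memory_bw() does
    }
    printf("approx. memory read bandwidth: %.2f GB/s\n", total);
    return 0;
}

As in the patch, the per-thread results are summed into one aggregate figure; note that a buffer this small may sit partly in cache, so the number is an upper-bound style estimate rather than sustained DRAM bandwidth.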