add device_inp_embd_delay, device_memory_bw, and device_cuda_memory_bw tests

Zonghang Li 2024-11-26 22:28:02 +04:00
parent a7a95b53fe
commit f78c437172
4 changed files with 346 additions and 63 deletions

View file

@ -32,6 +32,8 @@
#include <sys/types.h>
#include <vector>
#include <inttypes.h>
#include <thread>
#include <cuda_runtime.h>
const char * device_name() {
static char device_name[256];
@ -84,6 +86,7 @@ uint32_t device_cpu_cores() {
}
static float device_flops(struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t, profiler_backend_type btype, int n_threads) {
const int n_repeat = 1;
const int n_embd = llama_n_embd(model);
std::vector<float> matrix_A(n_embd * n_embd, 1.0f);
std::vector<float> matrix_B(n_embd * n_embd, 1.0f / n_embd);
@ -110,11 +113,8 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
return 0.0f;
}
size_t ctx_size = 0;
ctx_size += 2 * ggml_tensor_overhead(); // tensors
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
};
@ -130,16 +130,20 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx_cgraph = NULL;
struct ggml_tensor * cur = NULL;
{
struct ggml_init_params params0 = {
/*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
/*.mem_size =*/ ggml_tensor_overhead() * (n_repeat + 2) + ggml_graph_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx_cgraph = ggml_init(params0);
gf = ggml_new_graph(ctx_cgraph);
struct ggml_tensor * cur = ggml_mul_mat(ctx_cgraph, tensor_a, tensor_b);
cur = ggml_mul_mat(ctx_cgraph, tensor_a, tensor_b);
for (int i = 0; i < n_repeat - 1; i++) {
cur = ggml_mul_mat(ctx_cgraph, tensor_a, cur);
}
ggml_build_forward_expand(gf, cur);
}
@ -151,14 +155,14 @@ static float device_flops(struct llama_model * model, enum ggml_type src0t, enum
}
// warm-up
ggml_backend_graph_compute(backend, gf);
// ggml_backend_graph_compute(backend, gf);
const int64_t t_start = ggml_time_us();
ggml_backend_graph_compute(backend, gf);
const int64_t t_end = ggml_time_us();
double elapsed_seconds = ((double)t_end - (double)t_start) / 1e6; // convert to seconds
double flops = (2.0 * (double)n_embd * (double)n_embd * (double)n_embd) / elapsed_seconds / 1e9; // convert to GFLOPS
double flops = (2.0 * n_repeat * (double)n_embd * (double)n_embd * (double)n_embd) / elapsed_seconds / 1e9; // convert to GFLOPS
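// (sketch, illustrative numbers) one matmul of two n_embd x n_embd matrices costs
// 2 * n_embd^3 FLOPs, so the chain of n_repeat matmuls costs 2 * n_repeat * n_embd^3;
// e.g. n_embd = 4096 and n_repeat = 1 give ~137.4 GFLOP, so a 0.5 s run reports ~275 GFLOPS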
ggml_free(ctx_cgraph);
ggml_gallocr_free(allocr);
@ -195,6 +199,125 @@ float device_cuda_flops(struct llama_model * model, enum ggml_type src0t, enum g
return 0.0f;
}
float device_inp_embd_delay(struct llama_model * model, enum ggml_type src0t, int n_tokens, int n_threads) {
const int n_vocab = llama_n_vocab(model);
const int n_embd = llama_n_embd(model);
ggml_backend_t backend = ggml_backend_cpu_init();
if (!backend) {
LOG_INF("%s: ggml backend init failed\n", __func__);
return 0.0f;
}
size_t ctx_size = 0;
ctx_size += 2 * ggml_tensor_overhead(); // tensors
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
};
struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * tok_embd = ggml_new_tensor_2d(ctx, src0t, n_embd, n_vocab);
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
std::vector<int32_t> matrix_A(n_tokens);
for (int i = 0; i < n_tokens; ++i) {
matrix_A[i] = i % n_vocab;
}
const size_t embd_size = (size_t)n_vocab * n_embd; // cast first to avoid int overflow for large vocab * embd
void * matrix_B = nullptr;
// quantization and dequantization functions
ggml_type_traits_t qfns = ggml_internal_get_type_traits(src0t);
if (!qfns.from_float || !qfns.to_float) {
LOG_INF("Unsupported or uninitialized quantization type: %d\n", src0t);
ggml_free(ctx);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
return 0.0f;
}
size_t QK_K = 0;
switch (src0t) {
case GGML_TYPE_F32: {
matrix_B = malloc(embd_size * sizeof(float));
float * matrix_B_f32 = static_cast<float *>(matrix_B);
for (size_t i = 0; i < embd_size; ++i) {
matrix_B_f32[i] = static_cast<float>(rand()) / RAND_MAX;
}
break;
}
case GGML_TYPE_F16: {
matrix_B = malloc(embd_size * sizeof(ggml_fp16_t));
std::vector<float> temp_f32(embd_size);
for (size_t i = 0; i < embd_size; ++i) {
temp_f32[i] = static_cast<float>(rand()) / RAND_MAX;
}
ggml_fp32_to_fp16_row(temp_f32.data(), static_cast<ggml_fp16_t *>(matrix_B), embd_size);
break;
}
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
QK_K = 256;
matrix_B = malloc((embd_size / QK_K) * ggml_type_size(src0t));
break;
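// note: the K-quant types above store weights in super-blocks of QK_K = 256 elements,
// so the raw buffer holds embd_size / QK_K blocks of ggml_type_size(src0t) bytes each;
// it is left uninitialized here, which is fine for a pure lookup-latency measurement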
default:
LOG_INF("Unsupported type: %d\n", src0t);
ggml_free(ctx);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
return 0.0f;
}
ggml_backend_tensor_set(inp_tokens, matrix_A.data(), 0, ggml_nbytes(inp_tokens));
ggml_backend_tensor_set(tok_embd, matrix_B, 0, ggml_nbytes(tok_embd));
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx_cgraph = NULL;
{
struct ggml_init_params params0 = {
/*.mem_size =*/ ggml_tensor_overhead() * 3 + ggml_graph_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx_cgraph = ggml_init(params0);
gf = ggml_new_graph(ctx_cgraph);
struct ggml_tensor * cur = ggml_get_rows(ctx_cgraph, tok_embd, inp_tokens);
ggml_build_forward_expand(gf, cur);
}
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
ggml_gallocr_alloc_graph(allocr, gf);
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, n_threads);
}
// warm-up
// ggml_backend_graph_compute(backend, gf);
const int64_t t_start = ggml_time_us();
ggml_backend_graph_compute(backend, gf);
const int64_t t_end = ggml_time_us();
double elapsed_ms = ((double)t_end - (double)t_start) / 1e3; // convert to ms
ggml_free(ctx_cgraph);
ggml_gallocr_free(allocr);
ggml_free(ctx);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
return (float)elapsed_ms;
}
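// Example call (sketch; the dtype and thread count are illustrative assumptions, while
// the real caller passes the dtype of token_embd.weight, as done in llama_profile_device
// in this commit):
//   float embd_ms = device_inp_embd_delay(model, GGML_TYPE_Q4_K, /*n_tokens=*/1, /*n_threads=*/4);
//   LOG_INF("input embedding lookup: %.2f ms\n", embd_ms);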
uint64_t device_physical_memory(bool available) {
uint64_t memory = 0;
@ -357,41 +480,107 @@ uint64_t device_disk_read_bw(const char * test_file, size_t buffer_size_mb) {
return speed;
}
uint64_t device_memory_bw(size_t buffer_size_mb) {
uint64_t speed = 0;
size_t test_size = buffer_size_mb * 1024 * 1024; // convert MB to bytes
float device_memory_bw(int n_thread) {
size_t buffer_size = 5L * 1024 * 1024; // 5 MiB per thread
std::vector<std::thread> thread_pool;
std::vector<double> results(n_thread);
std::vector<char *> buffers(n_thread);
try {
// allocate memory for speed test
std::vector<char> buffer(test_size, 1);
for (int i = 0; i < n_thread; ++i) {
buffers[i] = new char[buffer_size];
}
// measure write speed
auto start_time = std::chrono::high_resolution_clock::now();
memset(buffer.data(), 0xAB, buffer.size());
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed_time = end_time - start_time;
double write_speed = static_cast<double>(test_size) / elapsed_time.count();
// measure read speed
start_time = std::chrono::high_resolution_clock::now();
auto memory_bw_test = [](char * buffer, size_t size, double & result) {
// read test
volatile char temp = 0;
for (size_t i = 0; i < buffer.size(); i += 64) {
temp += buffer[i]; // read in steps of cache line size to minimize cache thrashing
auto start = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < size; i += 64) {
temp += buffer[i];
}
end_time = std::chrono::high_resolution_clock::now();
elapsed_time = end_time - start_time;
double read_speed = static_cast<double>(test_size) / elapsed_time.count();
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;
result = size / elapsed.count() / 1e9; // GB/s
};
// average speed
speed = static_cast<uint64_t>((write_speed + read_speed) / 2.0);
buffer.clear();
buffer.shrink_to_fit();
} catch (const std::exception &e) {
LOG_ERR("Exception while calculating memory speed: %s\n", e.what());
for (int i = 0; i < n_thread; ++i) {
thread_pool.emplace_back(memory_bw_test, buffers[i], buffer_size, std::ref(results[i]));
}
for (auto & t : thread_pool) {
t.join();
}
return speed;
double bandwidth = 0.0f;
for (double result : results) {
bandwidth += result;
}
for (int i = 0; i < n_thread; ++i) {
delete[] buffers[i];
}
return static_cast<float>(bandwidth);
}
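// Design note: every worker thread streams its own 5 MiB buffer in 64-byte (cache-line)
// strides and the per-thread throughputs are summed, so the result approximates the
// aggregate read bandwidth over n_thread cores rather than single-core bandwidth.
// Example call (sketch; the thread count is an illustrative assumption):
//   float mem_bw_gbps = device_memory_bw(/*n_thread=*/4); // GB/s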
float device_cuda_memory_bw(struct llama_model * model) {
#ifndef GGML_USE_CUDA
return 0.0f;
#endif
const int n_embd = llama_n_embd(model) * 2;
std::vector<float> matrix_A(n_embd * n_embd, 1.0f);
ggml_backend_t backend = ggml_backend_cuda_init(0);
if (!backend) {
LOG_INF("%s: ggml backend init failed\n", __func__);
return 0.0f;
}
struct ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
};
struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * tensor_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
tensor_a->op = GGML_OP_READ;
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
ggml_backend_tensor_set(tensor_a, matrix_A.data(), 0, ggml_nbytes(tensor_a));
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx_cgraph = NULL;
{
struct ggml_init_params params0 = {
/*.mem_size =*/ ggml_tensor_overhead() + ggml_graph_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx_cgraph = ggml_init(params0);
gf = ggml_new_graph(ctx_cgraph);
ggml_build_forward_expand(gf, tensor_a);
}
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
ggml_gallocr_alloc_graph(allocr, gf);
const int64_t t_start = ggml_time_us();
ggml_backend_graph_compute(backend, gf);
const int64_t t_end = ggml_time_us();
double elapsed_s = ((double)t_end - (double)t_start) / 1e6;
size_t total_bytes = n_embd * n_embd * sizeof(float);
float bandwidth = (total_bytes / elapsed_s) / 1e9; // GB/s
ggml_free(ctx_cgraph);
ggml_gallocr_free(allocr);
ggml_free(ctx);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
return bandwidth;
}
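// Worked example (sketch, illustrative numbers): with llama_n_embd(model) = 4096 the test
// tensor is 8192 x 8192 floats, i.e. 8192 * 8192 * 4 B ~= 268 MB, so a 1 ms graph compute
// corresponds to a reported read bandwidth of ~268 GB/s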
int device_has_metal(void) {
@ -433,6 +622,68 @@ void device_get_props(struct llama_model * model, int device, struct ggml_backen
ggml_backend_dev_get_props(dev, props);
}
static float device_compute_delay(struct device_info & dev_info, int n_layers) {
struct model_flops n_flops = dev_info.model_flops;
struct cpu_props cpu = dev_info.cpu_props;
double total_latency = 0.0f;
#ifdef GGML_USE_CUDA
struct gpu_props gpu = dev_info.gpu_props;
total_latency += (double)n_flops.layer_f32_f32 / (double)gpu.cuda_flops_f32_f32 / 1e9;
total_latency += (double)n_flops.layer_f16_f32 / (double)gpu.cuda_flops_f16_f32 / 1e9;
total_latency += (double)n_flops.layer_q4k_f32 / (double)gpu.cuda_flops_q4k_f32 / 1e9;
total_latency += (double)n_flops.layer_q6k_f32 / (double)gpu.cuda_flops_q6k_f32 / 1e9;
total_latency += (double)n_flops.layer_q80_f32 / (double)gpu.cuda_flops_q80_f32 / 1e9;
#else
total_latency += (double)n_flops.layer_f32_f32 / (double)cpu.flops_f32_f32 / 1e9;
total_latency += (double)n_flops.layer_f16_f32 / (double)cpu.flops_f16_f32 / 1e9;
total_latency += (double)n_flops.layer_q4k_f32 / (double)cpu.flops_q4k_f32 / 1e9;
total_latency += (double)n_flops.layer_q6k_f32 / (double)cpu.flops_q6k_f32 / 1e9;
total_latency += (double)n_flops.layer_q80_f32 / (double)cpu.flops_q80_f32 / 1e9;
#endif
total_latency *= n_layers;
total_latency += (double)n_flops.output_f32_f32 / (double)cpu.flops_f32_f32 / 1e9;
total_latency += (double)n_flops.output_f16_f32 / (double)cpu.flops_f16_f32 / 1e9;
total_latency += (double)n_flops.output_q4k_f32 / (double)cpu.flops_q4k_f32 / 1e9;
total_latency += (double)n_flops.output_q6k_f32 / (double)cpu.flops_q6k_f32 / 1e9;
total_latency += (double)n_flops.output_q80_f32 / (double)cpu.flops_q80_f32 / 1e9;
total_latency *= 1000; // convert to ms
total_latency += n_flops.inp_embd_ms;
return static_cast<float>(total_latency);
}
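// In short (sketch): latency_ms ~= 1000 * ( n_layers * sum_t(layer_flops_t / dev_gflops_t / 1e9)
//                                         + sum_t(output_flops_t / cpu_gflops_t / 1e9) ) + inp_embd_ms.
// Illustrative numbers: 32 layers at 1.5 GFLOP per layer on a 100 GFLOPS device contribute
// 32 * 1.5 / 100 s ~= 480 ms before the output and embedding terms.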
// estimate the memory access delay, excluding the input embedding, which is already accounted for in n_flops.inp_embd_ms
static float device_memory_access_delay(struct device_info & dev_info, int n_layers) {
struct model_params n_params = dev_info.model_params;
int64_t total_params = 0;
total_params += n_params.layer_f32 * 4 +
n_params.layer_f16 * 2 +
n_params.layer_q4k / 2 +
n_params.layer_q6k * 3 / 8 +
n_params.layer_q80;
total_params *= n_layers;
total_params += n_params.output_f32 * 4 +
n_params.output_f16 * 2 +
n_params.output_q4k / 2 +
n_params.output_q6k * 3 / 8 +
n_params.output_q80;
#ifdef GGML_USE_CUDA
return (double)total_params / 1e6 / dev_info.gpu_props.read_bandwidth;
#else
return (double)total_params / 1e6 / dev_info.memory.read_bandwidth;
#endif
}
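// Unit note: after the per-type scaling above, total_params approximates a byte count, and
// bytes / 1e6 / (GB/s) yields milliseconds, so this returns the per-token weight-read delay in ms.
// Worked example (sketch, illustrative numbers): ~3.5e9 bytes of weights over a 50 GB/s bus
// gives 3.5e9 / 1e6 / 50 = 70 ms.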
void device_print_props(struct device_info * dev_info_set, int n, struct llama_model * model) {
LOG_INF("\n-------------------------------------------------------------------------------------------\n");
LOG_INF("| Property ");
@ -520,9 +771,9 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
}
LOG_INF("\n");
LOG_INF("| Mem Bandwidth (GB/s) ");
LOG_INF("| Mem Read Bandwidth (GB/s) ");
for (int i = 0; i < n; ++i) {
LOG_INF("| %-10.2f ", dev_info_set[i].memory.bandwidth);
LOG_INF("| %-10.2f ", dev_info_set[i].memory.read_bandwidth);
}
LOG_INF("\n");
@ -598,6 +849,12 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
}
LOG_INF("\n");
LOG_INF("| VRAM Read Bandwidth (GB/s) ");
for (int i = 0; i < n; ++i) {
LOG_INF("| %-10.2f ", dev_info_set[i].gpu_props.read_bandwidth);
}
LOG_INF("\n");
LOG_INF("| Metal flops (F32xF32, GFLOPS)");
for (int i = 0; i < n; ++i) {
LOG_INF("| %-10.1f ", dev_info_set[i].gpu_props.metal_flops_f32_f32);
@ -758,13 +1015,14 @@ void device_print_props(struct device_info * dev_info_set, int n, struct llama_m
LOG_INF("| %-10" PRId64 " ", dev_info_set[0].model_params.output_q80);
LOG_INF("\n");
model_flops ffo = dev_info_set[0].model_flops;
int64_t total_flops = ffo.output_f32_f32 + (ffo.layer_f32_f32 * llama_model_n_layers(model)); // todo
double cpu_flops_f16 = dev_info_set[0].cpu_props.flops_f16_f32 * 1e9;
float latency = 0.0f;
int n_layers = llama_model_n_layers(model);
latency += device_compute_delay(dev_info_set[0], n_layers);
latency += device_memory_access_delay(dev_info_set[0], n_layers);
// LOG_INF("| Token latency (ms) ");
// LOG_INF("| %-10.2f ", total_flops / cpu_flops_f16 * 1000);
// LOG_INF("\n");
LOG_INF("| Token latency (ms) ");
LOG_INF("| %-10.2f ", latency);
LOG_INF("\n");
LOG_INF("-------------------------------------------------------------------------------------------\n\n");
}
@ -790,7 +1048,7 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
+ sizeof(float) * 5 // cpu_props.flops_f32_f32, cpu_props.flops_f16_f32, cpu_props.flops_q4k_f32, cpu_props.flops_q6k_f32, cpu_props.flops_q80_f32
+ sizeof(struct memory_info)
+ sizeof(struct gpu_support)
+ sizeof(float) * 12; // gpu_props.memory_free, gpu_props.memory_total,
+ sizeof(float) * 13; // gpu_props.memory_free, gpu_props.memory_total, gpu_props.read_bandwidth,
// gpu_props.metal_flops_f32_f32, gpu_props.metal_flops_f16_f32, gpu_props.metal_flops_q4k_f32, gpu_props.metal_flops_q6k_f32, gpu_props.metal_flops_q80_f32,
// gpu_props.cuda_flops_f32_f32, gpu_props.cuda_flops_f16_f32, gpu_props.cuda_flops_q4k_f32, gpu_props.cuda_flops_q6k_f32, gpu_props.cuda_flops_q80_f32
@ -861,6 +1119,9 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
memcpy(ptr, &dev_info->gpu_props.memory_total, sizeof(float));
ptr += sizeof(float);
memcpy(ptr, &dev_info->gpu_props.read_bandwidth, sizeof(float));
ptr += sizeof(float);
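// note: serialize() and deserialize() must keep the same field order; the new
// gpu_props.read_bandwidth float is written here (and read at the matching offset below),
// which is what the "sizeof(float) * 13" accounting above reflects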
memcpy(ptr, &dev_info->gpu_props.metal_flops_f32_f32, sizeof(float));
ptr += sizeof(float);
@ -975,6 +1236,9 @@ void deserialize(const char * buffer, struct device_info * dev_info) {
memcpy(&dev_info->gpu_props.memory_total, ptr, sizeof(float));
ptr += sizeof(float);
memcpy(&dev_info->gpu_props.read_bandwidth, ptr, sizeof(float));
ptr += sizeof(float);
memcpy(&dev_info->gpu_props.metal_flops_f32_f32, ptr, sizeof(float));
ptr += sizeof(float);

View file

@ -30,14 +30,14 @@ struct memory_info {
float available_physical; // in GB
float total_swap; // in GB
float available_swap; // in GB
float bandwidth; // in GB/s
float read_bandwidth; // in GB/s
memory_info() :
total_physical (0.0f),
available_physical(0.0f),
total_swap (0.0f),
available_swap (0.0f),
bandwidth (0.0f) {}
read_bandwidth (0.0f) {}
};
struct gpu_support {
@ -64,6 +64,7 @@ struct gpu_props {
const char * description;
float memory_free; // in GB
float memory_total; // in GB
float read_bandwidth; // in GB/s
float metal_flops_f32_f32; // in GFLOPS
float metal_flops_f16_f32; // in GFLOPS
float metal_flops_q4k_f32; // in GFLOPS
@ -80,6 +81,7 @@ struct gpu_props {
description(""),
memory_free (0.0f),
memory_total (0.0f),
read_bandwidth (1.0f),
metal_flops_f32_f32(0.0f),
metal_flops_f16_f32(0.0f),
metal_flops_q4k_f32(0.0f),
@ -93,6 +95,7 @@ struct gpu_props {
};
struct model_flops {
float inp_embd_ms;
int64_t output_f32_f32;
int64_t output_f16_f32;
int64_t output_q4k_f32;
@ -105,6 +108,7 @@ struct model_flops {
int64_t layer_q80_f32;
model_flops() :
inp_embd_ms(0.0f),
output_f32_f32(0),
output_f16_f32(0),
output_q4k_f32(0),
@ -193,10 +197,12 @@ uint32_t device_cpu_cores (void);
float device_cpu_flops (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t, int n_threads);
float device_metal_flops (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t);
float device_cuda_flops (struct llama_model * model, enum ggml_type src0t, enum ggml_type src1t);
float device_inp_embd_delay (struct llama_model * model, enum ggml_type src0t, int n_tokens, int n_threads);
uint64_t device_physical_memory(bool available);
uint64_t device_swap_memory (bool available);
uint64_t device_disk_read_bw (const char * test_file, size_t buffer_size_mb);
uint64_t device_memory_bw (size_t buffer_size_mb);
float device_memory_bw (int n_thread);
float device_cuda_memory_bw (struct llama_model * model);
void device_get_props (struct llama_model * model, int device, struct ggml_backend_dev_props * props);
void device_print_props (struct device_info * dev_info_set, int n, struct llama_model * model);

View file

@ -531,7 +531,8 @@ extern "C" {
struct model_flops * n_flops,
struct model_params * n_params,
const int64_t n_input,
const int64_t n_history);
const int64_t n_history,
enum ggml_type * inp_embd_dtype);
// Get a llama model tensor
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

View file

@ -3560,7 +3560,7 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
dev_info->memory.available_physical = round(device_physical_memory(true) / (double)(1 << 30) * 100) / 100;
dev_info->memory.total_swap = round(device_swap_memory(false) / (double)(1 << 30) * 100) / 100;
dev_info->memory.available_swap = round(device_swap_memory(true) / (double)(1 << 30) * 100) / 100;
dev_info->memory.bandwidth = round(device_memory_bw (500) / (double)(1 << 30) * 100) / 100;
dev_info->memory.read_bandwidth = device_memory_bw(n_threads);
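// note: device_memory_bw() already returns GB/s, so unlike the byte counts above
// no conversion by (1 << 30) is applied here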
dev_info->disk_read_bandwidth = round(device_disk_read_bw(test_file, 500) / (double)(1 << 30) * 100) / 100;
@ -3584,6 +3584,7 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
dev_info->gpu_props.description = gpu_props.description;
dev_info->gpu_props.memory_free = round(gpu_props.memory_free / (double)(1 << 30) * 100) / 100;
dev_info->gpu_props.memory_total = round(gpu_props.memory_total / (double)(1 << 30) * 100) / 100;
dev_info->gpu_props.read_bandwidth = device_cuda_memory_bw(model);
dev_info->gpu_props.metal_flops_f32_f32 = device_metal_flops(model, GGML_TYPE_F32, GGML_TYPE_F32);
dev_info->gpu_props.metal_flops_f16_f32 = device_metal_flops(model, GGML_TYPE_F16, GGML_TYPE_F32);
dev_info->gpu_props.metal_flops_q4k_f32 = device_metal_flops(model, GGML_TYPE_Q4_K, GGML_TYPE_F32);
@ -3598,7 +3599,9 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
if (dev_info->rank == 0) {
struct model_flops * n_flops = &dev_info->model_flops;
struct model_params * n_params = &dev_info->model_params;
llama_model_n_flops(model, ml, n_flops, n_params, 1, 10);
enum ggml_type inp_embd_dtype = GGML_TYPE_F32;
llama_model_n_flops(model, ml, n_flops, n_params, 1, 32, &inp_embd_dtype);
n_flops->inp_embd_ms = device_inp_embd_delay(model, inp_embd_dtype, 1, n_threads);
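// the measured embedding-lookup time feeds the token latency estimate via
// device_compute_delay(), which is why device_memory_access_delay() excludes the input embedding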
}
}
@ -20766,6 +20769,7 @@ static void count_n_params(struct model_params * n_params, enum ggml_type dtype,
break;
case GGML_TYPE_Q8_0:
n_params->output_q80 += n_i64t;
break;
default:
throw std::runtime_error("Unrecognized weight type in PROFILER_LAYER_OUTPUT\n");
}
@ -20798,7 +20802,14 @@ static void count_n_params(struct model_params * n_params, enum ggml_type dtype,
}
}
void llama_model_n_flops(struct llama_model * model, struct llama_model_loader * ml, struct model_flops * n_flops, struct model_params * n_params, const int64_t n_input, const int64_t n_history) {
void llama_model_n_flops(
struct llama_model * model,
struct llama_model_loader * ml,
struct model_flops * n_flops,
struct model_params * n_params,
const int64_t n_input,
const int64_t n_history,
enum ggml_type * inp_embd_dtype) {
const llama_hparams hparams = model->hparams;
const int64_t n_layer = hparams.n_layer;
const int64_t n_vocab = hparams.n_vocab;
@ -20904,6 +20915,7 @@ void llama_model_n_flops(struct llama_model * model, struct llama_model_loader *
switch (it->second) {
case 1: { // "token_embd.weight"
count_n_params(n_params, cur->type, PROFILER_LAYER_INPUT, ggml_nelements(cur));
*inp_embd_dtype = cur->type;
break;
}
case 2: { // "output_norm.weight"