diff --git a/kt-kernel/cpu_backend/worker_pool.cpp b/kt-kernel/cpu_backend/worker_pool.cpp index 05564fc9..05d70660 100644 --- a/kt-kernel/cpu_backend/worker_pool.cpp +++ b/kt-kernel/cpu_backend/worker_pool.cpp @@ -13,437 +13,20 @@ #include #include #include -#include -#include #include #include #include #include -#include -#include -#include -#include -#include -#include #include -#include -#include #include "hwloc.h" -// RDTSC-based timer for lightweight timing -// Uses CPU timestamp counter instead of system clock for lower overhead -namespace { - -// Read CPU timestamp counter (RDTSC) -inline uint64_t rdtsc_now() { return __rdtsc(); } - -// Estimate RDTSC cycles for given milliseconds -// This is calculated once at startup -static uint64_t g_rdtsc_cycles_per_ms = 0; - -// Initialize RDTSC frequency by measuring against chrono -static uint64_t init_rdtsc_frequency() { - auto start_chrono = std::chrono::high_resolution_clock::now(); - uint64_t start_rdtsc = rdtsc_now(); - - // Busy wait for ~10ms to calibrate - while (true) { - auto now = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(now - start_chrono).count(); - if (elapsed >= 10) break; - } - - uint64_t end_rdtsc = rdtsc_now(); - auto end_chrono = std::chrono::high_resolution_clock::now(); - auto elapsed_ms = std::chrono::duration_cast(end_chrono - start_chrono).count(); - - if (elapsed_ms > 0) { - return (end_rdtsc - start_rdtsc) / elapsed_ms; - } - // Fallback: assume 2.5 GHz CPU - return 2500000; -} - -// Get cycles per millisecond (lazy initialization) -inline uint64_t get_rdtsc_cycles_per_ms() { - if (g_rdtsc_cycles_per_ms == 0) { - g_rdtsc_cycles_per_ms = init_rdtsc_frequency(); - } - return g_rdtsc_cycles_per_ms; -} - -} // namespace - -// ===================================================== -// Global per-thread timing for SFT MOE forward/backward -// Collects timing from InNumaPool worker threads -// ===================================================== -#ifndef SFT_TIMER_DISABLED -namespace sft_timer { - -constexpr int MAX_THREADS = 256; -static uint64_t forward_rt[MAX_THREADS] = {0}; -static uint64_t backward_rt[MAX_THREADS] = {0}; -static int forward_tasks[MAX_THREADS] = {0}; -static int backward_tasks[MAX_THREADS] = {0}; -static int forward_threads = 0; -static int backward_threads = 0; - -inline double ticks_to_ms(uint64_t cycles) { return (double)cycles / get_rdtsc_cycles_per_ms(); } - -// ===================================================== -// Chrome Trace Event Format support -// ===================================================== -struct TraceEvent { - std::string name; // event name (op_name) - std::string cat; // category - char ph; // phase: 'X' for complete event, 'B' for begin, 'E' for end - double ts; // timestamp in microseconds (with ns precision via decimals) - double dur; // duration in microseconds (with ns precision via decimals) - int pid; // process id (numa_id) - int tid; // thread id - int task_count; // number of tasks processed - std::string args_json; // optional custom args JSON (for kernel traces) -}; - -static std::vector g_trace_events; -static std::mutex g_trace_mutex; -static uint64_t g_trace_start_time = 0; // baseline timestamp (RDTSC) -static double g_trace_start_epoch_us = 0.0; // wall-clock epoch time in microseconds -static std::string g_trace_output_path = "sft_trace.json"; - -// Thread-safe initialization using std::call_once -static std::once_flag g_trace_init_flag; - -// Forward declaration for atexit registration. 
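For context on the timer deleted here: it calibrated the CPU timestamp counter against std::chrono once at startup. A minimal, self-contained sketch of that calibration idea, assuming x86 with `__rdtsc` from `<x86intrin.h>` (`calibrate_tsc_cycles_per_ms` is an illustrative name; the 2.5 GHz fallback mirrors the removed code):

```cpp
// Sketch of the removed timer's TSC calibration (illustration only, not
// part of this patch): busy-wait ~10 ms against the system clock, then
// derive cycles-per-millisecond from how far the TSC advanced.
#include <chrono>
#include <cstdint>
#include <x86intrin.h>

static uint64_t calibrate_tsc_cycles_per_ms() {
    auto t0 = std::chrono::high_resolution_clock::now();
    uint64_t c0 = __rdtsc();
    while (std::chrono::duration_cast<std::chrono::milliseconds>(
               std::chrono::high_resolution_clock::now() - t0).count() < 10) {
        // spin for the ~10 ms calibration window
    }
    uint64_t c1 = __rdtsc();
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::high_resolution_clock::now() - t0).count();
    return ms > 0 ? (c1 - c0) / static_cast<uint64_t>(ms)
                  : 2500000;  // fallback: assume a 2.5 GHz TSC
}
```

This only yields wall-clock-comparable numbers on CPUs with an invariant TSC, which the removed code implicitly assumed.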
-static void write_trace_to_file(); - -// Initialize trace start time (thread-safe) -static void init_trace() { - std::call_once(g_trace_init_flag, []() { - g_trace_start_time = rdtsc_now(); - // Record wall-clock epoch time for cross-process trace alignment - auto now_wall = std::chrono::system_clock::now(); - auto epoch_us = std::chrono::duration_cast(now_wall.time_since_epoch()).count(); - g_trace_start_epoch_us = static_cast(epoch_us); - // Check for custom output path from environment - const char* env_path = std::getenv("SFT_TRACE_PATH"); - if (env_path && env_path[0] != '\0') { - g_trace_output_path = env_path; - } - // Flush trace on normal exit before static destructors run. - std::atexit(write_trace_to_file); - }); -} - -// Convert RDTSC cycles to microseconds with nanosecond precision (as double) -// Chrome tracing uses microseconds but supports fractional values for sub-us precision -static double cycles_to_us(uint64_t cycles) { - // cycles_per_ms * 1000 = cycles_per_us - // cycles / cycles_per_us = microseconds - // Using 1e6 for cycles_per_ms -> cycles_per_s, then divide to get us with ns precision - double cycles_per_us = get_rdtsc_cycles_per_ms() / 1000.0; - return static_cast(cycles) / cycles_per_us; -} - -// Add trace events for an operation using absolute timestamps -static void add_trace_events(const char* op_name, int numa_id, int thread_count, const uint64_t* start_ts_arr, - const uint64_t* end_ts_arr, const int* tasks) { - init_trace(); - - std::lock_guard lock(g_trace_mutex); - - for (int i = 0; i < thread_count; i++) { - // Convert absolute RDTSC timestamps to relative microseconds from trace start - double start_us = (start_ts_arr[i] > g_trace_start_time) ? cycles_to_us(start_ts_arr[i] - g_trace_start_time) : 0.0; - double end_us = (end_ts_arr[i] > g_trace_start_time) ? 
cycles_to_us(end_ts_arr[i] - g_trace_start_time) : 0.0; - double dur_us = end_us - start_us; - if (dur_us < 0) dur_us = 0; - - TraceEvent ev; - ev.name = op_name; - ev.cat = "sft_op"; - ev.ph = 'X'; // Complete event - ev.ts = start_us; - ev.dur = dur_us; - ev.pid = numa_id; - ev.tid = i; - ev.task_count = tasks[i]; - - g_trace_events.push_back(ev); - } -} - -// Write trace events to JSON file (Chrome Trace Event Format) -static void write_trace_to_file() { - std::lock_guard lock(g_trace_mutex); - - if (g_trace_events.empty()) { - return; - } - - // Sort events by (pid, tid, ts) to fix overlap issues in Chrome trace viewer - // Events from same thread should be ordered by start time - std::sort(g_trace_events.begin(), g_trace_events.end(), [](const TraceEvent& a, const TraceEvent& b) { - if (a.pid != b.pid) return a.pid < b.pid; - if (a.tid != b.tid) return a.tid < b.tid; - return a.ts < b.ts; - }); - - std::ofstream ofs(g_trace_output_path); - if (!ofs.is_open()) { - fprintf(stderr, "sft_timer: Failed to open trace file: %s\n", g_trace_output_path.c_str()); - return; - } - - // Use fixed precision for nanosecond accuracy (3 decimal places in microseconds = nanoseconds) - ofs << std::fixed << std::setprecision(3); - - ofs << "{\n"; - ofs << " \"traceEvents\": [\n"; - - for (size_t i = 0; i < g_trace_events.size(); i++) { - const auto& ev = g_trace_events[i]; - ofs << " {"; - ofs << "\"name\":\"" << ev.name << "\","; - ofs << "\"cat\":\"" << ev.cat << "\","; - ofs << "\"ph\":\"" << ev.ph << "\","; - ofs << "\"ts\":" << ev.ts << ","; - ofs << "\"dur\":" << ev.dur << ","; - ofs << "\"pid\":" << ev.pid << ","; - ofs << "\"tid\":" << ev.tid << ","; - if (!ev.args_json.empty()) { - ofs << "\"args\":" << ev.args_json; - } else { - ofs << "\"args\":{\"task_count\":" << ev.task_count << "}"; - } - ofs << "}"; - if (i < g_trace_events.size() - 1) { - ofs << ","; - } - ofs << "\n"; - } - - ofs << " ],\n"; - ofs << " \"metadata\": {\"start_epoch_us\": " << std::setprecision(0) << g_trace_start_epoch_us << "},\n"; - ofs << std::setprecision(3); - ofs << " \"displayTimeUnit\": \"ns\"\n"; - ofs << "}\n"; - - ofs.close(); - fprintf(stderr, "sft_timer: Trace written to %s (%zu events)\n", g_trace_output_path.c_str(), g_trace_events.size()); -} - -// Signal handler for SIGTERM -static void sigterm_handler(int sig) { - fprintf(stderr, "sft_timer: Received signal %d, writing trace...\n", sig); - write_trace_to_file(); - // Re-raise the signal with default handler to allow normal termination - signal(sig, SIG_DFL); - raise(sig); -} - -// Register signal handlers -static void register_signal_handlers() { - static bool registered = false; - if (!registered) { - signal(SIGTERM, sigterm_handler); - signal(SIGINT, sigterm_handler); - registered = true; - } -} - -void print_rt(FILE* out, const char* name, uint64_t* rt, int* tasks, int rt_threads) { - if (rt_threads <= 0) return; - FILE* output = out ? 
out : stderr; - auto max_val = *std::max_element(rt, rt + rt_threads); - auto min_val = *std::min_element(rt, rt + rt_threads); - uint64_t sum = std::accumulate(rt, rt + rt_threads, (uint64_t)0); - int total_tasks = std::accumulate(tasks, tasks + rt_threads, 0); - - // Sort to find 20% and 80% percentile thresholds - std::vector sorted(rt, rt + rt_threads); - std::sort(sorted.begin(), sorted.end()); - int p20_idx = rt_threads * 20 / 100; - int p80_idx = rt_threads * 80 / 100; - uint64_t p20_threshold = sorted[p20_idx]; // Fast threshold (top 20%) - uint64_t p80_threshold = sorted[p80_idx]; // Slow threshold (bottom 20%) - - // ANSI color codes - const char* GREEN = "\033[32m"; - const char* RED = "\033[31m"; - const char* RESET = "\033[0m"; - - // Line 1: time - fprintf(output, "%30s max %.3f min %.3f avg %.3f : ", name, ticks_to_ms(max_val), ticks_to_ms(min_val), - ticks_to_ms(sum / rt_threads)); - for (int i = 0; i < rt_threads; i++) { - if (rt[i] <= p20_threshold) { - fprintf(output, "%s%.3f%s ", GREEN, ticks_to_ms(rt[i]), RESET); - } else if (rt[i] >= p80_threshold) { - fprintf(output, "%s%.3f%s ", RED, ticks_to_ms(rt[i]), RESET); - } else { - fprintf(output, "%.3f ", ticks_to_ms(rt[i])); - } - } - fprintf(output, "\n"); - - // Line 2: task count - fprintf(output, "%30s total %d : ", "tasks", total_tasks); - for (int i = 0; i < rt_threads; i++) { - if (rt[i] <= p20_threshold) { - fprintf(output, "%s%d%s ", GREEN, tasks[i], RESET); - } else if (rt[i] >= p80_threshold) { - fprintf(output, "%s%d%s ", RED, tasks[i], RESET); - } else { - fprintf(output, "%d ", tasks[i]); - } - } - fprintf(output, "\n"); -} - -void reset_forward() { - std::fill(forward_rt, forward_rt + MAX_THREADS, 0); - std::fill(forward_tasks, forward_tasks + MAX_THREADS, 0); - forward_threads = 0; -} - -void reset_backward() { - std::fill(backward_rt, backward_rt + MAX_THREADS, 0); - std::fill(backward_tasks, backward_tasks + MAX_THREADS, 0); - backward_threads = 0; -} - -void collect_forward(InNumaPool* pool) { - int n = pool->get_worker_count(); - for (int i = 0; i < n && forward_threads < MAX_THREADS; i++) { - forward_rt[forward_threads] = pool->get_thread_cycles(i); - forward_tasks[forward_threads] = pool->get_thread_task_count(i); - forward_threads++; - } -} - -void collect_backward(InNumaPool* pool) { - int n = pool->get_worker_count(); - for (int i = 0; i < n && backward_threads < MAX_THREADS; i++) { - backward_rt[backward_threads] = pool->get_thread_cycles(i); - backward_tasks[backward_threads] = pool->get_thread_task_count(i); - backward_threads++; - } -} - -void print_forward() { print_rt(stderr, "forward", forward_rt, forward_tasks, forward_threads); } -void print_backward(const char* name) { print_rt(stderr, name, backward_rt, backward_tasks, backward_threads); } - -void print_op_stats(InNumaPool* pool, const char* op_name) { - if (pool == nullptr || op_name == nullptr || op_name[0] == '\0') { - return; - } - int n = pool->get_worker_count(); - if (n <= 0) { - return; - } - - // Ensure signal handlers are registered on first call - static bool handlers_registered = false; - if (!handlers_registered) { - register_signal_handlers(); - handlers_registered = true; - } - - FILE* output = stderr; - int numa_id = pool->get_numa_id(); - // if (numa_id == 0) { - // output = stdout; - // } else if (numa_id == 1) { - // output = stderr; - // } - std::vector rt(n); - std::vector start_ts(n); - std::vector end_ts(n); - std::vector tasks(n); - for (int i = 0; i < n; i++) { - rt[i] = pool->get_thread_cycles(i); - tasks[i] = 
pool->get_thread_task_count(i); - start_ts[i] = pool->get_thread_start_ts(i); - end_ts[i] = pool->get_thread_end_ts(i); - } - // print_rt(output, op_name, rt.data(), tasks.data(), n); - - // Save trace data to memory for later export - add_trace_events(op_name, numa_id, n, start_ts.data(), end_ts.data(), tasks.data()); -} - -// ===================================================== -// Kernel-level tracing API implementation -// ===================================================== - -uint64_t get_trace_timestamp() { return rdtsc_now(); } - -void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id, - const char* args) { - init_trace(); - - // Convert absolute RDTSC timestamps to relative microseconds from trace start - double start_us = (start_ts > g_trace_start_time) ? cycles_to_us(start_ts - g_trace_start_time) : 0.0; - double end_us = (end_ts > g_trace_start_time) ? cycles_to_us(end_ts - g_trace_start_time) : 0.0; - double dur_us = end_us - start_us; - if (dur_us < 0) dur_us = 0; - - std::lock_guard lock(g_trace_mutex); - - TraceEvent ev; - ev.name = name; - ev.cat = "kernel"; - ev.ph = 'X'; // Complete event - ev.ts = start_us; - ev.dur = dur_us; - ev.pid = numa_id; - ev.tid = thread_id; - ev.task_count = 0; // Not applicable for kernel traces - if (args != nullptr && args[0] != '\0') { - ev.args_json = args; - } - - g_trace_events.push_back(ev); -} - -} // namespace sft_timer -#endif // SFT_TIMER_DISABLED - -// Intel ITT API for profiler integration (VTune, etc.) -// Allows profilers to identify spin-wait regions -#ifdef USE_ITT_NOTIFY -#include -static __itt_domain* g_itt_domain = nullptr; -static __itt_string_handle* g_itt_spin_wait = nullptr; - -static void init_itt() { - if (g_itt_domain == nullptr) { - g_itt_domain = __itt_domain_create("WorkerPool"); - g_itt_spin_wait = __itt_string_handle_create("SpinWait"); - } -} - -#define ITT_SYNC_PREPARE(addr) __itt_sync_prepare(addr) -#define ITT_SYNC_CANCEL(addr) __itt_sync_cancel(addr) -#define ITT_SYNC_ACQUIRED(addr) __itt_sync_acquired(addr) -#else -#define ITT_SYNC_PREPARE(addr) ((void)0) -#define ITT_SYNC_CANCEL(addr) ((void)0) -#define ITT_SYNC_ACQUIRED(addr) ((void)0) -static void init_itt() {} -#endif - thread_local int WorkerPool::thread_local_id = -1; InNumaPool::InNumaPool(int max_thread_num) { printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num); - numa_id_ = numa_node_of_cpu(sched_getcpu()); total_worker_count = max_thread_num; - block_size_ = 0; set_restricted_worker_count(total_worker_count); thread_state_ = std::unique_ptr(new ThreadState[max_thread_num]); for (int i = 0; i < total_worker_count; i++) { @@ -463,9 +46,7 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) { hwloc_topology_init(&topology); hwloc_topology_load(topology); printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num); - numa_id_ = numa_id; total_worker_count = max_thread_num; - block_size_ = 0; set_restricted_worker_count(total_worker_count); thread_state_ = std::unique_ptr(new ThreadState[max_thread_num]); for (int i = 0; i < total_worker_count; i++) { @@ -514,7 +95,11 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) { InNumaPool::~InNumaPool() { for (int i = 0; i < total_worker_count; i++) { - thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release); + { + std::lock_guard lock(thread_state_[i].mutex); + 
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release); + } + thread_state_[i].cv.notify_one(); } for (int i = 0; i < total_worker_count; i++) { if (workers_[i].joinable()) { @@ -551,51 +136,41 @@ void InNumaPool::wait() { #endif } -void InNumaPool::do_work_stealing_job(int task_num, std::function compute_func, const char* task_name, - int block_size, bool async) { - do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size); +void InNumaPool::do_work_stealing_job(int task_num, std::function compute_func) { + do_work_stealing_job(task_num, nullptr, compute_func, nullptr); } void InNumaPool::do_work_stealing_job(int task_num, std::function init_func, - std::function compute_func, std::function finalize_func, - const char* task_name, int block_size, bool async) { - bool has_name = task_name != nullptr && task_name[0] != '\0'; - if (has_name) { - reset_counters(); - } - do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func, block_size); - if (!async) wait(); - if (has_name) { - sft_timer::print_op_stats(this, task_name); - } + std::function compute_func, std::function finalize_func) { + do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func); + wait(); } void InNumaPool::do_work_stealing_job_async(int task_num, std::function init_func, std::function compute_func, - std::function finalize_func, int block_size) { + std::function finalize_func) { init_func_ = init_func; compute_func_ = compute_func; finalize_func_ = finalize_func; - block_size_ = block_size; worker_count = std::min(restricted_worker_count, task_num); curr_.store(0, std::memory_order_release); end_ = task_num; - for (int i = 0; i < worker_count; i++) { - thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release); + { + std::lock_guard lock(thread_state_[i].mutex); + thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release); + } + thread_state_[i].cv.notify_one(); } WorkerPool::thread_local_id = 0; process_tasks(0); } void InNumaPool::process_tasks(int thread_id) { - uint64_t start_cycles = rdtsc_now(); +#ifdef PROFILE_BALANCE + auto start = std::chrono::high_resolution_clock::now(); +#endif auto& s = thread_state_[thread_id]; - int local_task_count = 0; - - // Record absolute start timestamp - s.start_ts = start_cycles; - if (init_func_ != nullptr) { init_func_(thread_id); } @@ -608,12 +183,7 @@ void InNumaPool::process_tasks(int thread_id) { break; } - int block = 0; - if (block_size_ > 0) { - block = std::min(block_size_, rem); - } else { - block = (rem + worker_count - 1) / worker_count; - } + int block = (rem + worker_count - 1) / worker_count; block = 1; int task_id = curr_.fetch_add(block, std::memory_order_acq_rel); if (task_id >= end_) { @@ -625,7 +195,6 @@ void InNumaPool::process_tasks(int thread_id) { break; } compute_func_(task_id + i); - local_task_count++; } } @@ -633,44 +202,34 @@ void InNumaPool::process_tasks(int thread_id) { finalize_func_(thread_id); } - // IMPORTANT: Update timing BEFORE setting status to WAITING - // The release semantics of status.store() ensures all prior writes are visible - uint64_t end_cycles = rdtsc_now(); - s.finish_cycles = end_cycles - start_cycles; - s.task_count = local_task_count; - s.end_ts = end_cycles; - - // Signal completion - release ensures timing writes are visible to wait() s.status.store(ThreadStatus::WAITING, std::memory_order_release); +#ifdef PROFILE_BALANCE + s.finish_ns = + 
std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start).count(); +#endif } void InNumaPool::worker_thread(int thread_id, int numa_id) { if (numa_id >= 0) { set_memory_to_numa(numa_id); } - init_itt(); // Initialize ITT if enabled - // Use RDTSC for lightweight timing instead of std::chrono - const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles - uint64_t start = rdtsc_now(); + auto start = std::chrono::high_resolution_clock::now(); WorkerPool::thread_local_id = thread_id; // set the thread-local variable while (true) { - ITT_SYNC_PREPARE(&thread_state_[thread_id].status); // Signal profiler: about to spin-wait ThreadStatus status = thread_state_[thread_id].status.load(std::memory_order_acquire); if (status == ThreadStatus::WORKING) { - ITT_SYNC_ACQUIRED(&thread_state_[thread_id].status); // Signal profiler: acquired work process_tasks(thread_id); - start = rdtsc_now(); + start = std::chrono::high_resolution_clock::now(); } else if (status == ThreadStatus::WAITING) { - // PAUSE instruction hints to CPU this is a spin-wait loop - _mm_pause(); - uint64_t now = rdtsc_now(); - uint64_t elapsed_cycles = now - start; - if (elapsed_cycles > sleep_threshold_cycles) { - ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: going to sleep - std::this_thread::sleep_for(std::chrono::microseconds(100)); + auto now = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(now - start).count(); + if (duration > 50) { + std::unique_lock lock(thread_state_[thread_id].mutex); + thread_state_[thread_id].cv.wait(lock, [&] { + return thread_state_[thread_id].status.load(std::memory_order_acquire) != ThreadStatus::WAITING; + }); } } else if (status == ThreadStatus::EXIT) { - ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: exiting return; } } @@ -695,6 +254,8 @@ void NumaJobDistributor::init(std::vector numa_ids) { this->numa_ids = numa_ids; for (size_t i = 0; i < numa_count; i++) { status.push_back(nullptr); + mutexes.push_back(std::make_unique()); + cvs.push_back(std::make_unique()); } workers.resize(numa_count); @@ -716,6 +277,8 @@ void NumaJobDistributor::init(std::vector numa_ids, std::vector thread this->numa_ids = numa_ids; for (size_t i = 0; i < numa_count; i++) { status.push_back(nullptr); + mutexes.push_back(std::make_unique()); + cvs.push_back(std::make_unique()); } workers.resize(numa_count); @@ -767,7 +330,11 @@ void NumaJobDistributor::init(std::vector numa_ids, std::vector thread NumaJobDistributor::~NumaJobDistributor() { for (int i = 0; i < numa_count; i++) { - status[i]->store(ThreadStatus::EXIT, std::memory_order_release); + { + std::lock_guard lock(*mutexes[i]); + status[i]->store(ThreadStatus::EXIT, std::memory_order_release); + } + cvs[i]->notify_one(); } for (int i = 0; i < numa_count; i++) { if (workers[i].joinable()) { @@ -784,7 +351,11 @@ void NumaJobDistributor::do_numa_job(std::function compute_func) { for (int i = 0; i < numa_count; i++) { if (i == me_numa) continue; - status[i]->store(ThreadStatus::WORKING, std::memory_order_release); + { + std::lock_guard lock(*mutexes[i]); + status[i]->store(ThreadStatus::WORKING, std::memory_order_release); + } + cvs[i]->notify_one(); } compute_func(me_numa); for (int i = 0; i < numa_count; i++) { @@ -798,7 +369,11 @@ void NumaJobDistributor::do_numa_job(std::function compute_func) { this->compute_func = compute_func; for (int i = 0; i < numa_count; i++) { - 
status[i]->store(ThreadStatus::WORKING, std::memory_order_release); + { + std::lock_guard lock(*mutexes[i]); + status[i]->store(ThreadStatus::WORKING, std::memory_order_release); + } + cvs[i]->notify_one(); } for (int i = 0; i < numa_count; i++) { while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) { @@ -808,35 +383,29 @@ void NumaJobDistributor::do_numa_job(std::function compute_func) { #endif void NumaJobDistributor::worker_thread(int numa_id) { - init_itt(); // Initialize ITT if enabled - // Use RDTSC for lightweight timing instead of std::chrono - const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles - uint64_t start = rdtsc_now(); + auto start = std::chrono::high_resolution_clock::now(); set_memory_to_numa(numa_id); status[numa_id] = std::move(std::unique_ptr>(new std::atomic(ThreadStatus::WAITING))); ready_bar->arrive_and_wait(); while (true) { - ITT_SYNC_PREPARE(status[numa_id].get()); // Signal profiler: about to spin-wait auto stat = status[numa_id]->load(std::memory_order_acquire); if (stat == ThreadStatus::WORKING) { - ITT_SYNC_ACQUIRED(status[numa_id].get()); // Signal profiler: acquired work auto me_numa = numa_node_of_cpu(sched_getcpu()); // printf("numa work on %d, me %d\n", numa_id, me_numa); compute_func(numa_id); status[numa_id]->store(ThreadStatus::WAITING, std::memory_order_release); - start = rdtsc_now(); + start = std::chrono::high_resolution_clock::now(); } else if (stat == ThreadStatus::WAITING) { - // PAUSE instruction hints to CPU this is a spin-wait loop - _mm_pause(); - uint64_t now = rdtsc_now(); - uint64_t elapsed_cycles = now - start; - if (elapsed_cycles > sleep_threshold_cycles) { - ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: going to sleep - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + auto now = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(now - start).count(); + if (duration > 50) { + std::unique_lock lock(*mutexes[numa_id]); + cvs[numa_id]->wait(lock, [&] { + return status[numa_id]->load(std::memory_order_acquire) != ThreadStatus::WAITING; + }); } } else if (stat == ThreadStatus::EXIT) { - ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: exiting return; } } @@ -911,15 +480,10 @@ InNumaPool* WorkerPool::get_subpool(int numa_id) { return numa_worker_pools[numa NumaJobDistributor* WorkerPool::dispense_backend() { return distributor.get(); } void WorkerPool::do_work_stealing_job(int task_num, std::function init_func, - std::function compute_func, std::function finalize_func, - const char* task_name, int block_size, bool async) { - numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func, task_name, block_size, - async); + std::function compute_func, std::function finalize_func) { + numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func); } -void WorkerPool::do_work_stealing_job(int task_num, std::function compute_func, const char* task_name, - int block_size, bool async) { - do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size, async); +void WorkerPool::do_work_stealing_job(int task_num, std::function compute_func) { + do_work_stealing_job(task_num, nullptr, compute_func, nullptr); } - -void WorkerPool::wait() { numa_worker_pools[0]->wait(); } diff --git a/kt-kernel/cpu_backend/worker_pool.h b/kt-kernel/cpu_backend/worker_pool.h index bab75bbe..72c9b635 100644 --- a/kt-kernel/cpu_backend/worker_pool.h +++ 
b/kt-kernel/cpu_backend/worker_pool.h @@ -62,10 +62,11 @@ enum ThreadStatus { struct alignas(64) ThreadState { std::atomic status; - uint64_t finish_cycles; // Per-thread timing (always enabled) - int task_count; // Per-thread task count - uint64_t start_ts; // Absolute start timestamp (RDTSC) - uint64_t end_ts; // Absolute end timestamp (RDTSC) + std::mutex mutex; + std::condition_variable cv; +#ifdef PROFILE_BALANCE + size_t finish_ns; +#endif }; class InNumaPool { @@ -76,45 +77,21 @@ class InNumaPool { int get_thread_num(); void set_restricted_worker_count(int count); - void do_work_stealing_job_async(int, std::function, std::function, std::function, - int block_size = 0); + void do_work_stealing_job_async(int, std::function, std::function, std::function); void wait(); - void do_work_stealing_job(int, std::function, std::function, std::function, - const char* task_name = nullptr, int block_size = 0, bool async = false); - void do_work_stealing_job(int, std::function, const char* task_name = nullptr, int block_size = 0, - bool async = false); - - // Get per-thread timing info - int get_worker_count() const { return worker_count; } - int get_numa_id() const { return numa_id_; } - uint64_t get_thread_cycles(int tid) const { return thread_state_[tid].finish_cycles; } - int get_thread_task_count(int tid) const { return thread_state_[tid].task_count; } - uint64_t get_thread_start_ts(int tid) const { return thread_state_[tid].start_ts; } - uint64_t get_thread_end_ts(int tid) const { return thread_state_[tid].end_ts; } - - // Reset per-thread timing/task counters (call before timing a sequence of operations) - // NOTE: Only call when all workers are in WAITING state (after wait() returns) - void reset_counters() { - for (int i = 0; i < total_worker_count; i++) { - thread_state_[i].finish_cycles = 0; - thread_state_[i].task_count = 0; - thread_state_[i].start_ts = 0; - thread_state_[i].end_ts = 0; - } - } + void do_work_stealing_job(int, std::function, std::function, std::function); + void do_work_stealing_job(int, std::function); private: int worker_count; int total_worker_count; - int numa_id_; std::unique_ptr thread_state_; // [thread_num] std::vector workers_; // changed ever time called do_work_stealing_job_async int restricted_worker_count; - int block_size_; std::function init_func_; std::function compute_func_; std::function finalize_func_; @@ -144,6 +121,8 @@ class NumaJobDistributor { int numa_count; std::vector numa_ids; std::vector>> status; + std::vector> mutexes; + std::vector> cvs; std::function compute_func; std::vector workers; @@ -171,12 +150,8 @@ class WorkerPool { InNumaPool* get_subpool(int numa_id); - void do_work_stealing_job(int, std::function, std::function, std::function, - const char* task_name = nullptr, int block_size = 0, bool async = false); - void do_work_stealing_job(int, std::function, const char* task_name = nullptr, int block_size = 0, - bool async = false); - - void wait(); + void do_work_stealing_job(int, std::function, std::function, std::function); + void do_work_stealing_job(int, std::function); WorkerPoolConfig config; @@ -191,58 +166,4 @@ class WorkerPool { std::vector> numa_worker_pools; }; -// ===================================================== -// Global per-thread timing for SFT MOE forward/backward -// ===================================================== -// Define SFT_TIMER_DISABLED to disable all timing (functions become no-ops) -// #define SFT_TIMER_DISABLED -namespace sft_timer { - -#ifdef SFT_TIMER_DISABLED -// Disabled: all functions are 
no-ops -inline void reset_forward() {} -inline void reset_backward() {} -inline void collect_forward(InNumaPool*) {} -inline void collect_backward(InNumaPool*) {} -inline void print_forward() {} -inline void print_backward(const char* = "backward") {} -inline void print_op_stats(InNumaPool*, const char*) {} -inline uint64_t get_trace_timestamp() { return 0; } -inline void add_kernel_trace(const char*, uint64_t, uint64_t, int, int, const char* = nullptr) {} -#else -// Enabled: declarations only, implementation in worker_pool.cpp -void reset_forward(); -void reset_backward(); -void collect_forward(InNumaPool* pool); -void collect_backward(InNumaPool* pool); -void print_forward(); -void print_backward(const char* name = "backward"); - -// Print per-thread timing for a single operation -// Call pool->reset_counters() BEFORE the operation, then call this AFTER -void print_op_stats(InNumaPool* pool, const char* op_name); - -// ===================================================== -// Kernel-level tracing API -// For tracing individual kernels (e.g., AVX matmul) within worker threads -// ===================================================== - -// Get current RDTSC timestamp (lightweight, ~20 cycles overhead) -uint64_t get_trace_timestamp(); - -// Add a kernel trace event -// @param name Kernel name (e.g., "lora_bf16_matmul_t4r4") -// @param start_ts Start timestamp from get_trace_timestamp() -// @param end_ts End timestamp from get_trace_timestamp() -// @param numa_id NUMA node ID (use -1 for auto-detect or 0 if unknown) -// @param thread_id Thread ID within the pool (use WorkerPool::thread_local_id) -// @param args Optional JSON args string (e.g., "{\"tokens\":128,\"rank\":8}") -void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id, - const char* args = nullptr); - -static void write_trace_to_file(); // Write all collected traces to a file (e.g., "sft_kernel_traces.json") -#endif - -} // namespace sft_timer - #endif diff --git a/kt-kernel/operators/amx/la/avx_kernels.hpp b/kt-kernel/operators/amx/la/avx_kernels.hpp index 463a8f3f..1ae493d2 100644 --- a/kt-kernel/operators/amx/la/avx_kernels.hpp +++ b/kt-kernel/operators/amx/la/avx_kernels.hpp @@ -48,9 +48,6 @@ namespace avx { */ inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggml_bf16_t* __restrict weight, float* __restrict output, int num_tokens, int k_dim, int rank) { - // #if AVX_KERNEL_TRACE_ENABLED - // uint64_t trace_start = sft_timer::get_trace_timestamp(); - // #endif constexpr int T_BLOCK = 4; constexpr int R_BLOCK = 4; @@ -240,13 +237,6 @@ inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggm } } - // #if AVX_KERNEL_TRACE_ENABLED - // uint64_t trace_end = sft_timer::get_trace_timestamp(); - // char args_buf[128]; - // snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"K\":%d,\"R\":%d}", num_tokens, k_dim, rank); - // sft_timer::add_kernel_trace("lora_bf16_matmul_t4r4", trace_start, trace_end, 0, WorkerPool::thread_local_id, - // args_buf); - // #endif } /** @@ -275,7 +265,6 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim, float scale) { #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_start = sft_timer::get_trace_timestamp(); #endif constexpr int T_BLOCK = 4; @@ -579,11 +568,8 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const } #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_end = 
sft_timer::get_trace_timestamp(); char args_buf[128]; snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim); - sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add", trace_start, trace_end, 0, WorkerPool::thread_local_id, - args_buf); #endif } @@ -615,7 +601,6 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim, float scale) { #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_start = sft_timer::get_trace_timestamp(); #endif constexpr int T_BLOCK = 4; @@ -886,11 +871,8 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co } #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_end = sft_timer::get_trace_timestamp(); char args_buf[128]; snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim); - sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_wt", trace_start, trace_end, 0, WorkerPool::thread_local_id, - args_buf); #endif } @@ -1190,7 +1172,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed const ggml_bf16_t* __restrict weight_t, ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim, float scale) { #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_start = sft_timer::get_trace_timestamp(); #endif constexpr int T_BLOCK = 4; @@ -1384,11 +1365,8 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed } #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_end = sft_timer::get_trace_timestamp(); char args_buf[128]; snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim); - sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_transposed", trace_start, trace_end, 0, - WorkerPool::thread_local_id, args_buf); #endif } @@ -1408,7 +1386,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad, const ggml_bf16_t* __restrict lora_b_t, float* __restrict result, int num_tokens, int hidden, int rank) { #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_start = sft_timer::get_trace_timestamp(); #endif constexpr int H_BLOCK = 32; @@ -1448,11 +1425,8 @@ inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad, } #if AVX_KERNEL_TRACE_ENABLED - uint64_t trace_end = sft_timer::get_trace_timestamp(); char args_buf[128]; snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"H\":%d,\"R\":%d}", num_tokens, hidden, rank); - sft_timer::add_kernel_trace("lora_backward_matmul_transposed", trace_start, trace_end, 0, WorkerPool::thread_local_id, - args_buf); #endif } diff --git a/kt-kernel/operators/amx/sft_moe.hpp b/kt-kernel/operators/amx/sft_moe.hpp index bbd8eeea..3ef9c044 100644 --- a/kt-kernel/operators/amx/sft_moe.hpp +++ b/kt-kernel/operators/amx/sft_moe.hpp @@ -1075,13 +1075,13 @@ class AMX_SFT_MOE_TP : public BaseMOE { } // Step 3: Copy input to expert buffers - auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size) { + auto direct_or_pool = [&](int count, auto&& fn) { if (qlen < 10) { for (int i = 0; i < count; i++) { fn(i); } } else { - pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size); + pool->do_work_stealing_job(count, nullptr, fn, nullptr); } }; @@ -1095,8 +1095,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, 
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); } - }, - "fwd_pack_input", 1); + }); // NaN Check: Step 3 - Packed input if (is_nan_check_enabled()) { @@ -1118,8 +1117,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { [this](int task_id) { int expert_idx = m_expert_id_map_[task_id]; gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); - }, - "fwd_quantize_in", 1); + }); // Step 5: Gate + Up GEMM (base projection) int nth = T::recommended_nth(config_.intermediate_size); @@ -1137,7 +1135,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); } }, - nullptr, "fwd_gate_up_gemm", 1); + nullptr); // NaN Check: Step 5 - Gate/Up GEMM output (before LoRA) if (is_nan_check_enabled()) { @@ -1230,10 +1228,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Step 6: Activation (silu(gate) * up) { - uint64_t act_start = sft_timer::get_trace_timestamp(); Base::apply_activation(activated_expert, nth, qlen); - uint64_t act_end = sft_timer::get_trace_timestamp(); - sft_timer::add_kernel_trace("apply_activation", act_start, act_end, tp_part_idx, 0); } // NaN Check: Step 6 - Activation output (silu(gate) * up) @@ -1295,7 +1290,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { int expert_idx = m_expert_id_map_[task_id]; down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); }, - nullptr, "fwd_down_quantize"); + nullptr); // Step 8: Down GEMM nth = T::recommended_nth(config_.hidden_size); @@ -1307,7 +1302,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { this->do_down_gemm(expert_idx, ith, nth, qlen); down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); }, - nullptr, "fwd_down_gemm", 1); + nullptr); // NaN Check: Step 8 - Down GEMM output (before LoRA) if (is_nan_check_enabled()) { @@ -1373,7 +1368,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { f32out[1] = x1; } }, - nullptr, "fwd_merge"); + nullptr); // NaN Check: Step 9 - Final output (after weighted merge) if (is_nan_check_enabled()) { @@ -1428,12 +1423,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { int activated_expert = cache.activated_expert_cache; constexpr int kSmallBwdDirectQlen = 0; constexpr int kSmallBwdDirectMaxTasks = 16; - auto trace_phase = [this](const char* name, auto&& fn) { - uint64_t start = sft_timer::get_trace_timestamp(); - fn(); - uint64_t end = sft_timer::get_trace_timestamp(); - sft_timer::add_kernel_trace(name, start, end, tp_part_idx, 0); - }; // NaN Check: grad_output input if (is_nan_check_enabled()) { @@ -1492,9 +1481,8 @@ class AMX_SFT_MOE_TP : public BaseMOE { } // ★ Allocate backward-phase buffers ★ - trace_phase("bwd_setup_alloc", [&] { alloc_backward_buffers(); }); + alloc_backward_buffers(); - trace_phase("bwd_setup_state", [&] { // ★ share_backward_bb: check if async repack already prepared this layer ★ if (config_.share_backward_bb) { auto& shared = SFTSharedPools::instance(); @@ -1544,12 +1532,10 @@ class AMX_SFT_MOE_TP : public BaseMOE { m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; offset += m_local_num_[i]; } - }); // Restore input data from cache into m_local_input_ (shared_mem_buffer may have been // overwritten by subsequent layers' forward passes). This is needed for gate/up LoRA // gradient computation which reads from m_local_input_ptr_. 
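All of the call-site churn in this file converges on one shape: the pool's bare four-argument form, with the `task_name`/`block_size`/`async` knobs gone. A hedged usage sketch (identifiers such as `config_`, `tp_part_idx`, `m_expert_id_map_`, `m_local_num_`, and `activated_expert` follow the surrounding code; the lambda body is illustrative):

```cpp
// Post-patch shape of a parallel section: (task_num, init, compute, finalize).
auto* pool_local = config_.pool->get_subpool(tp_part_idx);
pool_local->do_work_stealing_job(
    /*task_num=*/activated_expert,
    /*init_func=*/nullptr,
    /*compute_func=*/
    [&](int task_id) {
        int expert_idx = m_expert_id_map_[task_id];
        if (m_local_num_[expert_idx] == 0) return;  // no tokens routed here
        // ... per-expert kernel work ...
    },
    /*finalize_func=*/nullptr);
```

Note that the new `process_tasks` forces `block = 1;` right after computing the balanced chunk size, so stealing granularity is a single task regardless of the removed per-call tuning.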
- trace_phase("bwd_phase_down", [&] { auto pool_local = config_.pool->get_subpool(tp_part_idx); auto restore_input = [&](int i) { for (int j = 0; j < k; j++) { @@ -1569,7 +1555,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { restore_input(i); } } else { - pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr, "bwd_restore_input", 1); + pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr); } // Step 1: Down projection backward @@ -1579,7 +1565,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { } else { // backward_down(cache, grad_output, grad_down_lora_a, grad_down_lora_b); } - }); // // Compute total tokens for debug // size_t total_tokens = 0; @@ -1655,7 +1640,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // } // } - trace_phase("bwd_phase_act", [&] { backward_activation(cache); }); + backward_activation(cache); // NaN Check: Step 2 - After backward_activation if (is_nan_check_enabled()) { @@ -1695,14 +1680,12 @@ class AMX_SFT_MOE_TP : public BaseMOE { // } // } - trace_phase("bwd_phase_gate_up", [&] { if constexpr (supports_standard_mat_mul_v) { backward_gate_up_amx(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b, full_intermediate_size, fp32_grad_gate_lora_a, fp32_grad_up_lora_a); } else { // backward_gate_up(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b); } - }); // NaN Check: Step 3 - After backward_gate_up if (is_nan_check_enabled()) { @@ -1740,7 +1723,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Step 4: Compute grad_weights (gradient for routing weights) // grad_weights[token_idx, expert_pos] = dot(grad_output[token_idx], down_output[token, expert]) if (grad_weights != nullptr) { - trace_phase("bwd_phase_gradw", [&] { auto pool = config_.pool->get_subpool(tp_part_idx); float* grad_w = (float*)grad_weights; const ggml_bf16_t* grad_out = (const ggml_bf16_t*)grad_output; @@ -1788,9 +1770,8 @@ class AMX_SFT_MOE_TP : public BaseMOE { compute_grad_weight(token_idx); } } else { - pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr, "bwd_grad_weights"); + pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr); } - }); } // NaN Check: Step 4 - After grad_weights computation @@ -2106,7 +2087,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { break; } }, - nullptr, "transpose_lora_b_weights"); + nullptr); } /** @@ -2167,7 +2148,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { break; } }, - nullptr, "fwd_lora_prep"); + nullptr); lora_weights_prepared_ = true; } @@ -2212,7 +2193,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { break; } }, - nullptr, "bwd_lora_prep"); + nullptr); lora_backward_weights_prepared_ = true; } @@ -2265,7 +2246,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { dst_bb->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.intermediate_size, config_.hidden_size, ith, nth_gate_up); }, - nullptr, "bwd_prep_gate_up"); + nullptr); // Phase 2: down backward // down_proj: [hidden_size, intermediate_size] -> transposed BufferB [intermediate_size, hidden_size] @@ -2281,7 +2262,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { down_backward_bb_[expert_idx]->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.hidden_size, config_.intermediate_size, ith, nth_down); }, - nullptr, "bwd_prep_down"); + nullptr); backward_weights_prepared_ = true; } @@ -2317,7 +2298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth); } }, - nullptr, "bwd_repack_gate_up"); + nullptr); // 
Phase 2: down (uses [hidden_size, intermediate_size] -> [intermediate_size, hidden_size]) pool->do_work_stealing_job( @@ -2339,7 +2320,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth); } }, - nullptr, "bwd_repack_down"); + nullptr); backward_weights_prepared_ = true; } @@ -2543,7 +2524,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { down_backward_bb_[expert_idx].get()); } }, - nullptr, "load_bwd_kt"); + nullptr); if (!ok.load()) return false; backward_weights_prepared_ = true; @@ -2596,7 +2577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "load_bwd_projs"); + nullptr); backward_weights_prepared_ = true; } @@ -3421,7 +3402,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx]; bc->to_mat(m, inter_ptr, ith, nth); }, - nullptr, "fwd_lora_gu_a"); + nullptr); // ===================================================== // Step 2: Quantize lora_intermediate to BufferA @@ -3439,7 +3420,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { ggml_bf16_t* ptr = do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx]; ba->from_mat(m, ptr, 0, 1); }, - nullptr, "fwd_lora_gu_quant"); + nullptr); // ===================================================== // Step 3a: lora_intermediate @ lora_B^T -> lora_output (GEMM only) @@ -3464,7 +3445,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // GEMM: [m, padded_lora_rank] @ [intermediate_size, padded_lora_rank]^T -> [m, intermediate_size] amx::mat_mul(m, config_.intermediate_size, padded_lora_rank_, ba, bb, bc, ith, nth); }, - nullptr, "fwd_lora_gu_gemm"); + nullptr); // ===================================================== // Step 3b: Add LoRA output to main output with scaling @@ -3489,7 +3470,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { add_lora_output_to_main(bc.get(), main_output, m, config_.intermediate_size, lora_scaling_, ith, nth, lora_sum_ptr); }, - nullptr, "fwd_lora_gu_add"); + nullptr); } /** @@ -3581,7 +3562,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Convert BufferC to BF16 for step 2 input bc->to_mat(m, lora_gate_intermediate_ptr_[expert_idx], ith, nth); }, - nullptr, "fwd_lora_down_a"); + nullptr); @@ -3597,7 +3578,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Reuse gate intermediate buffer (no race condition for down projection) lora_gate_intermediate_ba_[expert_idx]->from_mat(m, lora_gate_intermediate_ptr_[expert_idx], 0, 1); }, - nullptr, "fwd_lora_down_quant"); + nullptr); // ===================================================== // Step 3a: lora_intermediate @ down_lora_B^T -> lora_output (GEMM only) @@ -3620,7 +3601,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // GEMM: [m, padded_lora_rank] @ [hidden_size, padded_lora_rank]^T -> [m, hidden_size] amx::mat_mul(m, config_.hidden_size, padded_lora_rank_, ba, bb, bc, ith, nth); }, - nullptr, "fwd_lora_down_gemm", 1); + nullptr); // ===================================================== // Step 3b: Add LoRA output to main output with scaling @@ -3642,7 +3623,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { add_lora_output_to_main(bc.get(), m_local_down_output_ptr_[expert_idx], m, config_.hidden_size, lora_scaling_, ith, nth, &down_lora_sum); }, - nullptr, "fwd_lora_down_add"); + nullptr); // // Print LoRA contribution statistics // size_t total_elements = 0; @@ -3791,7 +3772,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { output + t_start * inter_size, // output [local_num_tokens, inter_size] local_num_tokens, rank, 
inter_size, scale); }, - nullptr, "fwd_lora_gu"); + nullptr); } /** @@ -3856,7 +3837,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { output + t_start * hidden, // output [local_num_tokens, hidden] local_num_tokens, rank, hidden, scale); }, - nullptr, "fwd_lora_down"); + nullptr); } ForwardCache& push_cache() { @@ -3941,7 +3922,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "save_cache"); + nullptr); cache.valid = true; } @@ -3968,7 +3949,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { memcpy(cache.intermediate_cache + cache_offsets_[i] * config_.intermediate_size, m_local_gate_output_ptr_[expert_idx], num_tokens * config_.intermediate_size * sizeof(ggml_bf16_t)); }, - nullptr, "save_inter_cache"); + nullptr); } /** @@ -3993,7 +3974,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { memcpy(cache.down_output_cache + cache_offsets_[i] * config_.hidden_size, src_ptr, num_tokens * config_.hidden_size * sizeof(ggml_bf16_t)); }, - nullptr, "save_down_cache"); + nullptr); } void backward_down(const ForwardCache& cache, const void* grad_output, void* grad_down_lora_a, @@ -4025,7 +4006,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { memset(reinterpret_cast(grad_intermediate_) + offset, 0, size); } }, - nullptr, "bwd_down_memset"); + nullptr); } // Scatter grad_output to per-expert buffers and compute gradients @@ -4176,7 +4157,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "bwd_down"); + nullptr); } /** @@ -4198,13 +4179,13 @@ class AMX_SFT_MOE_TP : public BaseMOE { int k = cache.k_cache; constexpr int kSmallBwdDirectQlen = 0; constexpr int kSmallBwdDirectMaxTasks = 16; - auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) { + auto direct_or_pool = [&](int count, auto&& fn) { if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) { for (int i = 0; i < count; i++) { fn(i); } } else { - pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size); + pool->do_work_stealing_job(count, nullptr, fn, nullptr); } }; @@ -4258,8 +4239,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { int num_tokens = m_local_num_[expert_idx]; if (num_tokens == 0) return; memset(grad_output_bf16_ptr_[expert_idx], 0, num_tokens * config_.hidden_size * sizeof(ggml_bf16_t)); - }, - "bwd_down_zero"); + }); // ===================================================== // Step 2: Scatter grad_output to per-expert BF16 buffers @@ -4305,8 +4285,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { dst_row[h] = GGML_FP32_TO_BF16(cur); } } - }, - "bwd_down_scatter"); + }); } // ===================================================== @@ -4319,8 +4298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { int num_tokens = m_local_num_[expert_idx]; if (num_tokens == 0) return; grad_output_ba_[expert_idx]->from_mat(num_tokens, grad_output_bf16_ptr_[expert_idx], 0, 1); - }, - "bwd_down_quantize"); + }); // ===================================================== // Step 3+4: AMX GEMM + to_mat (merged to use same ith/nth) @@ -4365,7 +4343,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { // to_mat: Convert BufferC to BF16 - use same ith, nth as mat_mul! 
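The `to_mat` comment just above states the partitioning invariant this file relies on, and the call that follows obeys it. In hedged sketch form (identifiers follow the surrounding code; `n_dim`, `k_dim`, and `dst_ptr` are placeholders):

```cpp
// Each pool task owns strip `ith` of `nth` for one activated expert: the
// AMX GEMM fills only that strip of BufferC, so the BF16 drain must reuse
// the identical (ith, nth) split; with a different split, a worker could
// convert rows that another worker has not written yet.
int task_idx   = task_id / nth;               // which activated expert
int ith        = task_id % nth;               // which strip of its output
int expert_idx = m_expert_id_map_[task_idx];
int m          = m_local_num_[expert_idx];    // tokens routed to this expert
amx::mat_mul(m, n_dim, k_dim, ba, bb, bc, ith, nth);  // fill strip ith/nth
bc->to_mat(m, dst_ptr, ith, nth);                     // drain the same strip
```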
bc->to_mat(m, grad_intermediate_ + expert_offsets[task_idx], ith, nth); }, - nullptr, "bwd_down_gemm", 1); + nullptr); // ===================================================== // Step 3.5: Add LoRA contribution to grad_intermediate (AVX512) @@ -4417,8 +4395,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { expert_lora_a, // [rank, inter_size] BF16 grad_inter + t_start * inter_size, // [local_num_tokens, inter_size] BF16 local_num_tokens, rank, inter_size, scale); - }, - "bwd_down_lora_to_inter"); + }); } @@ -4600,7 +4577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "bwd_down_lora_grad_AB"); + nullptr); } else { struct LoraGradTask { int expert_task = -1; @@ -4609,7 +4586,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }; std::vector lora_grad_tasks; - uint64_t grad_setup_start = sft_timer::get_trace_timestamp(); float* grad_b_accum_all = down_lora_grad_b_accum_pool_; float* grad_a_accum_all = down_lora_grad_a_accum_pool_; if (activated_expert > 0) { @@ -4630,8 +4606,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { lora_grad_tasks.push_back({task_id, t, std::min(t + token_tile, buf.num_tokens)}); } } - uint64_t grad_setup_end = sft_timer::get_trace_timestamp(); - sft_timer::add_kernel_trace("bwd_down_lora_grad_setup", grad_setup_start, grad_setup_end, tp_part_idx, 0); if (!lora_grad_tasks.empty()) { direct_or_pool( @@ -4748,8 +4722,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { grad_a_global[off] += grad_a_local[off]; } } - }, - "bwd_down_lora_grad_AB"); + }); constexpr int kDownGradBTile = 512; constexpr int kDownGradATile = 512; @@ -4788,7 +4761,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "bwd_down_lora_write_B"); + nullptr); pool->do_work_stealing_job( activated_expert * grad_a_blocks, nullptr, @@ -4812,7 +4785,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "bwd_down_lora_write_A"); + nullptr); } } } @@ -4824,13 +4797,13 @@ class AMX_SFT_MOE_TP : public BaseMOE { int qlen = cache.qlen_cache; constexpr int kSmallBwdDirectQlen = 0; constexpr int kSmallBwdDirectMaxTasks = 16; - auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) { + auto direct_or_pool = [&](int count, auto&& fn) { if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) { for (int i = 0; i < count; i++) { fn(i); } } else { - pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size); + pool->do_work_stealing_job(count, nullptr, fn, nullptr); } }; @@ -4942,8 +4915,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { grad_gate[i] = GGML_FP32_TO_BF16(grad_i_val * u_val * sigmoid_val * (1.0f + g_val * (1.0f - sigmoid_val))); grad_up[i] = GGML_FP32_TO_BF16(grad_i_val * silu_val); } - }, - "bwd_act_silu"); + }); } @@ -4963,13 +4935,13 @@ class AMX_SFT_MOE_TP : public BaseMOE { int k = cache.k_cache; constexpr int kSmallBwdDirectQlen = 0; constexpr int kSmallBwdDirectMaxTasks = 16; - auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) { + auto direct_or_pool = [&](int count, auto&& fn) { if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) { for (int i = 0; i < count; i++) { fn(i); } } else { - pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size); + pool->do_work_stealing_job(count, nullptr, fn, nullptr); } }; @@ -5065,7 +5037,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } - auto scatter_to_grad_input = [&](float scale, const char* task_name) { + auto scatter_to_grad_input = [&](float scale) { ggml_bf16_t* grad_input_bf16 = 
(ggml_bf16_t*)grad_input; const int hidden = config_.hidden_size; const int hidden_vec_end = hidden & ~31; @@ -5101,13 +5073,10 @@ class AMX_SFT_MOE_TP : public BaseMOE { dst[h] = GGML_FP32_TO_BF16(cur); } } - }, - task_name); + }); }; auto base_pass = [&](bool do_up) { - const char* quant_name = do_up ? "bwd_gu_base_q_up" : "bwd_gu_base_q_gate"; - const char* gemm_name = do_up ? "bwd_gu_base_gemm_up" : "bwd_gu_base_gemm_gate"; // Quantize grad to BufferA direct_or_pool( @@ -5121,8 +5090,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { ggml_bf16_t* grad = do_up ? (grad_up_output_ + offset * config_.intermediate_size) : (grad_gate_output_ + offset * config_.intermediate_size); down_ba_[expert_idx]->from_mat(m, grad, 0, 1); - }, - quant_name); + }); int nth = T::recommended_nth(config_.hidden_size); pool->do_work_stealing_job( @@ -5141,9 +5109,9 @@ class AMX_SFT_MOE_TP : public BaseMOE { amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth); bc->to_mat(m, grad_output_bf16_ptr_[expert_idx], ith, nth); }, - nullptr, gemm_name, 1); + nullptr); - scatter_to_grad_input(1.0f, "bwd_gu_scatter_base"); + scatter_to_grad_input(1.0f); }; base_pass(false); // gate @@ -5341,8 +5309,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { up_gradb_global[off] += up_gradb_local[off]; } } - }, - "bwd_gu_lora_u_gradb_fused"); + }); constexpr int kGuGradBBlock = 256; int gradb_blocks = (inter_size + kGuGradBBlock - 1) / kGuGradBBlock; @@ -5375,7 +5342,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, "bwd_gu_lora_gradb_fused_write"); + nullptr); } // ===================================================== @@ -5386,8 +5353,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Gate and up still run sequentially because they share grad_output_bf16_ptr_. // ===================================================== auto lora_pass_remainder = [&](bool do_up) { - const char* gb_gradin_name = do_up ? "bwd_gu_lora_gb_gradin_fused_up" : "bwd_gu_lora_gb_gradin_fused_gate"; - const char* grad_a_name = do_up ? "bwd_gu_lora_gradA_up" : "bwd_gu_lora_gradA_gate"; struct GuLoraGradInTask { int expert_task = -1; @@ -5443,11 +5408,10 @@ class AMX_SFT_MOE_TP : public BaseMOE { memset(grad_out, 0, static_cast(local_tokens) * hidden * sizeof(ggml_bf16_t)); avx::lora_fp32_bf16_fused_add_transposed(gb, lora_a, grad_out, local_tokens, lora_rank_, hidden, 1.0f); - }, - gb_gradin_name); + }); } - scatter_to_grad_input(lora_scaling_, "bwd_gu_scatter_lora"); + scatter_to_grad_input(lora_scaling_); // Step 6: grad_A = G_B^T @ X ggml_bf16_t* grad_lora_a = do_up ? 
grad_up_a : grad_gate_a; @@ -5547,7 +5511,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } }, - nullptr, grad_a_name); + nullptr); }; lora_pass_remainder(false); // gate: gb_gradin_fused, scatter, gradA diff --git a/kt-kernel/operators/moe-sft-tp.hpp b/kt-kernel/operators/moe-sft-tp.hpp index 35c3e9ec..71361ae3 100644 --- a/kt-kernel/operators/moe-sft-tp.hpp +++ b/kt-kernel/operators/moe-sft-tp.hpp @@ -318,7 +318,7 @@ class TP_MOE_SFT : public TP_MOE { sizeof(ggml_bf16_t) * tpc.intermediate_size); } }, - nullptr, "memcpy_weights_tmp"); + nullptr); } // Set BF16 weight pointers on sub-MOEs for backward @@ -490,16 +490,13 @@ class TP_MOE_SFT : public TP_MOE { throw std::runtime_error("Weights not loaded"); } - auto start_sft = sft_timer::get_trace_timestamp(); int qlen = *qlen_ptr; auto pool = config.pool; // Reset forward timing before computation - // sft_timer::reset_forward(); // Reset per-thread counters in each subpool (to accumulate all do_work_stealing_job calls) for (int i = 0; i < tp_count; i++) { - pool->get_subpool(i)->reset_counters(); } // Run forward on each NUMA node @@ -508,24 +505,18 @@ class TP_MOE_SFT : public TP_MOE { save_for_backward); }); - auto end_fwd = sft_timer::get_trace_timestamp(); // // Collect per-thread timing from all NUMA subpools // for (int i = 0; i < tp_count; i++) { - // sft_timer::collect_forward(pool->get_subpool(i)); // } // // Print per-thread forward timing - // sft_timer::print_forward(); // Merge results from all NUMA nodes this->merge_results(qlen, output); - auto end_merge = sft_timer::get_trace_timestamp(); pool->dispense_backend()->do_numa_job([&](int numa_id) { - sft_timer::add_kernel_trace("fwd", start_sft, end_fwd, numa_id, 0); - sft_timer::add_kernel_trace("merge", end_fwd, end_merge, numa_id, 0); }); } @@ -561,7 +552,6 @@ class TP_MOE_SFT : public TP_MOE { void* grad_weights) { auto pool = config.pool; - auto start_sft = sft_timer::get_trace_timestamp(); // Get full intermediate_size (before TP partitioning) int full_intermediate_size = sft_config.intermediate_size; @@ -665,9 +655,8 @@ class TP_MOE_SFT : public TP_MOE { const auto& seg = clear_segs[(size_t)seg_idx]; std::memset(seg.ptr, 0, seg.len); }, - nullptr, "bwd_alloc_memset"); + nullptr); - auto end_alloc = sft_timer::get_trace_timestamp(); // Compute TP-slice pointers for copy-type direct writes // Each TP writes to its own I-slice of the final output tensor @@ -697,7 +686,6 @@ class TP_MOE_SFT : public TP_MOE { // Run backward on each NUMA node pool->dispense_backend()->do_numa_job([&](int numa_id) { - auto start_Bwd = sft_timer::get_trace_timestamp(); tps[numa_id]->backward(grad_output, part_grad_input_[numa_id], // reduce-type: BF16 pointer unused (FP32 sparse used instead) nullptr, /* grad_gate_lora_a — unused, FP32 path below */ @@ -708,18 +696,13 @@ class TP_MOE_SFT : public TP_MOE { nullptr, /* grad_down_lora_b — unused, FP32 path below */ part_grad_weights_[numa_id], full_intermediate_size, tp_fp32_down_b[numa_id], tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id]); - auto end_bwd = sft_timer::get_trace_timestamp(); - sft_timer::add_kernel_trace("bwd_alloc", start_sft, end_alloc, numa_id, 0); - sft_timer::add_kernel_trace("bwd_tp", start_Bwd, end_bwd, numa_id, 0); }); // // Collect per-thread timing from all NUMA subpools // for (int i = 0; i < tp_count; i++) { - // sft_timer::collect_backward(pool->get_subpool(i)); // } // // Print per-thread backward timing - // sft_timer::print_backward(); // // Print expert token distribution for load balancing analysis // { @@ -739,7 
diff --git a/kt-kernel/python/experts_base.py b/kt-kernel/python/experts_base.py
index 9a50427a..d23cf8d4 100644
--- a/kt-kernel/python/experts_base.py
+++ b/kt-kernel/python/experts_base.py
@@ -90,7 +90,7 @@ class KExpertsCPUBuffer:
         hidden_size = hidden_states.shape[-1]
         batch_size = hidden_states.shape[0]

-        pin_memory = False
+        pin_memory = True

         if batch_size in cls.capture_buffers:
             return cls.capture_buffers[batch_size]
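The `pin_memory = True` flip above makes the capture buffers page-locked, which is what lets host-to-device copies run asynchronously. A minimal sketch of the pattern with PyTorch; the buffer names and shapes are illustrative, not the class's actual code:

```python
import torch

batch_size, hidden_size = 8, 4096

# Page-locked host buffer: the DMA engine can copy from it without an extra staging copy.
host_buf = torch.empty(batch_size, hidden_size, dtype=torch.bfloat16, pin_memory=True)

if torch.cuda.is_available():
    # non_blocking copies only overlap with compute when the source is pinned.
    dev_buf = host_buf.to("cuda", non_blocking=True)
```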
diff --git a/kt-kernel/scripts/convert_cpu_weights.py b/kt-kernel/scripts/convert_cpu_weights.py
index 9f498889..9295d672 100644
--- a/kt-kernel/scripts/convert_cpu_weights.py
+++ b/kt-kernel/scripts/convert_cpu_weights.py
@@ -818,35 +818,30 @@ class OnlineQuantConverter(ConverterBase):
                 print(f"  [Fused] tensor {p} shape: {tuple(w.shape)}")
                 fused_tensors.append(w)

-        # fused_tensors[0] : down_proj, [E, H, I]
-        # fused_tensors[1] : gate_up_proj, [E, 2I, H]
+        # fused_tensors[0] : down-like, [E, I, H]
+        # fused_tensors[1] : gate_up-like, [E, H, 2I]
         down_fused = fused_tensors[0]
         gate_up_fused = fused_tensors[1]

-        # gate_up_fused is [E, 2I, H] — split on dim 1, no transpose needed
+        # gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
         if gate_up_fused.dim() != 3:
             raise ValueError(
                 f"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}"
             )
-        E = gate_up_fused.shape[0]
-        I = self.moe_intermediate_size
-        H = self.hidden_size
+        E, H, twoI = gate_up_fused.shape
+        if twoI % 2 != 0:
+            raise ValueError(f"[Fused] gate_up last dim (2I) not even: {twoI}")
+        I = twoI // 2

-        if gate_up_fused.shape != (E, 2 * I, H):
-            raise ValueError(
-                f"[Fused] gate_up shape {tuple(gate_up_fused.shape)} != expected ({E}, {2*I}, {H}). "
-                f"If your model stores gate_up as [E, H, 2I], transpose is needed."
-            )
-
-        gate_proj = gate_up_fused[:, :I, :].contiguous()  # [E, I, H]
-        up_proj = gate_up_fused[:, I:, :].contiguous()  # [E, I, H]
+        gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]
+        gate_proj = gate_up_T[:, :I, :]  # [E, I, H]
+        up_proj = gate_up_T[:, I:, :]  # [E, I, H]

         if down_fused.dim() != 3:
             raise ValueError(f"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
         if down_fused.shape[0] != E:
             raise ValueError(f"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}")

-        # down_proj is [E, H, I] — matches load_weights_from_tensors expectation, no transpose
-        down_proj = down_fused.contiguous()  # [E, H, I]
+        down_proj = down_fused.transpose(1, 2).contiguous()  # [E, H, I]

         del fused_tensors
         del gate_up_fused
         del down_fused
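To sanity-check the layout change above: the converter now expects fused gate_up as [E, H, 2I] and fused down as [E, I, H], and transposes both into the [E, I, H] / [E, H, I] forms the loader wants. A standalone sketch with toy dimensions (values are arbitrary, not a real model's):

```python
import torch

E, H, I = 2, 6, 4

gate_up_fused = torch.randn(E, H, 2 * I)  # new expected layout: [E, H, 2I]
down_fused = torch.randn(E, I, H)         # new expected layout: [E, I, H]

gate_up_T = gate_up_fused.transpose(1, 2).contiguous()         # [E, 2I, H]
gate_proj, up_proj = gate_up_T[:, :I, :], gate_up_T[:, I:, :]  # each [E, I, H]
down_proj = down_fused.transpose(1, 2).contiguous()            # [E, H, I]

assert gate_proj.shape == up_proj.shape == (E, I, H)
assert down_proj.shape == (E, H, I)
# transpose(1, 2) round-trips, so the original fused layout is recoverable.
assert torch.equal(down_proj.transpose(1, 2), down_fused)
```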