align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mrhaoxx 2026-04-21 17:39:56 +08:00
parent 168e10f254
commit a789729923
7 changed files with 159 additions and 766 deletions

worker_pool.cpp

@ -13,437 +13,20 @@
#include <hwloc/bitmap.h>
#include <numa.h>
#include <numaif.h>
#include <signal.h>
#include <x86intrin.h>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "hwloc.h"
// RDTSC-based timer for lightweight timing
// Uses CPU timestamp counter instead of system clock for lower overhead
namespace {
// Read CPU timestamp counter (RDTSC)
inline uint64_t rdtsc_now() { return __rdtsc(); }
// Estimated RDTSC cycles per millisecond; calibrated once on first use
static uint64_t g_rdtsc_cycles_per_ms = 0;
// Initialize RDTSC frequency by measuring against chrono
static uint64_t init_rdtsc_frequency() {
auto start_chrono = std::chrono::high_resolution_clock::now();
uint64_t start_rdtsc = rdtsc_now();
// Busy wait for ~10ms to calibrate
while (true) {
auto now = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - start_chrono).count();
if (elapsed >= 10) break;
}
uint64_t end_rdtsc = rdtsc_now();
auto end_chrono = std::chrono::high_resolution_clock::now();
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_chrono - start_chrono).count();
if (elapsed_ms > 0) {
return (end_rdtsc - start_rdtsc) / elapsed_ms;
}
// Fallback: assume 2.5 GHz CPU
return 2500000;
}
// Get cycles per millisecond (lazy initialization)
inline uint64_t get_rdtsc_cycles_per_ms() {
if (g_rdtsc_cycles_per_ms == 0) {
g_rdtsc_cycles_per_ms = init_rdtsc_frequency();
}
return g_rdtsc_cycles_per_ms;
}
} // namespace
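// Worked example of the calibration above (illustrative numbers): if the
// ~10 ms busy-wait spans 25,000,000 TSC ticks across 10 measured milliseconds,
// then cycles_per_ms = 25,000,000 / 10 = 2,500,000, i.e. a 2.5 GHz invariant
// TSC, which is exactly the value the fallback path assumes.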
// =====================================================
// Global per-thread timing for SFT MOE forward/backward
// Collects timing from InNumaPool worker threads
// =====================================================
#ifndef SFT_TIMER_DISABLED
namespace sft_timer {
constexpr int MAX_THREADS = 256;
static uint64_t forward_rt[MAX_THREADS] = {0};
static uint64_t backward_rt[MAX_THREADS] = {0};
static int forward_tasks[MAX_THREADS] = {0};
static int backward_tasks[MAX_THREADS] = {0};
static int forward_threads = 0;
static int backward_threads = 0;
inline double ticks_to_ms(uint64_t cycles) { return (double)cycles / get_rdtsc_cycles_per_ms(); }
// =====================================================
// Chrome Trace Event Format support
// =====================================================
struct TraceEvent {
std::string name; // event name (op_name)
std::string cat; // category
char ph; // phase: 'X' for complete event, 'B' for begin, 'E' for end
double ts; // timestamp in microseconds (with ns precision via decimals)
double dur; // duration in microseconds (with ns precision via decimals)
int pid; // process id (numa_id)
int tid; // thread id
int task_count; // number of tasks processed
std::string args_json; // optional custom args JSON (for kernel traces)
};
static std::vector<TraceEvent> g_trace_events;
static std::mutex g_trace_mutex;
static uint64_t g_trace_start_time = 0; // baseline timestamp (RDTSC)
static double g_trace_start_epoch_us = 0.0; // wall-clock epoch time in microseconds
static std::string g_trace_output_path = "sft_trace.json";
// Thread-safe initialization using std::call_once
static std::once_flag g_trace_init_flag;
// Forward declaration for atexit registration.
static void write_trace_to_file();
// Initialize trace start time (thread-safe)
static void init_trace() {
std::call_once(g_trace_init_flag, []() {
g_trace_start_time = rdtsc_now();
// Record wall-clock epoch time for cross-process trace alignment
auto now_wall = std::chrono::system_clock::now();
auto epoch_us = std::chrono::duration_cast<std::chrono::microseconds>(now_wall.time_since_epoch()).count();
g_trace_start_epoch_us = static_cast<double>(epoch_us);
// Check for custom output path from environment
const char* env_path = std::getenv("SFT_TRACE_PATH");
if (env_path && env_path[0] != '\0') {
g_trace_output_path = env_path;
}
// Flush trace on normal exit before static destructors run.
std::atexit(write_trace_to_file);
});
}
// Convert RDTSC cycles to microseconds with nanosecond precision (as double)
// Chrome tracing uses microseconds but supports fractional values for sub-us precision
static double cycles_to_us(uint64_t cycles) {
// cycles_per_ms / 1000 = cycles_per_us
// cycles / cycles_per_us = microseconds (fractional values carry sub-us precision)
double cycles_per_us = get_rdtsc_cycles_per_ms() / 1000.0;
return static_cast<double>(cycles) / cycles_per_us;
}
// Add trace events for an operation using absolute timestamps
static void add_trace_events(const char* op_name, int numa_id, int thread_count, const uint64_t* start_ts_arr,
const uint64_t* end_ts_arr, const int* tasks) {
init_trace();
std::lock_guard<std::mutex> lock(g_trace_mutex);
for (int i = 0; i < thread_count; i++) {
// Convert absolute RDTSC timestamps to relative microseconds from trace start
double start_us = (start_ts_arr[i] > g_trace_start_time) ? cycles_to_us(start_ts_arr[i] - g_trace_start_time) : 0.0;
double end_us = (end_ts_arr[i] > g_trace_start_time) ? cycles_to_us(end_ts_arr[i] - g_trace_start_time) : 0.0;
double dur_us = end_us - start_us;
if (dur_us < 0) dur_us = 0;
TraceEvent ev;
ev.name = op_name;
ev.cat = "sft_op";
ev.ph = 'X'; // Complete event
ev.ts = start_us;
ev.dur = dur_us;
ev.pid = numa_id;
ev.tid = i;
ev.task_count = tasks[i];
g_trace_events.push_back(ev);
}
}
// Write trace events to JSON file (Chrome Trace Event Format)
static void write_trace_to_file() {
std::lock_guard<std::mutex> lock(g_trace_mutex);
if (g_trace_events.empty()) {
return;
}
// Sort events by (pid, tid, ts) to fix overlap issues in Chrome trace viewer
// Events from same thread should be ordered by start time
std::sort(g_trace_events.begin(), g_trace_events.end(), [](const TraceEvent& a, const TraceEvent& b) {
if (a.pid != b.pid) return a.pid < b.pid;
if (a.tid != b.tid) return a.tid < b.tid;
return a.ts < b.ts;
});
std::ofstream ofs(g_trace_output_path);
if (!ofs.is_open()) {
fprintf(stderr, "sft_timer: Failed to open trace file: %s\n", g_trace_output_path.c_str());
return;
}
// Use fixed precision for nanosecond accuracy (3 decimal places in microseconds = nanoseconds)
ofs << std::fixed << std::setprecision(3);
ofs << "{\n";
ofs << " \"traceEvents\": [\n";
for (size_t i = 0; i < g_trace_events.size(); i++) {
const auto& ev = g_trace_events[i];
ofs << " {";
ofs << "\"name\":\"" << ev.name << "\",";
ofs << "\"cat\":\"" << ev.cat << "\",";
ofs << "\"ph\":\"" << ev.ph << "\",";
ofs << "\"ts\":" << ev.ts << ",";
ofs << "\"dur\":" << ev.dur << ",";
ofs << "\"pid\":" << ev.pid << ",";
ofs << "\"tid\":" << ev.tid << ",";
if (!ev.args_json.empty()) {
ofs << "\"args\":" << ev.args_json;
} else {
ofs << "\"args\":{\"task_count\":" << ev.task_count << "}";
}
ofs << "}";
if (i < g_trace_events.size() - 1) {
ofs << ",";
}
ofs << "\n";
}
ofs << " ],\n";
ofs << " \"metadata\": {\"start_epoch_us\": " << std::setprecision(0) << g_trace_start_epoch_us << "},\n";
ofs << std::setprecision(3);
ofs << " \"displayTimeUnit\": \"ns\"\n";
ofs << "}\n";
ofs.close();
fprintf(stderr, "sft_timer: Trace written to %s (%zu events)\n", g_trace_output_path.c_str(), g_trace_events.size());
}
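// For reference, the emitted file follows the Chrome Trace Event Format and
// loads in chrome://tracing or Perfetto. A hand-written illustrative sample
// (values are made up; field layout matches the writer above):
//   {
//     "traceEvents": [
//       {"name":"fwd_gate_up_gemm","cat":"sft_op","ph":"X",
//        "ts":12.345,"dur":6.789,"pid":0,"tid":3,"args":{"task_count":17}}
//     ],
//     "metadata": {"start_epoch_us": 1766300000000000},
//     "displayTimeUnit": "ns"
//   }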
// Signal handler for SIGTERM
static void sigterm_handler(int sig) {
fprintf(stderr, "sft_timer: Received signal %d, writing trace...\n", sig);
write_trace_to_file();
// Re-raise the signal with default handler to allow normal termination
signal(sig, SIG_DFL);
raise(sig);
}
// Register signal handlers
static void register_signal_handlers() {
static bool registered = false;
if (!registered) {
signal(SIGTERM, sigterm_handler);
signal(SIGINT, sigterm_handler);
registered = true;
}
}
void print_rt(FILE* out, const char* name, uint64_t* rt, int* tasks, int rt_threads) {
if (rt_threads <= 0) return;
FILE* output = out ? out : stderr;
auto max_val = *std::max_element(rt, rt + rt_threads);
auto min_val = *std::min_element(rt, rt + rt_threads);
uint64_t sum = std::accumulate(rt, rt + rt_threads, (uint64_t)0);
int total_tasks = std::accumulate(tasks, tasks + rt_threads, 0);
// Sort to find 20% and 80% percentile thresholds
std::vector<uint64_t> sorted(rt, rt + rt_threads);
std::sort(sorted.begin(), sorted.end());
int p20_idx = rt_threads * 20 / 100;
int p80_idx = rt_threads * 80 / 100;
uint64_t p20_threshold = sorted[p20_idx]; // Fast threshold (top 20%)
uint64_t p80_threshold = sorted[p80_idx]; // Slow threshold (bottom 20%)
// ANSI color codes
const char* GREEN = "\033[32m";
const char* RED = "\033[31m";
const char* RESET = "\033[0m";
// Line 1: time
fprintf(output, "%30s max %.3f min %.3f avg %.3f : ", name, ticks_to_ms(max_val), ticks_to_ms(min_val),
ticks_to_ms(sum / rt_threads));
for (int i = 0; i < rt_threads; i++) {
if (rt[i] <= p20_threshold) {
fprintf(output, "%s%.3f%s ", GREEN, ticks_to_ms(rt[i]), RESET);
} else if (rt[i] >= p80_threshold) {
fprintf(output, "%s%.3f%s ", RED, ticks_to_ms(rt[i]), RESET);
} else {
fprintf(output, "%.3f ", ticks_to_ms(rt[i]));
}
}
fprintf(output, "\n");
// Line 2: task count
fprintf(output, "%30s total %d : ", "tasks", total_tasks);
for (int i = 0; i < rt_threads; i++) {
if (rt[i] <= p20_threshold) {
fprintf(output, "%s%d%s ", GREEN, tasks[i], RESET);
} else if (rt[i] >= p80_threshold) {
fprintf(output, "%s%d%s ", RED, tasks[i], RESET);
} else {
fprintf(output, "%d ", tasks[i]);
}
}
fprintf(output, "\n");
}
void reset_forward() {
std::fill(forward_rt, forward_rt + MAX_THREADS, 0);
std::fill(forward_tasks, forward_tasks + MAX_THREADS, 0);
forward_threads = 0;
}
void reset_backward() {
std::fill(backward_rt, backward_rt + MAX_THREADS, 0);
std::fill(backward_tasks, backward_tasks + MAX_THREADS, 0);
backward_threads = 0;
}
void collect_forward(InNumaPool* pool) {
int n = pool->get_worker_count();
for (int i = 0; i < n && forward_threads < MAX_THREADS; i++) {
forward_rt[forward_threads] = pool->get_thread_cycles(i);
forward_tasks[forward_threads] = pool->get_thread_task_count(i);
forward_threads++;
}
}
void collect_backward(InNumaPool* pool) {
int n = pool->get_worker_count();
for (int i = 0; i < n && backward_threads < MAX_THREADS; i++) {
backward_rt[backward_threads] = pool->get_thread_cycles(i);
backward_tasks[backward_threads] = pool->get_thread_task_count(i);
backward_threads++;
}
}
void print_forward() { print_rt(stderr, "forward", forward_rt, forward_tasks, forward_threads); }
void print_backward(const char* name) { print_rt(stderr, name, backward_rt, backward_tasks, backward_threads); }
void print_op_stats(InNumaPool* pool, const char* op_name) {
if (pool == nullptr || op_name == nullptr || op_name[0] == '\0') {
return;
}
int n = pool->get_worker_count();
if (n <= 0) {
return;
}
// Ensure signal handlers are registered on first call
static bool handlers_registered = false;
if (!handlers_registered) {
register_signal_handlers();
handlers_registered = true;
}
FILE* output = stderr;
int numa_id = pool->get_numa_id();
// if (numa_id == 0) {
// output = stdout;
// } else if (numa_id == 1) {
// output = stderr;
// }
std::vector<uint64_t> rt(n);
std::vector<uint64_t> start_ts(n);
std::vector<uint64_t> end_ts(n);
std::vector<int> tasks(n);
for (int i = 0; i < n; i++) {
rt[i] = pool->get_thread_cycles(i);
tasks[i] = pool->get_thread_task_count(i);
start_ts[i] = pool->get_thread_start_ts(i);
end_ts[i] = pool->get_thread_end_ts(i);
}
// print_rt(output, op_name, rt.data(), tasks.data(), n);
// Save trace data to memory for later export
add_trace_events(op_name, numa_id, n, start_ts.data(), end_ts.data(), tasks.data());
}
// =====================================================
// Kernel-level tracing API implementation
// =====================================================
uint64_t get_trace_timestamp() { return rdtsc_now(); }
void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id,
const char* args) {
init_trace();
// Convert absolute RDTSC timestamps to relative microseconds from trace start
double start_us = (start_ts > g_trace_start_time) ? cycles_to_us(start_ts - g_trace_start_time) : 0.0;
double end_us = (end_ts > g_trace_start_time) ? cycles_to_us(end_ts - g_trace_start_time) : 0.0;
double dur_us = end_us - start_us;
if (dur_us < 0) dur_us = 0;
std::lock_guard<std::mutex> lock(g_trace_mutex);
TraceEvent ev;
ev.name = name;
ev.cat = "kernel";
ev.ph = 'X'; // Complete event
ev.ts = start_us;
ev.dur = dur_us;
ev.pid = numa_id;
ev.tid = thread_id;
ev.task_count = 0; // Not applicable for kernel traces
if (args != nullptr && args[0] != '\0') {
ev.args_json = args;
}
g_trace_events.push_back(ev);
}
} // namespace sft_timer
#endif // SFT_TIMER_DISABLED
// Intel ITT API for profiler integration (VTune, etc.)
// Allows profilers to identify spin-wait regions
#ifdef USE_ITT_NOTIFY
#include <ittnotify.h>
static __itt_domain* g_itt_domain = nullptr;
static __itt_string_handle* g_itt_spin_wait = nullptr;
static void init_itt() {
if (g_itt_domain == nullptr) {
g_itt_domain = __itt_domain_create("WorkerPool");
g_itt_spin_wait = __itt_string_handle_create("SpinWait");
}
}
#define ITT_SYNC_PREPARE(addr) __itt_sync_prepare(addr)
#define ITT_SYNC_CANCEL(addr) __itt_sync_cancel(addr)
#define ITT_SYNC_ACQUIRED(addr) __itt_sync_acquired(addr)
#else
#define ITT_SYNC_PREPARE(addr) ((void)0)
#define ITT_SYNC_CANCEL(addr) ((void)0)
#define ITT_SYNC_ACQUIRED(addr) ((void)0)
static void init_itt() {}
#endif
thread_local int WorkerPool::thread_local_id = -1;
InNumaPool::InNumaPool(int max_thread_num) {
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
numa_id_ = numa_node_of_cpu(sched_getcpu());
total_worker_count = max_thread_num;
block_size_ = 0;
set_restricted_worker_count(total_worker_count);
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
for (int i = 0; i < total_worker_count; i++) {
@ -463,9 +46,7 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {
hwloc_topology_init(&topology);
hwloc_topology_load(topology);
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
numa_id_ = numa_id;
total_worker_count = max_thread_num;
block_size_ = 0;
set_restricted_worker_count(total_worker_count);
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
for (int i = 0; i < total_worker_count; i++) {
@ -514,7 +95,11 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {
InNumaPool::~InNumaPool() {
for (int i = 0; i < total_worker_count; i++) {
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);
{
std::lock_guard<std::mutex> lock(thread_state_[i].mutex);
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);
}
thread_state_[i].cv.notify_one();
}
for (int i = 0; i < total_worker_count; i++) {
if (workers_[i].joinable()) {
@ -551,51 +136,41 @@ void InNumaPool::wait() {
#endif
}
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
int block_size, bool async) {
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size);
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {
do_work_stealing_job(task_num, nullptr, compute_func, nullptr);
}
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
const char* task_name, int block_size, bool async) {
bool has_name = task_name != nullptr && task_name[0] != '\0';
if (has_name) {
reset_counters();
}
do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func, block_size);
if (!async) wait();
if (has_name) {
sft_timer::print_op_stats(this, task_name);
}
std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {
do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func);
wait();
}
void InNumaPool::do_work_stealing_job_async(int task_num, std::function<void(int)> init_func,
std::function<void(int)> compute_func,
std::function<void(int)> finalize_func, int block_size) {
std::function<void(int)> finalize_func) {
init_func_ = init_func;
compute_func_ = compute_func;
finalize_func_ = finalize_func;
block_size_ = block_size;
worker_count = std::min(restricted_worker_count, task_num);
curr_.store(0, std::memory_order_release);
end_ = task_num;
for (int i = 0; i < worker_count; i++) {
thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);
{
std::lock_guard<std::mutex> lock(thread_state_[i].mutex);
thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);
}
thread_state_[i].cv.notify_one();
}
WorkerPool::thread_local_id = 0;
process_tasks(0);
}
void InNumaPool::process_tasks(int thread_id) {
uint64_t start_cycles = rdtsc_now();
#ifdef PROFILE_BALANCE
auto start = std::chrono::high_resolution_clock::now();
#endif
auto& s = thread_state_[thread_id];
int local_task_count = 0;
// Record absolute start timestamp
s.start_ts = start_cycles;
if (init_func_ != nullptr) {
init_func_(thread_id);
}
@ -608,12 +183,7 @@ void InNumaPool::process_tasks(int thread_id) {
break;
}
int block = 0;
if (block_size_ > 0) {
block = std::min(block_size_, rem);
} else {
block = (rem + worker_count - 1) / worker_count;
}
int block = (rem + worker_count - 1) / worker_count;
block = 1;
int task_id = curr_.fetch_add(block, std::memory_order_acq_rel);
if (task_id >= end_) {
@ -625,7 +195,6 @@ void InNumaPool::process_tasks(int thread_id) {
break;
}
compute_func_(task_id + i);
local_task_count++;
}
}
@ -633,44 +202,34 @@ void InNumaPool::process_tasks(int thread_id) {
finalize_func_(thread_id);
}
// IMPORTANT: Update timing BEFORE setting status to WAITING
// The release semantics of status.store() ensures all prior writes are visible
uint64_t end_cycles = rdtsc_now();
s.finish_cycles = end_cycles - start_cycles;
s.task_count = local_task_count;
s.end_ts = end_cycles;
// Signal completion - release ensures timing writes are visible to wait()
s.status.store(ThreadStatus::WAITING, std::memory_order_release);
#ifdef PROFILE_BALANCE
s.finish_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() - start).count();
#endif
}
void InNumaPool::worker_thread(int thread_id, int numa_id) {
if (numa_id >= 0) {
set_memory_to_numa(numa_id);
}
init_itt(); // Initialize ITT if enabled
// Use RDTSC for lightweight timing instead of std::chrono
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
uint64_t start = rdtsc_now();
auto start = std::chrono::high_resolution_clock::now();
WorkerPool::thread_local_id = thread_id; // set the thread-local id
while (true) {
ITT_SYNC_PREPARE(&thread_state_[thread_id].status); // Signal profiler: about to spin-wait
ThreadStatus status = thread_state_[thread_id].status.load(std::memory_order_acquire);
if (status == ThreadStatus::WORKING) {
ITT_SYNC_ACQUIRED(&thread_state_[thread_id].status); // Signal profiler: acquired work
process_tasks(thread_id);
start = rdtsc_now();
start = std::chrono::high_resolution_clock::now();
} else if (status == ThreadStatus::WAITING) {
// PAUSE instruction hints to CPU this is a spin-wait loop
_mm_pause();
uint64_t now = rdtsc_now();
uint64_t elapsed_cycles = now - start;
if (elapsed_cycles > sleep_threshold_cycles) {
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: going to sleep
std::this_thread::sleep_for(std::chrono::microseconds(100));
auto now = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
if (duration > 50) {
std::unique_lock<std::mutex> lock(thread_state_[thread_id].mutex);
thread_state_[thread_id].cv.wait(lock, [&] {
return thread_state_[thread_id].status.load(std::memory_order_acquire) != ThreadStatus::WAITING;
});
}
} else if (status == ThreadStatus::EXIT) {
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: exiting
return;
}
}
@ -695,6 +254,8 @@ void NumaJobDistributor::init(std::vector<int> numa_ids) {
this->numa_ids = numa_ids;
for (size_t i = 0; i < numa_count; i++) {
status.push_back(nullptr);
mutexes.push_back(std::make_unique<std::mutex>());
cvs.push_back(std::make_unique<std::condition_variable>());
}
workers.resize(numa_count);
@ -716,6 +277,8 @@ void NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread
this->numa_ids = numa_ids;
for (size_t i = 0; i < numa_count; i++) {
status.push_back(nullptr);
mutexes.push_back(std::make_unique<std::mutex>());
cvs.push_back(std::make_unique<std::condition_variable>());
}
workers.resize(numa_count);
@ -767,7 +330,11 @@ void NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread
NumaJobDistributor::~NumaJobDistributor() {
for (int i = 0; i < numa_count; i++) {
status[i]->store(ThreadStatus::EXIT, std::memory_order_release);
{
std::lock_guard<std::mutex> lock(*mutexes[i]);
status[i]->store(ThreadStatus::EXIT, std::memory_order_release);
}
cvs[i]->notify_one();
}
for (int i = 0; i < numa_count; i++) {
if (workers[i].joinable()) {
@ -784,7 +351,11 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
for (int i = 0; i < numa_count; i++) {
if (i == me_numa) continue;
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
{
std::lock_guard<std::mutex> lock(*mutexes[i]);
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
}
cvs[i]->notify_one();
}
compute_func(me_numa);
for (int i = 0; i < numa_count; i++) {
@ -798,7 +369,11 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
this->compute_func = compute_func;
for (int i = 0; i < numa_count; i++) {
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
{
std::lock_guard<std::mutex> lock(*mutexes[i]);
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
}
cvs[i]->notify_one();
}
for (int i = 0; i < numa_count; i++) {
while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
@ -808,35 +383,29 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
#endif
void NumaJobDistributor::worker_thread(int numa_id) {
init_itt(); // Initialize ITT if enabled
// Use RDTSC for lightweight timing instead of std::chrono
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
uint64_t start = rdtsc_now();
auto start = std::chrono::high_resolution_clock::now();
set_memory_to_numa(numa_id);
status[numa_id] =
std::move(std::unique_ptr<std::atomic<ThreadStatus>>(new std::atomic<ThreadStatus>(ThreadStatus::WAITING)));
ready_bar->arrive_and_wait();
while (true) {
ITT_SYNC_PREPARE(status[numa_id].get()); // Signal profiler: about to spin-wait
auto stat = status[numa_id]->load(std::memory_order_acquire);
if (stat == ThreadStatus::WORKING) {
ITT_SYNC_ACQUIRED(status[numa_id].get()); // Signal profiler: acquired work
auto me_numa = numa_node_of_cpu(sched_getcpu());
// printf("numa work on %d, me %d\n", numa_id, me_numa);
compute_func(numa_id);
status[numa_id]->store(ThreadStatus::WAITING, std::memory_order_release);
start = rdtsc_now();
start = std::chrono::high_resolution_clock::now();
} else if (stat == ThreadStatus::WAITING) {
// PAUSE instruction hints to CPU this is a spin-wait loop
_mm_pause();
uint64_t now = rdtsc_now();
uint64_t elapsed_cycles = now - start;
if (elapsed_cycles > sleep_threshold_cycles) {
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: going to sleep
std::this_thread::sleep_for(std::chrono::milliseconds(1));
auto now = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
if (duration > 50) {
std::unique_lock<std::mutex> lock(*mutexes[numa_id]);
cvs[numa_id]->wait(lock, [&] {
return status[numa_id]->load(std::memory_order_acquire) != ThreadStatus::WAITING;
});
}
} else if (stat == ThreadStatus::EXIT) {
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: exiting
return;
}
}
@ -911,15 +480,10 @@ InNumaPool* WorkerPool::get_subpool(int numa_id) { return numa_worker_pools[numa
NumaJobDistributor* WorkerPool::dispense_backend() { return distributor.get(); }
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
const char* task_name, int block_size, bool async) {
numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func, task_name, block_size,
async);
std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {
numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func);
}
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
int block_size, bool async) {
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size, async);
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {
do_work_stealing_job(task_num, nullptr, compute_func, nullptr);
}
void WorkerPool::wait() { numa_worker_pools[0]->wait(); }
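
For orientation, a minimal self-contained sketch of the task-claiming core that survives the revert: every participating thread, including the caller (which runs as thread 0, mirroring process_tasks(0) above), claims indices from a shared atomic counter until the range is exhausted. Names below are illustrative stand-ins, not the project's API.

#include <atomic>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

static void run_work_stealing_job(int task_num, int worker_count,
                                  const std::function<void(int)>& compute_func) {
    std::atomic<int> curr{0};
    auto claim_loop = [&] {
        while (true) {
            // Claim one task index; acq_rel mirrors curr_.fetch_add above.
            int task_id = curr.fetch_add(1, std::memory_order_acq_rel);
            if (task_id >= task_num) break;
            compute_func(task_id);
        }
    };
    std::vector<std::thread> helpers;
    for (int i = 1; i < worker_count; i++) helpers.emplace_back(claim_loop);
    claim_loop();  // the caller participates, like process_tasks(0)
    for (auto& t : helpers) t.join();
}

int main() {
    run_work_stealing_job(8, 3, [](int id) { std::printf("task %d done\n", id); });
    return 0;
}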

worker_pool.h

@ -62,10 +62,11 @@ enum ThreadStatus {
struct alignas(64) ThreadState {
std::atomic<ThreadStatus> status;
uint64_t finish_cycles; // Per-thread timing (always enabled)
int task_count; // Per-thread task count
uint64_t start_ts; // Absolute start timestamp (RDTSC)
uint64_t end_ts; // Absolute end timestamp (RDTSC)
std::mutex mutex;
std::condition_variable cv;
#ifdef PROFILE_BALANCE
size_t finish_ns;
#endif
};
class InNumaPool {
@ -76,45 +77,21 @@ class InNumaPool {
int get_thread_num();
void set_restricted_worker_count(int count);
void do_work_stealing_job_async(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
int block_size = 0);
void do_work_stealing_job_async(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
void wait();
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
const char* task_name = nullptr, int block_size = 0, bool async = false);
void do_work_stealing_job(int, std::function<void(int)>, const char* task_name = nullptr, int block_size = 0,
bool async = false);
// Get per-thread timing info
int get_worker_count() const { return worker_count; }
int get_numa_id() const { return numa_id_; }
uint64_t get_thread_cycles(int tid) const { return thread_state_[tid].finish_cycles; }
int get_thread_task_count(int tid) const { return thread_state_[tid].task_count; }
uint64_t get_thread_start_ts(int tid) const { return thread_state_[tid].start_ts; }
uint64_t get_thread_end_ts(int tid) const { return thread_state_[tid].end_ts; }
// Reset per-thread timing/task counters (call before timing a sequence of operations)
// NOTE: Only call when all workers are in WAITING state (after wait() returns)
void reset_counters() {
for (int i = 0; i < total_worker_count; i++) {
thread_state_[i].finish_cycles = 0;
thread_state_[i].task_count = 0;
thread_state_[i].start_ts = 0;
thread_state_[i].end_ts = 0;
}
}
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
void do_work_stealing_job(int, std::function<void(int)>);
private:
int worker_count;
int total_worker_count;
int numa_id_;
std::unique_ptr<ThreadState[]> thread_state_; // [thread_num]
std::vector<std::thread> workers_;
// changed every time do_work_stealing_job_async is called
int restricted_worker_count;
int block_size_;
std::function<void(int)> init_func_;
std::function<void(int)> compute_func_;
std::function<void(int)> finalize_func_;
@ -144,6 +121,8 @@ class NumaJobDistributor {
int numa_count;
std::vector<int> numa_ids;
std::vector<std::unique_ptr<std::atomic<ThreadStatus>>> status;
std::vector<std::unique_ptr<std::mutex>> mutexes;
std::vector<std::unique_ptr<std::condition_variable>> cvs;
std::function<void(int)> compute_func;
std::vector<std::thread> workers;
@ -171,12 +150,8 @@ class WorkerPool {
InNumaPool* get_subpool(int numa_id);
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
const char* task_name = nullptr, int block_size = 0, bool async = false);
void do_work_stealing_job(int, std::function<void(int)>, const char* task_name = nullptr, int block_size = 0,
bool async = false);
void wait();
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
void do_work_stealing_job(int, std::function<void(int)>);
WorkerPoolConfig config;
@ -191,58 +166,4 @@ class WorkerPool {
std::vector<std::unique_ptr<InNumaPool>> numa_worker_pools;
};
// =====================================================
// Global per-thread timing for SFT MOE forward/backward
// =====================================================
// Define SFT_TIMER_DISABLED to disable all timing (functions become no-ops)
// #define SFT_TIMER_DISABLED
namespace sft_timer {
#ifdef SFT_TIMER_DISABLED
// Disabled: all functions are no-ops
inline void reset_forward() {}
inline void reset_backward() {}
inline void collect_forward(InNumaPool*) {}
inline void collect_backward(InNumaPool*) {}
inline void print_forward() {}
inline void print_backward(const char* = "backward") {}
inline void print_op_stats(InNumaPool*, const char*) {}
inline uint64_t get_trace_timestamp() { return 0; }
inline void add_kernel_trace(const char*, uint64_t, uint64_t, int, int, const char* = nullptr) {}
#else
// Enabled: declarations only, implementation in worker_pool.cpp
void reset_forward();
void reset_backward();
void collect_forward(InNumaPool* pool);
void collect_backward(InNumaPool* pool);
void print_forward();
void print_backward(const char* name = "backward");
// Print per-thread timing for a single operation
// Call pool->reset_counters() BEFORE the operation, then call this AFTER
void print_op_stats(InNumaPool* pool, const char* op_name);
// =====================================================
// Kernel-level tracing API
// For tracing individual kernels (e.g., AVX matmul) within worker threads
// =====================================================
// Get current RDTSC timestamp (lightweight, ~20 cycles overhead)
uint64_t get_trace_timestamp();
// Add a kernel trace event
// @param name Kernel name (e.g., "lora_bf16_matmul_t4r4")
// @param start_ts Start timestamp from get_trace_timestamp()
// @param end_ts End timestamp from get_trace_timestamp()
// @param numa_id NUMA node ID (use -1 for auto-detect or 0 if unknown)
// @param thread_id Thread ID within the pool (use WorkerPool::thread_local_id)
// @param args Optional JSON args string (e.g., "{\"tokens\":128,\"rank\":8}")
void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id,
const char* args = nullptr);
static void write_trace_to_file(); // Write all collected traces to a file (e.g., "sft_kernel_traces.json")
#endif
} // namespace sft_timer
#endif
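
The condition-variable plumbing restored above (a per-thread mutex and cv in ThreadState, plus the mutexes/cvs vectors in NumaJobDistributor) implements a spin-then-park idle policy: spin with _mm_pause for low wake-up latency, then block on the cv after roughly 50 ms of idling, as in worker_thread in worker_pool.cpp. A minimal self-contained sketch of that discipline, with illustrative names rather than the project's API:

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <immintrin.h>
#include <mutex>
#include <thread>

enum class Status { WAITING, WORKING, EXIT };

struct Slot {
    std::atomic<Status> status{Status::WAITING};
    std::mutex mutex;
    std::condition_variable cv;
};

// Wake a worker: flip status under the mutex, then notify, matching the
// lock/store/notify_one sequence in do_work_stealing_job_async.
static void wake(Slot& s, Status st) {
    {
        std::lock_guard<std::mutex> lock(s.mutex);
        s.status.store(st, std::memory_order_release);
    }
    s.cv.notify_one();
}

static void worker_loop(Slot& s) {
    auto idle_since = std::chrono::high_resolution_clock::now();
    while (true) {
        Status st = s.status.load(std::memory_order_acquire);
        if (st == Status::WORKING) {
            /* ... run the assigned job here ... */
            s.status.store(Status::WAITING, std::memory_order_release);
            idle_since = std::chrono::high_resolution_clock::now();
        } else if (st == Status::WAITING) {
            _mm_pause();  // cheap spin keeps wake-up latency low
            auto idle = std::chrono::high_resolution_clock::now() - idle_since;
            if (std::chrono::duration_cast<std::chrono::milliseconds>(idle).count() > 50) {
                // Park after 50 ms idle; the predicate re-checks status on wake.
                std::unique_lock<std::mutex> lock(s.mutex);
                s.cv.wait(lock, [&] { return s.status.load(std::memory_order_acquire) != Status::WAITING; });
            }
        } else {
            return;  // EXIT
        }
    }
}

int main() {
    Slot s;
    std::thread t([&] { worker_loop(s); });
    wake(s, Status::WORKING);
    std::this_thread::sleep_for(std::chrono::milliseconds(120));
    wake(s, Status::EXIT);
    t.join();
    return 0;
}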

avx_kernels.hpp

@ -48,9 +48,6 @@ namespace avx {
*/
inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggml_bf16_t* __restrict weight,
float* __restrict output, int num_tokens, int k_dim, int rank) {
// #if AVX_KERNEL_TRACE_ENABLED
// uint64_t trace_start = sft_timer::get_trace_timestamp();
// #endif
constexpr int T_BLOCK = 4;
constexpr int R_BLOCK = 4;
@ -240,13 +237,6 @@ inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggm
}
}
// #if AVX_KERNEL_TRACE_ENABLED
// uint64_t trace_end = sft_timer::get_trace_timestamp();
// char args_buf[128];
// snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"K\":%d,\"R\":%d}", num_tokens, k_dim, rank);
// sft_timer::add_kernel_trace("lora_bf16_matmul_t4r4", trace_start, trace_end, 0, WorkerPool::thread_local_id,
// args_buf);
// #endif
}
/**
@ -275,7 +265,6 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const
ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim,
float scale) {
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_start = sft_timer::get_trace_timestamp();
#endif
constexpr int T_BLOCK = 4;
@ -579,11 +568,8 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const
}
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_end = sft_timer::get_trace_timestamp();
char args_buf[128];
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add", trace_start, trace_end, 0, WorkerPool::thread_local_id,
args_buf);
#endif
}
@ -615,7 +601,6 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co
ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim,
float scale) {
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_start = sft_timer::get_trace_timestamp();
#endif
constexpr int T_BLOCK = 4;
@ -886,11 +871,8 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co
}
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_end = sft_timer::get_trace_timestamp();
char args_buf[128];
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_wt", trace_start, trace_end, 0, WorkerPool::thread_local_id,
args_buf);
#endif
}
@ -1190,7 +1172,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
const ggml_bf16_t* __restrict weight_t, ggml_bf16_t* __restrict output,
int num_tokens, int rank, int output_dim, float scale) {
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_start = sft_timer::get_trace_timestamp();
#endif
constexpr int T_BLOCK = 4;
@ -1384,11 +1365,8 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
}
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_end = sft_timer::get_trace_timestamp();
char args_buf[128];
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_transposed", trace_start, trace_end, 0,
WorkerPool::thread_local_id, args_buf);
#endif
}
@ -1408,7 +1386,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad, const ggml_bf16_t* __restrict lora_b_t,
float* __restrict result, int num_tokens, int hidden, int rank) {
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_start = sft_timer::get_trace_timestamp();
#endif
constexpr int H_BLOCK = 32;
@ -1448,11 +1425,8 @@ inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad,
}
#if AVX_KERNEL_TRACE_ENABLED
uint64_t trace_end = sft_timer::get_trace_timestamp();
char args_buf[128];
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"H\":%d,\"R\":%d}", num_tokens, hidden, rank);
sft_timer::add_kernel_trace("lora_backward_matmul_transposed", trace_start, trace_end, 0, WorkerPool::thread_local_id,
args_buf);
#endif
}

sft_moe.hpp

@ -1075,13 +1075,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
// Step 3: Copy input to expert buffers
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size) {
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen < 10) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
@ -1095,8 +1095,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
},
"fwd_pack_input", 1);
});
// NaN Check: Step 3 - Packed input
if (is_nan_check_enabled()) {
@ -1118,8 +1117,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
},
"fwd_quantize_in", 1);
});
// Step 5: Gate + Up GEMM (base projection)
int nth = T::recommended_nth(config_.intermediate_size);
@ -1137,7 +1135,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr, "fwd_gate_up_gemm", 1);
nullptr);
// NaN Check: Step 5 - Gate/Up GEMM output (before LoRA)
if (is_nan_check_enabled()) {
@ -1230,10 +1228,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// Step 6: Activation (silu(gate) * up)
{
uint64_t act_start = sft_timer::get_trace_timestamp();
Base::apply_activation(activated_expert, nth, qlen);
uint64_t act_end = sft_timer::get_trace_timestamp();
sft_timer::add_kernel_trace("apply_activation", act_start, act_end, tp_part_idx, 0);
}
// NaN Check: Step 6 - Activation output (silu(gate) * up)
@ -1295,7 +1290,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr, "fwd_down_quantize");
nullptr);
// Step 8: Down GEMM
nth = T::recommended_nth(config_.hidden_size);
@ -1307,7 +1302,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
this->do_down_gemm(expert_idx, ith, nth, qlen);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr, "fwd_down_gemm", 1);
nullptr);
// NaN Check: Step 8 - Down GEMM output (before LoRA)
if (is_nan_check_enabled()) {
@ -1373,7 +1368,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
f32out[1] = x1;
}
},
nullptr, "fwd_merge");
nullptr);
// NaN Check: Step 9 - Final output (after weighted merge)
if (is_nan_check_enabled()) {
@ -1428,12 +1423,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int activated_expert = cache.activated_expert_cache;
constexpr int kSmallBwdDirectQlen = 0;
constexpr int kSmallBwdDirectMaxTasks = 16;
auto trace_phase = [this](const char* name, auto&& fn) {
uint64_t start = sft_timer::get_trace_timestamp();
fn();
uint64_t end = sft_timer::get_trace_timestamp();
sft_timer::add_kernel_trace(name, start, end, tp_part_idx, 0);
};
// NaN Check: grad_output input
if (is_nan_check_enabled()) {
@ -1492,9 +1481,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
// ★ Allocate backward-phase buffers ★
trace_phase("bwd_setup_alloc", [&] { alloc_backward_buffers(); });
alloc_backward_buffers();
trace_phase("bwd_setup_state", [&] {
// ★ share_backward_bb: check if async repack already prepared this layer ★
if (config_.share_backward_bb) {
auto& shared = SFTSharedPools::instance();
@ -1544,12 +1532,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
});
// Restore input data from cache into m_local_input_ (shared_mem_buffer may have been
// overwritten by subsequent layers' forward passes). This is needed for gate/up LoRA
// gradient computation which reads from m_local_input_ptr_.
trace_phase("bwd_phase_down", [&] {
auto pool_local = config_.pool->get_subpool(tp_part_idx);
auto restore_input = [&](int i) {
for (int j = 0; j < k; j++) {
@ -1569,7 +1555,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
restore_input(i);
}
} else {
pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr, "bwd_restore_input", 1);
pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr);
}
// Step 1: Down projection backward
@ -1579,7 +1565,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
} else {
// backward_down(cache, grad_output, grad_down_lora_a, grad_down_lora_b);
}
});
// // Compute total tokens for debug
// size_t total_tokens = 0;
@ -1655,7 +1640,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// }
// }
trace_phase("bwd_phase_act", [&] { backward_activation(cache); });
backward_activation(cache);
// NaN Check: Step 2 - After backward_activation
if (is_nan_check_enabled()) {
@ -1695,14 +1680,12 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// }
// }
trace_phase("bwd_phase_gate_up", [&] {
if constexpr (supports_standard_mat_mul_v<T>) {
backward_gate_up_amx(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b,
full_intermediate_size, fp32_grad_gate_lora_a, fp32_grad_up_lora_a);
} else {
// backward_gate_up(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b);
}
});
// NaN Check: Step 3 - After backward_gate_up
if (is_nan_check_enabled()) {
@ -1740,7 +1723,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// Step 4: Compute grad_weights (gradient for routing weights)
// grad_weights[token_idx, expert_pos] = dot(grad_output[token_idx], down_output[token, expert])
if (grad_weights != nullptr) {
trace_phase("bwd_phase_gradw", [&] {
auto pool = config_.pool->get_subpool(tp_part_idx);
float* grad_w = (float*)grad_weights;
const ggml_bf16_t* grad_out = (const ggml_bf16_t*)grad_output;
@ -1788,9 +1770,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
compute_grad_weight(token_idx);
}
} else {
pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr, "bwd_grad_weights");
pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr);
}
});
}
// NaN Check: Step 4 - After grad_weights computation
@ -2106,7 +2087,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
break;
}
},
nullptr, "transpose_lora_b_weights");
nullptr);
}
/**
@ -2167,7 +2148,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
break;
}
},
nullptr, "fwd_lora_prep");
nullptr);
lora_weights_prepared_ = true;
}
@ -2212,7 +2193,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
break;
}
},
nullptr, "bwd_lora_prep");
nullptr);
lora_backward_weights_prepared_ = true;
}
@ -2265,7 +2246,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
dst_bb->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.intermediate_size,
config_.hidden_size, ith, nth_gate_up);
},
nullptr, "bwd_prep_gate_up");
nullptr);
// Phase 2: down backward
// down_proj: [hidden_size, intermediate_size] -> transposed BufferB [intermediate_size, hidden_size]
@ -2281,7 +2262,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
down_backward_bb_[expert_idx]->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.hidden_size,
config_.intermediate_size, ith, nth_down);
},
nullptr, "bwd_prep_down");
nullptr);
backward_weights_prepared_ = true;
}
@ -2317,7 +2298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth);
}
},
nullptr, "bwd_repack_gate_up");
nullptr);
// Phase 2: down (uses [hidden_size, intermediate_size] -> [intermediate_size, hidden_size])
pool->do_work_stealing_job(
@ -2339,7 +2320,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth);
}
},
nullptr, "bwd_repack_down");
nullptr);
backward_weights_prepared_ = true;
}
@ -2543,7 +2524,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
down_backward_bb_[expert_idx].get());
}
},
nullptr, "load_bwd_kt");
nullptr);
if (!ok.load()) return false;
backward_weights_prepared_ = true;
@ -2596,7 +2577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "load_bwd_projs");
nullptr);
backward_weights_prepared_ = true;
}
@ -3421,7 +3402,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx];
bc->to_mat(m, inter_ptr, ith, nth);
},
nullptr, "fwd_lora_gu_a");
nullptr);
// =====================================================
// Step 2: Quantize lora_intermediate to BufferA
@ -3439,7 +3420,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
ggml_bf16_t* ptr = do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx];
ba->from_mat(m, ptr, 0, 1);
},
nullptr, "fwd_lora_gu_quant");
nullptr);
// =====================================================
// Step 3a: lora_intermediate @ lora_B^T -> lora_output (GEMM only)
@ -3464,7 +3445,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// GEMM: [m, padded_lora_rank] @ [intermediate_size, padded_lora_rank]^T -> [m, intermediate_size]
amx::mat_mul(m, config_.intermediate_size, padded_lora_rank_, ba, bb, bc, ith, nth);
},
nullptr, "fwd_lora_gu_gemm");
nullptr);
// =====================================================
// Step 3b: Add LoRA output to main output with scaling
@ -3489,7 +3470,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
add_lora_output_to_main(bc.get(), main_output, m, config_.intermediate_size, lora_scaling_, ith, nth,
lora_sum_ptr);
},
nullptr, "fwd_lora_gu_add");
nullptr);
}
/**
@ -3581,7 +3562,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// Convert BufferC to BF16 for step 2 input
bc->to_mat(m, lora_gate_intermediate_ptr_[expert_idx], ith, nth);
},
nullptr, "fwd_lora_down_a");
nullptr);
@ -3597,7 +3578,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// Reuse gate intermediate buffer (no race condition for down projection)
lora_gate_intermediate_ba_[expert_idx]->from_mat(m, lora_gate_intermediate_ptr_[expert_idx], 0, 1);
},
nullptr, "fwd_lora_down_quant");
nullptr);
// =====================================================
// Step 3a: lora_intermediate @ down_lora_B^T -> lora_output (GEMM only)
@ -3620,7 +3601,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// GEMM: [m, padded_lora_rank] @ [hidden_size, padded_lora_rank]^T -> [m, hidden_size]
amx::mat_mul(m, config_.hidden_size, padded_lora_rank_, ba, bb, bc, ith, nth);
},
nullptr, "fwd_lora_down_gemm", 1);
nullptr);
// =====================================================
// Step 3b: Add LoRA output to main output with scaling
@ -3642,7 +3623,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
add_lora_output_to_main(bc.get(), m_local_down_output_ptr_[expert_idx], m, config_.hidden_size, lora_scaling_,
ith, nth, &down_lora_sum);
},
nullptr, "fwd_lora_down_add");
nullptr);
// // Print LoRA contribution statistics
// size_t total_elements = 0;
@ -3791,7 +3772,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
output + t_start * inter_size, // output [local_num_tokens, inter_size]
local_num_tokens, rank, inter_size, scale);
},
nullptr, "fwd_lora_gu");
nullptr);
}
/**
@ -3856,7 +3837,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
output + t_start * hidden, // output [local_num_tokens, hidden]
local_num_tokens, rank, hidden, scale);
},
nullptr, "fwd_lora_down");
nullptr);
}
ForwardCache& push_cache() {
@ -3941,7 +3922,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "save_cache");
nullptr);
cache.valid = true;
}
@ -3968,7 +3949,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
memcpy(cache.intermediate_cache + cache_offsets_[i] * config_.intermediate_size,
m_local_gate_output_ptr_[expert_idx], num_tokens * config_.intermediate_size * sizeof(ggml_bf16_t));
},
nullptr, "save_inter_cache");
nullptr);
}
/**
@ -3993,7 +3974,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
memcpy(cache.down_output_cache + cache_offsets_[i] * config_.hidden_size, src_ptr,
num_tokens * config_.hidden_size * sizeof(ggml_bf16_t));
},
nullptr, "save_down_cache");
nullptr);
}
void backward_down(const ForwardCache& cache, const void* grad_output, void* grad_down_lora_a,
@ -4025,7 +4006,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
memset(reinterpret_cast<char*>(grad_intermediate_) + offset, 0, size);
}
},
nullptr, "bwd_down_memset");
nullptr);
}
// Scatter grad_output to per-expert buffers and compute gradients
@ -4176,7 +4157,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "bwd_down");
nullptr);
}
/**
@ -4198,13 +4179,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int k = cache.k_cache;
constexpr int kSmallBwdDirectQlen = 0;
constexpr int kSmallBwdDirectMaxTasks = 16;
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
@ -4258,8 +4239,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int num_tokens = m_local_num_[expert_idx];
if (num_tokens == 0) return;
memset(grad_output_bf16_ptr_[expert_idx], 0, num_tokens * config_.hidden_size * sizeof(ggml_bf16_t));
},
"bwd_down_zero");
});
// =====================================================
// Step 2: Scatter grad_output to per-expert BF16 buffers
@ -4305,8 +4285,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
dst_row[h] = GGML_FP32_TO_BF16(cur);
}
}
},
"bwd_down_scatter");
});
}
// =====================================================
@ -4319,8 +4298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int num_tokens = m_local_num_[expert_idx];
if (num_tokens == 0) return;
grad_output_ba_[expert_idx]->from_mat(num_tokens, grad_output_bf16_ptr_[expert_idx], 0, 1);
},
"bwd_down_quantize");
});
// =====================================================
// Step 3+4: AMX GEMM + to_mat (merged to use same ith/nth)
@ -4365,7 +4343,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// to_mat: Convert BufferC to BF16 - use same ith, nth as mat_mul!
bc->to_mat(m, grad_intermediate_ + expert_offsets[task_idx], ith, nth);
},
nullptr, "bwd_down_gemm", 1);
nullptr);
// =====================================================
// Step 3.5: Add LoRA contribution to grad_intermediate (AVX512)
@ -4417,8 +4395,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
expert_lora_a, // [rank, inter_size] BF16
grad_inter + t_start * inter_size, // [local_num_tokens, inter_size] BF16
local_num_tokens, rank, inter_size, scale);
},
"bwd_down_lora_to_inter");
});
}
@ -4600,7 +4577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "bwd_down_lora_grad_AB");
nullptr);
} else {
struct LoraGradTask {
int expert_task = -1;
@ -4609,7 +4586,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
};
std::vector<LoraGradTask> lora_grad_tasks;
uint64_t grad_setup_start = sft_timer::get_trace_timestamp();
float* grad_b_accum_all = down_lora_grad_b_accum_pool_;
float* grad_a_accum_all = down_lora_grad_a_accum_pool_;
if (activated_expert > 0) {
@ -4630,8 +4606,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
lora_grad_tasks.push_back({task_id, t, std::min(t + token_tile, buf.num_tokens)});
}
}
uint64_t grad_setup_end = sft_timer::get_trace_timestamp();
sft_timer::add_kernel_trace("bwd_down_lora_grad_setup", grad_setup_start, grad_setup_end, tp_part_idx, 0);
if (!lora_grad_tasks.empty()) {
direct_or_pool(
@ -4748,8 +4722,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
grad_a_global[off] += grad_a_local[off];
}
}
},
"bwd_down_lora_grad_AB");
});
constexpr int kDownGradBTile = 512;
constexpr int kDownGradATile = 512;
@ -4788,7 +4761,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "bwd_down_lora_write_B");
nullptr);
pool->do_work_stealing_job(
activated_expert * grad_a_blocks, nullptr,
@ -4812,7 +4785,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "bwd_down_lora_write_A");
nullptr);
}
}
}
@ -4824,13 +4797,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int qlen = cache.qlen_cache;
constexpr int kSmallBwdDirectQlen = 0;
constexpr int kSmallBwdDirectMaxTasks = 16;
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
@ -4942,8 +4915,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
grad_gate[i] = GGML_FP32_TO_BF16(grad_i_val * u_val * sigmoid_val * (1.0f + g_val * (1.0f - sigmoid_val)));
grad_up[i] = GGML_FP32_TO_BF16(grad_i_val * silu_val);
}
},
"bwd_act_silu");
});
}
@ -4963,13 +4935,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
int k = cache.k_cache;
constexpr int kSmallBwdDirectQlen = 0;
constexpr int kSmallBwdDirectMaxTasks = 16;
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
@ -5065,7 +5037,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
auto scatter_to_grad_input = [&](float scale, const char* task_name) {
auto scatter_to_grad_input = [&](float scale) {
ggml_bf16_t* grad_input_bf16 = (ggml_bf16_t*)grad_input;
const int hidden = config_.hidden_size;
const int hidden_vec_end = hidden & ~31;
@ -5101,13 +5073,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
dst[h] = GGML_FP32_TO_BF16(cur);
}
}
},
task_name);
});
};
auto base_pass = [&](bool do_up) {
const char* quant_name = do_up ? "bwd_gu_base_q_up" : "bwd_gu_base_q_gate";
const char* gemm_name = do_up ? "bwd_gu_base_gemm_up" : "bwd_gu_base_gemm_gate";
// Quantize grad to BufferA
direct_or_pool(
@ -5121,8 +5090,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
ggml_bf16_t* grad = do_up ? (grad_up_output_ + offset * config_.intermediate_size)
: (grad_gate_output_ + offset * config_.intermediate_size);
down_ba_[expert_idx]->from_mat(m, grad, 0, 1);
},
quant_name);
});
int nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
@ -5141,9 +5109,9 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
bc->to_mat(m, grad_output_bf16_ptr_[expert_idx], ith, nth);
},
nullptr, gemm_name, 1);
nullptr);
scatter_to_grad_input(1.0f, "bwd_gu_scatter_base");
scatter_to_grad_input(1.0f);
};
base_pass(false); // gate
@ -5341,8 +5309,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
up_gradb_global[off] += up_gradb_local[off];
}
}
},
"bwd_gu_lora_u_gradb_fused");
});
constexpr int kGuGradBBlock = 256;
int gradb_blocks = (inter_size + kGuGradBBlock - 1) / kGuGradBBlock;
@ -5375,7 +5342,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, "bwd_gu_lora_gradb_fused_write");
nullptr);
}
// =====================================================
@ -5386,8 +5353,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
// Gate and up still run sequentially because they share grad_output_bf16_ptr_.
// =====================================================
auto lora_pass_remainder = [&](bool do_up) {
const char* gb_gradin_name = do_up ? "bwd_gu_lora_gb_gradin_fused_up" : "bwd_gu_lora_gb_gradin_fused_gate";
const char* grad_a_name = do_up ? "bwd_gu_lora_gradA_up" : "bwd_gu_lora_gradA_gate";
struct GuLoraGradInTask {
int expert_task = -1;
@ -5443,11 +5408,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
memset(grad_out, 0, static_cast<size_t>(local_tokens) * hidden * sizeof(ggml_bf16_t));
avx::lora_fp32_bf16_fused_add_transposed(gb, lora_a, grad_out, local_tokens, lora_rank_, hidden, 1.0f);
},
gb_gradin_name);
});
}
scatter_to_grad_input(lora_scaling_, "bwd_gu_scatter_lora");
scatter_to_grad_input(lora_scaling_);
// Step 6: grad_A = G_B^T @ X
ggml_bf16_t* grad_lora_a = do_up ? grad_up_a : grad_gate_a;
@ -5547,7 +5511,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
}
}
},
nullptr, grad_a_name);
nullptr);
};
lora_pass_remainder(false); // gate: gb_gradin_fused, scatter, gradA
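
The direct_or_pool helpers simplified in the hunks above keep one load-related heuristic: tiny batches run inline on the calling thread because dispatch overhead would dominate, while larger ones go to the pool. A self-contained sketch of that dispatch shape, with pool_run as an illustrative stand-in for pool->do_work_stealing_job(count, nullptr, fn, nullptr):

#include <functional>

static void direct_or_pool(int count, int qlen, const std::function<void(int)>& fn,
                           const std::function<void(int, const std::function<void(int)>&)>& pool_run) {
    if (qlen < 10) {  // threshold from the forward-path variant above
        for (int i = 0; i < count; i++) fn(i);  // run inline, no dispatch
    } else {
        pool_run(count, fn);  // hand the batch to the worker pool
    }
}

int main() {
    int sum = 0;
    direct_or_pool(4, 1, [&](int i) { sum += i; },
                   [](int n, const std::function<void(int)>& f) {
                       for (int i = 0; i < n; i++) f(i);  // serial stand-in for the pool
                   });
    return sum == 6 ? 0 : 1;
}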

moe-sft-tp.hpp

@ -318,7 +318,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
sizeof(ggml_bf16_t) * tpc.intermediate_size);
}
},
nullptr, "memcpy_weights_tmp");
nullptr);
}
// Set BF16 weight pointers on sub-MOEs for backward
@ -490,16 +490,13 @@ class TP_MOE_SFT : public TP_MOE<T> {
throw std::runtime_error("Weights not loaded");
}
auto start_sft = sft_timer::get_trace_timestamp();
int qlen = *qlen_ptr;
auto pool = config.pool;
// Reset forward timing before computation
// sft_timer::reset_forward();
// Reset per-thread counters in each subpool (to accumulate all do_work_stealing_job calls)
for (int i = 0; i < tp_count; i++) {
pool->get_subpool(i)->reset_counters();
}
// Run forward on each NUMA node
@ -508,24 +505,18 @@ class TP_MOE_SFT : public TP_MOE<T> {
save_for_backward);
});
auto end_fwd = sft_timer::get_trace_timestamp();
// // Collect per-thread timing from all NUMA subpools
// for (int i = 0; i < tp_count; i++) {
// sft_timer::collect_forward(pool->get_subpool(i));
// }
// // Print per-thread forward timing
// sft_timer::print_forward();
// Merge results from all NUMA nodes
this->merge_results(qlen, output);
auto end_merge = sft_timer::get_trace_timestamp();
pool->dispense_backend()->do_numa_job([&](int numa_id) {
sft_timer::add_kernel_trace("fwd", start_sft, end_fwd, numa_id, 0);
sft_timer::add_kernel_trace("merge", end_fwd, end_merge, numa_id, 0);
});
}
@ -561,7 +552,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
void* grad_weights) {
auto pool = config.pool;
auto start_sft = sft_timer::get_trace_timestamp();
// Get full intermediate_size (before TP partitioning)
int full_intermediate_size = sft_config.intermediate_size;
@ -665,9 +655,8 @@ class TP_MOE_SFT : public TP_MOE<T> {
const auto& seg = clear_segs[(size_t)seg_idx];
std::memset(seg.ptr, 0, seg.len);
},
nullptr, "bwd_alloc_memset");
nullptr);
auto end_alloc = sft_timer::get_trace_timestamp();
// Compute TP-slice pointers for copy-type direct writes
// Each TP writes to its own I-slice of the final output tensor
@ -697,7 +686,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
// Run backward on each NUMA node
pool->dispense_backend()->do_numa_job([&](int numa_id) {
auto start_Bwd = sft_timer::get_trace_timestamp();
tps[numa_id]->backward(grad_output, part_grad_input_[numa_id],
// reduce-type: BF16 pointer unused (FP32 sparse used instead)
nullptr, /* grad_gate_lora_a — unused, FP32 path below */
@ -708,18 +696,13 @@ class TP_MOE_SFT : public TP_MOE<T> {
nullptr, /* grad_down_lora_b — unused, FP32 path below */
part_grad_weights_[numa_id], full_intermediate_size, tp_fp32_down_b[numa_id],
tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id]);
auto end_bwd = sft_timer::get_trace_timestamp();
sft_timer::add_kernel_trace("bwd_alloc", start_sft, end_alloc, numa_id, 0);
sft_timer::add_kernel_trace("bwd_tp", start_Bwd, end_bwd, numa_id, 0);
});
// // Collect per-thread timing from all NUMA subpools
// for (int i = 0; i < tp_count; i++) {
// sft_timer::collect_backward(pool->get_subpool(i));
// }
// // Print per-thread backward timing
// sft_timer::print_backward();
// // Print expert token distribution for load balancing analysis
// {
@ -739,7 +722,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
// }
// Bug #22 fix: Merge grad_input from all NUMA nodes (sum them together)
auto start_sum = sft_timer::get_trace_timestamp();
{
auto* out = (ggml_bf16_t*)grad_input;
pool->do_work_stealing_job(
@ -784,13 +766,11 @@ class TP_MOE_SFT : public TP_MOE<T> {
dst[h] = GGML_FP32_TO_BF16(sum);
}
},
nullptr, "merge_grad_input");
nullptr);
}
auto end_sum = sft_timer::get_trace_timestamp();
// Merge reduce-type LoRA gradients: sparse FP32 sum across TPs → BF16 final output
// Copy-type grads (gate/up_lora_b, down_lora_a) were written directly — no merge needed.
auto start_merge = sft_timer::get_trace_timestamp();
if constexpr (!kSkipLoRA) {
// Sparse merge for gate_lora_a, up_lora_a: [active_count, r, H] FP32 → [E, r, H] BF16
{
@ -835,7 +815,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
ud[h] = GGML_FP32_TO_BF16(us);
}
},
nullptr, "merge_lora_a");
nullptr);
}
// Sparse merge for down_lora_b: [active_count, H, r] FP32 → [E, H, r] BF16
@ -861,7 +841,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
}
}
},
nullptr, "merge_down_lora_b");
nullptr);
}
} // if constexpr (!kSkipLoRA)
@ -900,14 +880,10 @@ class TP_MOE_SFT : public TP_MOE<T> {
out_grad_weights[i] = sum;
}
},
nullptr, "merge_grad_weights");
nullptr);
}
auto end_merge = sft_timer::get_trace_timestamp();
pool->dispense_backend()->do_numa_job([&](int numa_id) {
sft_timer::add_kernel_trace("merge_tp", start_sum, end_sum, numa_id, 0);
sft_timer::add_kernel_trace("merge_lora_a", end_sum, start_merge, numa_id, 0);
sft_timer::add_kernel_trace("merge_grad_weights", start_merge, end_merge, numa_id, 0);
});
}
@ -989,8 +965,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
numa_id * tp_inter,
sizeof(ggml_bf16_t) * tp_inter);
}
},
"upd_lora_tp");
});
// Update weights after all memcpy complete
tps[numa_id]->update_lora_weights(gate_lora_a, partitioned_gate_lora_b_[numa_id], up_lora_a,
@ -1074,7 +1049,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
sizeof(ggml_bf16_t) * tpc.intermediate_size);
}
},
nullptr, "memcpy_bwd_tmp");
nullptr);
tps[i]->prepare_bwd(temp_gate, temp_up, temp_down);

KExpertsCPUBuffer (Python, inference path)

@ -90,7 +90,7 @@ class KExpertsCPUBuffer:
hidden_size = hidden_states.shape[-1]
batch_size = hidden_states.shape[0]
pin_memory = False
pin_memory = True
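# Pinned (page-locked) host buffers enable non-blocking host-to-device copies on the inference path.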
if batch_size in cls.capture_buffers:
return cls.capture_buffers[batch_size]

convert_cpu_weights.py

@ -818,35 +818,30 @@ class OnlineQuantConverter(ConverterBase):
print(f" [Fused] tensor {p} shape: {tuple(w.shape)}")
fused_tensors.append(w)
# fused_tensors[0] : down_proj, [E, H, I]
# fused_tensors[1] : gate_up_proj, [E, 2I, H]
# fused_tensors[0] : down-like, [E, I, H]
# fused_tensors[1] : gate_up-like, [E, H, 2I]
down_fused = fused_tensors[0]
gate_up_fused = fused_tensors[1]
# gate_up_fused is [E, 2I, H] — split on dim 1, no transpose needed
# gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
if gate_up_fused.dim() != 3:
raise ValueError(
f"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}"
)
E = gate_up_fused.shape[0]
I = self.moe_intermediate_size
H = self.hidden_size
E, H, twoI = gate_up_fused.shape
if twoI % 2 != 0:
raise ValueError(f"[Fused] gate_up last dim (2I) not even: {twoI}")
I = twoI // 2
if gate_up_fused.shape != (E, 2 * I, H):
raise ValueError(
f"[Fused] gate_up shape {tuple(gate_up_fused.shape)} != expected ({E}, {2*I}, {H}). "
f"If your model stores gate_up as [E, H, 2I], transpose is needed."
)
gate_proj = gate_up_fused[:, :I, :].contiguous() # [E, I, H]
up_proj = gate_up_fused[:, I:, :].contiguous() # [E, I, H]
gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :] # [E, I, H]
up_proj = gate_up_T[:, I:, :] # [E, I, H]
if down_fused.dim() != 3:
raise ValueError(f"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
if down_fused.shape[0] != E:
raise ValueError(f"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}")
# down_proj is [E, H, I] — matches load_weights_from_tensors expectation, no transpose
down_proj = down_fused.contiguous() # [E, H, I]
down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
del fused_tensors
del gate_up_fused
del down_fused