mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults
- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace, sft_timer namespace, ITT API, extended do_work_stealing_job API) - Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp, moe-sft-tp.hpp, avx_kernels.hpp) - Restore pin_memory=True in KExpertsCPUBuffer (inference path) - Restore fused tensor transpose logic in convert_cpu_weights.py (main layout) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
168e10f254
commit
a789729923
7 changed files with 159 additions and 766 deletions
|
|
@ -13,437 +13,20 @@
|
|||
#include <hwloc/bitmap.h>
|
||||
#include <numa.h>
|
||||
#include <numaif.h>
|
||||
#include <signal.h>
|
||||
#include <x86intrin.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwloc.h"
|
||||
|
||||
// RDTSC-based timer for lightweight timing
|
||||
// Uses CPU timestamp counter instead of system clock for lower overhead
|
||||
namespace {

/// Returns the current value of the CPU timestamp counter (RDTSC intrinsic).
inline uint64_t rdtsc_now() { return __rdtsc(); }

/// Cached calibration result: timestamp-counter ticks per millisecond.
/// Stays 0 until get_rdtsc_cycles_per_ms() has been called once.
static uint64_t g_rdtsc_cycles_per_ms = 0;

/// Calibrates the timestamp counter against std::chrono.
///
/// Busy-waits for roughly 10 ms, then divides the number of elapsed ticks by
/// the number of elapsed whole milliseconds.
///
/// @return Ticks per millisecond; never 0. When the measurement is unusable
///         (no elapsed time, or a zero quotient) the value for a 2.5 GHz
///         counter is returned instead.
static uint64_t init_rdtsc_frequency() {
  using Clock = std::chrono::high_resolution_clock;
  constexpr uint64_t kFallbackCyclesPerMs = 2500000;  // 2.5 GHz

  const auto start_chrono = Clock::now();
  const uint64_t start_rdtsc = rdtsc_now();

  const auto elapsed_ms_since_start = [&start_chrono] {
    return std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - start_chrono).count();
  };

  // Busy wait for ~10ms to calibrate.
  while (elapsed_ms_since_start() < 10) {
  }

  // Read the counter first, then the wall clock, mirroring the start order.
  const uint64_t end_rdtsc = rdtsc_now();
  const auto elapsed_ms = elapsed_ms_since_start();

  if (elapsed_ms > 0) {
    const uint64_t measured = (end_rdtsc - start_rdtsc) / static_cast<uint64_t>(elapsed_ms);
    // A zero result would later be used as a divisor; treat it as a failed measurement.
    if (measured != 0) {
      return measured;
    }
  }
  return kFallbackCyclesPerMs;
}

/// Returns ticks per millisecond, calibrating on the first call.
///
/// The calibration runs exactly once even when several threads make their
/// first call at the same time: initialisation of a function-local static is
/// thread-safe in C++11 and later. The previous check-then-write on the plain
/// global was a data race because every pool worker calls this on start-up.
inline uint64_t get_rdtsc_cycles_per_ms() {
  // The global is still assigned so that any other reader of it keeps working.
  static const uint64_t cycles_per_ms = (g_rdtsc_cycles_per_ms = init_rdtsc_frequency());
  return cycles_per_ms;
}

}  // namespace
|
||||
|
||||
// =====================================================
|
||||
// Global per-thread timing for SFT MOE forward/backward
|
||||
// Collects timing from InNumaPool worker threads
|
||||
// =====================================================
|
||||
#ifndef SFT_TIMER_DISABLED
|
||||
namespace sft_timer {
|
||||
|
||||
// Capacity of each per-thread sample table below.
constexpr int MAX_THREADS = 256;

// Forward-pass samples: elapsed timestamp-counter ticks per slot.
static uint64_t forward_rt[MAX_THREADS] = {};
// Backward-pass samples: elapsed timestamp-counter ticks per slot.
static uint64_t backward_rt[MAX_THREADS] = {};
// Number of tasks processed by the thread recorded in each forward slot.
static int forward_tasks[MAX_THREADS] = {};
// Number of tasks processed by the thread recorded in each backward slot.
static int backward_tasks[MAX_THREADS] = {};
// Number of forward / backward slots currently filled.
static int forward_threads = 0;
static int backward_threads = 0;
|
||||
|
||||
/// Converts a timestamp-counter tick count to milliseconds using the
/// calibrated ticks-per-millisecond value.
inline double ticks_to_ms(uint64_t cycles) {
  const uint64_t ticks_per_ms = get_rdtsc_cycles_per_ms();
  return static_cast<double>(cycles) / ticks_per_ms;
}
|
||||
|
||||
// =====================================================
|
||||
// Chrome Trace Event Format support
|
||||
// =====================================================
|
||||
/// One record of the Chrome Trace Event Format output.
///
/// Every scalar member has a default so that a default-constructed event
/// never carries indeterminate values into the trace file.
struct TraceEvent {
  std::string name;       ///< Event name (op_name).
  std::string cat;        ///< Category.
  char ph = 'X';          ///< Phase: 'X' for complete event, 'B' for begin, 'E' for end.
  double ts = 0.0;        ///< Timestamp in microseconds (fractional part gives sub-us precision).
  double dur = 0.0;       ///< Duration in microseconds (fractional part gives sub-us precision).
  int pid = 0;            ///< Process id slot; filled with the NUMA id.
  int tid = 0;            ///< Thread id.
  int task_count = 0;     ///< Number of tasks processed.
  std::string args_json;  ///< Optional custom args JSON (for kernel traces); empty when unused.
};
|
||||
|
||||
// Recorded events; every access in this file happens with g_trace_mutex held.
static std::vector<TraceEvent> g_trace_events{};
static std::mutex g_trace_mutex{};

// Timestamp-counter value captured when tracing was initialised; event
// timestamps are reported relative to it.
static uint64_t g_trace_start_time{0};

// Wall-clock time (microseconds since the epoch) captured at the same moment.
static double g_trace_start_epoch_us{0.0};

// Destination file; can be overridden through the SFT_TRACE_PATH environment variable.
static std::string g_trace_output_path{"sft_trace.json"};

// Makes the one-time setup in init_trace() run exactly once across threads.
static std::once_flag g_trace_init_flag{};
|
||||
|
||||
// Forward declaration for atexit registration.
|
||||
static void write_trace_to_file();
|
||||
|
||||
// Initialize trace start time (thread-safe)
|
||||
static void init_trace() {
|
||||
std::call_once(g_trace_init_flag, []() {
|
||||
g_trace_start_time = rdtsc_now();
|
||||
// Record wall-clock epoch time for cross-process trace alignment
|
||||
auto now_wall = std::chrono::system_clock::now();
|
||||
auto epoch_us = std::chrono::duration_cast<std::chrono::microseconds>(now_wall.time_since_epoch()).count();
|
||||
g_trace_start_epoch_us = static_cast<double>(epoch_us);
|
||||
// Check for custom output path from environment
|
||||
const char* env_path = std::getenv("SFT_TRACE_PATH");
|
||||
if (env_path && env_path[0] != '\0') {
|
||||
g_trace_output_path = env_path;
|
||||
}
|
||||
// Flush trace on normal exit before static destructors run.
|
||||
std::atexit(write_trace_to_file);
|
||||
});
|
||||
}
|
||||
|
||||
// Convert RDTSC cycles to microseconds with nanosecond precision (as double)
|
||||
// Chrome tracing uses microseconds but supports fractional values for sub-us precision
|
||||
static double cycles_to_us(uint64_t cycles) {
|
||||
// cycles_per_ms * 1000 = cycles_per_us
|
||||
// cycles / cycles_per_us = microseconds
|
||||
// Using 1e6 for cycles_per_ms -> cycles_per_s, then divide to get us with ns precision
|
||||
double cycles_per_us = get_rdtsc_cycles_per_ms() / 1000.0;
|
||||
return static_cast<double>(cycles) / cycles_per_us;
|
||||
}
|
||||
|
||||
// Add trace events for an operation using absolute timestamps
|
||||
static void add_trace_events(const char* op_name, int numa_id, int thread_count, const uint64_t* start_ts_arr,
|
||||
const uint64_t* end_ts_arr, const int* tasks) {
|
||||
init_trace();
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
||||
|
||||
for (int i = 0; i < thread_count; i++) {
|
||||
// Convert absolute RDTSC timestamps to relative microseconds from trace start
|
||||
double start_us = (start_ts_arr[i] > g_trace_start_time) ? cycles_to_us(start_ts_arr[i] - g_trace_start_time) : 0.0;
|
||||
double end_us = (end_ts_arr[i] > g_trace_start_time) ? cycles_to_us(end_ts_arr[i] - g_trace_start_time) : 0.0;
|
||||
double dur_us = end_us - start_us;
|
||||
if (dur_us < 0) dur_us = 0;
|
||||
|
||||
TraceEvent ev;
|
||||
ev.name = op_name;
|
||||
ev.cat = "sft_op";
|
||||
ev.ph = 'X'; // Complete event
|
||||
ev.ts = start_us;
|
||||
ev.dur = dur_us;
|
||||
ev.pid = numa_id;
|
||||
ev.tid = i;
|
||||
ev.task_count = tasks[i];
|
||||
|
||||
g_trace_events.push_back(ev);
|
||||
}
|
||||
}
|
||||
|
||||
// Write trace events to JSON file (Chrome Trace Event Format)
|
||||
static void write_trace_to_file() {
|
||||
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
||||
|
||||
if (g_trace_events.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort events by (pid, tid, ts) to fix overlap issues in Chrome trace viewer
|
||||
// Events from same thread should be ordered by start time
|
||||
std::sort(g_trace_events.begin(), g_trace_events.end(), [](const TraceEvent& a, const TraceEvent& b) {
|
||||
if (a.pid != b.pid) return a.pid < b.pid;
|
||||
if (a.tid != b.tid) return a.tid < b.tid;
|
||||
return a.ts < b.ts;
|
||||
});
|
||||
|
||||
std::ofstream ofs(g_trace_output_path);
|
||||
if (!ofs.is_open()) {
|
||||
fprintf(stderr, "sft_timer: Failed to open trace file: %s\n", g_trace_output_path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fixed precision for nanosecond accuracy (3 decimal places in microseconds = nanoseconds)
|
||||
ofs << std::fixed << std::setprecision(3);
|
||||
|
||||
ofs << "{\n";
|
||||
ofs << " \"traceEvents\": [\n";
|
||||
|
||||
for (size_t i = 0; i < g_trace_events.size(); i++) {
|
||||
const auto& ev = g_trace_events[i];
|
||||
ofs << " {";
|
||||
ofs << "\"name\":\"" << ev.name << "\",";
|
||||
ofs << "\"cat\":\"" << ev.cat << "\",";
|
||||
ofs << "\"ph\":\"" << ev.ph << "\",";
|
||||
ofs << "\"ts\":" << ev.ts << ",";
|
||||
ofs << "\"dur\":" << ev.dur << ",";
|
||||
ofs << "\"pid\":" << ev.pid << ",";
|
||||
ofs << "\"tid\":" << ev.tid << ",";
|
||||
if (!ev.args_json.empty()) {
|
||||
ofs << "\"args\":" << ev.args_json;
|
||||
} else {
|
||||
ofs << "\"args\":{\"task_count\":" << ev.task_count << "}";
|
||||
}
|
||||
ofs << "}";
|
||||
if (i < g_trace_events.size() - 1) {
|
||||
ofs << ",";
|
||||
}
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
ofs << " ],\n";
|
||||
ofs << " \"metadata\": {\"start_epoch_us\": " << std::setprecision(0) << g_trace_start_epoch_us << "},\n";
|
||||
ofs << std::setprecision(3);
|
||||
ofs << " \"displayTimeUnit\": \"ns\"\n";
|
||||
ofs << "}\n";
|
||||
|
||||
ofs.close();
|
||||
fprintf(stderr, "sft_timer: Trace written to %s (%zu events)\n", g_trace_output_path.c_str(), g_trace_events.size());
|
||||
}
|
||||
|
||||
// Signal handler for SIGTERM
|
||||
/// Signal handler installed for SIGTERM and SIGINT: writes the trace file,
/// then lets the signal take its default action.
///
/// NOTE(review): fprintf, the mutex lock and the file stream used by
/// write_trace_to_file() are not on the POSIX async-signal-safe list. If the
/// signal interrupts a thread that already holds g_trace_mutex this may
/// deadlock — confirm whether that is acceptable for a diagnostics-only path.
static void sigterm_handler(int sig) {
  fprintf(stderr, "sft_timer: Received signal %d, writing trace...\n", sig);

  write_trace_to_file();

  // Restore the default disposition and deliver the signal again so the
  // process terminates the way it would have without this handler.
  signal(sig, SIG_DFL);
  raise(sig);
}
|
||||
|
||||
// Register signal handlers
|
||||
static void register_signal_handlers() {
|
||||
static bool registered = false;
|
||||
if (!registered) {
|
||||
signal(SIGTERM, sigterm_handler);
|
||||
signal(SIGINT, sigterm_handler);
|
||||
registered = true;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rt(FILE* out, const char* name, uint64_t* rt, int* tasks, int rt_threads) {
|
||||
if (rt_threads <= 0) return;
|
||||
FILE* output = out ? out : stderr;
|
||||
auto max_val = *std::max_element(rt, rt + rt_threads);
|
||||
auto min_val = *std::min_element(rt, rt + rt_threads);
|
||||
uint64_t sum = std::accumulate(rt, rt + rt_threads, (uint64_t)0);
|
||||
int total_tasks = std::accumulate(tasks, tasks + rt_threads, 0);
|
||||
|
||||
// Sort to find 20% and 80% percentile thresholds
|
||||
std::vector<uint64_t> sorted(rt, rt + rt_threads);
|
||||
std::sort(sorted.begin(), sorted.end());
|
||||
int p20_idx = rt_threads * 20 / 100;
|
||||
int p80_idx = rt_threads * 80 / 100;
|
||||
uint64_t p20_threshold = sorted[p20_idx]; // Fast threshold (top 20%)
|
||||
uint64_t p80_threshold = sorted[p80_idx]; // Slow threshold (bottom 20%)
|
||||
|
||||
// ANSI color codes
|
||||
const char* GREEN = "\033[32m";
|
||||
const char* RED = "\033[31m";
|
||||
const char* RESET = "\033[0m";
|
||||
|
||||
// Line 1: time
|
||||
fprintf(output, "%30s max %.3f min %.3f avg %.3f : ", name, ticks_to_ms(max_val), ticks_to_ms(min_val),
|
||||
ticks_to_ms(sum / rt_threads));
|
||||
for (int i = 0; i < rt_threads; i++) {
|
||||
if (rt[i] <= p20_threshold) {
|
||||
fprintf(output, "%s%.3f%s ", GREEN, ticks_to_ms(rt[i]), RESET);
|
||||
} else if (rt[i] >= p80_threshold) {
|
||||
fprintf(output, "%s%.3f%s ", RED, ticks_to_ms(rt[i]), RESET);
|
||||
} else {
|
||||
fprintf(output, "%.3f ", ticks_to_ms(rt[i]));
|
||||
}
|
||||
}
|
||||
fprintf(output, "\n");
|
||||
|
||||
// Line 2: task count
|
||||
fprintf(output, "%30s total %d : ", "tasks", total_tasks);
|
||||
for (int i = 0; i < rt_threads; i++) {
|
||||
if (rt[i] <= p20_threshold) {
|
||||
fprintf(output, "%s%d%s ", GREEN, tasks[i], RESET);
|
||||
} else if (rt[i] >= p80_threshold) {
|
||||
fprintf(output, "%s%d%s ", RED, tasks[i], RESET);
|
||||
} else {
|
||||
fprintf(output, "%d ", tasks[i]);
|
||||
}
|
||||
}
|
||||
fprintf(output, "\n");
|
||||
}
|
||||
|
||||
/// Clears all forward-pass samples and marks the forward table as empty.
void reset_forward() {
  forward_threads = 0;
  std::fill_n(forward_rt, MAX_THREADS, 0);
  std::fill_n(forward_tasks, MAX_THREADS, 0);
}
|
||||
|
||||
/// Clears all backward-pass samples and marks the backward table as empty.
void reset_backward() {
  backward_threads = 0;
  std::fill_n(backward_rt, MAX_THREADS, 0);
  std::fill_n(backward_tasks, MAX_THREADS, 0);
}
|
||||
|
||||
/// Appends the pool's per-worker tick and task counts to the forward table.
///
/// Appending stops silently once MAX_THREADS slots are filled.
void collect_forward(InNumaPool* pool) {
  const int workers = pool->get_worker_count();
  for (int worker = 0; worker < workers; worker++) {
    if (forward_threads >= MAX_THREADS) {
      break;  // table full
    }
    const int slot = forward_threads++;
    forward_rt[slot] = pool->get_thread_cycles(worker);
    forward_tasks[slot] = pool->get_thread_task_count(worker);
  }
}
|
||||
|
||||
/// Appends the pool's per-worker tick and task counts to the backward table.
///
/// Appending stops silently once MAX_THREADS slots are filled.
void collect_backward(InNumaPool* pool) {
  const int workers = pool->get_worker_count();
  for (int worker = 0; worker < workers; worker++) {
    if (backward_threads >= MAX_THREADS) {
      break;  // table full
    }
    const int slot = backward_threads++;
    backward_rt[slot] = pool->get_thread_cycles(worker);
    backward_tasks[slot] = pool->get_thread_task_count(worker);
  }
}
|
||||
|
||||
/// Prints the collected forward-pass samples to stderr under the label "forward".
void print_forward() {
  print_rt(stderr, "forward", forward_rt, forward_tasks, forward_threads);
}
|
||||
/// Prints the collected backward-pass samples to stderr under the given label.
void print_backward(const char* name) {
  print_rt(stderr, name, backward_rt, backward_tasks, backward_threads);
}
|
||||
|
||||
/// Records one trace event per worker for the operation the pool just ran.
///
/// Despite the name, nothing is printed: the per-worker start/end timestamps
/// and task counts are read from the pool and stored in memory for later
/// export. The first call also installs the SIGTERM/SIGINT handlers that
/// flush the trace.
///
/// @param pool     Pool to read from; the call is a no-op when null or when
///                 the pool reports no workers.
/// @param op_name  Event name; the call is a no-op when null or empty.
void print_op_stats(InNumaPool* pool, const char* op_name) {
  if (pool == nullptr || op_name == nullptr || op_name[0] == '\0') {
    return;
  }
  const int n = pool->get_worker_count();
  if (n <= 0) {
    return;
  }

  // Install the signal handlers exactly once, even if several pools report
  // their first operation concurrently.
  static std::once_flag handlers_once;
  std::call_once(handlers_once, register_signal_handlers);

  const int numa_id = pool->get_numa_id();

  std::vector<uint64_t> start_ts(n);
  std::vector<uint64_t> end_ts(n);
  std::vector<int> tasks(n);
  for (int i = 0; i < n; i++) {
    tasks[i] = pool->get_thread_task_count(i);
    start_ts[i] = pool->get_thread_start_ts(i);
    end_ts[i] = pool->get_thread_end_ts(i);
  }

  // Save trace data to memory for later export.
  add_trace_events(op_name, numa_id, n, start_ts.data(), end_ts.data(), tasks.data());
}
|
||||
|
||||
// =====================================================
|
||||
// Kernel-level tracing API implementation
|
||||
// =====================================================
|
||||
|
||||
/// Returns the current timestamp-counter value, for use as a start/end stamp
/// passed to add_kernel_trace().
uint64_t get_trace_timestamp() {
  const uint64_t stamp = rdtsc_now();
  return stamp;
}
|
||||
|
||||
void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id,
|
||||
const char* args) {
|
||||
init_trace();
|
||||
|
||||
// Convert absolute RDTSC timestamps to relative microseconds from trace start
|
||||
double start_us = (start_ts > g_trace_start_time) ? cycles_to_us(start_ts - g_trace_start_time) : 0.0;
|
||||
double end_us = (end_ts > g_trace_start_time) ? cycles_to_us(end_ts - g_trace_start_time) : 0.0;
|
||||
double dur_us = end_us - start_us;
|
||||
if (dur_us < 0) dur_us = 0;
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
||||
|
||||
TraceEvent ev;
|
||||
ev.name = name;
|
||||
ev.cat = "kernel";
|
||||
ev.ph = 'X'; // Complete event
|
||||
ev.ts = start_us;
|
||||
ev.dur = dur_us;
|
||||
ev.pid = numa_id;
|
||||
ev.tid = thread_id;
|
||||
ev.task_count = 0; // Not applicable for kernel traces
|
||||
if (args != nullptr && args[0] != '\0') {
|
||||
ev.args_json = args;
|
||||
}
|
||||
|
||||
g_trace_events.push_back(ev);
|
||||
}
|
||||
|
||||
} // namespace sft_timer
|
||||
#endif // SFT_TIMER_DISABLED
|
||||
|
||||
// Intel ITT API for profiler integration (VTune, etc.)
|
||||
// Allows profilers to identify spin-wait regions
|
||||
#ifdef USE_ITT_NOTIFY
#include <ittnotify.h>

// Handles created once by init_itt().
static __itt_domain* g_itt_domain = nullptr;
static __itt_string_handle* g_itt_spin_wait = nullptr;

/// Creates the ITT domain and string handle on the first call.
///
/// Every pool worker calls this when it starts, so the creation is wrapped
/// in std::call_once rather than an unsynchronised null check on the global.
static void init_itt() {
  static std::once_flag itt_once;
  std::call_once(itt_once, [] {
    g_itt_domain = __itt_domain_create("WorkerPool");
    g_itt_spin_wait = __itt_string_handle_create("SpinWait");
  });
}

// Forward the spin-wait annotations to the ITT sync API.
#define ITT_SYNC_PREPARE(addr) __itt_sync_prepare(addr)
#define ITT_SYNC_CANCEL(addr) __itt_sync_cancel(addr)
#define ITT_SYNC_ACQUIRED(addr) __itt_sync_acquired(addr)

#else

// Without USE_ITT_NOTIFY the annotations compile to nothing; the argument is not evaluated.
#define ITT_SYNC_PREPARE(addr) ((void)0)
#define ITT_SYNC_CANCEL(addr) ((void)0)
#define ITT_SYNC_ACQUIRED(addr) ((void)0)

/// No-op stand-in so call sites need no conditional compilation.
static void init_itt() {}

#endif
|
||||
|
||||
thread_local int WorkerPool::thread_local_id = -1;
|
||||
|
||||
InNumaPool::InNumaPool(int max_thread_num) {
|
||||
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
|
||||
numa_id_ = numa_node_of_cpu(sched_getcpu());
|
||||
total_worker_count = max_thread_num;
|
||||
block_size_ = 0;
|
||||
set_restricted_worker_count(total_worker_count);
|
||||
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
|
||||
for (int i = 0; i < total_worker_count; i++) {
|
||||
|
|
@ -463,9 +46,7 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {
|
|||
hwloc_topology_init(&topology);
|
||||
hwloc_topology_load(topology);
|
||||
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
|
||||
numa_id_ = numa_id;
|
||||
total_worker_count = max_thread_num;
|
||||
block_size_ = 0;
|
||||
set_restricted_worker_count(total_worker_count);
|
||||
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
|
||||
for (int i = 0; i < total_worker_count; i++) {
|
||||
|
|
@ -514,7 +95,11 @@ InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {
|
|||
|
||||
InNumaPool::~InNumaPool() {
|
||||
for (int i = 0; i < total_worker_count; i++) {
|
||||
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(thread_state_[i].mutex);
|
||||
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);
|
||||
}
|
||||
thread_state_[i].cv.notify_one();
|
||||
}
|
||||
for (int i = 0; i < total_worker_count; i++) {
|
||||
if (workers_[i].joinable()) {
|
||||
|
|
@ -551,51 +136,41 @@ void InNumaPool::wait() {
|
|||
#endif
|
||||
}
|
||||
|
||||
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
|
||||
int block_size, bool async) {
|
||||
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size);
|
||||
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {
|
||||
do_work_stealing_job(task_num, nullptr, compute_func, nullptr);
|
||||
}
|
||||
|
||||
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
|
||||
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
|
||||
const char* task_name, int block_size, bool async) {
|
||||
bool has_name = task_name != nullptr && task_name[0] != '\0';
|
||||
if (has_name) {
|
||||
reset_counters();
|
||||
}
|
||||
do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func, block_size);
|
||||
if (!async) wait();
|
||||
if (has_name) {
|
||||
sft_timer::print_op_stats(this, task_name);
|
||||
}
|
||||
std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {
|
||||
do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func);
|
||||
wait();
|
||||
}
|
||||
|
||||
void InNumaPool::do_work_stealing_job_async(int task_num, std::function<void(int)> init_func,
|
||||
std::function<void(int)> compute_func,
|
||||
std::function<void(int)> finalize_func, int block_size) {
|
||||
std::function<void(int)> finalize_func) {
|
||||
init_func_ = init_func;
|
||||
compute_func_ = compute_func;
|
||||
finalize_func_ = finalize_func;
|
||||
block_size_ = block_size;
|
||||
worker_count = std::min(restricted_worker_count, task_num);
|
||||
curr_.store(0, std::memory_order_release);
|
||||
end_ = task_num;
|
||||
|
||||
for (int i = 0; i < worker_count; i++) {
|
||||
thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(thread_state_[i].mutex);
|
||||
thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
}
|
||||
thread_state_[i].cv.notify_one();
|
||||
}
|
||||
WorkerPool::thread_local_id = 0;
|
||||
process_tasks(0);
|
||||
}
|
||||
|
||||
void InNumaPool::process_tasks(int thread_id) {
|
||||
uint64_t start_cycles = rdtsc_now();
|
||||
#ifdef PROFILE_BALANCE
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
#endif
|
||||
auto& s = thread_state_[thread_id];
|
||||
int local_task_count = 0;
|
||||
|
||||
// Record absolute start timestamp
|
||||
s.start_ts = start_cycles;
|
||||
|
||||
if (init_func_ != nullptr) {
|
||||
init_func_(thread_id);
|
||||
}
|
||||
|
|
@ -608,12 +183,7 @@ void InNumaPool::process_tasks(int thread_id) {
|
|||
break;
|
||||
}
|
||||
|
||||
int block = 0;
|
||||
if (block_size_ > 0) {
|
||||
block = std::min(block_size_, rem);
|
||||
} else {
|
||||
block = (rem + worker_count - 1) / worker_count;
|
||||
}
|
||||
int block = (rem + worker_count - 1) / worker_count;
|
||||
block = 1;
|
||||
int task_id = curr_.fetch_add(block, std::memory_order_acq_rel);
|
||||
if (task_id >= end_) {
|
||||
|
|
@ -625,7 +195,6 @@ void InNumaPool::process_tasks(int thread_id) {
|
|||
break;
|
||||
}
|
||||
compute_func_(task_id + i);
|
||||
local_task_count++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -633,44 +202,34 @@ void InNumaPool::process_tasks(int thread_id) {
|
|||
finalize_func_(thread_id);
|
||||
}
|
||||
|
||||
// IMPORTANT: Update timing BEFORE setting status to WAITING
|
||||
// The release semantics of status.store() ensures all prior writes are visible
|
||||
uint64_t end_cycles = rdtsc_now();
|
||||
s.finish_cycles = end_cycles - start_cycles;
|
||||
s.task_count = local_task_count;
|
||||
s.end_ts = end_cycles;
|
||||
|
||||
// Signal completion - release ensures timing writes are visible to wait()
|
||||
s.status.store(ThreadStatus::WAITING, std::memory_order_release);
|
||||
#ifdef PROFILE_BALANCE
|
||||
s.finish_ns =
|
||||
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() - start).count();
|
||||
#endif
|
||||
}
|
||||
|
||||
void InNumaPool::worker_thread(int thread_id, int numa_id) {
|
||||
if (numa_id >= 0) {
|
||||
set_memory_to_numa(numa_id);
|
||||
}
|
||||
init_itt(); // Initialize ITT if enabled
|
||||
// Use RDTSC for lightweight timing instead of std::chrono
|
||||
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
|
||||
uint64_t start = rdtsc_now();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
WorkerPool::thread_local_id = thread_id; // 设置线程本地变量
|
||||
while (true) {
|
||||
ITT_SYNC_PREPARE(&thread_state_[thread_id].status); // Signal profiler: about to spin-wait
|
||||
ThreadStatus status = thread_state_[thread_id].status.load(std::memory_order_acquire);
|
||||
if (status == ThreadStatus::WORKING) {
|
||||
ITT_SYNC_ACQUIRED(&thread_state_[thread_id].status); // Signal profiler: acquired work
|
||||
process_tasks(thread_id);
|
||||
start = rdtsc_now();
|
||||
start = std::chrono::high_resolution_clock::now();
|
||||
} else if (status == ThreadStatus::WAITING) {
|
||||
// PAUSE instruction hints to CPU this is a spin-wait loop
|
||||
_mm_pause();
|
||||
uint64_t now = rdtsc_now();
|
||||
uint64_t elapsed_cycles = now - start;
|
||||
if (elapsed_cycles > sleep_threshold_cycles) {
|
||||
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: going to sleep
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(100));
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
|
||||
if (duration > 50) {
|
||||
std::unique_lock<std::mutex> lock(thread_state_[thread_id].mutex);
|
||||
thread_state_[thread_id].cv.wait(lock, [&] {
|
||||
return thread_state_[thread_id].status.load(std::memory_order_acquire) != ThreadStatus::WAITING;
|
||||
});
|
||||
}
|
||||
} else if (status == ThreadStatus::EXIT) {
|
||||
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: exiting
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -695,6 +254,8 @@ void NumaJobDistributor::init(std::vector<int> numa_ids) {
|
|||
this->numa_ids = numa_ids;
|
||||
for (size_t i = 0; i < numa_count; i++) {
|
||||
status.push_back(nullptr);
|
||||
mutexes.push_back(std::make_unique<std::mutex>());
|
||||
cvs.push_back(std::make_unique<std::condition_variable>());
|
||||
}
|
||||
|
||||
workers.resize(numa_count);
|
||||
|
|
@ -716,6 +277,8 @@ void NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread
|
|||
this->numa_ids = numa_ids;
|
||||
for (size_t i = 0; i < numa_count; i++) {
|
||||
status.push_back(nullptr);
|
||||
mutexes.push_back(std::make_unique<std::mutex>());
|
||||
cvs.push_back(std::make_unique<std::condition_variable>());
|
||||
}
|
||||
|
||||
workers.resize(numa_count);
|
||||
|
|
@ -767,7 +330,11 @@ void NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread
|
|||
|
||||
NumaJobDistributor::~NumaJobDistributor() {
|
||||
for (int i = 0; i < numa_count; i++) {
|
||||
status[i]->store(ThreadStatus::EXIT, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*mutexes[i]);
|
||||
status[i]->store(ThreadStatus::EXIT, std::memory_order_release);
|
||||
}
|
||||
cvs[i]->notify_one();
|
||||
}
|
||||
for (int i = 0; i < numa_count; i++) {
|
||||
if (workers[i].joinable()) {
|
||||
|
|
@ -784,7 +351,11 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
|||
for (int i = 0; i < numa_count; i++) {
|
||||
if (i == me_numa) continue;
|
||||
|
||||
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*mutexes[i]);
|
||||
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
}
|
||||
cvs[i]->notify_one();
|
||||
}
|
||||
compute_func(me_numa);
|
||||
for (int i = 0; i < numa_count; i++) {
|
||||
|
|
@ -798,7 +369,11 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
|||
void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
||||
this->compute_func = compute_func;
|
||||
for (int i = 0; i < numa_count; i++) {
|
||||
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*mutexes[i]);
|
||||
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
||||
}
|
||||
cvs[i]->notify_one();
|
||||
}
|
||||
for (int i = 0; i < numa_count; i++) {
|
||||
while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
|
||||
|
|
@ -808,35 +383,29 @@ void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
|||
#endif
|
||||
|
||||
void NumaJobDistributor::worker_thread(int numa_id) {
|
||||
init_itt(); // Initialize ITT if enabled
|
||||
// Use RDTSC for lightweight timing instead of std::chrono
|
||||
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
|
||||
uint64_t start = rdtsc_now();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
set_memory_to_numa(numa_id);
|
||||
status[numa_id] =
|
||||
std::move(std::unique_ptr<std::atomic<ThreadStatus>>(new std::atomic<ThreadStatus>(ThreadStatus::WAITING)));
|
||||
ready_bar->arrive_and_wait();
|
||||
while (true) {
|
||||
ITT_SYNC_PREPARE(status[numa_id].get()); // Signal profiler: about to spin-wait
|
||||
auto stat = status[numa_id]->load(std::memory_order_acquire);
|
||||
if (stat == ThreadStatus::WORKING) {
|
||||
ITT_SYNC_ACQUIRED(status[numa_id].get()); // Signal profiler: acquired work
|
||||
auto me_numa = numa_node_of_cpu(sched_getcpu());
|
||||
// printf("numa work on %d, me %d\n", numa_id, me_numa);
|
||||
compute_func(numa_id);
|
||||
status[numa_id]->store(ThreadStatus::WAITING, std::memory_order_release);
|
||||
start = rdtsc_now();
|
||||
start = std::chrono::high_resolution_clock::now();
|
||||
} else if (stat == ThreadStatus::WAITING) {
|
||||
// PAUSE instruction hints to CPU this is a spin-wait loop
|
||||
_mm_pause();
|
||||
uint64_t now = rdtsc_now();
|
||||
uint64_t elapsed_cycles = now - start;
|
||||
if (elapsed_cycles > sleep_threshold_cycles) {
|
||||
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: going to sleep
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
|
||||
if (duration > 50) {
|
||||
std::unique_lock<std::mutex> lock(*mutexes[numa_id]);
|
||||
cvs[numa_id]->wait(lock, [&] {
|
||||
return status[numa_id]->load(std::memory_order_acquire) != ThreadStatus::WAITING;
|
||||
});
|
||||
}
|
||||
} else if (stat == ThreadStatus::EXIT) {
|
||||
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: exiting
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -911,15 +480,10 @@ InNumaPool* WorkerPool::get_subpool(int numa_id) { return numa_worker_pools[numa
|
|||
/// Returns a non-owning pointer to the object held by `distributor`.
NumaJobDistributor* WorkerPool::dispense_backend() {
  NumaJobDistributor* const backend = distributor.get();
  return backend;
}
|
||||
|
||||
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
|
||||
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
|
||||
const char* task_name, int block_size, bool async) {
|
||||
numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func, task_name, block_size,
|
||||
async);
|
||||
std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {
|
||||
numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func);
|
||||
}
|
||||
|
||||
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
|
||||
int block_size, bool async) {
|
||||
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size, async);
|
||||
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {
|
||||
do_work_stealing_job(task_num, nullptr, compute_func, nullptr);
|
||||
}
|
||||
|
||||
/// Forwards to wait() on the first entry of numa_worker_pools.
void WorkerPool::wait() {
  auto& first_pool = numa_worker_pools[0];
  first_pool->wait();
}
|
||||
|
|
|
|||
|
|
@ -62,10 +62,11 @@ enum ThreadStatus {
|
|||
|
||||
struct alignas(64) ThreadState {
|
||||
std::atomic<ThreadStatus> status;
|
||||
uint64_t finish_cycles; // Per-thread timing (always enabled)
|
||||
int task_count; // Per-thread task count
|
||||
uint64_t start_ts; // Absolute start timestamp (RDTSC)
|
||||
uint64_t end_ts; // Absolute end timestamp (RDTSC)
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv;
|
||||
#ifdef PROFILE_BALANCE
|
||||
size_t finish_ns;
|
||||
#endif
|
||||
};
|
||||
|
||||
class InNumaPool {
|
||||
|
|
@ -76,45 +77,21 @@ class InNumaPool {
|
|||
int get_thread_num();
|
||||
void set_restricted_worker_count(int count);
|
||||
|
||||
void do_work_stealing_job_async(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
|
||||
int block_size = 0);
|
||||
void do_work_stealing_job_async(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
|
||||
void wait();
|
||||
|
||||
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
|
||||
const char* task_name = nullptr, int block_size = 0, bool async = false);
|
||||
void do_work_stealing_job(int, std::function<void(int)>, const char* task_name = nullptr, int block_size = 0,
|
||||
bool async = false);
|
||||
|
||||
// Get per-thread timing info
|
||||
int get_worker_count() const { return worker_count; }
|
||||
int get_numa_id() const { return numa_id_; }
|
||||
uint64_t get_thread_cycles(int tid) const { return thread_state_[tid].finish_cycles; }
|
||||
int get_thread_task_count(int tid) const { return thread_state_[tid].task_count; }
|
||||
uint64_t get_thread_start_ts(int tid) const { return thread_state_[tid].start_ts; }
|
||||
uint64_t get_thread_end_ts(int tid) const { return thread_state_[tid].end_ts; }
|
||||
|
||||
// Reset per-thread timing/task counters (call before timing a sequence of operations)
|
||||
// NOTE: Only call when all workers are in WAITING state (after wait() returns)
|
||||
void reset_counters() {
|
||||
for (int i = 0; i < total_worker_count; i++) {
|
||||
thread_state_[i].finish_cycles = 0;
|
||||
thread_state_[i].task_count = 0;
|
||||
thread_state_[i].start_ts = 0;
|
||||
thread_state_[i].end_ts = 0;
|
||||
}
|
||||
}
|
||||
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
|
||||
void do_work_stealing_job(int, std::function<void(int)>);
|
||||
|
||||
private:
|
||||
int worker_count;
|
||||
int total_worker_count;
|
||||
int numa_id_;
|
||||
|
||||
std::unique_ptr<ThreadState[]> thread_state_; // [thread_num]
|
||||
std::vector<std::thread> workers_;
|
||||
|
||||
// changed ever time called do_work_stealing_job_async
|
||||
int restricted_worker_count;
|
||||
int block_size_;
|
||||
std::function<void(int)> init_func_;
|
||||
std::function<void(int)> compute_func_;
|
||||
std::function<void(int)> finalize_func_;
|
||||
|
|
@ -144,6 +121,8 @@ class NumaJobDistributor {
|
|||
int numa_count;
|
||||
std::vector<int> numa_ids;
|
||||
std::vector<std::unique_ptr<std::atomic<ThreadStatus>>> status;
|
||||
std::vector<std::unique_ptr<std::mutex>> mutexes;
|
||||
std::vector<std::unique_ptr<std::condition_variable>> cvs;
|
||||
std::function<void(int)> compute_func;
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
|
|
@ -171,12 +150,8 @@ class WorkerPool {
|
|||
|
||||
InNumaPool* get_subpool(int numa_id);
|
||||
|
||||
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>,
|
||||
const char* task_name = nullptr, int block_size = 0, bool async = false);
|
||||
void do_work_stealing_job(int, std::function<void(int)>, const char* task_name = nullptr, int block_size = 0,
|
||||
bool async = false);
|
||||
|
||||
void wait();
|
||||
void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);
|
||||
void do_work_stealing_job(int, std::function<void(int)>);
|
||||
|
||||
WorkerPoolConfig config;
|
||||
|
||||
|
|
@ -191,58 +166,4 @@ class WorkerPool {
|
|||
std::vector<std::unique_ptr<InNumaPool>> numa_worker_pools;
|
||||
};
|
||||
|
||||
// =====================================================
|
||||
// Global per-thread timing for SFT MOE forward/backward
|
||||
// =====================================================
|
||||
// Define SFT_TIMER_DISABLED to disable all timing (functions become no-ops)
|
||||
// #define SFT_TIMER_DISABLED
|
||||
namespace sft_timer {
|
||||
|
||||
#ifdef SFT_TIMER_DISABLED
|
||||
// Disabled: all functions are no-ops
|
||||
inline void reset_forward() {}
|
||||
inline void reset_backward() {}
|
||||
inline void collect_forward(InNumaPool*) {}
|
||||
inline void collect_backward(InNumaPool*) {}
|
||||
inline void print_forward() {}
|
||||
inline void print_backward(const char* = "backward") {}
|
||||
inline void print_op_stats(InNumaPool*, const char*) {}
|
||||
inline uint64_t get_trace_timestamp() { return 0; }
|
||||
inline void add_kernel_trace(const char*, uint64_t, uint64_t, int, int, const char* = nullptr) {}
|
||||
#else
|
||||
// Enabled: declarations only, implementation in worker_pool.cpp
|
||||
void reset_forward();
|
||||
void reset_backward();
|
||||
void collect_forward(InNumaPool* pool);
|
||||
void collect_backward(InNumaPool* pool);
|
||||
void print_forward();
|
||||
void print_backward(const char* name = "backward");
|
||||
|
||||
// Print per-thread timing for a single operation
|
||||
// Call pool->reset_counters() BEFORE the operation, then call this AFTER
|
||||
void print_op_stats(InNumaPool* pool, const char* op_name);
|
||||
|
||||
// =====================================================
|
||||
// Kernel-level tracing API
|
||||
// For tracing individual kernels (e.g., AVX matmul) within worker threads
|
||||
// =====================================================
|
||||
|
||||
// Get current RDTSC timestamp (lightweight, ~20 cycles overhead)
|
||||
uint64_t get_trace_timestamp();
|
||||
|
||||
// Add a kernel trace event
|
||||
// @param name Kernel name (e.g., "lora_bf16_matmul_t4r4")
|
||||
// @param start_ts Start timestamp from get_trace_timestamp()
|
||||
// @param end_ts End timestamp from get_trace_timestamp()
|
||||
// @param numa_id NUMA node ID (use -1 for auto-detect or 0 if unknown)
|
||||
// @param thread_id Thread ID within the pool (use WorkerPool::thread_local_id)
|
||||
// @param args Optional JSON args string (e.g., "{\"tokens\":128,\"rank\":8}")
|
||||
void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id,
|
||||
const char* args = nullptr);
|
||||
|
||||
static void write_trace_to_file(); // Write all collected traces to a file (e.g., "sft_kernel_traces.json")
|
||||
#endif
|
||||
|
||||
} // namespace sft_timer
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -48,9 +48,6 @@ namespace avx {
|
|||
*/
|
||||
inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggml_bf16_t* __restrict weight,
|
||||
float* __restrict output, int num_tokens, int k_dim, int rank) {
|
||||
// #if AVX_KERNEL_TRACE_ENABLED
|
||||
// uint64_t trace_start = sft_timer::get_trace_timestamp();
|
||||
// #endif
|
||||
|
||||
constexpr int T_BLOCK = 4;
|
||||
constexpr int R_BLOCK = 4;
|
||||
|
|
@ -240,13 +237,6 @@ inline void lora_bf16_matmul_t4r4(const ggml_bf16_t* __restrict input, const ggm
|
|||
}
|
||||
}
|
||||
|
||||
// #if AVX_KERNEL_TRACE_ENABLED
|
||||
// uint64_t trace_end = sft_timer::get_trace_timestamp();
|
||||
// char args_buf[128];
|
||||
// snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"K\":%d,\"R\":%d}", num_tokens, k_dim, rank);
|
||||
// sft_timer::add_kernel_trace("lora_bf16_matmul_t4r4", trace_start, trace_end, 0, WorkerPool::thread_local_id,
|
||||
// args_buf);
|
||||
// #endif
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -275,7 +265,6 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const
|
|||
ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim,
|
||||
float scale) {
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_start = sft_timer::get_trace_timestamp();
|
||||
#endif
|
||||
|
||||
constexpr int T_BLOCK = 4;
|
||||
|
|
@ -579,11 +568,8 @@ inline void lora_fp32_bf16_fused_add(const float* __restrict intermediate, const
|
|||
}
|
||||
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_end = sft_timer::get_trace_timestamp();
|
||||
char args_buf[128];
|
||||
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
|
||||
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add", trace_start, trace_end, 0, WorkerPool::thread_local_id,
|
||||
args_buf);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -615,7 +601,6 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co
|
|||
ggml_bf16_t* __restrict output, int num_tokens, int rank, int output_dim,
|
||||
float scale) {
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_start = sft_timer::get_trace_timestamp();
|
||||
#endif
|
||||
|
||||
constexpr int T_BLOCK = 4;
|
||||
|
|
@ -886,11 +871,8 @@ inline void lora_fp32_bf16_fused_add_wt(const float* __restrict intermediate, co
|
|||
}
|
||||
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_end = sft_timer::get_trace_timestamp();
|
||||
char args_buf[128];
|
||||
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
|
||||
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_wt", trace_start, trace_end, 0, WorkerPool::thread_local_id,
|
||||
args_buf);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -1190,7 +1172,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
|
|||
const ggml_bf16_t* __restrict weight_t, ggml_bf16_t* __restrict output,
|
||||
int num_tokens, int rank, int output_dim, float scale) {
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_start = sft_timer::get_trace_timestamp();
|
||||
#endif
|
||||
|
||||
constexpr int T_BLOCK = 4;
|
||||
|
|
@ -1384,11 +1365,8 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
|
|||
}
|
||||
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_end = sft_timer::get_trace_timestamp();
|
||||
char args_buf[128];
|
||||
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"R\":%d,\"O\":%d}", num_tokens, rank, output_dim);
|
||||
sft_timer::add_kernel_trace("lora_fp32_bf16_fused_add_transposed", trace_start, trace_end, 0,
|
||||
WorkerPool::thread_local_id, args_buf);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -1408,7 +1386,6 @@ inline void lora_fp32_bf16_fused_add_transposed(const float* __restrict intermed
|
|||
inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad, const ggml_bf16_t* __restrict lora_b_t,
|
||||
float* __restrict result, int num_tokens, int hidden, int rank) {
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_start = sft_timer::get_trace_timestamp();
|
||||
#endif
|
||||
|
||||
constexpr int H_BLOCK = 32;
|
||||
|
|
@ -1448,11 +1425,8 @@ inline void lora_backward_matmul_transposed(const ggml_bf16_t* __restrict grad,
|
|||
}
|
||||
|
||||
#if AVX_KERNEL_TRACE_ENABLED
|
||||
uint64_t trace_end = sft_timer::get_trace_timestamp();
|
||||
char args_buf[128];
|
||||
snprintf(args_buf, sizeof(args_buf), "{\"T\":%d,\"H\":%d,\"R\":%d}", num_tokens, hidden, rank);
|
||||
sft_timer::add_kernel_trace("lora_backward_matmul_transposed", trace_start, trace_end, 0, WorkerPool::thread_local_id,
|
||||
args_buf);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1075,13 +1075,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
|
||||
// Step 3: Copy input to expert buffers
|
||||
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size) {
|
||||
auto direct_or_pool = [&](int count, auto&& fn) {
|
||||
if (qlen < 10) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
fn(i);
|
||||
}
|
||||
} else {
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -1095,8 +1095,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
|
||||
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
|
||||
}
|
||||
},
|
||||
"fwd_pack_input", 1);
|
||||
});
|
||||
|
||||
// NaN Check: Step 3 - Packed input
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1118,8 +1117,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
[this](int task_id) {
|
||||
int expert_idx = m_expert_id_map_[task_id];
|
||||
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
"fwd_quantize_in", 1);
|
||||
});
|
||||
|
||||
// Step 5: Gate + Up GEMM (base projection)
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
|
|
@ -1137,7 +1135,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
|
||||
}
|
||||
},
|
||||
nullptr, "fwd_gate_up_gemm", 1);
|
||||
nullptr);
|
||||
|
||||
// NaN Check: Step 5 - Gate/Up GEMM output (before LoRA)
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1230,10 +1228,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
|
||||
// Step 6: Activation (silu(gate) * up)
|
||||
{
|
||||
uint64_t act_start = sft_timer::get_trace_timestamp();
|
||||
Base::apply_activation(activated_expert, nth, qlen);
|
||||
uint64_t act_end = sft_timer::get_trace_timestamp();
|
||||
sft_timer::add_kernel_trace("apply_activation", act_start, act_end, tp_part_idx, 0);
|
||||
}
|
||||
|
||||
// NaN Check: Step 6 - Activation output (silu(gate) * up)
|
||||
|
|
@ -1295,7 +1290,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int expert_idx = m_expert_id_map_[task_id];
|
||||
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
nullptr, "fwd_down_quantize");
|
||||
nullptr);
|
||||
|
||||
// Step 8: Down GEMM
|
||||
nth = T::recommended_nth(config_.hidden_size);
|
||||
|
|
@ -1307,7 +1302,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
this->do_down_gemm(expert_idx, ith, nth, qlen);
|
||||
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
|
||||
},
|
||||
nullptr, "fwd_down_gemm", 1);
|
||||
nullptr);
|
||||
|
||||
// NaN Check: Step 8 - Down GEMM output (before LoRA)
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1373,7 +1368,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
f32out[1] = x1;
|
||||
}
|
||||
},
|
||||
nullptr, "fwd_merge");
|
||||
nullptr);
|
||||
|
||||
// NaN Check: Step 9 - Final output (after weighted merge)
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1428,12 +1423,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int activated_expert = cache.activated_expert_cache;
|
||||
constexpr int kSmallBwdDirectQlen = 0;
|
||||
constexpr int kSmallBwdDirectMaxTasks = 16;
|
||||
auto trace_phase = [this](const char* name, auto&& fn) {
|
||||
uint64_t start = sft_timer::get_trace_timestamp();
|
||||
fn();
|
||||
uint64_t end = sft_timer::get_trace_timestamp();
|
||||
sft_timer::add_kernel_trace(name, start, end, tp_part_idx, 0);
|
||||
};
|
||||
|
||||
// NaN Check: grad_output input
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1492,9 +1481,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
|
||||
// ★ Allocate backward-phase buffers ★
|
||||
trace_phase("bwd_setup_alloc", [&] { alloc_backward_buffers(); });
|
||||
alloc_backward_buffers();
|
||||
|
||||
trace_phase("bwd_setup_state", [&] {
|
||||
// ★ share_backward_bb: check if async repack already prepared this layer ★
|
||||
if (config_.share_backward_bb) {
|
||||
auto& shared = SFTSharedPools::instance();
|
||||
|
|
@ -1544,12 +1532,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
|
||||
offset += m_local_num_[i];
|
||||
}
|
||||
});
|
||||
|
||||
// Restore input data from cache into m_local_input_ (shared_mem_buffer may have been
|
||||
// overwritten by subsequent layers' forward passes). This is needed for gate/up LoRA
|
||||
// gradient computation which reads from m_local_input_ptr_.
|
||||
trace_phase("bwd_phase_down", [&] {
|
||||
auto pool_local = config_.pool->get_subpool(tp_part_idx);
|
||||
auto restore_input = [&](int i) {
|
||||
for (int j = 0; j < k; j++) {
|
||||
|
|
@ -1569,7 +1555,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
restore_input(i);
|
||||
}
|
||||
} else {
|
||||
pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr, "bwd_restore_input", 1);
|
||||
pool_local->do_work_stealing_job(qlen, nullptr, restore_input, nullptr);
|
||||
}
|
||||
|
||||
// Step 1: Down projection backward
|
||||
|
|
@ -1579,7 +1565,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
} else {
|
||||
// backward_down(cache, grad_output, grad_down_lora_a, grad_down_lora_b);
|
||||
}
|
||||
});
|
||||
|
||||
// // Compute total tokens for debug
|
||||
// size_t total_tokens = 0;
|
||||
|
|
@ -1655,7 +1640,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// }
|
||||
// }
|
||||
|
||||
trace_phase("bwd_phase_act", [&] { backward_activation(cache); });
|
||||
backward_activation(cache);
|
||||
|
||||
// NaN Check: Step 2 - After backward_activation
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1695,14 +1680,12 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// }
|
||||
// }
|
||||
|
||||
trace_phase("bwd_phase_gate_up", [&] {
|
||||
if constexpr (supports_standard_mat_mul_v<T>) {
|
||||
backward_gate_up_amx(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b,
|
||||
full_intermediate_size, fp32_grad_gate_lora_a, fp32_grad_up_lora_a);
|
||||
} else {
|
||||
// backward_gate_up(cache, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b);
|
||||
}
|
||||
});
|
||||
|
||||
// NaN Check: Step 3 - After backward_gate_up
|
||||
if (is_nan_check_enabled()) {
|
||||
|
|
@ -1740,7 +1723,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// Step 4: Compute grad_weights (gradient for routing weights)
|
||||
// grad_weights[token_idx, expert_pos] = dot(grad_output[token_idx], down_output[token, expert])
|
||||
if (grad_weights != nullptr) {
|
||||
trace_phase("bwd_phase_gradw", [&] {
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
float* grad_w = (float*)grad_weights;
|
||||
const ggml_bf16_t* grad_out = (const ggml_bf16_t*)grad_output;
|
||||
|
|
@ -1788,9 +1770,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
compute_grad_weight(token_idx);
|
||||
}
|
||||
} else {
|
||||
pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr, "bwd_grad_weights");
|
||||
pool->do_work_stealing_job(qlen, nullptr, compute_grad_weight, nullptr);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NaN Check: Step 4 - After grad_weights computation
|
||||
|
|
@ -2106,7 +2087,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
break;
|
||||
}
|
||||
},
|
||||
nullptr, "transpose_lora_b_weights");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -2167,7 +2148,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
break;
|
||||
}
|
||||
},
|
||||
nullptr, "fwd_lora_prep");
|
||||
nullptr);
|
||||
|
||||
lora_weights_prepared_ = true;
|
||||
}
|
||||
|
|
@ -2212,7 +2193,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
break;
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_lora_prep");
|
||||
nullptr);
|
||||
|
||||
lora_backward_weights_prepared_ = true;
|
||||
}
|
||||
|
|
@ -2265,7 +2246,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
dst_bb->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.intermediate_size,
|
||||
config_.hidden_size, ith, nth_gate_up);
|
||||
},
|
||||
nullptr, "bwd_prep_gate_up");
|
||||
nullptr);
|
||||
|
||||
// Phase 2: down backward
|
||||
// down_proj: [hidden_size, intermediate_size] -> transposed BufferB [intermediate_size, hidden_size]
|
||||
|
|
@ -2281,7 +2262,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
down_backward_bb_[expert_idx]->from_mat_transposed((ggml_bf16_t*)(src + expert_offset), config_.hidden_size,
|
||||
config_.intermediate_size, ith, nth_down);
|
||||
},
|
||||
nullptr, "bwd_prep_down");
|
||||
nullptr);
|
||||
|
||||
backward_weights_prepared_ = true;
|
||||
}
|
||||
|
|
@ -2317,7 +2298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth);
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_repack_gate_up");
|
||||
nullptr);
|
||||
|
||||
// Phase 2: down (uses [hidden_size, intermediate_size] -> [intermediate_size, hidden_size])
|
||||
pool->do_work_stealing_job(
|
||||
|
|
@ -2339,7 +2320,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
dst_bb->from_mat_transposed(workspace.data(), src_bb->n, src_bb->k, p, dst_nth);
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_repack_down");
|
||||
nullptr);
|
||||
|
||||
backward_weights_prepared_ = true;
|
||||
}
|
||||
|
|
@ -2543,7 +2524,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
down_backward_bb_[expert_idx].get());
|
||||
}
|
||||
},
|
||||
nullptr, "load_bwd_kt");
|
||||
nullptr);
|
||||
|
||||
if (!ok.load()) return false;
|
||||
backward_weights_prepared_ = true;
|
||||
|
|
@ -2596,7 +2577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "load_bwd_projs");
|
||||
nullptr);
|
||||
|
||||
backward_weights_prepared_ = true;
|
||||
}
|
||||
|
|
@ -3421,7 +3402,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx];
|
||||
bc->to_mat(m, inter_ptr, ith, nth);
|
||||
},
|
||||
nullptr, "fwd_lora_gu_a");
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 2: Quantize lora_intermediate to BufferA
|
||||
|
|
@ -3439,7 +3420,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
ggml_bf16_t* ptr = do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx];
|
||||
ba->from_mat(m, ptr, 0, 1);
|
||||
},
|
||||
nullptr, "fwd_lora_gu_quant");
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 3a: lora_intermediate @ lora_B^T -> lora_output (GEMM only)
|
||||
|
|
@ -3464,7 +3445,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// GEMM: [m, padded_lora_rank] @ [intermediate_size, padded_lora_rank]^T -> [m, intermediate_size]
|
||||
amx::mat_mul(m, config_.intermediate_size, padded_lora_rank_, ba, bb, bc, ith, nth);
|
||||
},
|
||||
nullptr, "fwd_lora_gu_gemm");
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 3b: Add LoRA output to main output with scaling
|
||||
|
|
@ -3489,7 +3470,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
add_lora_output_to_main(bc.get(), main_output, m, config_.intermediate_size, lora_scaling_, ith, nth,
|
||||
lora_sum_ptr);
|
||||
},
|
||||
nullptr, "fwd_lora_gu_add");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -3581,7 +3562,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// Convert BufferC to BF16 for step 2 input
|
||||
bc->to_mat(m, lora_gate_intermediate_ptr_[expert_idx], ith, nth);
|
||||
},
|
||||
nullptr, "fwd_lora_down_a");
|
||||
nullptr);
|
||||
|
||||
|
||||
|
||||
|
|
@ -3597,7 +3578,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// Reuse gate intermediate buffer (no race condition for down projection)
|
||||
lora_gate_intermediate_ba_[expert_idx]->from_mat(m, lora_gate_intermediate_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
nullptr, "fwd_lora_down_quant");
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 3a: lora_intermediate @ down_lora_B^T -> lora_output (GEMM only)
|
||||
|
|
@ -3620,7 +3601,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// GEMM: [m, padded_lora_rank] @ [hidden_size, padded_lora_rank]^T -> [m, hidden_size]
|
||||
amx::mat_mul(m, config_.hidden_size, padded_lora_rank_, ba, bb, bc, ith, nth);
|
||||
},
|
||||
nullptr, "fwd_lora_down_gemm", 1);
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 3b: Add LoRA output to main output with scaling
|
||||
|
|
@ -3642,7 +3623,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
add_lora_output_to_main(bc.get(), m_local_down_output_ptr_[expert_idx], m, config_.hidden_size, lora_scaling_,
|
||||
ith, nth, &down_lora_sum);
|
||||
},
|
||||
nullptr, "fwd_lora_down_add");
|
||||
nullptr);
|
||||
|
||||
// // Print LoRA contribution statistics
|
||||
// size_t total_elements = 0;
|
||||
|
|
@ -3791,7 +3772,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
output + t_start * inter_size, // output [local_num_tokens, inter_size]
|
||||
local_num_tokens, rank, inter_size, scale);
|
||||
},
|
||||
nullptr, "fwd_lora_gu");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -3856,7 +3837,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
output + t_start * hidden, // output [local_num_tokens, hidden]
|
||||
local_num_tokens, rank, hidden, scale);
|
||||
},
|
||||
nullptr, "fwd_lora_down");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
ForwardCache& push_cache() {
|
||||
|
|
@ -3941,7 +3922,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "save_cache");
|
||||
nullptr);
|
||||
|
||||
cache.valid = true;
|
||||
}
|
||||
|
|
@ -3968,7 +3949,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
memcpy(cache.intermediate_cache + cache_offsets_[i] * config_.intermediate_size,
|
||||
m_local_gate_output_ptr_[expert_idx], num_tokens * config_.intermediate_size * sizeof(ggml_bf16_t));
|
||||
},
|
||||
nullptr, "save_inter_cache");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -3993,7 +3974,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
memcpy(cache.down_output_cache + cache_offsets_[i] * config_.hidden_size, src_ptr,
|
||||
num_tokens * config_.hidden_size * sizeof(ggml_bf16_t));
|
||||
},
|
||||
nullptr, "save_down_cache");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void backward_down(const ForwardCache& cache, const void* grad_output, void* grad_down_lora_a,
|
||||
|
|
@ -4025,7 +4006,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
memset(reinterpret_cast<char*>(grad_intermediate_) + offset, 0, size);
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_down_memset");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Scatter grad_output to per-expert buffers and compute gradients
|
||||
|
|
@ -4176,7 +4157,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_down");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -4198,13 +4179,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int k = cache.k_cache;
|
||||
constexpr int kSmallBwdDirectQlen = 0;
|
||||
constexpr int kSmallBwdDirectMaxTasks = 16;
|
||||
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
|
||||
auto direct_or_pool = [&](int count, auto&& fn) {
|
||||
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
fn(i);
|
||||
}
|
||||
} else {
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -4258,8 +4239,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int num_tokens = m_local_num_[expert_idx];
|
||||
if (num_tokens == 0) return;
|
||||
memset(grad_output_bf16_ptr_[expert_idx], 0, num_tokens * config_.hidden_size * sizeof(ggml_bf16_t));
|
||||
},
|
||||
"bwd_down_zero");
|
||||
});
|
||||
|
||||
// =====================================================
|
||||
// Step 2: Scatter grad_output to per-expert BF16 buffers
|
||||
|
|
@ -4305,8 +4285,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
dst_row[h] = GGML_FP32_TO_BF16(cur);
|
||||
}
|
||||
}
|
||||
},
|
||||
"bwd_down_scatter");
|
||||
});
|
||||
}
|
||||
|
||||
// =====================================================
|
||||
|
|
@ -4319,8 +4298,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int num_tokens = m_local_num_[expert_idx];
|
||||
if (num_tokens == 0) return;
|
||||
grad_output_ba_[expert_idx]->from_mat(num_tokens, grad_output_bf16_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
"bwd_down_quantize");
|
||||
});
|
||||
|
||||
// =====================================================
|
||||
// Step 3+4: AMX GEMM + to_mat (merged to use same ith/nth)
|
||||
|
|
@ -4365,7 +4343,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// to_mat: Convert BufferC to BF16 - use same ith, nth as mat_mul!
|
||||
bc->to_mat(m, grad_intermediate_ + expert_offsets[task_idx], ith, nth);
|
||||
},
|
||||
nullptr, "bwd_down_gemm", 1);
|
||||
nullptr);
|
||||
|
||||
// =====================================================
|
||||
// Step 3.5: Add LoRA contribution to grad_intermediate (AVX512)
|
||||
|
|
@ -4417,8 +4395,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
expert_lora_a, // [rank, inter_size] BF16
|
||||
grad_inter + t_start * inter_size, // [local_num_tokens, inter_size] BF16
|
||||
local_num_tokens, rank, inter_size, scale);
|
||||
},
|
||||
"bwd_down_lora_to_inter");
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -4600,7 +4577,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_down_lora_grad_AB");
|
||||
nullptr);
|
||||
} else {
|
||||
struct LoraGradTask {
|
||||
int expert_task = -1;
|
||||
|
|
@ -4609,7 +4586,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
};
|
||||
std::vector<LoraGradTask> lora_grad_tasks;
|
||||
|
||||
uint64_t grad_setup_start = sft_timer::get_trace_timestamp();
|
||||
float* grad_b_accum_all = down_lora_grad_b_accum_pool_;
|
||||
float* grad_a_accum_all = down_lora_grad_a_accum_pool_;
|
||||
if (activated_expert > 0) {
|
||||
|
|
@ -4630,8 +4606,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
lora_grad_tasks.push_back({task_id, t, std::min(t + token_tile, buf.num_tokens)});
|
||||
}
|
||||
}
|
||||
uint64_t grad_setup_end = sft_timer::get_trace_timestamp();
|
||||
sft_timer::add_kernel_trace("bwd_down_lora_grad_setup", grad_setup_start, grad_setup_end, tp_part_idx, 0);
|
||||
|
||||
if (!lora_grad_tasks.empty()) {
|
||||
direct_or_pool(
|
||||
|
|
@ -4748,8 +4722,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
grad_a_global[off] += grad_a_local[off];
|
||||
}
|
||||
}
|
||||
},
|
||||
"bwd_down_lora_grad_AB");
|
||||
});
|
||||
|
||||
constexpr int kDownGradBTile = 512;
|
||||
constexpr int kDownGradATile = 512;
|
||||
|
|
@ -4788,7 +4761,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_down_lora_write_B");
|
||||
nullptr);
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
activated_expert * grad_a_blocks, nullptr,
|
||||
|
|
@ -4812,7 +4785,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_down_lora_write_A");
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4824,13 +4797,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int qlen = cache.qlen_cache;
|
||||
constexpr int kSmallBwdDirectQlen = 0;
|
||||
constexpr int kSmallBwdDirectMaxTasks = 16;
|
||||
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
|
||||
auto direct_or_pool = [&](int count, auto&& fn) {
|
||||
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
fn(i);
|
||||
}
|
||||
} else {
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -4942,8 +4915,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
grad_gate[i] = GGML_FP32_TO_BF16(grad_i_val * u_val * sigmoid_val * (1.0f + g_val * (1.0f - sigmoid_val)));
|
||||
grad_up[i] = GGML_FP32_TO_BF16(grad_i_val * silu_val);
|
||||
}
|
||||
},
|
||||
"bwd_act_silu");
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -4963,13 +4935,13 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
int k = cache.k_cache;
|
||||
constexpr int kSmallBwdDirectQlen = 0;
|
||||
constexpr int kSmallBwdDirectMaxTasks = 16;
|
||||
auto direct_or_pool = [&](int count, auto&& fn, const char* task_name, int block_size = 1) {
|
||||
auto direct_or_pool = [&](int count, auto&& fn) {
|
||||
if (qlen <= kSmallBwdDirectQlen && count <= kSmallBwdDirectMaxTasks) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
fn(i);
|
||||
}
|
||||
} else {
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr, task_name, block_size);
|
||||
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -5065,7 +5037,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
|
||||
auto scatter_to_grad_input = [&](float scale, const char* task_name) {
|
||||
auto scatter_to_grad_input = [&](float scale) {
|
||||
ggml_bf16_t* grad_input_bf16 = (ggml_bf16_t*)grad_input;
|
||||
const int hidden = config_.hidden_size;
|
||||
const int hidden_vec_end = hidden & ~31;
|
||||
|
|
@ -5101,13 +5073,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
dst[h] = GGML_FP32_TO_BF16(cur);
|
||||
}
|
||||
}
|
||||
},
|
||||
task_name);
|
||||
});
|
||||
};
|
||||
|
||||
auto base_pass = [&](bool do_up) {
|
||||
const char* quant_name = do_up ? "bwd_gu_base_q_up" : "bwd_gu_base_q_gate";
|
||||
const char* gemm_name = do_up ? "bwd_gu_base_gemm_up" : "bwd_gu_base_gemm_gate";
|
||||
|
||||
// Quantize grad to BufferA
|
||||
direct_or_pool(
|
||||
|
|
@ -5121,8 +5090,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
ggml_bf16_t* grad = do_up ? (grad_up_output_ + offset * config_.intermediate_size)
|
||||
: (grad_gate_output_ + offset * config_.intermediate_size);
|
||||
down_ba_[expert_idx]->from_mat(m, grad, 0, 1);
|
||||
},
|
||||
quant_name);
|
||||
});
|
||||
|
||||
int nth = T::recommended_nth(config_.hidden_size);
|
||||
pool->do_work_stealing_job(
|
||||
|
|
@ -5141,9 +5109,9 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
|
||||
bc->to_mat(m, grad_output_bf16_ptr_[expert_idx], ith, nth);
|
||||
},
|
||||
nullptr, gemm_name, 1);
|
||||
nullptr);
|
||||
|
||||
scatter_to_grad_input(1.0f, "bwd_gu_scatter_base");
|
||||
scatter_to_grad_input(1.0f);
|
||||
};
|
||||
|
||||
base_pass(false); // gate
|
||||
|
|
@ -5341,8 +5309,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
up_gradb_global[off] += up_gradb_local[off];
|
||||
}
|
||||
}
|
||||
},
|
||||
"bwd_gu_lora_u_gradb_fused");
|
||||
});
|
||||
|
||||
constexpr int kGuGradBBlock = 256;
|
||||
int gradb_blocks = (inter_size + kGuGradBBlock - 1) / kGuGradBBlock;
|
||||
|
|
@ -5375,7 +5342,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "bwd_gu_lora_gradb_fused_write");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// =====================================================
|
||||
|
|
@ -5386,8 +5353,6 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
// Gate and up still run sequentially because they share grad_output_bf16_ptr_.
|
||||
// =====================================================
|
||||
auto lora_pass_remainder = [&](bool do_up) {
|
||||
const char* gb_gradin_name = do_up ? "bwd_gu_lora_gb_gradin_fused_up" : "bwd_gu_lora_gb_gradin_fused_gate";
|
||||
const char* grad_a_name = do_up ? "bwd_gu_lora_gradA_up" : "bwd_gu_lora_gradA_gate";
|
||||
|
||||
struct GuLoraGradInTask {
|
||||
int expert_task = -1;
|
||||
|
|
@ -5443,11 +5408,10 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
|
||||
memset(grad_out, 0, static_cast<size_t>(local_tokens) * hidden * sizeof(ggml_bf16_t));
|
||||
avx::lora_fp32_bf16_fused_add_transposed(gb, lora_a, grad_out, local_tokens, lora_rank_, hidden, 1.0f);
|
||||
},
|
||||
gb_gradin_name);
|
||||
});
|
||||
}
|
||||
|
||||
scatter_to_grad_input(lora_scaling_, "bwd_gu_scatter_lora");
|
||||
scatter_to_grad_input(lora_scaling_);
|
||||
|
||||
// Step 6: grad_A = G_B^T @ X
|
||||
ggml_bf16_t* grad_lora_a = do_up ? grad_up_a : grad_gate_a;
|
||||
|
|
@ -5547,7 +5511,7 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, grad_a_name);
|
||||
nullptr);
|
||||
};
|
||||
|
||||
lora_pass_remainder(false); // gate: gb_gradin_fused, scatter, gradA
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
sizeof(ggml_bf16_t) * tpc.intermediate_size);
|
||||
}
|
||||
},
|
||||
nullptr, "memcpy_weights_tmp");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Set BF16 weight pointers on sub-MOEs for backward
|
||||
|
|
@ -490,16 +490,13 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
throw std::runtime_error("Weights not loaded");
|
||||
}
|
||||
|
||||
auto start_sft = sft_timer::get_trace_timestamp();
|
||||
|
||||
int qlen = *qlen_ptr;
|
||||
auto pool = config.pool;
|
||||
|
||||
// Reset forward timing before computation
|
||||
// sft_timer::reset_forward();
|
||||
// Reset per-thread counters in each subpool (to accumulate all do_work_stealing_job calls)
|
||||
for (int i = 0; i < tp_count; i++) {
|
||||
pool->get_subpool(i)->reset_counters();
|
||||
}
|
||||
|
||||
// Run forward on each NUMA node
|
||||
|
|
@ -508,24 +505,18 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
save_for_backward);
|
||||
});
|
||||
|
||||
auto end_fwd = sft_timer::get_trace_timestamp();
|
||||
|
||||
// // Collect per-thread timing from all NUMA subpools
|
||||
// for (int i = 0; i < tp_count; i++) {
|
||||
// sft_timer::collect_forward(pool->get_subpool(i));
|
||||
// }
|
||||
|
||||
// // Print per-thread forward timing
|
||||
// sft_timer::print_forward();
|
||||
|
||||
// Merge results from all NUMA nodes
|
||||
this->merge_results(qlen, output);
|
||||
|
||||
auto end_merge = sft_timer::get_trace_timestamp();
|
||||
|
||||
pool->dispense_backend()->do_numa_job([&](int numa_id) {
|
||||
sft_timer::add_kernel_trace("fwd", start_sft, end_fwd, numa_id, 0);
|
||||
sft_timer::add_kernel_trace("merge", end_fwd, end_merge, numa_id, 0);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -561,7 +552,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
void* grad_weights) {
|
||||
auto pool = config.pool;
|
||||
|
||||
auto start_sft = sft_timer::get_trace_timestamp();
|
||||
|
||||
// Get full intermediate_size (before TP partitioning)
|
||||
int full_intermediate_size = sft_config.intermediate_size;
|
||||
|
|
@ -665,9 +655,8 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
const auto& seg = clear_segs[(size_t)seg_idx];
|
||||
std::memset(seg.ptr, 0, seg.len);
|
||||
},
|
||||
nullptr, "bwd_alloc_memset");
|
||||
nullptr);
|
||||
|
||||
auto end_alloc = sft_timer::get_trace_timestamp();
|
||||
|
||||
// Compute TP-slice pointers for copy-type direct writes
|
||||
// Each TP writes to its own I-slice of the final output tensor
|
||||
|
|
@ -697,7 +686,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
|
||||
// Run backward on each NUMA node
|
||||
pool->dispense_backend()->do_numa_job([&](int numa_id) {
|
||||
auto start_Bwd = sft_timer::get_trace_timestamp();
|
||||
tps[numa_id]->backward(grad_output, part_grad_input_[numa_id],
|
||||
// reduce-type: BF16 pointer unused (FP32 sparse used instead)
|
||||
nullptr, /* grad_gate_lora_a — unused, FP32 path below */
|
||||
|
|
@ -708,18 +696,13 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
nullptr, /* grad_down_lora_b — unused, FP32 path below */
|
||||
part_grad_weights_[numa_id], full_intermediate_size, tp_fp32_down_b[numa_id],
|
||||
tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id]);
|
||||
auto end_bwd = sft_timer::get_trace_timestamp();
|
||||
sft_timer::add_kernel_trace("bwd_alloc", start_sft, end_alloc, numa_id, 0);
|
||||
sft_timer::add_kernel_trace("bwd_tp", start_Bwd, end_bwd, numa_id, 0);
|
||||
});
|
||||
|
||||
// // Collect per-thread timing from all NUMA subpools
|
||||
// for (int i = 0; i < tp_count; i++) {
|
||||
// sft_timer::collect_backward(pool->get_subpool(i));
|
||||
// }
|
||||
|
||||
// // Print per-thread backward timing
|
||||
// sft_timer::print_backward();
|
||||
|
||||
// // Print expert token distribution for load balancing analysis
|
||||
// {
|
||||
|
|
@ -739,7 +722,6 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
// }
|
||||
|
||||
// Bug #22 fix: Merge grad_input from all NUMA nodes (sum them together)
|
||||
auto start_sum = sft_timer::get_trace_timestamp();
|
||||
{
|
||||
auto* out = (ggml_bf16_t*)grad_input;
|
||||
pool->do_work_stealing_job(
|
||||
|
|
@ -784,13 +766,11 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
dst[h] = GGML_FP32_TO_BF16(sum);
|
||||
}
|
||||
},
|
||||
nullptr, "merge_grad_input");
|
||||
nullptr);
|
||||
}
|
||||
auto end_sum = sft_timer::get_trace_timestamp();
|
||||
|
||||
// Merge reduce-type LoRA gradients: sparse FP32 sum across TPs → BF16 final output
|
||||
// Copy-type grads (gate/up_lora_b, down_lora_a) were written directly — no merge needed.
|
||||
auto start_merge = sft_timer::get_trace_timestamp();
|
||||
if constexpr (!kSkipLoRA) {
|
||||
// Sparse merge for gate_lora_a, up_lora_a: [active_count, r, H] FP32 → [E, r, H] BF16
|
||||
{
|
||||
|
|
@ -835,7 +815,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
ud[h] = GGML_FP32_TO_BF16(us);
|
||||
}
|
||||
},
|
||||
nullptr, "merge_lora_a");
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Sparse merge for down_lora_b: [active_count, H, r] FP32 → [E, H, r] BF16
|
||||
|
|
@ -861,7 +841,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
nullptr, "merge_down_lora_b");
|
||||
nullptr);
|
||||
}
|
||||
} // if constexpr (!kSkipLoRA)
|
||||
|
||||
|
|
@ -900,14 +880,10 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
out_grad_weights[i] = sum;
|
||||
}
|
||||
},
|
||||
nullptr, "merge_grad_weights");
|
||||
nullptr);
|
||||
}
|
||||
auto end_merge = sft_timer::get_trace_timestamp();
|
||||
|
||||
pool->dispense_backend()->do_numa_job([&](int numa_id) {
|
||||
sft_timer::add_kernel_trace("merge_tp", start_sum, end_sum, numa_id, 0);
|
||||
sft_timer::add_kernel_trace("merge_lora_a", end_sum, start_merge, numa_id, 0);
|
||||
sft_timer::add_kernel_trace("merge_grad_weights", start_merge, end_merge, numa_id, 0);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -989,8 +965,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
numa_id * tp_inter,
|
||||
sizeof(ggml_bf16_t) * tp_inter);
|
||||
}
|
||||
},
|
||||
"upd_lora_tp");
|
||||
});
|
||||
|
||||
// Update weights after all memcpy complete
|
||||
tps[numa_id]->update_lora_weights(gate_lora_a, partitioned_gate_lora_b_[numa_id], up_lora_a,
|
||||
|
|
@ -1074,7 +1049,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
|
|||
sizeof(ggml_bf16_t) * tpc.intermediate_size);
|
||||
}
|
||||
},
|
||||
nullptr, "memcpy_bwd_tmp");
|
||||
nullptr);
|
||||
|
||||
tps[i]->prepare_bwd(temp_gate, temp_up, temp_down);
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class KExpertsCPUBuffer:
|
|||
hidden_size = hidden_states.shape[-1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
pin_memory = False
|
||||
pin_memory = True
|
||||
|
||||
if batch_size in cls.capture_buffers:
|
||||
return cls.capture_buffers[batch_size]
|
||||
|
|
|
|||
|
|
@ -818,35 +818,30 @@ class OnlineQuantConverter(ConverterBase):
|
|||
print(f" [Fused] tensor {p} shape: {tuple(w.shape)}")
|
||||
fused_tensors.append(w)
|
||||
|
||||
# fused_tensors[0] : down_proj, [E, H, I]
|
||||
# fused_tensors[1] : gate_up_proj, [E, 2I, H]
|
||||
# fused_tensors[0] : down-like, [E, I, H]
|
||||
# fused_tensors[1] : gate_up-like, [E, H, 2I]
|
||||
down_fused = fused_tensors[0]
|
||||
gate_up_fused = fused_tensors[1]
|
||||
|
||||
# gate_up_fused is [E, 2I, H] — split on dim 1, no transpose needed
|
||||
# gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
|
||||
if gate_up_fused.dim() != 3:
|
||||
raise ValueError(
|
||||
f"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}"
|
||||
)
|
||||
E = gate_up_fused.shape[0]
|
||||
I = self.moe_intermediate_size
|
||||
H = self.hidden_size
|
||||
E, H, twoI = gate_up_fused.shape
|
||||
if twoI % 2 != 0:
|
||||
raise ValueError(f"[Fused] gate_up last dim (2I) not even: {twoI}")
|
||||
I = twoI // 2
|
||||
|
||||
if gate_up_fused.shape != (E, 2 * I, H):
|
||||
raise ValueError(
|
||||
f"[Fused] gate_up shape {tuple(gate_up_fused.shape)} != expected ({E}, {2*I}, {H}). "
|
||||
f"If your model stores gate_up as [E, H, 2I], transpose is needed."
|
||||
)
|
||||
|
||||
gate_proj = gate_up_fused[:, :I, :].contiguous() # [E, I, H]
|
||||
up_proj = gate_up_fused[:, I:, :].contiguous() # [E, I, H]
|
||||
gate_up_T = gate_up_fused.transpose(1, 2).contiguous() # [E, 2I, H]
|
||||
gate_proj = gate_up_T[:, :I, :] # [E, I, H]
|
||||
up_proj = gate_up_T[:, I:, :] # [E, I, H]
|
||||
|
||||
if down_fused.dim() != 3:
|
||||
raise ValueError(f"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
|
||||
if down_fused.shape[0] != E:
|
||||
raise ValueError(f"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}")
|
||||
# down_proj is [E, H, I] — matches load_weights_from_tensors expectation, no transpose
|
||||
down_proj = down_fused.contiguous() # [E, H, I]
|
||||
down_proj = down_fused.transpose(1, 2).contiguous() # [E, H, I]
|
||||
del fused_tensors
|
||||
del gate_up_fused
|
||||
del down_fused
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue