diff --git a/kt-sft/csrc/ktransformers_ext/operators/amx/sft_route_moe.hpp b/kt-sft/csrc/ktransformers_ext/operators/amx/sft_route_moe.hpp index 209cc553..c7953de4 100644 --- a/kt-sft/csrc/ktransformers_ext/operators/amx/sft_route_moe.hpp +++ b/kt-sft/csrc/ktransformers_ext/operators/amx/sft_route_moe.hpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "debug_sft_moe.hpp" @@ -1461,9 +1462,11 @@ public: float scaling = config_.lora_scaling; // grad_B[i, r] += sum_t(gate_grad[t,i] * lora_inter[t,r]) * scaling + #pragma omp parallel for collapse(2) schedule(dynamic) for (int i = 0; i < config_.intermediate_size; i++) { for (int r = 0; r < config_.lora_rank; r++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { float grad_val = ggml_bf16_to_fp32(gate_grad[t * config_.intermediate_size + i]); float inter_val = lora_inter[t * padded_lora_rank_ + r]; // BufferC stores in row-major @@ -1512,9 +1515,11 @@ public: // grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling // Note: Only multiply by scaling once (not scaling²) + #pragma omp parallel for collapse(2) schedule(dynamic) for (int r = 0; r < config_.lora_rank; r++) { for (int h = 0; h < config_.hidden_size; h++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { float temp_val = temp_grad[t * padded_lora_rank_ + r]; float input_val = ggml_bf16_to_fp32(input[t * config_.hidden_size + h]); @@ -1557,9 +1562,11 @@ public: float scaling = config_.lora_scaling; + #pragma omp parallel for collapse(2) schedule(dynamic) for (int i = 0; i < config_.intermediate_size; i++) { for (int r = 0; r < config_.lora_rank; r++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { sum += ggml_bf16_to_fp32(up_grad[t * config_.intermediate_size + i]) * lora_inter[t * padded_lora_rank_ + r]; } @@ -1602,9 +1609,11 @@ public: // grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling // Note: Only multiply by scaling once (not scaling²) + #pragma omp parallel for collapse(2) schedule(dynamic) for (int r = 0; r < config_.lora_rank; r++) { for (int h = 0; h < config_.hidden_size; h++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(input[t * config_.hidden_size + h]); } @@ -1664,10 +1673,11 @@ public: ggml_bf16_t *grad_B_dst = (ggml_bf16_t *)config_.grad_down_lora_B + expert_idx * config_.hidden_size * config_.lora_rank; float scaling = config_.lora_scaling; - + #pragma omp parallel for collapse(2) schedule(dynamic) for (int h = 0; h < config_.hidden_size; h++) { for (int r = 0; r < config_.lora_rank; r++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { sum += ggml_bf16_to_fp32(down_grad_weighted[t * config_.hidden_size + h]) * lora_inter[t * padded_lora_rank_ + r]; } @@ -1712,9 +1722,11 @@ public: // grad_A[r, i] += sum_t(temp_grad[t,r] * intermediate[t,i]) * scaling // Note: Only multiply by scaling once (not scaling²) + #pragma omp parallel for collapse(2) schedule(dynamic) for (int r = 0; r < config_.lora_rank; r++) { for (int i = 0; i < config_.intermediate_size; i++) { float sum = 0.0f; + #pragma omp simd reduction(+:sum) for (int t = 0; t < num_tokens; t++) { sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(intermediate[t * config_.intermediate_size + i]); }