sft_route_moe: use OpenMP to parallelize the backward pass (#1662)

Co-authored-by: unknown <xiongchenhui@hisense.ad>
This commit is contained in:
Pory 2025-12-06 20:38:50 +08:00 committed by GitHub
parent fb787283f6
commit 4662368606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -15,6 +15,7 @@
#include <vector>
#include <fstream>
#include <filesystem>
#include <omp.h>
#include "debug_sft_moe.hpp"
@ -1461,9 +1462,11 @@ public:
float scaling = config_.lora_scaling;
// grad_B[i, r] += sum_t(gate_grad[t,i] * lora_inter[t,r]) * scaling
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int i = 0; i < config_.intermediate_size; i++) {
for (int r = 0; r < config_.lora_rank; r++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
float grad_val = ggml_bf16_to_fp32(gate_grad[t * config_.intermediate_size + i]);
float inter_val = lora_inter[t * padded_lora_rank_ + r]; // BufferC stores in row-major
@ -1512,9 +1515,11 @@ public:
// grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling
// Note: Only multiply by scaling once (not scaling²)
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int r = 0; r < config_.lora_rank; r++) {
for (int h = 0; h < config_.hidden_size; h++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
float temp_val = temp_grad[t * padded_lora_rank_ + r];
float input_val = ggml_bf16_to_fp32(input[t * config_.hidden_size + h]);
@ -1557,9 +1562,11 @@ public:
float scaling = config_.lora_scaling;
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int i = 0; i < config_.intermediate_size; i++) {
for (int r = 0; r < config_.lora_rank; r++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
sum += ggml_bf16_to_fp32(up_grad[t * config_.intermediate_size + i]) * lora_inter[t * padded_lora_rank_ + r];
}
@ -1602,9 +1609,11 @@ public:
// grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling
// Note: Only multiply by scaling once (not scaling²)
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int r = 0; r < config_.lora_rank; r++) {
for (int h = 0; h < config_.hidden_size; h++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(input[t * config_.hidden_size + h]);
}
@ -1664,10 +1673,11 @@ public:
ggml_bf16_t *grad_B_dst = (ggml_bf16_t *)config_.grad_down_lora_B + expert_idx * config_.hidden_size * config_.lora_rank;
float scaling = config_.lora_scaling;
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int h = 0; h < config_.hidden_size; h++) {
for (int r = 0; r < config_.lora_rank; r++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
sum += ggml_bf16_to_fp32(down_grad_weighted[t * config_.hidden_size + h]) * lora_inter[t * padded_lora_rank_ + r];
}
@ -1712,9 +1722,11 @@ public:
// grad_A[r, i] += sum_t(temp_grad[t,r] * intermediate[t,i]) * scaling
// Note: Only multiply by scaling once (not scaling²)
#pragma omp parallel for collapse(2) schedule(dynamic)
for (int r = 0; r < config_.lora_rank; r++) {
for (int i = 0; i < config_.intermediate_size; i++) {
float sum = 0.0f;
#pragma omp simd reduction(+:sum)
for (int t = 0; t < num_tokens; t++) {
sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(intermediate[t * config_.intermediate_size + i]);
}