mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
sft_route_moe: use OpenMP to parallelize the LoRA gradient loops in the backward pass (#1662)
Co-authored-by: unknown <xiongchenhui@hisense.ad>
This commit is contained in:
parent
fb787283f6
commit
4662368606
1 changed file with 13 additions and 1 deletion
|
|
@ -15,6 +15,7 @@
|
|||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
#include <omp.h>
|
||||
|
||||
#include "debug_sft_moe.hpp"
|
||||
|
||||
|
|
@ -1461,9 +1462,11 @@ public:
|
|||
float scaling = config_.lora_scaling;
|
||||
|
||||
// grad_B[i, r] += sum_t(gate_grad[t,i] * lora_inter[t,r]) * scaling
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int i = 0; i < config_.intermediate_size; i++) {
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
float grad_val = ggml_bf16_to_fp32(gate_grad[t * config_.intermediate_size + i]);
|
||||
float inter_val = lora_inter[t * padded_lora_rank_ + r]; // BufferC stores in row-major
|
||||
|
|
@ -1512,9 +1515,11 @@ public:
|
|||
|
||||
// grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling
|
||||
// Note: Only multiply by scaling once (not scaling²)
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
for (int h = 0; h < config_.hidden_size; h++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
float temp_val = temp_grad[t * padded_lora_rank_ + r];
|
||||
float input_val = ggml_bf16_to_fp32(input[t * config_.hidden_size + h]);
|
||||
|
|
@ -1557,9 +1562,11 @@ public:
|
|||
|
||||
float scaling = config_.lora_scaling;
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int i = 0; i < config_.intermediate_size; i++) {
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
sum += ggml_bf16_to_fp32(up_grad[t * config_.intermediate_size + i]) * lora_inter[t * padded_lora_rank_ + r];
|
||||
}
|
||||
|
|
@ -1602,9 +1609,11 @@ public:
|
|||
|
||||
// grad_A[r, h] += sum_t(temp_grad[t,r] * input[t,h]) * scaling
|
||||
// Note: Only multiply by scaling once (not scaling²)
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
for (int h = 0; h < config_.hidden_size; h++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(input[t * config_.hidden_size + h]);
|
||||
}
|
||||
|
|
@ -1664,10 +1673,11 @@ public:
|
|||
ggml_bf16_t *grad_B_dst = (ggml_bf16_t *)config_.grad_down_lora_B + expert_idx * config_.hidden_size * config_.lora_rank;
|
||||
|
||||
float scaling = config_.lora_scaling;
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int h = 0; h < config_.hidden_size; h++) {
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
sum += ggml_bf16_to_fp32(down_grad_weighted[t * config_.hidden_size + h]) * lora_inter[t * padded_lora_rank_ + r];
|
||||
}
|
||||
|
|
@ -1712,9 +1722,11 @@ public:
|
|||
|
||||
// grad_A[r, i] += sum_t(temp_grad[t,r] * intermediate[t,i]) * scaling
|
||||
// Note: Only multiply by scaling once (not scaling²)
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic)
|
||||
for (int r = 0; r < config_.lora_rank; r++) {
|
||||
for (int i = 0; i < config_.intermediate_size; i++) {
|
||||
float sum = 0.0f;
|
||||
#pragma omp simd reduction(+:sum)
|
||||
for (int t = 0; t < num_tokens; t++) {
|
||||
sum += temp_grad[t * padded_lora_rank_ + r] * ggml_bf16_to_fp32(intermediate[t * config_.intermediate_size + i]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue