hexagon: add MROPE and IMROPE support in HTP rope op (#23317)

This commit is contained in:
Aparna M P 2026-05-20 02:40:13 +05:30 committed by GitHub
parent 67ace021da
commit 17d22a35b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 102 additions and 23 deletions

View file

@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
int mode = op_params[2];
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
if (mode == GGML_ROPE_TYPE_VISION) {
return false;
}
if (mode & 1) {

View file

@ -18,9 +18,11 @@
#include "htp-ops.h"
#include "htp-ops.h"
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
// Redefined the rope type constants as we can't include ggml.h
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2
#define HTP_ROPE_TYPE_MROPE 8
#define HTP_ROPE_TYPE_IMROPE 40
#define HTP_ROPE_SPAD_NROWS 16
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
return (1 - MIN(1, MAX(0, y)));
}
// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling.
static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims,
uint32_t i0, float ext_factor, float mscale,
float * cache) {
float theta_extrap = theta;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta_final = theta_interp;
float mscale_final = mscale;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
cache[i0 + 0] = cosf(theta_final) * mscale_final;
cache[i0 + 1] = sinf(theta_final) * mscale_final;
}
static void rope_cache_init(const float theta_base,
const float freq_scale,
const float * freq_factors,
@ -96,29 +121,65 @@ static void rope_cache_init(const float theta_base,
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
float theta_extrap = theta / ff;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta_final = theta_interp;
float mscale_final = mscale;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
cache[i0 + 0] = cosf(theta_final) * mscale_final;
cache[i0 + 1] = sinf(theta_final) * mscale_final;
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
theta *= theta_scale;
}
}
// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra).
// sections[4]: number of head dims assigned to each position component.
static void mrope_cache_init(const float pos_t,
const float pos_h,
const float pos_w,
const float pos_e,
const int32_t sections[4],
const bool is_imrope,
const float freq_scale,
const float * freq_factors,
float * corr_dims,
const uint32_t ne0,
const float ext_factor,
const float mscale,
float * cache,
const float theta_scale) {
const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
const int sec_w = sections[0] + sections[1];
const int sec_e = sec_w + sections[2];
float theta_t = pos_t;
float theta_h = pos_h;
float theta_w = pos_w;
float theta_e = pos_e;
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
const int sector = (i0 / 2) % sect_dims;
float theta;
if (is_imrope) {
// Interleaved: sector mod 3 selects component
if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; }
else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; }
else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; }
else { theta = theta_e; }
} else {
// Contiguous sections
if (sector < sections[0]) { theta = theta_t; }
else if (sector < sec_w) { theta = theta_h; }
else if (sector < sec_e) { theta = theta_w; }
else { theta = theta_e; }
}
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
theta_t *= theta_scale;
theta_h *= theta_scale;
theta_w *= theta_scale;
theta_e *= theta_scale;
}
}
#define M_PI 3.1415926535897932384626433
static void rope_corr_dims(int n_dims,
@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
uint64_t tt = HAP_perf_get_qtimer_count();
const int32_t mode = rctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
// MROPE and IMROPE use NEOX-style pairing for the rotation
const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE);
// VTCM setup
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
if (i2 != prev_i2) {
prev_i2 = i2;
const int32_t p = pos[i2];
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0;
if (is_mrope) {
// src1 holds four position arrays stacked along ne0:
// pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3]
const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE);
mrope_cache_init(
(float) pos[i2],
(float) pos[i2 + ne2],
(float) pos[i2 + ne2 * 2],
(float) pos[i2 + ne2 * 3],
rctx->sections, is_imrope,
rctx->freq_scale, freq_factors, rctx->corr_dims,
ne0, rctx->ext_factor, rctx->attn_factor,
theta_cache, rctx->theta_scale);
} else {
rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims,
ne0, rctx->ext_factor, rctx->attn_factor,
theta_cache, rctx->theta_scale);
}
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));