Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)

commit 7030ebf401: Merge branch 'upstream' into concedo_experimental

    # Conflicts:
    #   docs/backend/SYCL.md
    #   ggml/src/CMakeLists.txt
    #   ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
    #   ggml/src/ggml-sycl/CMakeLists.txt
    #   tests/test-backend-ops.cpp

24 changed files with 1883 additions and 398 deletions

@@ -529,6 +529,8 @@ class Model:
         reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
         added_vocab = tokenizer.get_added_vocab()

+        added_tokens_decoder = tokenizer.added_tokens_decoder
+
         for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
@@ -538,13 +540,13 @@ class Model:
                 if token in added_vocab:
                     # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
                     # To avoid unexpected issues - we make sure to normalize non-normalized tokens
-                    if not tokenizer.added_tokens_decoder[i].normalized:
+                    if not added_tokens_decoder[i].normalized:
                         previous_token = token
                         token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
                         if previous_token != token:
                             logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

-                    if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
+                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
                         toktypes.append(gguf.TokenType.CONTROL)
                     else:
                         # NOTE: this was added for Gemma.

Binary file not shown.
@@ -99,13 +99,9 @@ export default function ChatScreen() {
     canvasData,
     replaceMessageAndGenerate,
   } = useAppContext();
-  const [inputMsg, setInputMsg] = useState(prefilledMsg.content());
-  const inputRef = useRef<HTMLTextAreaElement>(null);
+  const textarea = useOptimizedTextarea(prefilledMsg.content());

-  const { extraContext, clearExtraContext } = useVSCodeContext(
-    inputRef,
-    setInputMsg
-  );
+  const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
   // TODO: improve this when we have "upload file" feature
   const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;

@@ -135,9 +131,10 @@ export default function ChatScreen() {
   };

   const sendNewMessage = async () => {
-    if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
-    const lastInpMsg = inputMsg;
-    setInputMsg('');
+    const lastInpMsg = textarea.value();
+    if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
+      return;
+    textarea.setValue('');
     scrollToBottom(false);
     setCurrNodeId(-1);
     // get the last message node
@@ -146,13 +143,13 @@ export default function ChatScreen() {
       !(await sendMessage(
         currConvId,
         lastMsgNodeId,
-        inputMsg,
+        lastInpMsg,
         currExtra,
         onChunk
       ))
     ) {
       // restore the input message if failed
-      setInputMsg(lastInpMsg);
+      textarea.setValue(lastInpMsg);
     }
     // OK
     clearExtraContext();
@@ -195,16 +192,13 @@ export default function ChatScreen() {
       // send the prefilled message if needed
       sendNewMessage();
     } else {
-      // otherwise, focus on the input and move the cursor to the end
-      if (inputRef.current) {
-        inputRef.current.focus();
-        inputRef.current.selectionStart = inputRef.current.value.length;
-      }
+      // otherwise, focus on the input
+      textarea.focus();
     }
     prefilledMsg.clear();
     // no need to keep track of sendNewMessage
     // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [inputRef]);
+  }, [textarea.ref]);

   // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
   const pendingMsgDisplay: MessageDisplay[] =
@@ -258,9 +252,7 @@ export default function ChatScreen() {
          <textarea
            className="textarea textarea-bordered w-full"
            placeholder="Type a message (Shift+Enter to add a new line)"
-           ref={inputRef}
-           value={inputMsg}
-           onChange={(e) => setInputMsg(e.target.value)}
+           ref={textarea.ref}
            onKeyDown={(e) => {
              if (e.nativeEvent.isComposing || e.keyCode === 229) return;
              if (e.key === 'Enter' && e.shiftKey) return;
@@ -280,11 +272,7 @@ export default function ChatScreen() {
              Stop
            </button>
          ) : (
-           <button
-             className="btn btn-primary ml-2"
-             onClick={sendNewMessage}
-             disabled={inputMsg.trim().length === 0}
-           >
+           <button className="btn btn-primary ml-2" onClick={sendNewMessage}>
              Send
            </button>
          )}
@@ -298,3 +286,43 @@ export default function ChatScreen() {
     </div>
   );
 }
+
+export interface OptimizedTextareaValue {
+  value: () => string;
+  setValue: (value: string) => void;
+  focus: () => void;
+  ref: React.RefObject<HTMLTextAreaElement>;
+}
+
+// This is a workaround to prevent the textarea from re-rendering when the inner content changes
+// See https://github.com/ggml-org/llama.cpp/pull/12299
+function useOptimizedTextarea(initValue: string): OptimizedTextareaValue {
+  const [savedInitValue, setSavedInitValue] = useState<string>(initValue);
+  const textareaRef = useRef<HTMLTextAreaElement>(null);
+
+  useEffect(() => {
+    if (textareaRef.current && savedInitValue) {
+      textareaRef.current.value = savedInitValue;
+      setSavedInitValue('');
+    }
+  }, [textareaRef, savedInitValue, setSavedInitValue]);
+
+  return {
+    value: () => {
+      return textareaRef.current?.value ?? savedInitValue;
+    },
+    setValue: (value: string) => {
+      if (textareaRef.current) {
+        textareaRef.current.value = value;
+      }
+    },
+    focus: () => {
+      if (textareaRef.current) {
+        // focus and move the cursor to the end
+        textareaRef.current.focus();
+        textareaRef.current.selectionStart = textareaRef.current.value.length;
+      }
+    },
+    ref: textareaRef,
+  };
+}

@@ -1,5 +1,6 @@
 import { useEffect, useState } from 'react';
 import { MessageExtraContext } from './types';
+import { OptimizedTextareaValue } from '../components/ChatScreen';

 // Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
 // Ref: https://github.com/ggml-org/llama.cpp/pull/11940
@@ -14,10 +15,7 @@ interface SetTextEvData {
  * window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
  */

-export const useVSCodeContext = (
-  inputRef: React.RefObject<HTMLTextAreaElement>,
-  setInputMsg: (text: string) => void
-) => {
+export const useVSCodeContext = (textarea: OptimizedTextareaValue) => {
   const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
     null
   );
@@ -27,20 +25,20 @@
     const handleMessage = (event: MessageEvent) => {
       if (event.data?.command === 'setText') {
         const data: SetTextEvData = event.data;
-        setInputMsg(data?.text);
+        textarea.setValue(data?.text);
         if (data?.context && data.context.length > 0) {
           setExtraContext({
             type: 'context',
             content: data.context,
           });
         }
-        inputRef.current?.focus();
+        textarea.focus();
       }
     };

     window.addEventListener('message', handleMessage);
     return () => window.removeEventListener('message', handleMessage);
-  }, [inputRef, setInputMsg]);
+  }, [textarea]);

   // Add a keydown listener that sends the "escapePressed" message to the parent window
   useEffect(() => {

@@ -571,6 +571,10 @@ int main(int argc, char ** argv) {
     model_ttc = llama_init_ttc.model.get();
     ctx_ttc   = llama_init_ttc.context.get();

+    if (model_ttc == nullptr || ctx_ttc == nullptr) {
+        return ENOENT;
+    }
+
     const llama_vocab * vocab = llama_model_get_vocab(model_ttc);

     // TODO: refactor in a common struct
@@ -586,6 +590,10 @@ int main(int argc, char ** argv) {
     model_cts = llama_init_cts.model.get();
     ctx_cts   = llama_init_cts.context.get();

+    if (model_cts == nullptr || ctx_cts == nullptr) {
+        return ENOENT;
+    }
+
     std::vector<common_sampler *> smpl(n_parallel);
     for (int i = 0; i < n_parallel; ++i) {
         params.sampling.no_perf = (i != 0);
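
Both hunks add the same guard: if either the model or the context handle comes back null, bail out of main with ENOENT instead of dereferencing it later. A generic, hedged sketch of the pattern in isolation (the init_result type and every name below are placeholders, not the llama.cpp API):

    #include <cerrno>
    #include <cstdio>

    struct init_result {
        void * model   = nullptr;   // stand-ins for the real model/context handles
        void * context = nullptr;
    };

    static init_result init_from_params(bool ok) {
        init_result res;
        if (ok) {
            static int dummy_model, dummy_context;
            res.model   = &dummy_model;
            res.context = &dummy_context;
        }
        return res;
    }

    int main() {
        const init_result init = init_from_params(/*ok=*/false);

        // same shape as the checks added after the _ttc and _cts initializations:
        if (init.model == nullptr || init.context == nullptr) {
            fprintf(stderr, "failed to initialize model or context\n");
            return ENOENT;
        }
        return 0;
    }
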
File diff suppressed because it is too large.

@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
     *dst = dst_val / rowsum;
 }

-template<int D, int parallel_blocks> // D == head size
+template<int D> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
-        float * __restrict__ dst) {
-    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
-    dst       += D                 * gridDim.y*blockIdx.x;
+        float * __restrict__ dst,
+        const int parallel_blocks) {
+    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
+    dst       += D                 * gridDim.z*blockIdx.x;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

-    __shared__ float2 meta[parallel_blocks];
+    extern __shared__ float2 meta[];
     if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
     }

     __syncthreads();

     float kqmax = meta[0].x;
-#pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
         kqmax = max(kqmax, meta[l].x);
     }

     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
-#pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         const float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

-    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }

 static void on_no_fattn_vec_case(const int D) {
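
A note on what flash_attn_combine_results computes (my paraphrase of the code above, not text from the commit): with meta[l] = (m_l, d_l) holding the running softmax maximum and denominator of partial block l, and VKQ_parts holding the partial numerators N_l, each output element is the numerically stable combination

    \[
        m = \max_{0 \le l < P} m_l, \qquad
        \mathrm{dst}[t] = \frac{\sum_{l=0}^{P-1} e^{m_l - m}\, N_l[t]}{\sum_{l=0}^{P-1} e^{m_l - m}\, d_l}
    \]

where P is parallel_blocks and scale factors whose exponent m_l - m falls below SOFTMAX_FTZ_THRESHOLD are flushed to zero via ftz_mask. Because P is now a runtime argument rather than a template parameter, the fixed-size __shared__ array has to become an extern __shared__ buffer whose size is supplied at launch time.
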
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }

-// parallel_blocks == 0 is stream-k decomposition
-template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
-    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
-    const int warp_size = WARP_SIZE
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

@@ -748,12 +745,14 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }

+    int parallel_blocks = 1;
+
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

     const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
-    if (parallel_blocks == 0) {
+    if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +768,43 @@ void launch_fattn(

         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        blocks_num.x = parallel_blocks*ntiles_x;
-        blocks_num.y = Q->ne[2];
-        blocks_num.z = Q->ne[3];
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = Q->ne[2]*Q->ne[3];

         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
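
The else branch above now sizes the grid at runtime: it queries how many blocks of this kernel fit on one SM, starts from a value that fills one wave, and then searches for the parallel_blocks candidate with the best wave efficiency. A self-contained sketch of that heuristic on a toy kernel (toy_kernel and pick_parallel_blocks are illustrative names, not llama.cpp code):

    #include <algorithm>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void toy_kernel(float * x) { x[blockIdx.x*blockDim.x + threadIdx.x] *= 2.0f; }

    // Pick how many blocks to run per work tile, mirroring the heuristic added to launch_fattn:
    // enough blocks to fill one wave, capped by the data, then refined by wave efficiency.
    static int pick_parallel_blocks(const int ntiles_total, const int ntiles_max, const int block_size) {
        int nsm = 0;
        cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, /*device=*/0);

        int max_blocks_per_sm = 1;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, toy_kernel, block_size, /*dynamicSMemSize=*/0);

        int parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
        parallel_blocks     = std::min(parallel_blocks, ntiles_max);

        const int blocks_per_wave = nsm * max_blocks_per_sm;
        int nwaves_best     = 0;
        int efficiency_best = 0;
        for (int pb = parallel_blocks; pb <= ntiles_max; ++pb) {
            const int nblocks    = ntiles_total * pb;
            const int nwaves     = (nblocks + blocks_per_wave - 1) / blocks_per_wave;
            const int efficiency = 100 * nblocks / (nwaves * blocks_per_wave);
            if (efficiency_best >= 90 && nwaves > nwaves_best) {
                break; // already filling waves well; more waves would only add combine/fixup overhead
            }
            if (efficiency > efficiency_best) {
                nwaves_best     = nwaves;
                efficiency_best = efficiency;
                parallel_blocks = pb;
            }
        }
        return parallel_blocks;
    }

    int main() {
        printf("parallel_blocks = %d\n", pick_parallel_blocks(/*ntiles_total=*/12, /*ntiles_max=*/64, /*block_size=*/256));
        return 0;
    }

The 90% early-out mirrors the comment in the diff: once a candidate already keeps the waves nearly full, trying larger values only adds overhead.
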
@@ -803,7 +836,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -815,7 +848,7 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());

-    if constexpr (parallel_blocks == 0) {
+    if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +857,14 @@ void launch_fattn(
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
-    } else if constexpr (parallel_blocks > 1) {
+    } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

-        flash_attn_combine_results<D, parallel_blocks>
-            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D>
+            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
     CUDA_CHECK(cudaGetLastError());
 }

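
Since parallel_blocks is no longer a compile-time constant, the combine kernel is now launched with a dynamic shared-memory size (nbytes_shared_combine, the third <<<>>> parameter) instead of a fixed __shared__ array. A minimal, self-contained illustration of that CUDA pattern (toy kernel and names are mine, not from the commit):

    #include <cuda_runtime.h>

    // Sums `count` float2 values per block out of dynamically sized shared memory.
    __global__ void sum_meta(const float2 * in, float2 * out, const int count) {
        extern __shared__ float2 meta[];          // size is fixed at launch, not at compile time
        const int tid = threadIdx.x;
        if (tid < count) {
            meta[tid] = in[blockIdx.x*count + tid];
        }
        __syncthreads();
        if (tid == 0) {
            float2 acc = make_float2(0.0f, 0.0f);
            for (int l = 0; l < count; ++l) {
                acc.x += meta[l].x;
                acc.y += meta[l].y;
            }
            out[blockIdx.x] = acc;
        }
    }

    int main() {
        const int nblocks = 2, count = 8;
        float2 *in, *out;
        cudaMalloc(&in,  nblocks*count*sizeof(float2));
        cudaMemset(in, 0, nblocks*count*sizeof(float2));
        cudaMalloc(&out, nblocks*sizeof(float2));
        const size_t nbytes_shared = count*sizeof(float2);   // plays the role of nbytes_shared_combine
        sum_meta<<<nblocks, 32, nbytes_shared>>>(in, out, count);
        cudaDeviceSynchronize();
        cudaFree(in); cudaFree(out);
        return 0;
    }
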
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }

-    launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+    launch_fattn<D, ncols1, ncols2, KQ_per_iter>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }

@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F16 64

-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(

     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(

     __syncthreads();

-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         half kqmax_new[ncols/nwarps];
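
The tile and vec kernels all swap the old ip/parallel_blocks arithmetic for the same grid-stride idiom: a block starts at its own blockIdx.y tile and advances by gridDim.y tiles, so the amount of K/V-dimension parallelism becomes a pure launch-time choice. A stripped-down sketch of the idiom (toy kernel, not the llama.cpp code):

    #include <cuda_runtime.h>

    #define TILE 64

    // Each (blockIdx.x, blockIdx.y) pair walks a strided subset of one row's tiles:
    // blockIdx.y picks the starting tile and gridDim.y is the stride, which is the role
    // the compile-time parallel_blocks constant used to play.
    __global__ void partial_row_sums(const float * x, float * partials, const int row_len) {
        const int row = blockIdx.x;
        float acc = 0.0f;
        for (int k0 = blockIdx.y*TILE; k0 < row_len; k0 += gridDim.y*TILE) {
            const int k = k0 + threadIdx.x;   // blockDim.x == TILE
            if (k < row_len) {
                acc += x[row*row_len + k];
            }
        }
        // One partial per (row, blockIdx.y, thread); a second pass would combine them,
        // analogous to dst_tmp/dst_meta plus flash_attn_combine_results above.
        partials[(row*gridDim.y + blockIdx.y)*TILE + threadIdx.x] = acc;
    }

    int main() {
        const int nrows = 4, row_len = 1024, parallel_blocks = 2;
        float *x, *partials;
        cudaMalloc(&x, nrows*row_len*sizeof(float));
        cudaMemset(x, 0, nrows*row_len*sizeof(float));
        cudaMalloc(&partials, nrows*parallel_blocks*TILE*sizeof(float));
        partial_row_sums<<<dim3(nrows, parallel_blocks, 1), TILE>>>(x, partials, row_len);
        cudaDeviceSynchronize();
        cudaFree(x); cudaFree(partials);
        return 0;
    }
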
@@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16(
                 const int i0 = i00 + 2*threadIdx.x;

                 half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-                if (parallel_blocks == 1) {
+                if (gridDim.y == 1) {
                     dst_val /= __half2half2(kqsum_j);
                 }
-                const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+                const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
             }

-            if (parallel_blocks != 1 && threadIdx.x == 0) {
-                dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            if (gridDim.y != 1 && threadIdx.x == 0) {
+                dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
             }
         }
 #else
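
The new dst_meta write keeps the parallel-block index innermost, which a small restatement of the flat index makes explicit (the helper name is mine; the formula is the one on the + line above, with gridDim.y == parallel_blocks and gridDim.z == ne02*ne03 after this commit):

    #include <cstdio>

    // dst_meta[((col*gridDim.z + blockIdx.z) * gridDim.y) + blockIdx.y]
    static int dst_meta_index(int col, int channel, int n_channels, int parallel_blocks, int block) {
        return (col*n_channels + channel)*parallel_blocks + block;
    }

    int main() {
        const int n_channels = 8, parallel_blocks = 4;
        // the 4 partials for column 3, channel 1 occupy consecutive slots:
        for (int block = 0; block < parallel_blocks; ++block) {
            printf("%d ", dst_meta_index(3, 1, n_channels, parallel_blocks, block));
        }
        printf("\n");
        return 0;
    }

This matches the read side: flash_attn_combine_results offsets VKQ_meta by parallel_blocks*gridDim.z per column (blockIdx.x) and then reads parallel_blocks consecutive float2 entries for its blockIdx.z.
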
@@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }

-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten

     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }

@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F32 32

-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32(

     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32(

     __syncthreads();

-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         float kqmax_new_arr[ncols];
@@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32(
                 const int i0 = i00 + 2*threadIdx.x;

                 float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-                if (parallel_blocks == 1) {
+                if (gridDim.y == 1) {
                     dst_val.x /= kqsum_j;
                     dst_val.y /= kqsum_j;
                 }
-                const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+                const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
             }

-            if (parallel_blocks != 1 && threadIdx.x == 0) {
-                dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            if (gridDim.y != 1 && threadIdx.x == 0) {
+                dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
             }
         }
 #else
@@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32(
 #endif // FLASH_ATTN_AVAILABLE
 }

-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten

     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }

@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);

-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio);

     const half * maskh = (const half *) mask + ne11*ic0;

-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(

     half2 VKQ[ncols] = {{0.0f, 0.0f}};

-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16(
         kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);

         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }

-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }

 template <int D, ggml_type type_K, ggml_type type_V>
@@ -326,64 +324,47 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml

     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     if (Q->ne[1] == 2) {
         constexpr int cols_per_block = 2;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block = 4;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     constexpr int cols_per_block = 8;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }

@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);

-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;

-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32(

     float VKQ[ncols] = {0.0f};

-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         float kqmax_new_arr[ncols];
@@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);

         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }

-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }

 template <int D, ggml_type type_K, ggml_type type_V>
@@ -308,64 +306,47 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml

     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
|
constexpr int cols_per_block = 1;
|
||||||
constexpr int parallel_blocks = 4;
|
|
||||||
if (logit_softcap == 0.0f) {
|
if (logit_softcap == 0.0f) {
|
||||||
constexpr bool use_logit_softcap = false;
|
constexpr bool use_logit_softcap = false;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_logit_softcap = true;
|
constexpr bool use_logit_softcap = true;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] == 2) {
|
if (Q->ne[1] == 2) {
|
||||||
constexpr int cols_per_block = 2;
|
constexpr int cols_per_block = 2;
|
||||||
constexpr int parallel_blocks = 4;
|
|
||||||
if (logit_softcap == 0.0f) {
|
if (logit_softcap == 0.0f) {
|
||||||
constexpr bool use_logit_softcap = false;
|
constexpr bool use_logit_softcap = false;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_logit_softcap = true;
|
constexpr bool use_logit_softcap = true;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 4) {
|
if (Q->ne[1] <= 4) {
|
||||||
constexpr int cols_per_block = 4;
|
constexpr int cols_per_block = 4;
|
||||||
constexpr int parallel_blocks = 4;
|
|
||||||
if (logit_softcap == 0.0f) {
|
if (logit_softcap == 0.0f) {
|
||||||
constexpr bool use_logit_softcap = false;
|
constexpr bool use_logit_softcap = false;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_logit_softcap = true;
|
constexpr bool use_logit_softcap = true;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Q->ne[1] <= 8) {
|
|
||||||
constexpr int cols_per_block = 8;
|
|
||||||
constexpr int parallel_blocks = 4;
|
|
||||||
if (logit_softcap == 0.0f) {
|
|
||||||
constexpr bool use_logit_softcap = false;
|
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
|
||||||
} else {
|
|
||||||
constexpr bool use_logit_softcap = true;
|
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
constexpr int parallel_blocks = 1;
|
|
||||||
if (logit_softcap == 0.0f) {
|
if (logit_softcap == 0.0f) {
|
||||||
constexpr bool use_logit_softcap = false;
|
constexpr bool use_logit_softcap = false;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_logit_softcap = true;
|
constexpr bool use_logit_softcap = true;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
 __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
 const char * __restrict__ Q,
@@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16(

 constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
-const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

 static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
 static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
 constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
+const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
-const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
+const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
-const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
 const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
 const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);

 const int stride_Q = nb01 / sizeof(float);
 const int stride_KV = nb11 / sizeof(half);

-const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 const half slopeh = __float2half(slopef);
 const half2 slope2 = make_half2(slopef, slopef);

@@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16(
 __syncthreads();

 // Iterate over ne11 == previous tokens:
-for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
 // Calculate tile of KQ:
 #pragma unroll
 for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
 if (ic0 + j_VKQ >= ne01) {
 return;
 }
-const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;

 float KQ_rowsum_j;
 if (std::is_same<KQ_acc_t, float>::value) {
@@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16(
 break;
 }
 float dst_val = VKQ[j_VKQ*D_padded + i];
-if (parallel_blocks == 1) {
+if (gridDim.y == 1) {
 dst_val /= KQ_rowsum_j;
 }
-dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
 }

-if (parallel_blocks == 1 || threadIdx.x != 0) {
+if (gridDim.y == 1 || threadIdx.x != 0) {
 continue;
 }

@@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16(
 dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
 }
 dst_meta_val.y = KQ_rowsum_j;
-dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
 }
 #else
 NO_DEVICE_CODE;
@@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 const ggml_tensor * KQV = dst;
-const ggml_tensor * Q = dst->src[0];

 constexpr int nwarps = 4;

 constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

 float logit_softcap;
 memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

-if (4*blocks_num_pb1 < 2*nsm) {
-constexpr int parallel_blocks = 4;
 fattn_kernel_t fattn_kernel;
 if (logit_softcap == 0.0f) {
 constexpr bool use_logit_softcap = false;
 fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
 } else {
 constexpr bool use_logit_softcap = true;
 fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
 }
-launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
+launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
-return;
-}
-if (2*blocks_num_pb1 < 2*nsm) {
-constexpr int parallel_blocks = 2;
-fattn_kernel_t fattn_kernel;
-if (logit_softcap == 0.0f) {
-constexpr bool use_logit_softcap = false;
-fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-} else {
-constexpr bool use_logit_softcap = true;
-fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-}
-launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
-return;
-}
-constexpr int parallel_blocks = 1;
-fattn_kernel_t fattn_kernel;
-if (logit_softcap == 0.0f) {
-constexpr bool use_logit_softcap = false;
-fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-} else {
-constexpr bool use_logit_softcap = true;
-fattn_kernel = flash_attn_ext_f16<
-D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-}
-launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }

 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

 if (!fp16_mma_available(cc)) {
 if (prec == GGML_PREC_DEFAULT) {
-if (Q->ne[1] <= 8) {
+if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
 } else {
 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
 }
 } else {
-if (Q->ne[1] <= 8) {
+if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
 } else {
 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 return;
 }

-const int gqa_ratio = Q->ne[2] / K->ne[2];
-const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
-K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
+const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
+const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
 if (prec == GGML_PREC_DEFAULT) {
 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-return;
+} else {
-} else if(Q->ne[0] <= 128) {
 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-return;
 }
+return;
 }

 // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
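Annotation (not part of the patch): the rewritten batch-size-1 dispatch above splits the old single condition into named booleans. A rough standalone restatement of that decision logic, with stand-in types and flags rather than the real ggml structs and compute-capability constants:

// dispatch_sketch.cpp -- restates the batch-size-1 vector-kernel condition from the hunk above.
// All names here are local stand-ins, not the real ggml types or device-info helpers.
#include <cstdint>
#include <cstdio>

struct FAParams {
    int64_t n_batch;            // Q->ne[1]
    int64_t head_size;          // Q->ne[0]
    int64_t gqa_ratio;          // Q->ne[2] / K->ne[2]
    bool    has_mask;
    bool    kv_is_f16;          // K and V already stored as F16
    bool    new_mma_available;  // stand-in for new_mma_available(cc)
    bool    cc_before_ada;      // compute capability older than Ada Lovelace
    bool    prec_default;       // GGML_PREC_DEFAULT requested
    int     warp_size;
};

static bool use_vector_kernel_for_bs1(const FAParams & p) {
    const bool gqa_opt_applies    = (p.gqa_ratio % 2 == 0) && p.has_mask;
    const bool needs_conversion   = !p.kv_is_f16;
    const bool mma_faster_for_bs1 = p.new_mma_available && gqa_opt_applies &&
                                    p.cc_before_ada && !needs_conversion;
    const bool can_use_vector     = (p.head_size % (2 * p.warp_size) == 0) &&
                                    (p.prec_default || p.head_size <= 128);
    return p.n_batch == 1 && can_use_vector && !mma_faster_for_bs1;
}

int main() {
    FAParams p{1, 128, 4, true, true, true, true, true, 32};
    printf("vector kernel chosen: %d\n", use_vector_kernel_for_bs1(p));
    return 0;
}

The real dispatcher additionally falls through to the MMA/WMMA paths when the vector kernel is not chosen.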
@@ -3235,6 +3235,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
 return false;
 #endif // FLASH_ATTN_AVAILABLE
+if (op->src[0]->ne[3] != 1) {
+return false;
+}
 if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
 return false;
 }
ggml/src/ggml-cuda/vendors/hip.h
@@ -129,6 +129,7 @@
 #define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
ggml/src/ggml-cuda/vendors/musa.h
@@ -134,5 +134,6 @@
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 #define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor

 typedef mt_bfloat16 nv_bfloat16;
@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();

 inline dpct::err0 ggml_sycl_set_device(const int device) try {

 int current_device_id;
 SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));

@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
 }
 }

+T * realloc(size_t size) {
+GGML_ASSERT(pool != nullptr);
+if (ptr)
+pool->free(ptr, actual_size);
+ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+return ptr;
+}
+
 // size is in number of elements
 T * alloc(size_t size) {
 GGML_ASSERT(pool != nullptr);
@@ -371,10 +378,29 @@ struct ggml_backend_sycl_context {
 dnnl::stream stream_dnnl() {
 return stream_dnnl(device, 0);
 }
+dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+const dnnl::engine & eng, const queue_ptr q) {
+ggml_sycl_pool_alloc<uint8_t> * pool;
+auto it = scratchpad_map.find(q);
+if (it == scratchpad_map.end()) {
+scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+pool = scratchpad_map[q].get();
+} else {
+pool = it->second.get();
+}
+
+size_t scratchpad_size = scratchpad_md.get_size();
+if (scratchpad_size > pool->actual_size) {
+pool->realloc(scratchpad_size);
+}
+void * mem_ptr = pool->get();
+return dnnl::memory(scratchpad_md, eng, mem_ptr);
+}
 #endif

 // pool
 std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;

 std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
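Annotation (not part of the patch): the new get_scratchpad_mem / scratchpad_map members above cache one lazily grown scratch buffer per SYCL queue so that the oneDNN matmuls in gemm.hpp can use a user-managed scratchpad. A stripped-down sketch of just that caching strategy, with stand-in types instead of the SYCL/oneDNN ones:

// scratchpad_cache_sketch.cpp -- grow-on-demand, one-buffer-per-queue caching, SYCL/oneDNN types removed.
// QueueHandle and Buffer are stand-ins; only the caching idea from the hunk above is shown.
#include <cstddef>
#include <cstdio>
#include <memory>
#include <unordered_map>
#include <vector>

using QueueHandle = const void *;

struct Buffer {
    std::vector<unsigned char> data;
    size_t capacity() const { return data.size(); }
    void * get() { return data.data(); }
    void realloc(size_t n) { data.resize(n); }   // the real code frees and reallocates from a device pool
};

class ScratchpadCache {
    std::unordered_map<QueueHandle, std::unique_ptr<Buffer>> map_;
public:
    // Returns a scratch area of at least `size` bytes tied to `q`; it grows only when needed,
    // so repeated GEMMs submitted to the same queue reuse a single allocation.
    void * get(QueueHandle q, size_t size) {
        auto & buf = map_[q];
        if (!buf) buf = std::make_unique<Buffer>();
        if (size > buf->capacity()) buf->realloc(size);
        return buf->get();
    }
};

int main() {
    ScratchpadCache cache;
    int dummy_queue = 0;
    void * a = cache.get(&dummy_queue, 1024);
    void * b = cache.get(&dummy_queue, 512);   // smaller request reuses the 1 KiB allocation
    printf("same buffer reused: %d\n", a == b);
    return 0;
}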
@@ -13,9 +13,6 @@
 #ifndef GGML_SYCL_GEMM_HPP
 #define GGML_SYCL_GEMM_HPP

-#include <fstream>
-#include <iostream>

 #include "ggml-sycl.h"

 #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ public:
 else static_assert(0);
 }

-static inline void row_gemm(sycl::queue& q, bool a_trans,
-bool b_trans, int m, int n, int k,
-const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-{
-// Get the device associated with the queue
-sycl::device dev = q.get_device();
-// Get the context associated with the queue
-sycl::context ctx = q.get_context();
-const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
-const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
+const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+auto stream = ctx.stream_dnnl(q);
+auto eng = ctx.engine_dnnl(q);
 dnnl::memory::dims a_dims = { m, k };
 dnnl::memory::dims b_dims = { k, n };
 dnnl::memory::dims c_dims = { m, n };
 const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
 const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
 const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);

+dnnl::primitive_attr primitive_attr;
+primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

 auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
 auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
 auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);

-// Create the primitive.
+auto scratchpad_md = matmul_pd.scratchpad_desc();
+auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
 auto matmul_prim = dnnl::matmul(matmul_pd);
-// Primitive arguments.
-std::unordered_map<int, dnnl::memory> matmul_args;
-matmul_args.insert({ DNNL_ARG_SRC, a_mem });
-matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
-matmul_args.insert({ DNNL_ARG_DST, c_mem });
-
-matmul_prim.execute(stream, matmul_args);
-}
-
-static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
-bool b_trans, int m, int n, int k,
-const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-{
-auto const eng = stream.get_engine();
-dnnl::memory::dims a_dims = { m, k };
-dnnl::memory::dims b_dims = { k, n };
-dnnl::memory::dims c_dims = { m, n };
-const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
-auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
-auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
-auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
-
-// Create the primitive.
-auto matmul_prim = dnnl::matmul(matmul_pd);
-// Primitive arguments.
 std::unordered_map<int, dnnl::memory> matmul_args;
 matmul_args.insert({ DNNL_ARG_SRC, a_mem });
 matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
 matmul_args.insert({ DNNL_ARG_DST, c_mem });
+matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });

 matmul_prim.execute(stream, matmul_args);
 }
@@ -2058,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
 const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
 to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
-auto dnnl_stream = ctx.stream_dnnl(stream);
-DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
 const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
 to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
 #endif
@@ -2099,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
 dst_dd_i, ldc)));
 # endif
 #else
-auto dnnl_stream = ctx.stream_dnnl(stream);
-DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 #endif
 }
 GGML_UNUSED(dst);
@@ -311,8 +311,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
 const float16_t d = bl.block.d;
 const uint idx = coordInBlock[1];

-const uint ib32 = idx / 32;
+const uint ib32 = (idx & 0xE0) >> 5;
-const uint ib8 = idx / 8;
+const uint ib8 = (idx & 0xF8) >> 3;

 const uint qh = bl.block.qh[ib32];
 const uint qs = bl.block.qs[ib8];
@@ -330,14 +330,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
 block_iq1_m block;
 };

+layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
+block_iq1_m_packed64 block;
+};
+
 float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
-const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
-const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
+decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
 const uint idx = coordInBlock[1];

-const uint ib8 = idx / 8;
-const uint ib16 = idx / 16;
+uvec2 scales = unpack32(bl64.block.scales);
+const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
+
+const uint ib8 = (idx & 0xF8) >> 3;
+const uint ib16 = (idx & 0xF0) >> 4;
 const int i8 = int(idx % 8);
 const uint sc = bl.block.scales[ib8 / 8];
 const uint qs = bl.block.qs[ib8];
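Annotation (not part of the patch): the shader hunk above replaces idx / 32, idx / 16 and idx / 8 with masked shifts; for block-local indices below 256 (the IQ1 block size) the two forms agree. A small host-side check of that equivalence, written in C++ rather than GLSL for convenience:

// mask_shift_check.cpp -- sanity check for the masked-shift rewrites in the shader hunk above:
// for block-local indices below 256 the masks reproduce the old divisions exactly.
#include <cassert>
#include <cstdio>

int main() {
    for (unsigned idx = 0; idx < 256; ++idx) {
        assert(((idx & 0xE0u) >> 5) == idx / 32);   // ib32
        assert(((idx & 0xF0u) >> 4) == idx / 16);   // ib16
        assert(((idx & 0xF8u) >> 3) == idx / 8);    // ib8
    }
    printf("masked shifts match the divisions for idx < 256\n");
    return 0;
}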
@@ -2,6 +2,7 @@
 #if !defined(GGML_TYPES_COMP)
 #define GGML_TYPES_COMP

+#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
@@ -312,6 +313,12 @@ struct block_iq1_m {
 uint16_t scales[QUANT_K_IQ1_M/64];
 };

+struct block_iq1_m_packed64 {
+uint64_t qs[QUANT_K_IQ1_M/8/8];
+uint64_t qh[QUANT_K_IQ1_M/16/8];
+uint64_t scales;
+};
+
 #if defined(DATA_A_IQ1_S)
 #define QUANT_K QUANT_K_IQ1_S
 #define QUANT_R QUANT_R_IQ1_S
@@ -1154,6 +1154,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 // remember the sequence ids used during the encoding - needed for cross attention later
 cross.seq_ids_enc.resize(n_tokens);
 for (int32_t i = 0; i < n_tokens; i++) {
+cross.seq_ids_enc[i].clear();
 for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
 llama_seq_id seq_id = ubatch.seq_id[i][s];
 cross.seq_ids_enc[i].insert(seq_id);
@@ -276,7 +276,17 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
 }
 }

-// add extra buffer types
+bool has_gpu_device = false;
+for (auto * dev : devices) {
+if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+has_gpu_device = true;
+break;
+}
+}
+
+// add extra buffer types, only if no GPU device is present
+// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
+if (!has_gpu_device) {
 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
 auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
 auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
@@ -288,6 +298,9 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
 ++extra_bufts;
 }
 }
+} else {
+LLAMA_LOG_WARN("%s: disabling extra buffer types (i.e. repacking) since a GPU device is available\n", __func__);
+}

 // add a host buffer type
 // storing the tensors in a host buffer is useful when the processing of large batches
@@ -2305,9 +2318,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

 // optional bias tensors
-layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
+layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
-layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
-layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);

 layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

@@ -2424,7 +2437,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
 layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0);

-layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED);
 if (layer.wqkv == nullptr) {
 layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
 layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
@@ -3310,16 +3323,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 auto & layer = layers[i];

 layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
-layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);

 if (layer.wqkv == nullptr) {
 layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
 layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
 layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
-layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
-layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
-layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
 }

 layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -3430,12 +3443,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);

 layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED);
-layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED);
-layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED);
-layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED);
-layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED);
-layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED);
 GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));

 layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
@@ -3465,7 +3478,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

 output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
 output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

 const int time_mix_extra_dim = hparams.time_mix_extra_dim;
@@ -3491,7 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
 layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);

-layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED);
 layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
 layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
 layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
@@ -3500,9 +3513,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
 layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
 // optional bias tensors
-layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED);
-layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED);
-layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED);

 layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);

@@ -3623,8 +3636,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
 }

-layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED);
-layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED);

 try {
 layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
@@ -3641,8 +3654,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
 layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);

-layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
-layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
 layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);

 layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
@@ -6296,16 +6309,25 @@ struct llm_build_qwen2moe : public llm_graph_context {
 {
 // compute Q and K and RoPE them
 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+cb(Qcur, "Qcur", il);
+if (model.layers[il].bq) {
 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
 cb(Qcur, "Qcur", il);
+}

 ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+cb(Kcur, "Kcur", il);
+if (model.layers[il].bk) {
 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
 cb(Kcur, "Kcur", il);
+}

 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+cb(Vcur, "Vcur", il);
+if (model.layers[il].bv) {
 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
 cb(Vcur, "Vcur", il);
+}

 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
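Annotation (not part of the patch): the qwen2moe hunk above makes the attention bias tensors optional and only applies them when they were actually loaded. A tiny sketch of that guard pattern with stand-in types (not ggml's tensor API):

// optional_bias_sketch.cpp -- bias is nullptr when the model ships without it, and is only added when present.
#include <cstdio>
#include <vector>

using Tensor = std::vector<float>;

// Adds `bias` to `x` only if a bias tensor was loaded (nullptr means "tensor not required / absent").
static void add_bias_if_present(Tensor & x, const Tensor * bias) {
    if (!bias) return;
    for (size_t i = 0; i < x.size() && i < bias->size(); ++i) x[i] += (*bias)[i];
}

int main() {
    Tensor q{1.0f, 2.0f};
    add_bias_if_present(q, nullptr);       // model exported without attention biases: no-op
    Tensor bias{0.5f, 0.5f};
    add_bias_if_present(q, &bias);         // model exported with biases: applied
    printf("%.1f %.1f\n", q[0], q[1]);
    return 0;
}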