mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
Merge branch 'upstream' into concedo_experimental
This commit is contained in:
commit
c21c8cd00a
2 changed files with 10 additions and 3 deletions
|
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
|
|||
|
||||
{
|
||||
float S[Q] = { [0 ... Q-1] = 0.0f };
|
||||
- float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
||||
+ float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
// TODO: see if we can utilize quad-group functions for better performance
|
||||
|
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
|
|||
// reduce the warps sequentially
|
||||
for (ushort sg = 1; sg < nsg; ++sg) {
|
||||
float S = { 0.0f };
|
||||
- float M = { -__FLT16_MAX__/2 };
|
||||
+ float M = { -__FLT_MAX__/2 };
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
|
||||
{
|
||||
float S = 0.0f;
|
||||
- float M = -__FLT16_MAX__/2;
|
||||
+ float M = -__FLT_MAX__/2;
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
const short tx = tiisg%NL;
|
||||
|
|
|
@@ -262,6 +262,7 @@ struct vk_device_struct {
|
|||
bool pipeline_robustness;
|
||||
vk::Device device;
|
||||
uint32_t vendor_id;
|
||||
vk::DriverId driver_id;
|
||||
vk_device_architecture architecture;
|
||||
vk_queue compute_queue;
|
||||
vk_queue transfer_queue;
|
||||
|
@@ -1756,6 +1757,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
// chip specific tuning
|
||||
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
||||
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
||||
}
|
||||
|
||||
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
||||
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
||||
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
||||
|
@@ -2678,6 +2684,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device->physical_device.getProperties2(&props2);
|
||||
device->properties = props2.properties;
|
||||
device->vendor_id = device->properties.vendorID;
|
||||
device->driver_id = driver_props.driverID;
|
||||
|
||||
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue