clean and rename old clblast files in preparation for merge

This commit is contained in:
Concedo 2024-12-15 15:29:02 +08:00
parent a577015425
commit 1e07043a6e
6 changed files with 247 additions and 294 deletions

View file

@ -578,7 +578,7 @@ ggml_v1_failsafe.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
#opencl
# Build the renamed CLBlast backend (ggml_v3b) — object name kept as ggml-opencl.o
# so existing link lines keep working.
ggml-opencl.o: otherarch/ggml_v3b-opencl.cpp otherarch/ggml_v3b-opencl.h
	$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
ggml_v2-opencl.o: otherarch/ggml_v2-opencl.cpp otherarch/ggml_v2-opencl.h
	$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@

View file

@ -87,7 +87,7 @@
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#endif #endif
#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
#include "ggml-opencl.h" #include "ggml_v3b-opencl.h"
#endif #endif
// floating point type used to accumulate sums // floating point type used to accumulate sums

View file

@ -4125,7 +4125,9 @@ Current version indicated by LITEVER below.
const xtts_gen_endpoint = "/tts_to_audio/"; const xtts_gen_endpoint = "/tts_to_audio/";
const xtts_voices_endpoint = "/speakers_list"; const xtts_voices_endpoint = "/speakers_list";
const alltalk_gen_endpoint = "/api/tts-generate"; const alltalk_gen_endpoint = "/api/tts-generate";
const alltalk_stream_endpoint = "/api/tts-generate-streaming";
const alltalk_voices_endpoint = "/api/voices"; const alltalk_voices_endpoint = "/api/voices";
const alltalk_rvc_voices_endpoint = "/api/rvcvoices";
//support for quick news updates //support for quick news updates
const horde_news_endpoint = "https://hordenews.concedo.workers.dev" const horde_news_endpoint = "https://hordenews.concedo.workers.dev"
@ -4181,7 +4183,7 @@ Current version indicated by LITEVER below.
var current_wi = []; //each item stores a wi object. var current_wi = []; //each item stores a wi object.
var wi_insertlocation = 0; //after memory var wi_insertlocation = 0; //after memory
var wi_searchdepth = 0; //search everything var wi_searchdepth = 0; //search everything
var generateimagesinterval = 700; //if generated images is enabled, it will trigger after every 700 new characters in context. var generateimagesinterval = 750; //if generated images is enabled, it will trigger after every 700 new characters in context.
var nextgeneratedimagemilestone = generateimagesinterval; //used to keep track of when to generate the next image var nextgeneratedimagemilestone = generateimagesinterval; //used to keep track of when to generate the next image
var image_db = {}; //stores a dictionary of pending images var image_db = {}; //stores a dictionary of pending images
var interrogation_db = {}; var interrogation_db = {};
@ -5052,7 +5054,11 @@ Current version indicated by LITEVER below.
const foundChub = urlParams.get('chub'); const foundChub = urlParams.get('chub');
const foundPyg = urlParams.get('pyg'); const foundPyg = urlParams.get('pyg');
const foundAicc = urlParams.get('aicc'); const foundAicc = urlParams.get('aicc');
const foundQuery = urlParams.get('query'); let foundQuery = urlParams.get('query');
if (!foundQuery || foundQuery == "")
{
foundQuery = urlParams.get('q');
}
if (foundStory && foundStory != "") { if (foundStory && foundStory != "") {
if (localsettings.persist_session && !safe_to_overwrite()) { if (localsettings.persist_session && !safe_to_overwrite()) {
@ -8054,9 +8060,18 @@ Current version indicated by LITEVER below.
let style = (elem.trusted ? "style=\"color:#dd77ff;\"" : ""); let style = (elem.trusted ? "style=\"color:#dd77ff;\"" : "");
let brokenstyle = (elem.maintenance_mode ? "style=\"color:#ee4444;\"" : ""); let brokenstyle = (elem.maintenance_mode ? "style=\"color:#ee4444;\"" : "");
let workerNameHtml = escapeHtml(elem.name.substring(0, 40)); let workerNameHtml = escapeHtml(elem.name.substring(0, 40));
let clickinfo = "";
if(elem.info && elem.info!="") if(elem.info && elem.info!="")
{ {
workerNameHtml = "<a class=\"color_blueurl\" href=\"#\" onclick=\"msgbox(\'"+escapeHtml(replaceAll(elem.info,"\'","\\\'"))+"\','Worker Info',false,false,hide_msgbox)\">"+workerNameHtml+"</a>"; clickinfo += escapeHtml(replaceAll(elem.info,"\'","\\\'"));
}
if(elem.threads>1)
{
clickinfo += (clickinfo==""?"":"<br><br>") + "Threads: " + elem.threads;
}
if(clickinfo!="")
{
workerNameHtml = "<a class=\"color_blueurl\" href=\"#\" onclick=\"msgbox(\'"+clickinfo+"\','Worker Info',false,false,hide_msgbox)\">"+workerNameHtml+"</a>";
} }
let allmdls = ""; let allmdls = "";
for (let n = 0; n < elem.models.length; ++n) { for (let n = 0; n < elem.models.length; ++n) {
@ -11653,7 +11668,7 @@ Current version indicated by LITEVER below.
function do_auto_gen_image(truncated_context) function do_auto_gen_image(truncated_context)
{ {
var tclen = truncated_context.length; var tclen = truncated_context.length;
var sentence = truncated_context.substring(tclen - 380, tclen); var sentence = truncated_context.substring(tclen - 400, tclen);
sentence = start_trim_to_sentence(sentence); sentence = start_trim_to_sentence(sentence);
sentence = end_trim_to_sentence(sentence,true); sentence = end_trim_to_sentence(sentence,true);
if (sentence.length > 0) { if (sentence.length > 0) {
@ -11827,6 +11842,28 @@ Current version indicated by LITEVER below.
//alltalk mode //alltalk mode
data = data.voices; data = data.voices;
} }
else if(data && !data.length && data.constructor == Object)
{
//hybrid new xtts mantella
let newdata = [];
for(key in data)
{
let lang = data[key];
if(lang && lang.speakers && lang.speakers.length>0)
{
for(let i=0;i<lang.speakers.length;++i)
{
newdata.push(lang.speakers[i]);
}
}
}
if(newdata.length > 0)
{
data = newdata;
}
}
let dropdown = document.getElementById("xtts_voices"); let dropdown = document.getElementById("xtts_voices");
let selectionhtml = ``; let selectionhtml = ``;
for (var i = 0; i < data.length; ++i) { for (var i = 0; i < data.length; ++i) {
@ -11864,46 +11901,92 @@ Current version indicated by LITEVER below.
{ {
document.getElementById("xtts_container").classList.add("hidden"); document.getElementById("xtts_container").classList.add("hidden");
document.getElementById("oai_tts_container").classList.add("hidden"); document.getElementById("oai_tts_container").classList.add("hidden");
if(document.getElementById("ttsselect").value==XTTS_ID || document.getElementById("ttsselect").value==ALLTALK_ID) document.getElementById("alltalk_specific_controls").classList.add("hidden");
{
const selectedTTS = document.getElementById("ttsselect").value;
if(selectedTTS == XTTS_ID || selectedTTS == ALLTALK_ID) {
document.getElementById("xtts_container").classList.remove("hidden"); document.getElementById("xtts_container").classList.remove("hidden");
fetch_xtts_voices(true, document.getElementById("ttsselect").value==XTTS_ID);
if(selectedTTS == ALLTALK_ID) {
document.getElementById("alltalk_specific_controls").classList.remove("hidden");
fetch_rvc_voices();
adjust_alltalk_controls();
}
fetch_xtts_voices(true, selectedTTS == XTTS_ID);
} }
else if(document.getElementById("ttsselect").value==OAI_TTS_ID) else if(selectedTTS == OAI_TTS_ID) {
{
document.getElementById("oai_tts_container").classList.remove("hidden"); document.getElementById("oai_tts_container").classList.remove("hidden");
} }
} }
// Fetch RVC voices for AllTalk.
// Queries the AllTalk server's /api/rvcvoices endpoint and repopulates the
// #alltalk_rvc_voice dropdown. Only fires while not yet connected, to avoid
// hammering the endpoint on every UI event.
function fetch_rvc_voices()
{
	if(!xtts_is_connected) //prevent it from constantly fetching, will only fetch once before connecting
	{
		fetch(localsettings.saved_alltalk_url + alltalk_rvc_voices_endpoint)
		.then(response => response.json())
		.then(data => {
			console.log("RVC voices response:", data); // Debug log
			const rvcSelect = document.getElementById("alltalk_rvc_voice");
			// reset to the single "Disabled" default before repopulating
			rvcSelect.innerHTML = '<option value="Disabled">Disabled</option>';
			// AllTalk returns { status: "success", rvcvoices: [...] }
			if (data.status === "success" && Array.isArray(data.rvcvoices)) {
				data.rvcvoices.forEach(voice => {
					if (voice !== "Disabled") {
						const option = document.createElement("option");
						option.value = voice;
						// display just the model filename, without path or .pth extension
						option.textContent = voice.split("\\").pop().replace(".pth", "");
						rvcSelect.appendChild(option);
					}
				});
			}
		})
		.catch(error => {
			console.log("Error fetching RVC voices:", error);
		});
	}
}
//single callback to update alltalk controls on any alltalk UI event.
//keeps the pitch readout in sync with its slider, and disables the RVC
//controls whenever audio streaming is enabled.
function adjust_alltalk_controls() {
    const byId = (id) => document.getElementById(id);

    // mirror the slider position into its numeric readout label
    const slider = byId("alltalk_rvc_pitch");
    byId("alltalk_rvc_pitch_value").textContent = slider.value;

    // RVC options do not apply while streaming, so grey them out
    const streamingOn = byId("alltalk_streaming").checked === true;
    byId("alltalk_rvc_voice").disabled = streamingOn;
    slider.disabled = streamingOn;
}
// Prompt the user for the XTTS or AllTalk API server URL and save it.
// Which endpoint is being configured depends on the current #ttsselect value.
// On save: trims the input, strips a trailing slash, falls back to the
// built-in default when empty, resets the connection flag, and re-fetches
// the voice lists (including RVC voices for AllTalk).
function set_xtts_url() {
	let is_xtts = (document.getElementById("ttsselect").value==XTTS_ID);
	let epname = (is_xtts?"XTTS":"AllTalk");
	inputBox("Enter "+epname+" API Server URL.",epname+" API Server URL",(is_xtts?localsettings.saved_xtts_url:localsettings.saved_alltalk_url),"Input "+epname+" API Server URL", ()=>{
		let userinput = getInputBoxValue();
		userinput = userinput.trim();
		if(userinput!="" && userinput.slice(-1)=="/") {
			userinput = userinput.slice(0, -1); // normalize: no trailing slash
		}
		if(userinput=="") {
			userinput = (is_xtts?default_xtts_base:default_alltalk_base);
		}
		if (userinput != null && userinput!="") {
			xtts_is_connected = false; // force voice lists to be re-fetched
			if(is_xtts) {
				localsettings.saved_xtts_url = userinput.trim();
			} else {
				localsettings.saved_alltalk_url = userinput.trim();
				// Fetch RVC voices with new URL
				fetch_rvc_voices();
			}
			fetch_xtts_voices(false, is_xtts);
		}
	},false);
}
function tts_speak(text, speech_synth_override=null) function tts_speak(text, speech_synth_override=null)
{ {
if(!text || text=="" || text.trim()=="") if(!text || text=="" || text.trim()=="")
@ -12019,35 +12102,18 @@ Current version indicated by LITEVER below.
},300); },300);
}; };
}).catch((error) => { }).catch((error) => {
xtts_is_playing = false;
update_submit_button(false);
console.log("XTTS Speak Error: " + error); console.log("XTTS Speak Error: " + error);
}); });
} }
else else
{ {
//alltalk //alltalk
const formData = new FormData(); const isStreaming = (document.getElementById("alltalk_streaming").checked ? true : false);
formData.append("text_input", text); // max 2000 chars
formData.append("text_filtering", "none"); // (none|standard|html)
formData.append("character_voice_gen", document.getElementById("xtts_voices").value);
formData.append("narrator_enabled", false);
formData.append("narrator_voice_gen", document.getElementById("xtts_voices").value);
formData.append("text_not_inside", "character"); // character or narrator, determines which to use
formData.append("language", document.getElementById("xtts_lang").value.trim().toLowerCase());
formData.append("output_file_name", "audiofile"); // NOTE: file name only, with no extension and no dashes!
formData.append("output_file_timestamp", true);
formData.append("autoplay", false); //to play in browser
formData.append("autoplay_volume", 1.0); // (0.1..2.0)
formData.append("streaming", true); // unknown why
fetch(localsettings.saved_alltalk_url + alltalk_gen_endpoint, { let playDecodedAllTalkData = function(decodedData)
method: 'POST', {
body: formData, // send payload as FormData
})
.then(response => response.arrayBuffer())
.then(data => {
return audioContext.decodeAudioData(data);
})
.then(decodedData => {
const playSound = audioContext.createBufferSource(); const playSound = audioContext.createBufferSource();
playSound.buffer = decodedData; playSound.buffer = decodedData;
playSound.connect(audioContext.destination); playSound.connect(audioContext.destination);
@ -12061,9 +12127,109 @@ Current version indicated by LITEVER below.
console.log("Audio finished playing"); console.log("Audio finished playing");
},300); },300);
}; };
}).catch((error) => { }
console.log("AllTalk Speak Error: " + error);
}); if (isStreaming) {
// Create a URLSearchParams object for streaming
const params = new URLSearchParams({
text: text,
voice: document.getElementById("xtts_voices").value,
language: document.getElementById("xtts_lang").value.trim().toLowerCase(),
output_file: "klite_stream_output.wav",
});
// Create streaming URL, but right now it's as good as sync
const streamingUrl = `${localsettings.saved_alltalk_url}${alltalk_stream_endpoint}?${params.toString()}`;
fetch(streamingUrl)
.then(response => response.arrayBuffer())
.then(data => {
return audioContext.decodeAudioData(data);
})
.then(decodedData => {
playDecodedAllTalkData(decodedData);
})
.catch((error) => {
console.log("AllTalk v2 Speak Error:", data);
xtts_is_playing = false;
update_submit_button(false);
});
} else {
// Standard mode using FormData
const formData = new FormData();
formData.append("text_input", text);
formData.append("text_filtering", "none");
formData.append("character_voice_gen", document.getElementById("xtts_voices").value);
formData.append("narrator_enabled", false);
formData.append("narrator_voice_gen", document.getElementById("xtts_voices").value);
formData.append("text_not_inside", "character");
formData.append("language", document.getElementById("xtts_lang").value.trim().toLowerCase());
formData.append("output_file_name", "audiofile");
formData.append("output_file_timestamp", true);
formData.append("autoplay", false);
formData.append("autoplay_volume", 1.0);
formData.append("rvccharacter_voice_gen", document.getElementById("alltalk_rvc_voice").value);
formData.append("rvccharacter_pitch", document.getElementById("alltalk_rvc_pitch").value);
formData.append("rvcnarrator_voice_gen", document.getElementById("alltalk_rvc_voice").value);
formData.append("rvcnarrator_pitch", document.getElementById("alltalk_rvc_pitch").value);
fetch(localsettings.saved_alltalk_url + alltalk_gen_endpoint, {
method: 'POST',
body: formData, // send payload as FormData
}).then(response => {
//content type can be JSON (alltalk v2) or raw audio (v1)
const contentType = response.headers.get("Content-Type");
//alltalk v2 json
if (contentType && contentType.toLowerCase().includes("application/json"))
{
return response.json().then(data => {
if (data && data.output_file_url && data.status === "generate-success")
{
const audioUrl = `${localsettings.saved_alltalk_url}${data.output_file_url}`;
fetch(audioUrl)
.then(response => response.arrayBuffer())
.then(data => {
return audioContext.decodeAudioData(data);
})
.then(decodedData => {
playDecodedAllTalkData(decodedData);
})
.catch((error) => {
console.log("AllTalk v2 Speak Error:", data);
xtts_is_playing = false;
update_submit_button(false);
});
} else {
console.log("AllTalk Generation Error:", data);
xtts_is_playing = false;
update_submit_button(false);
}
})
.catch((error) => {
console.log("AllTalk Request Error:", error);
xtts_is_playing = false;
update_submit_button(false);
});
}
else //alltalk v1 audio
{
return response.arrayBuffer().then(data => {
return audioContext.decodeAudioData(data);
})
.then(decodedData => {
playDecodedAllTalkData(decodedData);
}).catch((error) => {
console.log("AllTalk v1 Speak Error: " + error);
xtts_is_playing = false;
update_submit_button(false);
});
}
}).catch((error) => {
console.log("AllTalk Non-Stream Req Error: " + error);
xtts_is_playing = false;
update_submit_button(false);
});
}
} }
} }
} }
@ -15880,7 +16046,7 @@ Current version indicated by LITEVER below.
whorun = "<br>You're using the Cohere API"; whorun = "<br>You're using the Cohere API";
} }
else { else {
whorun = `<br>There are <span class="color_orange">${selected_models.reduce((s, a) => s + a.count, 0)}</span> <a class="color_green mainnav" href="#" tabindex="${mainmenu_is_untab?`-1`:`0`}" onclick="get_and_show_workers()">volunteer(s)</a> running selected models with a total queue length of <span class="color_orange">${selected_models.reduce((s, a) => s + a.queued, 0)}</span> tokens`; whorun = `<br>Horde <a class="color_green mainnav" href="#" tabindex="${mainmenu_is_untab?`-1`:`0`}" onclick="get_and_show_workers()">Volunteer(s)</a> are running <span class="color_orange">${selected_models.reduce((s, a) => s + a.count, 0)} threads</span> for selected models with a total queue length of <span class="color_orange">${selected_models.reduce((s, a) => s + a.queued, 0)}</span> tokens`;
} }
let nowmode = (localsettings.opmode==1?"Story Mode":(localsettings.opmode==2?"Adventure Mode":(localsettings.opmode==3?"Chat Mode":"Instruct Mode"))); let nowmode = (localsettings.opmode==1?"Story Mode":(localsettings.opmode==2?"Adventure Mode":(localsettings.opmode==3?"Chat Mode":"Instruct Mode")));
let selmodelstr = ""; let selmodelstr = "";
@ -19562,13 +19728,34 @@ Current version indicated by LITEVER below.
</select> </select>
<button id="test_tts" type="button" class="bg_green btn btn-primary" style="height:20px; width:30px; padding:2px 3px;font-size:11px; margin-left: 2px;" onclick="test_tts()">Test</button> <button id="test_tts" type="button" class="bg_green btn btn-primary" style="height:20px; width:30px; padding:2px 3px;font-size:11px; margin-left: 2px;" onclick="test_tts()">Test</button>
<div id="xtts_container" class="settinglabel hidden"> <div id="xtts_container" class="settinglabel hidden">
<table width="100%"><tr> <div>
<td><button id="xtts_url" type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_xtts_url()">Set URL</button></td> <table width="100%"><tr>
<td><select class="form-control" id="xtts_voices" style="font-size:12px;height:20px;padding:0;margin:0px 0 0;"> <td><button id="xtts_url" type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_xtts_url()">Set URL</button></td>
<option value="female_calm" selected>female_calm</option><option value="female">female</option><option value="male">male</option> <td><select class="form-control" id="xtts_voices" style="font-size:12px;height:20px;padding:0;margin:0px 0 0;">
</select></td> <option value="female_calm" selected>female_calm</option><option value="female">female</option><option value="male">male</option>
</tr><tr style="font-size:12px;padding:2px;margin:0px 0 0;"><td>Language </td><td><input class="settinglabel miniinput" type="text" value="EN" id="xtts_lang" style="margin-left:3px; height:18px; width: 40px; padding: 2px;"></td></tr> </select></td>
</table> </tr><tr style="font-size:12px;padding:2px;margin:0px 0 0;"><td>Language </td><td><input class="settinglabel miniinput" type="text" value="EN" id="xtts_lang" style="margin-left:3px; height:18px; width: 40px; padding: 2px;"></td></tr>
</table>
</div>
<div id="alltalk_specific_controls" style="width:100%;font-size: 11px;" class="settinglabel hidden">
<div>
<div class="justifyleft" style="padding:2px" title="AllTalk Streaming">Audio Streaming </div>
<input title="AllTalk Streaming" onchange="adjust_alltalk_controls();" type="checkbox" id="alltalk_streaming" style="margin:0px 0px 0px auto;">
</div>
<div>
<div>RVC Voice</div>
<select class="form-control" id="alltalk_rvc_voice" style="font-size:12px;height:20px;padding:0;margin:0px 0 0;width:100%;">
<option value="Disabled">Disabled</option>
</select>
</div>
<div>
<div>RVC Pitch</div>
<div style="display:flex;align-items:center;">
<input oninput="adjust_alltalk_controls();" type="range" id="alltalk_rvc_pitch" min="-24" max="24" value="0" style="flex:1;height:20px;">
<span id="alltalk_rvc_pitch_value" style="margin-left:5px;font-size:12px;">0</span>
</div>
</div>
</div>
</div> </div>
<div id="oai_tts_container" class="settinglabel hidden"> <div id="oai_tts_container" class="settinglabel hidden">
<table width="100%"><tr> <table width="100%"><tr>

View file

@ -1,5 +1,5 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-opencl.h" #include "ggml_v3b-opencl.h"
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
@ -2012,224 +2012,3 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
tensor->extra = dst; tensor->extra = dst;
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
} }
// ggml-backend
// buffer
// Context owned by one OpenCL backend buffer: the parent cl_mem allocation
// plus every per-tensor sub-buffer carved out of it. The destructor releases
// all of them, so dropping the context cannot leak device memory.
struct ggml_backend_opencl_buffer_context {
    ~ggml_backend_opencl_buffer_context() {
        // parent allocation may be null if clCreateBuffer failed
        if (buffer) {
            clReleaseMemObject(buffer);
        }
        // release every tensor sub-buffer created from the parent
        for (auto * sub_buffer : sub_buffers) {
            clReleaseMemObject(sub_buffer);
        }
    }
    cl_mem buffer;                    // parent device allocation
    std::vector<cl_mem> sub_buffers;  // per-tensor sub-ranges of `buffer`
};
// Fake non-null "base address" handed to ggml: tensor->data values are byte
// offsets from this marker, converted back into sub-buffer regions in
// init_tensor below — they are never dereferenced as real pointers.
static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
// Human-readable buffer name.
// NOTE(review): this function is not referenced by the buffer interface
// table below — presumably kept for a ggml version whose buffer v-table
// still had a .get_name slot; confirm before removing.
static const char * ggml_backend_opencl_buffer_get_name(ggml_backend_buffer_t buffer) {
    return "OpenCL";
    GGML_UNUSED(buffer);
}
// Free callback: deleting the context releases the parent cl_mem and all
// sub-buffers via ~ggml_backend_opencl_buffer_context.
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
    delete ctx;
}
// Base-address callback: returns the fake cl_ptr_base marker, not real
// memory — ggml computes tensor->data as offsets from this value.
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
    return cl_ptr_base;
    GGML_UNUSED(buffer);
}
// Bind a tensor to device memory. A view at offset 0 shares its source's
// cl_mem; any other tensor gets a dedicated sub-buffer covering its region
// of the parent allocation. The sub-buffer handle is stored in tensor->extra.
static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    if (tensor->view_src != NULL && tensor->view_offs == 0) {
        // full view: reuse the source tensor's device handle directly
        tensor->extra = tensor->view_src->extra;
    } else {
        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
        // tensor->data is an offset from the fake cl_ptr_base, not a pointer
        cl_buffer_region region = {(size_t)((char *)tensor->data - (char *)cl_ptr_base), ggml_nbytes(tensor)};
        cl_int err;
        cl_mem sub_buffer = clCreateSubBuffer(ctx->buffer, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
        CL_CHECK(err);
        // track it so reset/destructor can release it
        ctx->sub_buffers.push_back(sub_buffer);
        tensor->extra = sub_buffer;
    }
    tensor->backend = GGML_BACKEND_TYPE_GPU;
}
// Copy `size` bytes of host data into the tensor's device sub-buffer at
// `offset`. Blocking: the write is enqueued and the queue is drained before
// returning. `queue` is the file-global OpenCL command queue.
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    cl_mem tensor_buffer = (cl_mem) tensor->extra;
    CL_CHECK(clEnqueueWriteBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
    CL_CHECK(clFinish(queue));
    GGML_UNUSED(buffer);
}
// Copy `size` bytes at `offset` from the tensor's device sub-buffer back to
// host memory. Blocking, mirroring set_tensor above.
static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    cl_mem tensor_buffer = (cl_mem) tensor->extra;
    CL_CHECK(clEnqueueReadBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
    CL_CHECK(clFinish(queue));
    GGML_UNUSED(buffer);
}
// Fill the entire parent allocation with the byte `value` (blocking).
static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
    CL_CHECK(clEnqueueFillBuffer(queue, ctx->buffer, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
    CL_CHECK(clFinish(queue));
}
// Release all per-tensor sub-buffers so the buffer can be re-partitioned;
// the parent allocation itself stays alive.
static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
    for (auto * sub_buffer : ctx->sub_buffers) {
        clReleaseMemObject(sub_buffer);
    }
    ctx->sub_buffers.clear();
}
// Buffer v-table registered with ggml-backend. memset_tensor and cpy_tensor
// are left NULL (ggml falls back to its generic paths for those).
static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
    /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer,
    /* .get_base = */ ggml_backend_opencl_buffer_get_base,
    /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor,
    /* .memset_tensor = */ NULL,
    /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor,
    /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor,
    /* .cpy_tensor = */ NULL,
    /* .clear = */ ggml_backend_opencl_buffer_clear,
    /* .reset = */ ggml_backend_opencl_buffer_reset,
};
// buffer type
// Name of the OpenCL buffer type (buffer-type v-table callback).
static const char * ggml_backend_opencl_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
    return "OpenCL";
    GGML_UNUSED(buffer_type);
}
// Allocate a new device buffer of `size` bytes and wrap it in a
// ggml_backend_buffer. Returns nullptr if the device allocation fails.
// `context` is the file-global OpenCL context, initialized by ggml_cl_init().
static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
    ggml_cl_init(); // lazy one-time OpenCL setup
    cl_int err;
    cl_mem mem = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err);
    if (err != CL_SUCCESS) {
        fprintf(stderr, "%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
        return nullptr;
    }
    // context object takes ownership of `mem`; freed in free_buffer
    ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context{mem, {}};
    return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
}
// Required allocation alignment, queried once from the device
// (CL_DEVICE_MEM_BASE_ADDR_ALIGN is reported in bits, hence the /8).
static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
    // FIXME: not thread safe, device may not be initialized yet
    static cl_uint alignment = -1; // (cl_uint)-1 marks "not yet queried"
    if (alignment == (cl_uint)-1) {
        ggml_cl_init();
        clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
        alignment /= 8; // bits to bytes
    }
    return alignment;
    GGML_UNUSED(buffer_type);
}
// Largest single device allocation, from CL_DEVICE_MAX_MEM_ALLOC_SIZE.
// FIXME: like get_alignment above, the lazy static init is not thread safe
// and the device may not be initialized yet on first call.
static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
    static size_t max_size = -1; // (size_t)-1 marks "not yet queried"
    if (max_size == (size_t)-1) {
        ggml_cl_init();
        // CL_DEVICE_MAX_MEM_ALLOC_SIZE is specified as cl_ulong; querying it
        // with sizeof(size_t) into a size_t reads the wrong number of bytes
        // on platforms where size_t is narrower than 64 bits.
        cl_ulong dev_max_alloc = 0;
        clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(cl_ulong), &dev_max_alloc, NULL);
        // clamp to what size_t can represent
        const cl_ulong size_t_max = (cl_ulong)(size_t)-1;
        max_size = (size_t)(dev_max_alloc < size_t_max ? dev_max_alloc : size_t_max);
    }
    return max_size;

    GGML_UNUSED(buffer_type);
}
// This buffer type is usable only through the CPU backend (the OpenCL path
// is driven from CPU code), so accept the CPU backend exclusively.
static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buffer_type, ggml_backend_t backend) {
    //return ggml_backend_is_opencl(backend); // opencl must be used through the cpu backend
    return ggml_backend_is_cpu(backend);
    GGML_UNUSED(buffer_type);
}
// Buffer-type v-table. get_alloc_size and is_host are NULL (ggml defaults).
// NOTE(review): ..._supports_backend above is not wired into this table —
// presumably the field was removed in this ggml version; confirm.
static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
    /* .get_name = */ ggml_backend_opencl_buffer_type_name,
    /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer,
    /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment,
    /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size,
    /* .get_alloc_size = */ NULL,
    /* .is_host = */ NULL,
};
// Public accessor: singleton buffer type describing OpenCL device memory.
ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
    static ggml_backend_buffer_type buffer_type = {
        /* .iface = */ ggml_backend_opencl_buffer_type_interface,
        /* .context = */ nullptr,
    };
    return &buffer_type;
}
#if 0
// DISABLED: pinned host-buffer support, compiled out. Kept for reference;
// depends on ggml_cl_host_malloc/ggml_cl_host_free which are also
// commented out in the header.
// host buffer type

// Name callbacks for the pinned host buffer type / buffers.
static const char * ggml_backend_opencl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
    return "CL_Host";

    GGML_UNUSED(buft);
}

static const char * ggml_backend_opencl_host_buffer_name(ggml_backend_buffer_t buffer) {
    return "CL_Host";

    GGML_UNUSED(buffer);
}

// Release the pinned host allocation backing this buffer.
static void ggml_backend_opencl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_cl_host_free(buffer->context);
}

// Allocate pinned host memory; fall back to an ordinary CPU buffer if
// pinned allocation fails.
static ggml_backend_buffer_t ggml_backend_opencl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * ptr = ggml_cl_host_malloc(size);

    if (ptr == nullptr) {
        // fallback to cpu buffer
        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    }

    // reuse the CPU buffer implementation, overriding name/free callbacks
    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
    buffer->buft = buft;
    buffer->iface.get_name = ggml_backend_opencl_host_buffer_name;
    buffer->iface.free_buffer = ggml_backend_opencl_host_buffer_free_buffer;

    return buffer;
}

// Singleton pinned host buffer type; borrows most callbacks from the CPU
// buffer type.
ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_opencl_buffer_type_host = {
        /* .iface = */ {
            /* .get_name = */ ggml_backend_opencl_host_buffer_type_name,
            /* .alloc_buffer = */ ggml_backend_opencl_host_buffer_type_alloc_buffer,
            /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
            /* .get_max_size = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
            /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
            /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
        },
        /* .context = */ nullptr,
    };

    return &ggml_backend_opencl_buffer_type_host;
}

// backend
#endif

View file

@ -16,22 +16,9 @@ GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struc
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
// GGML_API void * ggml_cl_host_malloc(size_t size);
// GGML_API void ggml_cl_host_free(void * ptr);
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor); GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor); GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
// backend API
// GGML_API ggml_backend_t ggml_backend_opencl_init(void);
// GGML_API bool ggml_backend_is_opencl(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
// GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View file

@ -15,7 +15,7 @@
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
# include "ggml-cuda.h" # include "ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST) #elif defined(GGML_USE_CLBLAST)
# include "ggml-opencl.h" # include "ggml_v3b-opencl.h"
#endif #endif