mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
[ADD] support multi-gpu qlen>1 q5_k
This commit is contained in:
parent
f293803156
commit
f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions
|
@ -21,6 +21,7 @@ class CUDAGraphRunner:
|
|||
position_ids,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
main_device,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert self.graph is None
|
||||
|
@ -29,15 +30,24 @@ class CUDAGraphRunner:
|
|||
self.graph = torch.cuda.CUDAGraph()
|
||||
#self.graph.enable_debug_mode()
|
||||
self.model = model
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
|
||||
with torch.cuda.graph(self.graph):
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
|
||||
# torch.cuda.set_device cannot take a bare "cuda"; it needs an explicit device index
|
||||
if main_device == "cuda":
|
||||
main_device = "cuda:0"
|
||||
torch.cuda.set_device(main_device)
|
||||
self.main_device = main_device
|
||||
capture_stream = torch.cuda.Stream()
|
||||
with torch.cuda.graph(self.graph, stream = capture_stream):
|
||||
logits=model(inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs)[0]
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
torch.cuda.set_device(main_device)
|
||||
torch.cuda.set_stream(capture_stream)
|
||||
past_key_values.change_seq_length(-1)
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize(self.main_device)
|
||||
#self.graph.debug_dump("cuda_graph_hooked.dot")
|
||||
|
||||
# Save the input and output buffers.
|
||||
|
@ -65,7 +75,7 @@ class CUDAGraphRunner:
|
|||
#print("begin replay")
|
||||
#time.sleep(1)
|
||||
self.graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize(self.main_device)
|
||||
# Return the output tensor.
|
||||
return self.output_buffers["logits"]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue