[ADD] support multi-gpu qlen>1 q5_k

This commit is contained in:
chenxl 2024-08-12 11:17:29 +00:00
parent f293803156
commit f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions

View file

@ -21,6 +21,7 @@ class CUDAGraphRunner:
position_ids,
cache_position,
past_key_values,
main_device,
**kwargs,
) -> None:
assert self.graph is None
@ -29,15 +30,24 @@ class CUDAGraphRunner:
self.graph = torch.cuda.CUDAGraph()
#self.graph.enable_debug_mode()
self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
with torch.cuda.graph(self.graph):
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
# torch.cuda.set_device needs an explicit device index, so a bare "cuda" is rewritten to "cuda:0" below
if main_device == "cuda":
main_device = "cuda:0"
torch.cuda.set_device(main_device)
self.main_device = main_device
capture_stream = torch.cuda.Stream()
with torch.cuda.graph(self.graph, stream = capture_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs)[0]
capture_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.set_device(main_device)
torch.cuda.set_stream(capture_stream)
past_key_values.change_seq_length(-1)
torch.cuda.synchronize()
torch.cuda.synchronize(self.main_device)
#self.graph.debug_dump("cuda_graph_hooked.dot")
# Save the input and output buffers.
@ -65,7 +75,7 @@ class CUDAGraphRunner:
#print("begin replay")
#time.sleep(1)
self.graph.replay()
torch.cuda.synchronize()
torch.cuda.synchronize(self.main_device)
# Return the output tensor.
return self.output_buffers["logits"]