Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
toy support for experts on GPU, no CUDA Graph
parent 1548c99234
commit c189d55bd1

6 changed files with 199 additions and 65 deletions
@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        self.w = None
+        self.weight = None
         self.has_bias = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
-        x = x @ self.w
+        x = x @ self.weight
         if self.has_bias:
             x = x + self.bias
         x = x.to(dtype=dtype, device=out_device)
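The forward path above follows a cast-in/cast-out pattern: the input is moved to the layer's device and dtype (the weight may live on a GPU-resident expert), multiplied against the renamed `weight`, and cast back to the caller's original dtype and device (`dtype` is captured from `x.dtype` just above this hunk's context). A minimal self-contained sketch of the same pattern; the function name is illustrative, not part of the repo:

import torch

def linear_forward(x: torch.Tensor, weight: torch.Tensor,
                   bias: torch.Tensor | None = None) -> torch.Tensor:
    # Remember where the caller's tensor lives so the result can go back.
    out_dtype, out_device = x.dtype, x.device
    # Compute on the weight's device/dtype (e.g. a GPU-resident expert).
    x = x.to(device=weight.device, dtype=weight.dtype)
    x = x @ weight                      # weight stored pre-transposed: (in, out)
    if bias is not None:
        x = x + bias
    return x.to(dtype=out_dtype, device=out_device)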
@@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
         if isinstance(w, nn.Parameter):
             try:
-                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w.to(dtype=self.dtype).T
+                self.weight = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
             try:
-                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w[0].to(dtype=self.dtype).T
+                self.weight = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
             raise ValueError("Invalid weight type")
         # self.linear = self.linear.to(device)
-        self.w = self.w.to(device)
+        self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)

     def unload(self):
-        if self.w is not None:
-            self.w = None
+        if self.weight is not None:
+            self.weight = None
         if self.has_bias:
             self.bias = None
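`load()` accepts either a bare `nn.Parameter` (weight only) or a `(weight, bias)` tuple from the loader, reshapes to `(out_features, in_features)` when the tensor allows it, and stores the transpose so `forward` can do a plain `x @ self.weight`. A condensed sketch of that dispatch; the helper name and the `float16` default are hypothetical, and the source's bare `except:` is narrowed here to `RuntimeError`, which is what a failing `view` raises:

import torch
import torch.nn as nn

def prepare_weight(w, out_features: int, in_features: int,
                   dtype: torch.dtype = torch.float16):
    """Normalize a loaded weight to (in, out) layout, plus optional bias."""
    if isinstance(w, nn.Parameter):
        weight, bias = w, None          # weight-only layer
    elif isinstance(w, tuple):
        weight, bias = w                # (weight, bias) pair from the loader
    else:
        raise ValueError("Invalid weight type")
    weight = weight.to(dtype=dtype)
    try:
        weight = weight.view(out_features, in_features)
    except RuntimeError:
        pass                            # not reshapeable; assume usable shape
    bias = bias.to(dtype=dtype) if bias is not None else None
    return weight.T, bias               # transpose once; forward is x @ weight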
@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = MarlinWorkspace(
             self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
         )
+        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
         self.marlin_q_w = marlin_q_w
         self.marlin_s = marlin_s
         self.g_idx = g_idx
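The added `self.weight = marlin_q_w` keeps generic modeling code working: as the inline comment says, `modeling_xxx.py` may reach for `linear.weight`, so the packed Marlin tensor is aliased under the standard attribute name. An alternative that cannot drift out of sync would be a read-only property; a hypothetical sketch, not what the repo does:

import torch

class MarlinLinearLike:
    """Illustrative stand-in: surface the packed weight as `.weight`."""
    def __init__(self, marlin_q_w: torch.Tensor):
        self.marlin_q_w = marlin_q_w

    @property
    def weight(self) -> torch.Tensor:
        # Always reflects the current packed tensor, so reloads and
        # unloads need no second attribute to keep updated.
        return self.marlin_q_w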
@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         if mode == InferenceState.PREFILL:
             self.generate_linear.unload()
             self.prefill_linear.load(w=w)
             self.device = self.prefill_linear.device
+            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.GENERATE:
             self.prefill_linear.unload()
             self.generate_linear.load(w=w)
             self.device = self.generate_linear.device
+            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.UNLOAD:
             self.prefill_linear.unload()
             self.generate_linear.unload()
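This last hunk is the mode switch of `KTransformersLinear`: at most one of the two backends (prefill vs. generate) holds weights at a time, and the wrapper mirrors `device` and `weight` from whichever backend is active, again so external code touching `linear.weight` sees the live tensor. A toy driver showing the call pattern, with stub backends standing in for the real `KLinearTorch`/`KLinearMarlin`; everything here besides the `InferenceState` member names is illustrative:

import enum

class InferenceState(enum.Enum):
    PREFILL = enum.auto()
    GENERATE = enum.auto()
    UNLOAD = enum.auto()

class StubLinear:
    """Stands in for a real backend such as KLinearTorch."""
    def __init__(self, device: str):
        self.device, self.weight = device, None
    def load(self, w=None):
        self.weight = w if w is not None else "stub-weight"
    def unload(self):
        self.weight = None

class SwitchingLinear:
    def __init__(self):
        self.prefill_linear = StubLinear("cuda:0")
        self.generate_linear = StubLinear("cpu")

    def set_inference_mode(self, mode: InferenceState, w=None):
        if mode == InferenceState.PREFILL:
            self.generate_linear.unload()        # free the inactive backend
            self.prefill_linear.load(w=w)
            self.device = self.prefill_linear.device
            self.weight = self.prefill_linear.weight
        elif mode == InferenceState.GENERATE:
            self.prefill_linear.unload()
            self.generate_linear.load(w=w)
            self.device = self.generate_linear.device
            self.weight = self.generate_linear.weight
        elif mode == InferenceState.UNLOAD:
            self.prefill_linear.unload()
            self.generate_linear.unload()

lin = SwitchingLinear()
lin.set_inference_mode(InferenceState.PREFILL)   # prefill backend resident
lin.set_inference_mode(InferenceState.GENERATE)  # swapped for decoding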