toy support for experts on GPU, no CUDA Graph

This commit is contained in:
Atream 2025-02-15 15:16:00 +00:00
parent 1548c99234
commit c189d55bd1
6 changed files with 199 additions and 65 deletions

View file

@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.w = None
self.weight = None
self.has_bias = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
x = x.to(device=self.device, dtype=self.dtype)
x = x @ self.w
x = x @ self.weight
if self.has_bias:
x = x + self.bias
x = x.to(dtype=dtype, device=out_device)
@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
if isinstance(w, nn.Parameter):
try:
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
self.w = w.to(dtype=self.dtype).T
self.weight = w.to(dtype=self.dtype).T
self.has_bias = False
elif isinstance(w, tuple):
try:
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
except:
self.w = w[0].to(dtype=self.dtype).T
self.weight = w[0].to(dtype=self.dtype).T
self.bias = w[1].to(dtype=self.dtype)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
# self.linear = self.linear.to(device)
self.w = self.w.to(device)
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
def unload(self):
if self.w is not None:
self.w = None
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
if mode == InferenceState.PREFILL:
self.generate_linear.unload()
self.prefill_linear.load(w=w)
self.device = self.prefill_linear.device
self.device = self.prefill_linear.device
self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.GENERATE:
self.prefill_linear.unload()
self.generate_linear.load(w=w)
self.device = self.generate_linear.device
self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.UNLOAD:
self.prefill_linear.unload()
self.generate_linear.unload()