diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 10e3a66..88960c7 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase): self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) - self.up = torch.cat(self.up, dim=0) - self.gate = torch.cat(self.gate, dim=0) - self.down = torch.cat(self.down, dim=0) + self.up = torch.stack(self.up, dim=0) + self.gate = torch.stack(self.gate, dim=0) + self.down = torch.stack(self.down, dim=0) return def unload(self):