diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 09a5a19..7083ac4 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -308,8 +308,7 @@ class LocalEncoder(LocalModelBase): kv=h, mask=cross_mask, ) - patch_embeds += patch_embeds_cross - return patch_embeds + return patch_embeds + patch_embeds_cross class LocalDecoder(LocalModelBase):