Merge pull request #330 from hrz6976/fix-nonetype

thanks..., I was about to submit and found that you had already modified it. Thank you for your contribution
This commit is contained in:
wang jiahao 2025-02-15 22:50:51 +08:00 committed by GitHub
commit ae8da019c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -172,7 +172,8 @@ class StaticCache(transformers.StaticCache):
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if self.value_cache[layer_idx] is not None:
self.value_cache[layer_idx].zero_()
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""