mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
Initial commit
This commit is contained in:
commit
18c42e67df
247 changed files with 53775 additions and 0 deletions
84
ktransformers/util/textstream.py
Normal file
84
ktransformers/util/textstream.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from typing import Any, List, Optional, Set
|
||||
class TextStreamer:
    """Incrementally decode generated token ids into printable text.

    Token ids are fed in one at a time via :meth:`put`; decoded text is
    handed back only once it forms a safely printable unit — a complete
    line, a CJK ideograph, or everything up to the last space — so that a
    partially decoded word is never emitted and later "changed" by a
    subsequent token.
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        # Tokenizer providing ``decode``; extra kwargs are forwarded to it.
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Streaming state: ids not yet flushed, and how many characters of
        # their decoded form have already been handed out to the caller.
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def reset(self):
        """Discard buffered tokens and start a fresh decode window."""
        self.token_cache = []
        self.print_len = 0

    def put(self, value) -> Optional[str]:
        """Buffer one token id and return any newly completed text.

        Returns ``None`` while skipping the prompt token, otherwise the
        chunk of decoded text that became safe to emit (possibly "").

        NOTE(review): since input is a single int per call, ``skip_prompt``
        suppresses only the first token — presumably the caller feeds only
        generated (non-prompt) tokens after that; confirm against call sites.
        """
        if not isinstance(value, int):
            raise ValueError("TextStreamer only supports batch size 1, and int type input")

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return None

        # Append the id and re-decode the whole pending window.
        self.token_cache.append(value)
        decoded = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)

        if decoded.endswith("\n"):
            # A full line is ready: emit everything pending and clear the cache.
            chunk = decoded[self.print_len:]
            self.reset()
        elif decoded and self._is_chinese_char(ord(decoded[-1])):
            # CJK ideographs are words on their own; emit them immediately.
            chunk = decoded[self.print_len:]
            self.print_len += len(chunk)
        else:
            # Emit only up to (and including) the last space, a simple
            # heuristic to avoid showing a word that later tokens may extend.
            chunk = decoded[self.print_len:decoded.rfind(" ") + 1]
            self.print_len += len(chunk)
        return chunk

    def end(self) -> Optional[str]:
        """Flush any remaining buffered text and re-arm prompt skipping."""
        if self.token_cache:
            tail = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
            remainder = tail[self.print_len:]
            self.reset()
        else:
            remainder = ""

        self.next_tokens_are_prompt = True
        return remainder

    def _is_chinese_char(self, cp):
        """Return True if codepoint ``cp`` lies in a CJK ideograph block.

        Covers the CJK Unified Ideographs blocks, their extensions, and the
        compatibility blocks. The CJK Unicode block is NOT all Japanese and
        Korean characters despite its name: Hangul, Hiragana, and Katakana
        live in other blocks and are written as space-separated words, so
        they need no special casing here.
        """
        cjk_ranges = (
            (0x4E00, 0x9FFF),    # CJK Unified Ideographs
            (0x3400, 0x4DBF),    # Extension A
            (0x20000, 0x2A6DF),  # Extension B
            (0x2A700, 0x2B73F),  # Extension C
            (0x2B740, 0x2B81F),  # Extension D
            (0x2B820, 0x2CEAF),  # Extension E
            (0xF900, 0xFAFF),    # Compatibility Ideographs
            (0x2F800, 0x2FA1F),  # Compatibility Ideographs Supplement
        )
        return any(lo <= cp <= hi for lo, hi in cjk_ranges)
|
Loading…
Add table
Add a link
Reference in a new issue