from typing import Any, List, Optional, Set


class TextStreamer:
    """Incrementally turn a stream of token ids into chunks of displayable text.

    Tokens are fed one at a time through :meth:`put`. Every call re-decodes the
    cached tokens and returns only the part of the decoded text that has not
    been returned before and is considered stable (see :meth:`put`).
    :meth:`end` returns whatever is left. Nothing is written to stdout; the
    caller decides what to do with the returned strings.

    Args:
        tokenizer: Object exposing ``decode(token_ids, **kwargs) -> str``.
        skip_prompt: When true, the first :meth:`put` call after construction
            (and after every :meth:`end`) is discarded and returns ``None``.
        **decode_kwargs: Extra keyword arguments forwarded to
            ``tokenizer.decode``. ``skip_special_tokens`` defaults to ``True``
            and may be overridden here.
    """

    # Inclusive code point ranges treated as "Chinese characters": the CJK
    # Unified Ideographs block, its extensions and the compatibility
    # ideograph blocks.
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Variables used in the streaming process.
        self.token_cache: List[int] = []  # tokens received since the last flush
        self.print_len = 0  # number of decoded characters already returned
        self.next_tokens_are_prompt = True

    def reset(self):
        """Drop the cached tokens and the count of already returned characters."""
        self.token_cache = []
        self.print_len = 0

    def _decode_cache(self) -> str:
        """Decode the cached tokens with the configured decode arguments."""
        # Merge instead of passing ``skip_special_tokens=True`` next to
        # ``**self.decode_kwargs``: the latter raises ``TypeError`` (duplicate
        # keyword) as soon as the caller supplies ``skip_special_tokens``.
        kwargs = {"skip_special_tokens": True, **self.decode_kwargs}
        return self.tokenizer.decode(self.token_cache, **kwargs)

    def put(self, value) -> Optional[str]:
        """Receive one token id and return the newly displayable text.

        Args:
            value: A single token id (``int``).

        Returns:
            ``None`` when the call is swallowed because of ``skip_prompt``;
            otherwise the new text that is safe to display, which may be the
            empty string.

        Raises:
            ValueError: If ``value`` is not an ``int``.
        """
        if not isinstance(value, int):
            raise ValueError("TextStreamer only supports batch size 1, and int type input")

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return None

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.append(value)
        text = self._decode_cache()

        if text.endswith("\n"):
            # After the symbol for a new line, we flush the cache.
            printable_text = text[self.print_len:]
            self.reset()
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            # If the last character is a CJK character, everything decoded so
            # far can be returned.
            printable_text = text[self.print_len:]
            self.print_len += len(printable_text)
        else:
            # Otherwise, return text only up to the last space character: a
            # simple heuristic to avoid emitting incomplete words, which may
            # change with the subsequent token.
            printable_text = text[self.print_len: text.rfind(" ") + 1]
            self.print_len += len(printable_text)
        return printable_text

    def end(self) -> Optional[str]:
        """Flush the cache and return the text that has not been returned yet.

        Also re-arms prompt skipping, so the next :meth:`put` call is treated
        as the prompt again.

        Returns:
            The remaining decoded text, or ``""`` when the cache is empty.
        """
        if len(self.token_cache) > 0:
            text = self._decode_cache()
            printable_text = text[self.print_len:]
            self.reset()
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        return printable_text

    def _is_chinese_char(self, cp):
        """Check whether ``cp`` is the code point of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean
        # characters, despite its name. The modern Korean Hangul alphabet is a
        # different block, as are Japanese Hiragana and Katakana. Those
        # alphabets are used to write space-separated words, so they are not
        # treated specially and are handled like all of the other languages.
        return any(low <= cp <= high for low, high in self._CJK_RANGES)