diff --git a/ktransformers/server/backend/base.py b/ktransformers/server/backend/base.py index 5148a48..aa011bf 100644 --- a/ktransformers/server/backend/base.py +++ b/ktransformers/server/backend/base.py @@ -15,6 +15,7 @@ from ktransformers.server.schemas.assistants.assistants import AssistantObject from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role from ktransformers.server.schemas.assistants.runs import RunObject from ktransformers.server.schemas.assistants.threads import ThreadObject +from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.base import ObjectID, Order from ktransformers.server.utils.multi_timer import Profiler @@ -142,12 +143,16 @@ class ThreadContext: yield reply_message.stream_response_with_event(MessageObject.Status.in_progress) yield self.run.stream_response_with_event(RunObject.Status.in_progress) - async for token, finish_reason in self.interface.inference(local_messages,self.thread.id): - if self.run.status == RunObject.Status.cancelling: - logger.warn(f'Run {self.run.id} cancelling') - break - yield reply_message.append_message_delta(token) - response_str_count+=1 + async for res in self.interface.inference(local_messages,self.thread.id): + if isinstance(res, RawUsage): + raw_usage = res + else: + token, finish_reason = res + if self.run.status == RunObject.Status.cancelling: + logger.warn(f'Run {self.run.id} cancelling') + break + yield reply_message.append_message_delta(token) + response_str_count+=1 if self.run.status == RunObject.Status.cancelling: yield self.run.stream_response_with_event(RunObject.Status.cancelled)