diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py
index e3a1a51..d0ac17e 100644
--- a/ktransformers/server/api/ollama/completions.py
+++ b/ktransformers/server/api/ollama/completions.py
@@ -12,8 +12,8 @@
 from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
-router = APIRouter(prefix='/api')
+router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +40,95 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field(
         "5m", description="Controls how long the model will stay loaded into memory following the request.")
-
 
 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)
-
 
 class OllamaGenerationResponse(BaseModel):
     pass
-
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
-
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
-
     config = Config()
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt,id):
-                d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
-                yield d.model_dump_json()+'\n'
-            # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-            # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
-            yield d.model_dump_json()+'\n'
-        return check_link_response(request,inner())
+            async for token in interface.inference(input.prompt, id):
+                d = OllamaGenerationStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    response=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                response='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
     else:
         raise NotImplementedError
 
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
-
+class OllamaChatCompletionMessage(BaseModel):
+    role: str
+    content: str
 
 class OllamaChatCompletionRequest(BaseModel):
-    pass
-
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(
+        ..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")
 
 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
-
+    model: str
+    created_at: str
+    message: str
+    done: bool = Field(...)
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
-
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+    # Convert the chat messages into a single prompt string
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            async for token in interface.inference(prompt, id):
+                d = OllamaChatCompletionStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    message=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
@@ -103,9 +137,8 @@ class OllamaModel(BaseModel):
     size: int
     # TODO: fill the rest correctly
-
 
 # mock ollama
-@router.get("/tags",tags=['ollama'])
+@router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby
@@ -138,25 +171,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()
-
-
 
 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
-        modelfile = "# Modelfile generated by ...",
-        parameters = " ",
-        template = " ",
-        details = OllamaShowDetial(
-            parent_model = " ",
-            format = "gguf",
-            family = " ",
-            families = [
-                " "
-            ],
-            parameter_size = " ",
-            quantization_level = " "
+        modelfile="# Modelfile generated by ...",
+        parameters=" ",
+        template=" ",
+        details=OllamaShowDetial(
+            parent_model=" ",
+            format="gguf",
+            family=" ",
+            families=[" "],
+            parameter_size=" ",
+            quantization_level=" "
         ),
-        model_info = OllamaModelInfo()
+        model_info=OllamaModelInfo()
     )
\ No newline at end of file
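
For reference, a minimal client-side sketch of how the new streaming `/api/chat` route added in this patch could be exercised. The base URL, port, and model name below are placeholders (not defined by this diff), and the example assumes the `requests` package is installed; the route streams newline-delimited JSON, one object per generated token, ending with an object whose `done` field is true.

```python
# Hypothetical smoke test for the /api/chat route added above.
# BASE_URL and MODEL are assumptions; point them at your own deployment.
import json

import requests

BASE_URL = "http://localhost:10002/api"   # placeholder host/port
MODEL = "DeepSeek-Coder-V2-Instruct"      # placeholder model name

payload = {
    "model": MODEL,
    "messages": [
        {"role": "user", "content": "Write a haiku about streaming tokens."}
    ],
    "stream": True,
}

# The endpoint emits newline-delimited JSON objects, one per token,
# terminated by a final object with done=True and an empty message.
with requests.post(f"{BASE_URL}/chat", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["message"], end="", flush=True)
        if chunk["done"]:
            print()
            break
```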