Mirror of https://github.com/carlrobertoh/ProxyAI.git, synced 2026-05-09 11:01:22 +00:00
Take max token value from the selected model (#68)

parent cc66ed9f74
commit fd072fd673

2 changed files with 15 additions and 12 deletions
@@ -8,6 +8,7 @@ import ee.carlrobert.codegpt.EncodingManager;
 import ee.carlrobert.codegpt.state.conversations.Conversation;
 import ee.carlrobert.codegpt.state.conversations.ConversationsState;
 import ee.carlrobert.codegpt.state.settings.SettingsState;
+import ee.carlrobert.openai.client.completion.chat.ChatCompletionModel;
 import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionMessage;
 import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionRequest;
 import ee.carlrobert.openai.client.completion.text.TextCompletionModel;
@@ -18,6 +19,9 @@ import java.util.Objects;
 
 class CompletionRequestProvider {
 
+  private static final int MAX_COMPLETION_TOKENS = 1000;
+
+  private final SettingsState settings = SettingsState.getInstance();
   private final EncodingManager encodingManager = EncodingManager.getInstance();
   private final String prompt;
   private final Conversation conversation;
@@ -28,15 +32,17 @@ class CompletionRequestProvider {
   }
 
   public ChatCompletionRequest buildChatCompletionRequest(String model) {
-    return new ChatCompletionRequest.Builder(buildMessages())
+    return (ChatCompletionRequest) new ChatCompletionRequest.Builder(buildMessages())
         .setModel(model)
+        .setMaxTokens(MAX_COMPLETION_TOKENS)
         .build();
   }
 
   public TextCompletionRequest buildTextCompletionRequest(String model) {
-    return new TextCompletionRequest.Builder(buildPrompt(model))
+    return (TextCompletionRequest) new TextCompletionRequest.Builder(buildPrompt(model))
         .setStop(List.of(" Human:", " AI:"))
         .setModel(model)
+        .setMaxTokens(MAX_COMPLETION_TOKENS)
         .build();
   }
 
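Both request builders now pass an explicit completion cap, so whatever limit the selected model reports, 1,000 of those tokens are reserved for the reply before any conversation messages are counted. A back-of-the-envelope sketch of that budget split; 4096 is only an assumption for the selected model's window (the real value comes from ChatCompletionModel#getMaxTokens), and CompletionBudget is an illustrative name, not a class from the project:

// Back-of-the-envelope: how much of the window is left for messages once
// 1,000 tokens are reserved for the completion.
class CompletionBudget {
  public static void main(String[] args) {
    int modelMaxTokens = 4096;        // assumed window of the selected model
    int maxCompletionTokens = 1000;   // MAX_COMPLETION_TOKENS from the diff
    System.out.println("Tokens left for messages: " + (modelMaxTokens - maxCompletionTokens)); // 3096
  }
}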
@@ -51,11 +57,9 @@ class CompletionRequestProvider {
     });
     messages.add(new ChatCompletionMessage("user", prompt));
 
-    var settingsState = SettingsState.getInstance();
-
     // TODO: Add support for other models
-    var isSeamlessConversationSupported = settingsState.isChatCompletionOptionSelected &&
-        List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settingsState.chatCompletionBaseModel);
+    var isSeamlessConversationSupported = settings.isChatCompletionOptionSelected &&
+        List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settings.chatCompletionBaseModel);
     if (isSeamlessConversationSupported) {
       return tryReducingMessagesOrThrow(messages);
     }
@@ -63,11 +67,10 @@ class CompletionRequestProvider {
   }
 
   private List<ChatCompletionMessage> tryReducingMessagesOrThrow(List<ChatCompletionMessage> messages) {
-    int MAX_TOKEN_LIMIT = 4097;
-    int totalMessagesUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum();
-    int totalUsage = totalMessagesUsage + 1000; // 1000 - max completion token size (currently not customizable)
+    var modelMaxTokens = ChatCompletionModel.findByCode(settings.chatCompletionBaseModel).getMaxTokens();
+    int totalUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum() + MAX_COMPLETION_TOKENS;
 
-    if (totalUsage <= MAX_TOKEN_LIMIT) {
+    if (totalUsage <= modelMaxTokens) {
       return messages;
     }
 
@@ -79,7 +82,7 @@ class CompletionRequestProvider {
 
     // skip the system prompt
     for (int i = 1; i < messages.size(); i++) {
-      if (totalUsage <= MAX_TOKEN_LIMIT) {
+      if (totalUsage <= modelMaxTokens) {
         break;
       }
 
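Taken together, the last two hunks replace the hard-coded 4097-token ceiling with the limit reported for the selected model: total usage is the token count of all messages plus the 1,000 reserved completion tokens, and older non-system messages are dropped until that sum fits. Below is a minimal, self-contained sketch of that idea; the word-count tokenizer, the Message record, and the exception thrown on failure are stand-ins for the plugin's EncodingManager, ChatCompletionMessage, and error handling, not the project's actual implementation.

// Self-contained sketch of the reduction logic after this change.
import java.util.ArrayList;
import java.util.List;

class TokenBudgetSketch {

  static final int MAX_COMPLETION_TOKENS = 1000; // reserved for the reply

  record Message(String role, String content) {}

  // Stand-in for EncodingManager#countMessageTokens: fake token count by words.
  static int countTokens(Message message) {
    return message.content().split("\\s+").length;
  }

  // Drop the oldest non-system messages until the conversation plus the
  // reserved completion tokens fit into the selected model's window.
  static List<Message> reduceOrThrow(List<Message> messages, int modelMaxTokens) {
    var reduced = new ArrayList<>(messages);
    int totalUsage = reduced.stream().mapToInt(TokenBudgetSketch::countTokens).sum()
        + MAX_COMPLETION_TOKENS;

    // Index 0 is the system prompt and is always kept.
    while (totalUsage > modelMaxTokens && reduced.size() > 1) {
      totalUsage -= countTokens(reduced.remove(1));
    }
    if (totalUsage > modelMaxTokens) {
      // The diff does not show the failure branch; a plain exception
      // stands in for whatever the plugin actually throws.
      throw new IllegalStateException("Prompt does not fit the model's context window");
    }
    return reduced;
  }

  public static void main(String[] args) {
    var messages = List.of(
        new Message("system", "You are a helpful assistant"),
        new Message("user", "first question"),
        new Message("assistant", "first answer"),
        new Message("user", "follow-up question"));
    // 4096 stands in for ChatCompletionModel#getMaxTokens() of the selected model.
    System.out.println(reduceOrThrow(messages, 4096).size());
  }
}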