Take max token value from the selected model (#68)

This commit is contained in:
Carl-Robert Linnupuu 2023-04-16 21:55:38 +01:00
parent cc66ed9f74
commit fd072fd673
2 changed files with 15 additions and 12 deletions

View file

@ -8,6 +8,7 @@ import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.state.conversations.Conversation;
import ee.carlrobert.codegpt.state.conversations.ConversationsState;
import ee.carlrobert.codegpt.state.settings.SettingsState;
import ee.carlrobert.openai.client.completion.chat.ChatCompletionModel;
import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionMessage;
import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionRequest;
import ee.carlrobert.openai.client.completion.text.TextCompletionModel;
@ -18,6 +19,9 @@ import java.util.Objects;
class CompletionRequestProvider {
private static final int MAX_COMPLETION_TOKENS = 1000;
private final SettingsState settings = SettingsState.getInstance();
private final EncodingManager encodingManager = EncodingManager.getInstance();
private final String prompt;
private final Conversation conversation;
@ -28,15 +32,17 @@ class CompletionRequestProvider {
}
public ChatCompletionRequest buildChatCompletionRequest(String model) {
return new ChatCompletionRequest.Builder(buildMessages())
return (ChatCompletionRequest) new ChatCompletionRequest.Builder(buildMessages())
.setModel(model)
.setMaxTokens(MAX_COMPLETION_TOKENS)
.build();
}
public TextCompletionRequest buildTextCompletionRequest(String model) {
return new TextCompletionRequest.Builder(buildPrompt(model))
return (TextCompletionRequest) new TextCompletionRequest.Builder(buildPrompt(model))
.setStop(List.of(" Human:", " AI:"))
.setModel(model)
.setMaxTokens(MAX_COMPLETION_TOKENS)
.build();
}
@ -51,11 +57,9 @@ class CompletionRequestProvider {
});
messages.add(new ChatCompletionMessage("user", prompt));
var settingsState = SettingsState.getInstance();
// TODO: Add support for other models
var isSeamlessConversationSupported = settingsState.isChatCompletionOptionSelected &&
List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settingsState.chatCompletionBaseModel);
var isSeamlessConversationSupported = settings.isChatCompletionOptionSelected &&
List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settings.chatCompletionBaseModel);
if (isSeamlessConversationSupported) {
return tryReducingMessagesOrThrow(messages);
}
@ -63,11 +67,10 @@ class CompletionRequestProvider {
}
private List<ChatCompletionMessage> tryReducingMessagesOrThrow(List<ChatCompletionMessage> messages) {
int MAX_TOKEN_LIMIT = 4097;
int totalMessagesUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum();
int totalUsage = totalMessagesUsage + 1000; // 1000 - max completion token size (currently not customizable)
var modelMaxTokens = ChatCompletionModel.findByCode(settings.chatCompletionBaseModel).getMaxTokens();
int totalUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum() + MAX_COMPLETION_TOKENS;
if (totalUsage <= MAX_TOKEN_LIMIT) {
if (totalUsage <= modelMaxTokens) {
return messages;
}
@ -79,7 +82,7 @@ class CompletionRequestProvider {
// skip the system prompt
for (int i = 1; i < messages.size(); i++) {
if (totalUsage <= MAX_TOKEN_LIMIT) {
if (totalUsage <= modelMaxTokens) {
break;
}