diff --git a/build.gradle.kts b/build.gradle.kts
index ec11230c..1c1266d1 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -22,7 +22,7 @@ dependencies {
   implementation("com.fifesoft:rsyntaxtextarea:3.3.2")
   implementation("com.vladsch.flexmark:flexmark-all:0.64.0")
   implementation("org.apache.commons:commons-text:1.10.0")
-  implementation("ee.carlrobert:openai-client:1.0.5")
+  implementation("ee.carlrobert:openai-client:1.0.6")
   implementation("com.knuddels:jtokkit:0.2.0")
 }

diff --git a/src/main/java/ee/carlrobert/codegpt/client/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/client/CompletionRequestProvider.java
index b16b81ab..a8c5e7d7 100644
--- a/src/main/java/ee/carlrobert/codegpt/client/CompletionRequestProvider.java
+++ b/src/main/java/ee/carlrobert/codegpt/client/CompletionRequestProvider.java
@@ -8,6 +8,7 @@ import ee.carlrobert.codegpt.EncodingManager;
 import ee.carlrobert.codegpt.state.conversations.Conversation;
 import ee.carlrobert.codegpt.state.conversations.ConversationsState;
 import ee.carlrobert.codegpt.state.settings.SettingsState;
+import ee.carlrobert.openai.client.completion.chat.ChatCompletionModel;
 import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionMessage;
 import ee.carlrobert.openai.client.completion.chat.request.ChatCompletionRequest;
 import ee.carlrobert.openai.client.completion.text.TextCompletionModel;
@@ -18,6 +19,9 @@ import java.util.Objects;

 class CompletionRequestProvider {

+  private static final int MAX_COMPLETION_TOKENS = 1000;
+
+  private final SettingsState settings = SettingsState.getInstance();
   private final EncodingManager encodingManager = EncodingManager.getInstance();
   private final String prompt;
   private final Conversation conversation;
@@ -28,15 +32,17 @@ class CompletionRequestProvider {
   }

   public ChatCompletionRequest buildChatCompletionRequest(String model) {
-    return new ChatCompletionRequest.Builder(buildMessages())
+    return (ChatCompletionRequest) new ChatCompletionRequest.Builder(buildMessages())
         .setModel(model)
+        .setMaxTokens(MAX_COMPLETION_TOKENS)
         .build();
   }

   public TextCompletionRequest buildTextCompletionRequest(String model) {
-    return new TextCompletionRequest.Builder(buildPrompt(model))
+    return (TextCompletionRequest) new TextCompletionRequest.Builder(buildPrompt(model))
         .setStop(List.of(" Human:", " AI:"))
         .setModel(model)
+        .setMaxTokens(MAX_COMPLETION_TOKENS)
         .build();
   }

@@ -51,11 +57,9 @@ class CompletionRequestProvider {
     });
     messages.add(new ChatCompletionMessage("user", prompt));

-    var settingsState = SettingsState.getInstance();
-
     // TODO: Add support for other models
-    var isSeamlessConversationSupported = settingsState.isChatCompletionOptionSelected &&
-        List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settingsState.chatCompletionBaseModel);
+    var isSeamlessConversationSupported = settings.isChatCompletionOptionSelected &&
+        List.of(GPT_3_5.getCode(), GPT_3_5_SNAPSHOT.getCode()).contains(settings.chatCompletionBaseModel);
     if (isSeamlessConversationSupported) {
       return tryReducingMessagesOrThrow(messages);
     }
@@ -63,11 +67,10 @@ class CompletionRequestProvider {
   }

   private List<ChatCompletionMessage> tryReducingMessagesOrThrow(List<ChatCompletionMessage> messages) {
-    int MAX_TOKEN_LIMIT = 4097;
-    int totalMessagesUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum();
-    int totalUsage = totalMessagesUsage + 1000; // 1000 - max completion token size (currently not customizable)
+    var modelMaxTokens = ChatCompletionModel.findByCode(settings.chatCompletionBaseModel).getMaxTokens();
+    int totalUsage = messages.parallelStream().mapToInt(encodingManager::countMessageTokens).sum() + MAX_COMPLETION_TOKENS;

-    if (totalUsage <= MAX_TOKEN_LIMIT) {
+    if (totalUsage <= modelMaxTokens) {
       return messages;
     }

@@ -79,7 +82,7 @@ class CompletionRequestProvider {

     // skip the system prompt
     for (int i = 1; i < messages.size(); i++) {
-      if (totalUsage <= MAX_TOKEN_LIMIT) {
+      if (totalUsage <= modelMaxTokens) {
         break;
       }
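
Reviewer note, not part of the patch: the core change replaces the hard-coded 4097-token ceiling with a per-model limit resolved via ChatCompletionModel.findByCode(...).getMaxTokens(), while still reserving MAX_COMPLETION_TOKENS for the reply. A minimal sketch of that budget arithmetic, where the helper name and the example token counts are hypothetical (EncodingManager#countMessageTokens is the real counterpart):

  // Sketch only: checks whether a request fits the model's context window
  // before sending, mirroring the condition in tryReducingMessagesOrThrow.
  static final int MAX_COMPLETION_TOKENS = 1000; // reserved for the reply, as in the patch

  static boolean fitsModelWindow(int[] messageTokenCounts, int modelMaxTokens) {
    // prompt-side tokens plus the reserved completion budget
    int totalUsage = java.util.Arrays.stream(messageTokenCounts).sum() + MAX_COMPLETION_TOKENS;
    return totalUsage <= modelMaxTokens;
  }

  // e.g. fitsModelWindow(new int[] {12, 840, 2600}, 4097)
  //   -> 3452 + 1000 = 4452 > 4097 -> false,
  // at which point tryReducingMessagesOrThrow drops the oldest
  // non-system messages until the request fits.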