chore: Improve code (#442)

* chore: Improve code

* Convert classes to records
This commit is contained in:
Rene Leonhardt 2024-04-10 13:47:38 +02:00 committed by GitHub
parent c29d3928db
commit 7d89650062
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
86 changed files with 528 additions and 976 deletions

View file

@ -9,11 +9,13 @@ import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.IntArrayList;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIMessageTextContent;
import java.util.List;
import java.util.stream.Stream;
@Service
public final class EncodingManager {
@ -31,13 +33,10 @@ public final class EncodingManager {
}
public int countConversationTokens(Conversation conversation) {
if (conversation != null) {
return conversation.getMessages().stream()
.mapToInt(
message -> countTokens(message.getPrompt()) + countTokens(message.getResponse()))
.sum();
}
return 0;
return (conversation == null ? Stream.<Message>empty() : conversation.getMessages().stream())
.mapToInt(
message -> countTokens(message.getPrompt()) + countTokens(message.getResponse()))
.sum();
}
public int countMessageTokens(OpenAIChatCompletionMessage message) {
@ -46,11 +45,11 @@ public final class EncodingManager {
}
return ((OpenAIChatCompletionDetailedMessage) message).getContent().stream()
.filter(it -> it instanceof OpenAIMessageTextContent)
.map(it -> countMessageTokens(
.filter(OpenAIMessageTextContent.class::isInstance)
.mapToInt(it -> countMessageTokens(
((OpenAIChatCompletionDetailedMessage) message).getRole(),
((OpenAIMessageTextContent) it).getText()))
.reduce(0, Integer::sum);
.sum();
}
public int countMessageTokens(String role, String content) {
@ -86,9 +85,7 @@ public final class EncodingManager {
private IntArrayList convertToIntArrayList(List<Integer> tokens) {
var result = new IntArrayList(tokens.size());
for (var integer : tokens) {
result.add(integer);
}
tokens.forEach(result::add);
return result;
}
}