diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 223c2fee..c8661f6b 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.1.0" junit = "5.11.0" kotlin = "2.0.0" -llm-client = "0.8.17" +llm-client = "0.8.18" okio = "3.9.0" tree-sitter = "0.22.6a" diff --git a/src/main/java/ee/carlrobert/codegpt/ProjectCompilationStatusListener.java b/src/main/java/ee/carlrobert/codegpt/ProjectCompilationStatusListener.java index e4e0930d..31daad3d 100644 --- a/src/main/java/ee/carlrobert/codegpt/ProjectCompilationStatusListener.java +++ b/src/main/java/ee/carlrobert/codegpt/ProjectCompilationStatusListener.java @@ -12,7 +12,7 @@ import com.intellij.openapi.compiler.CompileContext; import com.intellij.openapi.compiler.CompilerMessage; import com.intellij.openapi.compiler.CompilerMessageCategory; import com.intellij.openapi.project.Project; -import ee.carlrobert.codegpt.completions.CompletionRequestProvider; +import ee.carlrobert.codegpt.completions.CompletionRequestUtil; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; import ee.carlrobert.codegpt.toolwindow.chat.ChatToolWindowContentManager; @@ -69,7 +69,7 @@ public class ProjectCompilationStatusListener implements CompilationStatusListen .map(ReferencedFile::getFilePath) .toList()); message.setUserMessage(message.getPrompt()); - message.setPrompt(CompletionRequestProvider.getPromptWithContext( + message.setPrompt(CompletionRequestUtil.getPromptWithContext( new ArrayList<>(errorMapping.keySet()), prompt)); return message; diff --git a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java index 07829fb1..2ed218df 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java @@ -84,11 +84,10 @@ public class GenerateGitCommitMessageAction extends AnAction { var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI); if (commitWorkflowUi != null) { - CompletionRequestService.getInstance() - .generateCommitMessageAsync( - project.getService(CommitMessageTemplate.class).getSystemPrompt(), - gitDiff, - getEventListener(project, commitWorkflowUi)); + CompletionRequestService.getInstance().getCommitMessageAsync( + project.getService(CommitMessageTemplate.class).getSystemPrompt(), + gitDiff, + getEventListener(project, commitWorkflowUi)); } } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java deleted file mode 100644 index cdb33519..00000000 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ /dev/null @@ -1,698 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS; -import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY; -import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent; -import static java.lang.String.format; -import static java.util.Collections.emptyList; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.joining; -import static java.util.stream.Collectors.toList; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.intellij.openapi.application.ApplicationManager; -import ee.carlrobert.codegpt.EncodingManager; -import ee.carlrobert.codegpt.ReferencedFile; -import ee.carlrobert.codegpt.completions.llama.LlamaModel; -import ee.carlrobert.codegpt.completions.llama.PromptTemplate; -import ee.carlrobert.codegpt.conversations.Conversation; -import ee.carlrobert.codegpt.conversations.ConversationsState; -import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.credentials.CredentialsStore; -import ee.carlrobert.codegpt.settings.IncludedFilesSettings; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; -import ee.carlrobert.codegpt.settings.persona.PersonaSettings; -import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings; -import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState; -import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; -import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; -import ee.carlrobert.codegpt.util.file.FileUtil; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeBase64Source; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionDetailedMessage; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionMessage; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageImageContent; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageTextContent; -import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent; -import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; -import ee.carlrobert.llm.client.google.completion.GoogleContentPart; -import ee.carlrobert.llm.client.google.completion.GoogleContentPart.Blob; -import ee.carlrobert.llm.client.google.completion.GoogleGenerationConfig; -import ee.carlrobert.llm.client.google.models.GoogleModel; -import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage; -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest; -import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters; -import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIImageUrl; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIMessageImageURLContent; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIMessageTextContent; -import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDetails; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import okhttp3.Request; -import okhttp3.RequestBody; -import org.jetbrains.annotations.Nullable; - -public class CompletionRequestProvider { - - public static final String GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT = - getResourceContent("/prompts/generate-commit-message.txt"); - - public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = - getResourceContent("/prompts/fix-compile-errors.txt"); - - public static final String GENERATE_METHOD_NAMES_SYSTEM_PROMPT = - getResourceContent("/prompts/method-name-generator.txt"); - - public static final String EDIT_CODE_SYSTEM_PROMPT = - getResourceContent("/prompts/edit-code.txt"); - - public static String getPromptWithContext(List referencedFiles, - String userPrompt) { - var includedFilesSettings = IncludedFilesSettings.getCurrentState(); - var repeatableContext = referencedFiles.stream() - .map(item -> includedFilesSettings.getRepeatableContext() - .replace("{FILE_PATH}", item.getFilePath()) - .replace("{FILE_CONTENT}", format( - "```%s%n%s%n```", - item.getFileExtension(), - item.getFileContent().trim()))) - .collect(joining("\n\n")); - - return includedFilesSettings.getPromptTemplate() - .replace("{REPEATABLE_CONTEXT}", repeatableContext) - .replace("{QUESTION}", userPrompt); - } - - public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest(String context) { - return buildOpenAILookupCompletionRequest(context, OpenAISettings.getCurrentState().getModel()); - } - - public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest( - String context, - String model) { - return new OpenAIChatCompletionRequest.Builder( - List.of( - new OpenAIChatCompletionStandardMessage("system", GENERATE_METHOD_NAMES_SYSTEM_PROMPT), - new OpenAIChatCompletionStandardMessage("user", context))) - .setModel(model) - .setStream(false) - .build(); - } - - public static OpenAIChatCompletionRequest buildEditCodeRequest( - String context, - @Nullable String model) { - return new OpenAIChatCompletionRequest.Builder( - List.of( - new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT), - new OpenAIChatCompletionStandardMessage("user", context))) - .setModel(model) - .setStream(true) - .setMaxTokens(ConfigurationSettings.getState().getMaxTokens()) - .build(); - } - - public static Request buildCustomOpenAIChatCompletionRequest(CallParameters callParameters) { - return buildCustomOpenAIChatCompletionRequest( - ApplicationManager.getApplication().getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - CompletionRequestProvider.buildOpenAIMessages(callParameters), - true, - CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); - } - - private static Request buildCustomOpenAIChatCompletionRequest( - CustomServiceChatCompletionSettingsState settings, - List messages, - boolean streamRequest, - String credential) { - var requestBuilder = - new Request.Builder().url(requireNonNull(settings.getUrl()).trim()); - for (var entry : settings.getHeaders().entrySet()) { - String value = entry.getValue(); - if (credential != null && value.contains("$CUSTOM_SERVICE_API_KEY")) { - value = value.replace("$CUSTOM_SERVICE_API_KEY", credential); - } - requestBuilder.addHeader(entry.getKey(), value); - } - - var body = settings.getBody().entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - entry -> { - if (!streamRequest && "stream".equals(entry.getKey())) { - return false; - } - - var value = entry.getValue(); - if (value instanceof String string && "$OPENAI_MESSAGES".equals(string.trim())) { - return messages; - } - return value; - } - )); - - try { - var requestBody = RequestBody.create(new ObjectMapper() - .writerWithDefaultPrettyPrinter() - .writeValueAsString(body) - .getBytes(StandardCharsets.UTF_8)); - return requestBuilder.post(requestBody).build(); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - public static Request buildCustomOpenAIEditCodeRequest(String input) { - return buildCustomOpenAIChatCompletionRequest( - ApplicationManager.getApplication().getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - List.of( - new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT), - new OpenAIChatCompletionStandardMessage("user", input)), - true, - CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); - } - - public static Request buildCustomOpenAICompletionRequest(String system, String context) { - return buildCustomOpenAIChatCompletionRequest( - ApplicationManager.getApplication().getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - List.of( - new OpenAIChatCompletionStandardMessage("system", system), - new OpenAIChatCompletionStandardMessage("user", context)), - true, - CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); - } - - public static Request buildCustomOpenAICompletionRequest(String context, String url, - Map headers, Map body, String credential) { - var usedSettings = new CustomServiceChatCompletionSettingsState(); - usedSettings.setBody(body); - usedSettings.setHeaders(headers); - usedSettings.setUrl(url); - return buildCustomOpenAIChatCompletionRequest( - usedSettings, - List.of(new OpenAIChatCompletionStandardMessage("user", context)), - true, - credential); - } - - public static Request buildCustomOpenAILookupCompletionRequest(String context) { - return buildCustomOpenAIChatCompletionRequest( - ApplicationManager.getApplication().getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - List.of( - new OpenAIChatCompletionStandardMessage( - "system", - GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT), - new OpenAIChatCompletionStandardMessage("user", context)), - false, - CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); - } - - public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) { - return new LlamaCompletionRequest.Builder( - PromptTemplate.LLAMA.buildPrompt(GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT, context, List.of())) - .setStream(false) - .build(); - } - - public static LlamaCompletionRequest buildLlamaCompletionRequest( - Message message, - Conversation conversation, - ConversationType conversationType) { - var settings = LlamaSettings.getCurrentState(); - PromptTemplate promptTemplate; - if (settings.isRunLocalServer()) { - promptTemplate = settings.isUseCustomModel() - ? settings.getLocalModelPromptTemplate() - : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); - } else { - promptTemplate = settings.getRemoteModelPromptTemplate(); - } - - var systemPrompt = conversationType == FIX_COMPILE_ERRORS - ? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : PersonaSettings.getSystemPrompt(); - - var prompt = promptTemplate.buildPrompt( - systemPrompt, - message.getPrompt(), - conversation.getMessages()); - var configuration = ConfigurationSettings.getState(); - return new LlamaCompletionRequest.Builder(prompt) - .setN_predict(configuration.getMaxTokens()) - .setTemperature(configuration.getTemperature()) - .setTop_k(settings.getTopK()) - .setTop_p(settings.getTopP()) - .setMin_p(settings.getMinP()) - .setRepeat_penalty(settings.getRepeatPenalty()) - .setStop(promptTemplate.getStopTokens()) - .build(); - } - - public static LlamaCompletionRequest buildLlamaEditCodeRequest(String input) { - var settings = LlamaSettings.getCurrentState(); - PromptTemplate promptTemplate; - if (settings.isRunLocalServer()) { - promptTemplate = settings.isUseCustomModel() - ? settings.getLocalModelPromptTemplate() - : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); - } else { - promptTemplate = settings.getRemoteModelPromptTemplate(); - } - - var prompt = promptTemplate.buildPrompt(EDIT_CODE_SYSTEM_PROMPT, input, emptyList()); - var configuration = ConfigurationSettings.getState(); - return new LlamaCompletionRequest.Builder(prompt) - .setN_predict(configuration.getMaxTokens()) - .setTemperature(configuration.getTemperature()) - .setTop_k(settings.getTopK()) - .setTop_p(settings.getTopP()) - .setMin_p(settings.getMinP()) - .setRepeat_penalty(settings.getRepeatPenalty()) - .setStop(promptTemplate.getStopTokens()) - .build(); - } - - public static OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest( - @Nullable String model, - CallParameters callParameters) { - var configuration = ConfigurationSettings.getState(); - var requestBuilder = new OpenAIChatCompletionRequest.Builder( - buildOpenAIMessages(model, callParameters)) - .setModel(model) - .setMaxTokens(configuration.getMaxTokens()) - .setStream(true) - .setTemperature(configuration.getTemperature()); - if (callParameters.getMessage().isWebSearchIncluded()) { - // tri-state boolean - requestBuilder.setWebSearchIncluded(true); - } - var documentationDetails = - callParameters.getMessage().getDocumentationDetails(); - if (documentationDetails != null) { - var requestDocumentationDetails = new RequestDocumentationDetails(); - requestDocumentationDetails.setName(documentationDetails.getName()); - requestDocumentationDetails.setUrl(documentationDetails.getUrl()); - requestBuilder.setDocumentationDetails(requestDocumentationDetails); - } - return requestBuilder.build(); - } - - public static GoogleCompletionRequest buildGoogleChatCompletionRequest( - @Nullable String model, - CallParameters callParameters) { - var configuration = ConfigurationSettings.getState(); - return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters)) - .generationConfig(new GoogleGenerationConfig.Builder() - .maxOutputTokens(configuration.getMaxTokens()) - .temperature(configuration.getTemperature()).build()).build(); - } - - public static GoogleCompletionRequest buildGoogleEditCodeRequest(String input) { - var configuration = ConfigurationSettings.getState(); - return new GoogleCompletionRequest.Builder(List.of( - new GoogleCompletionContent("user", List.of(EDIT_CODE_SYSTEM_PROMPT)), - new GoogleCompletionContent("model", List.of("Understood.")), - new GoogleCompletionContent("user", List.of(input)))) - .generationConfig(new GoogleGenerationConfig.Builder() - .maxOutputTokens(configuration.getMaxTokens()) - .temperature(configuration.getTemperature()).build()).build(); - } - - public static ClaudeCompletionRequest buildAnthropicChatCompletionRequest( - CallParameters callParameters) { - var configuration = ConfigurationSettings.getState(); - var settings = AnthropicSettings.getCurrentState(); - var request = new ClaudeCompletionRequest(); - request.setModel(settings.getModel()); - request.setMaxTokens(configuration.getMaxTokens()); - request.setStream(true); - request.setSystem(PersonaSettings.getSystemPrompt()); - List messages = callParameters.getConversation().getMessages().stream() - .filter(prevMessage -> prevMessage.getResponse() != null - && !prevMessage.getResponse().isEmpty()) - .flatMap(prevMessage -> Stream.of( - new ClaudeCompletionStandardMessage("user", prevMessage.getPrompt()), - new ClaudeCompletionStandardMessage("assistant", prevMessage.getResponse()))) - .collect(toList()); - - if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) { - messages.add(new ClaudeCompletionDetailedMessage("user", - List.of( - new ClaudeMessageImageContent(new ClaudeBase64Source( - callParameters.getImageMediaType(), - callParameters.getImageData())), - new ClaudeMessageTextContent(callParameters.getMessage().getPrompt())))); - } else { - messages.add( - new ClaudeCompletionStandardMessage("user", callParameters.getMessage().getPrompt())); - } - request.setMessages(messages); - return request; - } - - public static ClaudeCompletionRequest buildAnthropicEditCodeRequest( - String input) { - var configuration = ConfigurationSettings.getState(); - var settings = AnthropicSettings.getCurrentState(); - var request = new ClaudeCompletionRequest(); - request.setModel(settings.getModel()); - request.setMaxTokens(configuration.getMaxTokens()); - request.setStream(true); - request.setSystem(EDIT_CODE_SYSTEM_PROMPT); - request.setMessages(List.of(new ClaudeCompletionStandardMessage("user", input))); - return request; - } - - public static OllamaChatCompletionRequest buildOllamaChatCompletionRequest( - CallParameters callParameters - ) { - var configuration = ConfigurationSettings.getState(); - var settings = ApplicationManager.getApplication().getService(OllamaSettings.class).getState(); - return new OllamaChatCompletionRequest - .Builder(settings.getModel(), buildOllamaMessages(callParameters)) - .setStream(true) - .setOptions(new OllamaParameters.Builder() - .numPredict(configuration.getMaxTokens()) - .temperature((double) configuration.getTemperature()) - .build()) - .build(); - } - - public static OllamaChatCompletionRequest buildOllamaEditCodeRequest(String input) { - var configuration = ConfigurationSettings.getState(); - var settings = - ApplicationManager.getApplication().getService(OllamaSettings.class).getState(); - return new OllamaChatCompletionRequest - .Builder(settings.getModel(), List.of( - new OllamaChatCompletionMessage("system", EDIT_CODE_SYSTEM_PROMPT, null), - new OllamaChatCompletionMessage("user", input, null))) - .setStream(true) - .setOptions(new OllamaParameters.Builder() - .numPredict(configuration.getMaxTokens()) - .temperature((double) configuration.getTemperature()) - .build()) - .build(); - } - - private static List buildOllamaMessages( - CallParameters callParameters) { - var message = callParameters.getMessage(); - var messages = new ArrayList(); - if (callParameters.getConversationType() == ConversationType.DEFAULT) { - messages.add( - new OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null)); - } - if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) { - messages.add( - new OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null) - ); - } - - for (var prevMessage : callParameters.getConversation().getMessages()) { - if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { - break; - } - var prevMessageImageFilePath = prevMessage.getImageFilePath(); - if (prevMessageImageFilePath != null && !prevMessageImageFilePath.isEmpty()) { - try { - var imageFilePath = Path.of(prevMessageImageFilePath); - var imageBytes = Files.readAllBytes(imageFilePath); - var imageBase64 = Base64.getEncoder().encodeToString(imageBytes); - messages.add( - new OllamaChatCompletionMessage( - "user", prevMessage.getPrompt(), List.of(imageBase64) - ) - ); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else { - messages.add( - new OllamaChatCompletionMessage("user", prevMessage.getPrompt(), null) - ); - } - messages.add( - new OllamaChatCompletionMessage("assistant", prevMessage.getResponse(), null) - ); - } - - if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) { - var imageBase64 = Base64.getEncoder().encodeToString(callParameters.getImageData()); - messages.add( - new OllamaChatCompletionMessage("user", message.getPrompt(), List.of(imageBase64)) - ); - } else { - messages.add(new OllamaChatCompletionMessage("user", message.getPrompt(), null)); - } - return messages; - } - - private static List buildOpenAIMessages( - CallParameters callParameters) { - var message = callParameters.getMessage(); - var messages = new ArrayList(); - if (callParameters.getConversationType() == ConversationType.DEFAULT) { - var sessionPersonaDetails = callParameters.getMessage().getPersonaDetails(); - if (callParameters.getMessage().getPersonaDetails() == null) { - messages.add( - new OpenAIChatCompletionStandardMessage("system", PersonaSettings.getSystemPrompt())); - } else { - messages.add(new OpenAIChatCompletionStandardMessage( - "system", - sessionPersonaDetails.instructions())); - } - } - if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) { - messages.add( - new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT)); - } - - for (var prevMessage : callParameters.getConversation().getMessages()) { - if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { - break; - } - var prevMessageImageFilePath = prevMessage.getImageFilePath(); - if (prevMessageImageFilePath != null && !prevMessageImageFilePath.isEmpty()) { - try { - var imageFilePath = Path.of(prevMessageImageFilePath); - var imageData = Files.readAllBytes(imageFilePath); - var imageMediaType = FileUtil.getImageMediaType(imageFilePath.getFileName().toString()); - messages.add(new OpenAIChatCompletionDetailedMessage("user", - List.of( - new OpenAIMessageImageURLContent(new OpenAIImageUrl(imageMediaType, imageData)), - new OpenAIMessageTextContent(prevMessage.getPrompt())))); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else { - messages.add(new OpenAIChatCompletionStandardMessage("user", prevMessage.getPrompt())); - } - messages.add( - new OpenAIChatCompletionStandardMessage("assistant", prevMessage.getResponse()) - ); - } - - if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) { - messages.add(new OpenAIChatCompletionDetailedMessage("user", - List.of( - new OpenAIMessageImageURLContent( - new OpenAIImageUrl(callParameters.getImageMediaType(), - callParameters.getImageData())), - new OpenAIMessageTextContent(message.getPrompt())))); - } else { - messages.add(new OpenAIChatCompletionStandardMessage("user", message.getPrompt())); - } - return messages; - } - - public static List buildOpenAIMessages( - @Nullable String model, - CallParameters callParameters) { - var messages = buildOpenAIMessages(callParameters); - - if (model == null) { - return messages; - } - - var encodingManager = EncodingManager.getInstance(); - int totalUsage = messages.parallelStream() - .mapToInt(encodingManager::countMessageTokens) - .sum() + ConfigurationSettings.getState().getMaxTokens(); - int modelMaxTokens; - try { - modelMaxTokens = OpenAIChatCompletionModel.findByCode(model).getMaxTokens(); - - if (totalUsage <= modelMaxTokens) { - return messages; - } - } catch (NoSuchElementException ex) { - return messages; - } - return tryReducingMessagesOrThrow( - messages, - callParameters.getConversation().isDiscardTokenLimit(), - totalUsage, - modelMaxTokens); - } - - private static List buildGoogleMessages(CallParameters callParameters) { - var message = callParameters.getMessage(); - var messages = new ArrayList(); - // Gemini API does not support direct 'system' prompts: - // see https://www.reddit.com/r/Bard/comments/1b90i8o/does_gemini_have_a_system_prompt_option_while/ - if (callParameters.getConversationType() == ConversationType.DEFAULT) { - messages.add(new GoogleCompletionContent("user", List.of(PersonaSettings.getSystemPrompt()))); - messages.add(new GoogleCompletionContent("model", List.of("Understood."))); - } - if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) { - messages.add( - new GoogleCompletionContent("user", List.of(FIX_COMPILE_ERRORS_SYSTEM_PROMPT))); - messages.add(new GoogleCompletionContent("model", List.of("Understood."))); - } - - for (var prevMessage : callParameters.getConversation().getMessages()) { - if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { - break; - } - var prevMessageImageFilePath = prevMessage.getImageFilePath(); - if (prevMessageImageFilePath != null && !prevMessageImageFilePath.isEmpty()) { - try { - var imageFilePath = Path.of(prevMessageImageFilePath); - var imageData = Files.readAllBytes(imageFilePath); - var imageMediaType = FileUtil.getImageMediaType(imageFilePath.getFileName().toString()); - messages.add(new GoogleCompletionContent( - List.of( - new GoogleContentPart(null, new Blob(imageMediaType, imageData)), - new GoogleContentPart(prevMessage.getPrompt())), "user")); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else { - messages.add(new GoogleCompletionContent("user", List.of(prevMessage.getPrompt()))); - } - messages.add(new GoogleCompletionContent("model", List.of(prevMessage.getResponse()))); - } - - if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) { - messages.add(new GoogleCompletionContent( - List.of( - new GoogleContentPart(null, - new Blob(callParameters.getImageMediaType(), callParameters.getImageData())), - new GoogleContentPart(message.getPrompt())), "user")); - } else { - messages.add(new GoogleCompletionContent("user", List.of(message.getPrompt()))); - } - return messages; - } - - private static List buildGoogleMessages( - @Nullable String model, - CallParameters callParameters) { - var messages = buildGoogleMessages(callParameters); - - if (model == null) { - return messages; - } - - var encodingManager = EncodingManager.getInstance(); - int totalUsage = messages.parallelStream() - .mapToInt(message -> encodingManager.countMessageTokens(message.getRole(), - String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList()))) - .sum() + ConfigurationSettings.getState().getMaxTokens(); - int modelMaxTokens; - try { - modelMaxTokens = GoogleModel.findByCode(model).getMaxTokens(); - - if (totalUsage <= modelMaxTokens) { - return messages; - } - } catch (NoSuchElementException ex) { - return messages; - } - return tryReducingGoogleMessagesOrThrow( - messages, - callParameters.getConversation().isDiscardTokenLimit(), - totalUsage, - modelMaxTokens); - } - - private static List tryReducingMessagesOrThrow( - List messages, - boolean discardTokenLimit, - int totalUsage, - int modelMaxTokens) { - if (!ConversationsState.getInstance().discardAllTokenLimits) { - if (!discardTokenLimit) { - throw new TotalUsageExceededException(); - } - } - var encodingManager = EncodingManager.getInstance(); - // skip the system prompt - for (int i = 1; i < messages.size(); i++) { - if (totalUsage <= modelMaxTokens) { - break; - } - - var message = messages.get(i); - if (message instanceof OpenAIChatCompletionStandardMessage) { - totalUsage -= encodingManager.countMessageTokens(message); - messages.set(i, null); - } - } - - return messages.stream().filter(Objects::nonNull).toList(); - } - - private static List tryReducingGoogleMessagesOrThrow( - List messages, - boolean discardTokenLimit, - int totalUsage, - int modelMaxTokens) { - if (!ConversationsState.getInstance().discardAllTokenLimits) { - if (!discardTokenLimit) { - throw new TotalUsageExceededException(); - } - } - var encodingManager = EncodingManager.getInstance(); - // skip the system prompt - for (int i = 1; i < messages.size(); i++) { - if (totalUsage <= modelMaxTokens) { - break; - } - - var message = messages.get(i); - totalUsage -= encodingManager.countMessageTokens(message.getRole(), - String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList())); - messages.set(i, null); - } - - return messages.stream().filter(Objects::nonNull).toList(); - } -} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index eafc2be8..f3573209 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -4,41 +4,28 @@ import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams; -import ee.carlrobert.codegpt.completions.llama.LlamaModel; -import ee.carlrobert.codegpt.completions.llama.PromptTemplate; +import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest; import ee.carlrobert.codegpt.credentials.CredentialsStore; import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey; import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; import ee.carlrobert.codegpt.settings.service.ServiceType; -import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings; import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; -import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings; import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; -import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsState; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; -import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; import ee.carlrobert.llm.client.DeserializationUtil; import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; -import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage; -import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent; import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; -import ee.carlrobert.llm.client.google.completion.GoogleGenerationConfig; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage; import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest.Builder; -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage; +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest; import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse; import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice; import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta; import ee.carlrobert.llm.completion.CompletionEventListener; +import ee.carlrobert.llm.completion.CompletionRequest; import java.io.IOException; import java.util.Collection; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.stream.Stream; @@ -76,242 +63,138 @@ public final class CompletionRequestService { new OpenAIChatCompletionEventSourceListener(eventListener)); } - public EventSource getChatCompletionAsync( - CallParameters callParameters, + public String getLookupCompletion(String prompt) { + return getChatCompletion( + CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) + .createLookupCompletionRequest(prompt)); + } + + public EventSource getCommitMessageAsync( + String systemPrompt, + String gitDiff, CompletionEventListener eventListener) { - var application = ApplicationManager.getApplication(); - return switch (GeneralSettings.getSelectedService()) { - case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( - CompletionRequestProvider.buildOpenAIChatCompletionRequest( - application.getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .getModel(), - callParameters), - eventListener - ); - case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( - CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAISettings.getCurrentState().getModel(), - callParameters), - eventListener); - case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync( - CompletionRequestProvider.buildCustomOpenAIChatCompletionRequest(callParameters), - eventListener); - case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync( - CompletionRequestProvider.buildAnthropicChatCompletionRequest(callParameters), - eventListener); - case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync( - CompletionRequestProvider.buildOpenAIChatCompletionRequest(null, callParameters), - eventListener); - case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync( - CompletionRequestProvider.buildLlamaCompletionRequest( - callParameters.getMessage(), - callParameters.getConversation(), - callParameters.getConversationType()), - eventListener); - case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync( - CompletionRequestProvider.buildOllamaChatCompletionRequest(callParameters), - eventListener); - case GOOGLE -> { - var settings = application.getService(GoogleSettings.class).getState(); - yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync( - CompletionRequestProvider.buildGoogleChatCompletionRequest( - settings.getModel(), - callParameters), - settings.getModel(), - eventListener); - } - }; + return getChatCompletionAsync( + CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) + .createCommitMessageCompletionRequest(systemPrompt, gitDiff), + eventListener); } public EventSource getEditCodeCompletionAsync( EditCodeRequestParams params, CompletionEventListener eventListener) { var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText()); - return switch (GeneralSettings.getSelectedService()) { - case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( - CompletionRequestProvider.buildEditCodeRequest( - input, - ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .getModel() - ), - eventListener - ); - case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( - CompletionRequestProvider.buildEditCodeRequest( - input, - OpenAISettings.getCurrentState().getModel()), - eventListener); - case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync( - CompletionRequestProvider.buildCustomOpenAIEditCodeRequest(input), - eventListener); - case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync( - CompletionRequestProvider.buildAnthropicEditCodeRequest(input), - eventListener); - case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync( - CompletionRequestProvider.buildEditCodeRequest(input, null), - eventListener); - case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync( - CompletionRequestProvider.buildLlamaEditCodeRequest(input), - eventListener); - case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync( - CompletionRequestProvider.buildOllamaEditCodeRequest(input), - eventListener); - case GOOGLE -> { - var model = - ApplicationManager.getApplication().getService(GoogleSettings.class) - .getState() - .getModel(); - yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync( - CompletionRequestProvider.buildGoogleEditCodeRequest(input), - model, - eventListener); - } - }; + return getChatCompletionAsync( + CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) + .createEditCodeCompletionRequest(input), + eventListener); } - public void generateCommitMessageAsync( - String systemPrompt, - String gitDiff, + public EventSource getChatCompletionAsync( + CallParameters callParameters, CompletionEventListener eventListener) { - var configuration = ConfigurationSettings.getState(); - var openaiRequestBuilder = new Builder(List.of( - new OpenAIChatCompletionStandardMessage("system", systemPrompt), - new OpenAIChatCompletionStandardMessage("user", gitDiff))) - .setModel(OpenAISettings.getCurrentState().getModel()); - var selectedService = GeneralSettings.getSelectedService(); - switch (selectedService) { - case CODEGPT: - CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( - openaiRequestBuilder - .setModel( - ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .getModel()) - .build(), - eventListener); - break; - case OPENAI: - CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( - openaiRequestBuilder - .setModel(OpenAISettings.getCurrentState().getModel()) - .build(), - eventListener); - break; - case CUSTOM_OPENAI: - var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); - EventSources.createFactory(httpClient).newEventSource( - CompletionRequestProvider.buildCustomOpenAICompletionRequest( - systemPrompt, - gitDiff), - new OpenAIChatCompletionEventSourceListener(eventListener)); - break; - case ANTHROPIC: - var anthropicSettings = AnthropicSettings.getCurrentState(); - var claudeRequest = new ClaudeCompletionRequest(); - claudeRequest.setSystem(systemPrompt); - claudeRequest.setStream(true); - claudeRequest.setMaxTokens(configuration.getMaxTokens()); - claudeRequest.setModel(anthropicSettings.getModel()); - claudeRequest.setMessages(List.of(new ClaudeCompletionStandardMessage("user", gitDiff))); - CompletionClientProvider.getClaudeClient() - .getCompletionAsync(claudeRequest, eventListener); - break; - case AZURE: - CompletionClientProvider.getAzureClient() - .getChatCompletionAsync(openaiRequestBuilder.build(), eventListener); - break; - case LLAMA_CPP: - var settings = LlamaSettings.getCurrentState(); - PromptTemplate promptTemplate; - if (settings.isRunLocalServer()) { - promptTemplate = settings.isUseCustomModel() - ? settings.getLocalModelPromptTemplate() - : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()) - .getPromptTemplate(); - } else { - promptTemplate = settings.getRemoteModelPromptTemplate(); - } - var finalPrompt = promptTemplate.buildPrompt(systemPrompt, gitDiff, List.of()); - CompletionClientProvider.getLlamaClient().getChatCompletionAsync( - new LlamaCompletionRequest.Builder(finalPrompt) - .setN_predict(configuration.getMaxTokens()) - .setTemperature(configuration.getTemperature()) - .setTop_k(settings.getTopK()) - .setTop_p(settings.getTopP()) - .setMin_p(settings.getMinP()) - .setRepeat_penalty(settings.getRepeatPenalty()) - .build(), eventListener); - break; - case OLLAMA: - var model = ApplicationManager.getApplication() - .getService(OllamaSettings.class) - .getState() - .getModel(); - var request = new OllamaChatCompletionRequest.Builder( - model, - List.of( - new OllamaChatCompletionMessage("system", systemPrompt, null), - new OllamaChatCompletionMessage("user", gitDiff, null) - ) - ).build(); - CompletionClientProvider.getOllamaClient().getChatCompletionAsync(request, eventListener); - break; - case GOOGLE: - GoogleSettingsState state = ApplicationManager.getApplication() - .getService(GoogleSettings.class).getState(); - CompletionClientProvider.getGoogleClient() - .getChatCompletionAsync(new GoogleCompletionRequest.Builder( - List.of( - new GoogleCompletionContent("user", List.of(systemPrompt)), - new GoogleCompletionContent("model", List.of("Understood.")), - new GoogleCompletionContent("user", List.of(gitDiff)) - )) - .generationConfig(new GoogleGenerationConfig.Builder() - .maxOutputTokens(configuration.getMaxTokens()) - .temperature(configuration.getTemperature()).build()) - .build(), state.getModel(), eventListener); - break; - default: - LOG.debug("Unknown service: {}", selectedService); - break; - } + return getChatCompletionAsync( + CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) + .createChatCompletionRequest(callParameters), + eventListener); } - public Optional getLookupCompletion(String prompt) { - var openaiRequest = CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt); - var selectedService = GeneralSettings.getSelectedService(); - switch (selectedService) { - case CODEGPT: - var model = ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .getModel(); - return tryExtractContent( - CompletionClientProvider.getCodeGPTClient().getChatCompletion( - CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt, model))); - case OPENAI: - return tryExtractContent( - CompletionClientProvider.getOpenAIClient().getChatCompletion(openaiRequest)); - case AZURE: - return tryExtractContent( - CompletionClientProvider.getAzureClient().getChatCompletion(openaiRequest)); - case CUSTOM_OPENAI: - var request = CompletionRequestProvider.buildCustomOpenAILookupCompletionRequest(prompt); - var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); - try (var response = httpClient.newCall(request).execute()) { - return tryExtractContent( - DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class)); - } catch (IOException e) { - throw new RuntimeException(e); - } - default: - return Optional.empty(); + private EventSource getChatCompletionAsync( + CompletionRequest request, + CompletionEventListener eventListener) { + if (request instanceof OpenAIChatCompletionRequest completionRequest) { + return switch (GeneralSettings.getSelectedService()) { + case CODEGPT -> CompletionClientProvider.getCodeGPTClient() + .getChatCompletionAsync(completionRequest, eventListener); + case OPENAI -> CompletionClientProvider.getOpenAIClient() + .getChatCompletionAsync(completionRequest, eventListener); + case AZURE -> CompletionClientProvider.getAzureClient() + .getChatCompletionAsync(completionRequest, eventListener); + default -> throw new RuntimeException("Unknown service selected"); + }; } + if (request instanceof CustomOpenAIRequest completionRequest) { + return getCustomOpenAIChatCompletionAsync(completionRequest.getRequest(), eventListener); + } + if (request instanceof ClaudeCompletionRequest completionRequest) { + return CompletionClientProvider.getClaudeClient().getCompletionAsync( + completionRequest, + eventListener); + } + if (request instanceof GoogleCompletionRequest completionRequest) { + return CompletionClientProvider.getGoogleClient().getChatCompletionAsync( + completionRequest, + ApplicationManager.getApplication().getService(GoogleSettings.class) + .getState() + .getModel(), + eventListener); + } + if (request instanceof OllamaChatCompletionRequest completionRequest) { + return CompletionClientProvider.getOllamaClient().getChatCompletionAsync( + completionRequest, + eventListener); + } + if (request instanceof LlamaCompletionRequest completionRequest) { + return CompletionClientProvider.getLlamaClient().getChatCompletionAsync( + completionRequest, + eventListener); + } + + throw new IllegalStateException("Unknown request type: " + request.getClass()); + } + + private String getChatCompletion(CompletionRequest request) { + if (request instanceof OpenAIChatCompletionRequest completionRequest) { + var response = switch (GeneralSettings.getSelectedService()) { + case CODEGPT -> CompletionClientProvider.getCodeGPTClient() + .getChatCompletion(completionRequest); + case OPENAI -> CompletionClientProvider.getOpenAIClient() + .getChatCompletion(completionRequest); + case AZURE -> CompletionClientProvider.getAzureClient() + .getChatCompletion(completionRequest); + default -> throw new RuntimeException("Unknown service selected"); + }; + return tryExtractContent(response).orElseThrow(); + } + if (request instanceof CustomOpenAIRequest completionRequest) { + var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); + try (var response = httpClient.newCall(completionRequest.getRequest()).execute()) { + return DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class) + .getChoices().get(0) + .getMessage() + .getContent(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + if (request instanceof ClaudeCompletionRequest completionRequest) { + return CompletionClientProvider.getClaudeClient() + .getCompletion(completionRequest) + .getContent().get(0) + .getText(); + } + if (request instanceof GoogleCompletionRequest completionRequest) { + return CompletionClientProvider.getGoogleClient().getChatCompletion( + completionRequest, + ApplicationManager.getApplication().getService(GoogleSettings.class) + .getState() + .getModel()) + .getCandidates().get(0) + .getContent().getParts().get(0) + .getText(); + } + if (request instanceof OllamaChatCompletionRequest completionRequest) { + return CompletionClientProvider.getOllamaClient() + .getChatCompletion(completionRequest) + .getMessage() + .getContent(); + } + if (request instanceof LlamaCompletionRequest completionRequest) { + return CompletionClientProvider.getLlamaClient() + .getChatCompletion(completionRequest) + .getContent(); + } + + throw new IllegalStateException("Unknown request type: " + request.getClass()); } public boolean isAllowed() { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java index 8c69c1b0..53e2f8ea 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java @@ -55,17 +55,21 @@ public class MethodNameLookupListener implements LookupManagerListener { LookupImpl lookup, Application application, String prompt) { - CompletionRequestService.getInstance().getLookupCompletion(prompt) - .ifPresent(response -> { - for (var value : response.split(",")) { - application.invokeLater(() -> application.runReadAction(() -> { - lookup.addItem( - LookupElementBuilder.create(value.trim()).withIcon(Icons.Sparkle), - PrefixMatcher.ALWAYS_TRUE); - lookup.refreshUi(true, true); - })); - } - }); + try { + var response = CompletionRequestService.getInstance().getLookupCompletion(prompt); + if (!response.isEmpty()) { + for (var value : response.split(",")) { + application.invokeLater(() -> application.runReadAction(() -> { + lookup.addItem( + LookupElementBuilder.create(value.trim()).withIcon(Icons.Sparkle), + PrefixMatcher.ALWAYS_TRUE); + lookup.refreshUi(true, true); + })); + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to add completion lookup values", e); + } } private enum PSIMethodMapping { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/TotalUsageExceededException.java b/src/main/java/ee/carlrobert/codegpt/completions/TotalUsageExceededException.java index 3143455f..52a15587 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/TotalUsageExceededException.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/TotalUsageExceededException.java @@ -1,5 +1,5 @@ package ee.carlrobert.codegpt.completions; -class TotalUsageExceededException extends RuntimeException { +public class TotalUsageExceededException extends RuntimeException { } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/LlamaSettings.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/LlamaSettings.java index 84362d34..dd61e807 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/LlamaSettings.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/LlamaSettings.java @@ -12,6 +12,7 @@ import com.intellij.openapi.components.Storage; import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate; import ee.carlrobert.codegpt.completions.HuggingFaceModel; import ee.carlrobert.codegpt.completions.llama.LlamaModel; +import ee.carlrobert.codegpt.completions.llama.PromptTemplate; import ee.carlrobert.codegpt.credentials.CredentialsStore; import ee.carlrobert.codegpt.settings.GeneralSettings; import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm; @@ -42,13 +43,16 @@ public class LlamaSettings implements PersistentStateComponent> extends JPanel super(new FlowLayout(FlowLayout.LEADING, 0, 0)); this.enumClass = enumClass; promptTemplateComboBox = new ComboBox<>(new EnumComboBoxModel<>(enumClass)); - promptTemplateComboBox.setSelectedItem(initiallySelectedTemplate); + promptTemplateComboBox.setSelectedItem( + initiallySelectedTemplate == null ? PromptTemplate.CODE_QWEN : initiallySelectedTemplate); promptTemplateComboBox.setEnabled(enabled); promptTemplateComboBox.addItemListener( item -> updatePromptTemplateHelpTooltip(enumClass.cast(item.getItem()))); @@ -45,7 +47,9 @@ public abstract class BasePromptTemplatePanel> extends JPanel CodeGPTBundle.get(helpTextKey), true); promptTemplateHelpText.setBorder(JBUI.Borders.empty(0, 4)); - updatePromptTemplateHelpTooltip(initiallySelectedTemplate); + if (initiallySelectedTemplate != null) { + updatePromptTemplateHelpTooltip(initiallySelectedTemplate); + } } public void setPromptTemplate(T promptTemplate) { diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java index 6efa4ecc..77ed854a 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java @@ -388,8 +388,9 @@ public class LlamaServerPreferencesForm { } public PromptTemplate getPromptTemplate() { - return isRunLocalServer() ? llamaModelPreferencesForm.getPromptTemplate() + var template = isRunLocalServer() ? llamaModelPreferencesForm.getPromptTemplate() : remotePromptTemplatePanel.getPromptTemplate(); + return template == null ? PromptTemplate.CODE_QWEN : template; } public @Nullable String getApiKey() { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index 28a37e72..ba271bf8 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -1,6 +1,5 @@ package ee.carlrobert.codegpt.toolwindow.chat; -import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.getPromptWithContext; import static ee.carlrobert.codegpt.ui.UIUtil.createScrollPaneWithSmartScroller; import static java.lang.String.format; @@ -18,6 +17,7 @@ import ee.carlrobert.codegpt.actions.ActionType; import ee.carlrobert.codegpt.completions.CallParameters; import ee.carlrobert.codegpt.completions.CompletionRequestHandler; import ee.carlrobert.codegpt.completions.CompletionRequestService; +import ee.carlrobert.codegpt.completions.CompletionRequestUtil; import ee.carlrobert.codegpt.completions.ConversationType; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationService; @@ -134,7 +134,8 @@ public class ChatToolWindowTabPanel implements Disposable { .toList(); message.setReferencedFilePaths(referencedFilePaths); message.setUserMessage(message.getPrompt()); - message.setPrompt(getPromptWithContext(referencedFiles, message.getPrompt())); + message.setPrompt( + CompletionRequestUtil.getPromptWithContext(referencedFiles, message.getPrompt())); totalTokensPanel.updateReferencedFilesTokens(referencedFiles); @@ -300,21 +301,22 @@ public class ChatToolWindowTabPanel implements Disposable { } promptBuilder.append(remainingText); - String highlightedTextMd = ""; + String selectedText = ""; + String selectedTextMd = ""; if (editor != null) { var selectionModel = editor.getSelectionModel(); - var selectedText = selectionModel.getSelectedText(); + selectedText = selectionModel.getSelectedText(); if (selectedText != null && !selectedText.isEmpty()) { var fileExtension = FileUtil.getFileExtension(editor.getVirtualFile().getName()); - highlightedTextMd = format("\n```%s\n%s\n```\n", fileExtension, selectedText); + selectedTextMd = format("\n```%s\n%s\n```\n", fileExtension, selectedText); selectionModel.removeSelection(); } } - message.setUserMessage(highlightedTextMd + promptBuilder); - message.setPrompt(highlightedTextMd + promptBuilder); + message.setUserMessage(selectedTextMd + promptBuilder); + message.setPrompt(selectedTextMd + promptBuilder); - sendMessage(message, ConversationType.DEFAULT, processEditorSelection(editor, message)); + sendMessage(message, ConversationType.DEFAULT, selectedText); return Unit.INSTANCE; } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt new file mode 100644 index 00000000..e8ed1136 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt @@ -0,0 +1,56 @@ +package ee.carlrobert.codegpt.completions + +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_METHOD_NAMES_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.factory.* +import ee.carlrobert.codegpt.settings.service.ServiceType +import ee.carlrobert.llm.completion.CompletionRequest + +interface CompletionRequestFactory { + fun createChatCompletionRequest(callParameters: CallParameters): CompletionRequest + fun createEditCodeCompletionRequest(input: String): CompletionRequest + fun createCommitMessageCompletionRequest( + systemPrompt: String, + gitDiff: String + ): CompletionRequest + fun createLookupCompletionRequest(prompt: String): CompletionRequest + + companion object { + @JvmStatic + fun getFactory(serviceType: ServiceType): CompletionRequestFactory { + return when (serviceType) { + ServiceType.CODEGPT -> CodeGPTRequestFactory() + ServiceType.OPENAI -> OpenAIRequestFactory() + ServiceType.CUSTOM_OPENAI -> CustomOpenAIRequestFactory() + ServiceType.AZURE -> AzureRequestFactory() + ServiceType.ANTHROPIC -> ClaudeRequestFactory() + ServiceType.GOOGLE -> GoogleRequestFactory() + ServiceType.OLLAMA -> OllamaRequestFactory() + ServiceType.LLAMA_CPP -> LlamaRequestFactory() + } + } + } +} + +abstract class BaseRequestFactory : CompletionRequestFactory { + override fun createEditCodeCompletionRequest(input: String): CompletionRequest { + return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true) + } + + override fun createCommitMessageCompletionRequest( + systemPrompt: String, + gitDiff: String + ): CompletionRequest { + return createBasicCompletionRequest(systemPrompt, gitDiff, true) + } + + override fun createLookupCompletionRequest(prompt: String): CompletionRequest { + return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt) + } + + abstract fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean = false + ): CompletionRequest +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestUtil.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestUtil.kt new file mode 100644 index 00000000..91f2ff0c --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestUtil.kt @@ -0,0 +1,42 @@ +package ee.carlrobert.codegpt.completions + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.ReferencedFile +import ee.carlrobert.codegpt.settings.IncludedFilesSettings +import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent +import java.util.stream.Collectors + +object CompletionRequestUtil { + val GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT = + getResourceContent("/prompts/generate-commit-message.txt") + val FIX_COMPILE_ERRORS_SYSTEM_PROMPT = + getResourceContent("/prompts/fix-compile-errors.txt") + val GENERATE_METHOD_NAMES_SYSTEM_PROMPT = + getResourceContent("/prompts/method-name-generator.txt") + val EDIT_CODE_SYSTEM_PROMPT = + getResourceContent("/prompts/edit-code.txt") + + @JvmStatic + fun getPromptWithContext( + referencedFiles: List, + userPrompt: String? + ): String { + val includedFilesSettings = service().state + val repeatableContext = referencedFiles.stream() + .map { item: ReferencedFile -> + includedFilesSettings.repeatableContext + .replace("{FILE_PATH}", item.filePath) + .replace( + "{FILE_CONTENT}", String.format( + "```%s%n%s%n```", + item.fileExtension, + item.fileContent.trim { it <= ' ' }) + ) + } + .collect(Collectors.joining("\n\n")) + + return includedFilesSettings.promptTemplate + .replace("{REPEATABLE_CONTEXT}", repeatableContext) + .replace("{QUESTION}", userPrompt!!) + } +} \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt new file mode 100644 index 00000000..c7a18499 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt @@ -0,0 +1,34 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest +import ee.carlrobert.llm.completion.CompletionRequest + +class AzureRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + val configuration = service().state + val requestBuilder: OpenAIChatCompletionRequest.Builder = + OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters)) + .setMaxTokens(configuration.maxTokens) + .setStream(true) + .setTemperature(configuration.temperature.toDouble()) + return requestBuilder.build() + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): CompletionRequest { + return OpenAIRequestFactory.createBasicCompletionRequest( + systemPrompt, + userPrompt, + isStream = stream + ) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt new file mode 100644 index 00000000..66ee4b12 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt @@ -0,0 +1,72 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.persona.PersonaSettings +import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings +import ee.carlrobert.llm.client.anthropic.completion.* +import ee.carlrobert.llm.completion.CompletionRequest + +class ClaudeRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): ClaudeCompletionRequest { + return ClaudeCompletionRequest().apply { + model = service().state.model + maxTokens = service().state.maxTokens + isStream = true + system = PersonaSettings.getSystemPrompt() + + messages = callParameters.conversation.messages + .filter { it.response != null && it.response.isNotEmpty() } + .flatMap { prevMessage -> + sequenceOf( + ClaudeCompletionStandardMessage("user", prevMessage.prompt), + ClaudeCompletionStandardMessage("assistant", prevMessage.response) + ) + } + .toList() + + when { + callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty() -> { + messages.add( + ClaudeCompletionDetailedMessage( + "user", + listOf( + ClaudeMessageImageContent( + ClaudeBase64Source( + callParameters.imageMediaType, + callParameters.imageData + ) + ), + ClaudeMessageTextContent(callParameters.message.prompt) + ) + ) + ) + } + + else -> { + messages.add( + ClaudeCompletionStandardMessage("user", callParameters.message.prompt) + ) + } + } + } + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): CompletionRequest { + return ClaudeCompletionRequest().apply { + system = systemPrompt + isStream = stream + maxTokens = service().state.maxTokens + model = service().state.model + messages = + listOf(ClaudeCompletionStandardMessage("user", userPrompt)) + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt new file mode 100644 index 00000000..7c6a4f84 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -0,0 +1,48 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest +import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDetails + +class CodeGPTRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + val model = service().state.chatCompletionSettings.model + val configuration = service().state + val requestBuilder: OpenAIChatCompletionRequest.Builder = + OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) + .setModel(model) + .setMaxTokens(configuration.maxTokens) + .setStream(true) + .setTemperature(configuration.temperature.toDouble()) + if (callParameters.message.isWebSearchIncluded) { + requestBuilder.setWebSearchIncluded(true) + } + val documentationDetails = callParameters.message.documentationDetails + if (documentationDetails != null) { + val requestDocumentationDetails = RequestDocumentationDetails() + requestDocumentationDetails.name = documentationDetails.name + requestDocumentationDetails.url = documentationDetails.url + requestBuilder.setDocumentationDetails(requestDocumentationDetails) + } + return requestBuilder.build() + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): OpenAIChatCompletionRequest { + return OpenAIRequestFactory.createBasicCompletionRequest( + systemPrompt, + userPrompt, + service().state.chatCompletionSettings.model, + stream + ) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt new file mode 100644 index 00000000..41418282 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt @@ -0,0 +1,108 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.fasterxml.jackson.databind.ObjectMapper +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey +import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential +import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState +import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage +import ee.carlrobert.llm.completion.CompletionRequest +import okhttp3.Request +import okhttp3.RequestBody.Companion.toRequestBody +import java.nio.charset.StandardCharsets + +class CustomOpenAIRequest(val request: Request) : CompletionRequest + +class CustomOpenAIRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): CustomOpenAIRequest { + val request = buildCustomOpenAIChatCompletionRequest( + service() + .state + .chatCompletionSettings, + OpenAIRequestFactory.buildOpenAIMessages(null, callParameters), + true, + getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY) + ) + return CustomOpenAIRequest(request) + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): CompletionRequest { + val request = buildCustomOpenAIChatCompletionRequest( + service().state.chatCompletionSettings, + listOf( + OpenAIChatCompletionStandardMessage("system", systemPrompt), + OpenAIChatCompletionStandardMessage("user", userPrompt) + ), + stream, + getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY) + ) + return CustomOpenAIRequest(request) + } + + companion object { + fun buildCustomOpenAICompletionRequest( + context: String, + url: String, + headers: MutableMap, + body: MutableMap, + credential: String? + ): Request { + val usedSettings = CustomServiceChatCompletionSettingsState() + usedSettings.body = body + usedSettings.headers = headers + usedSettings.url = url + return buildCustomOpenAIChatCompletionRequest( + usedSettings, + listOf(OpenAIChatCompletionStandardMessage("user", context)), + true, + credential + ) + } + + fun buildCustomOpenAIChatCompletionRequest( + settings: CustomServiceChatCompletionSettingsState, + messages: List, + streamRequest: Boolean, + credential: String? + ): Request { + val requestBuilder = Request.Builder().url(requireNotNull(settings.url).trim()) + + settings.headers.forEach { (key, value) -> + val headerValue = when { + credential != null && value.contains("\$CUSTOM_SERVICE_API_KEY") -> + value.replace("\$CUSTOM_SERVICE_API_KEY", credential) + + else -> value + } + requestBuilder.addHeader(key, headerValue) + } + + val body = settings.body.mapValues { (key, value) -> + when { + !streamRequest && key == "stream" -> false + value is String && value.trim() == "\$OPENAI_MESSAGES" -> messages + else -> value + } + } + + return try { + val requestBody = ObjectMapper().writerWithDefaultPrettyPrinter() + .writeValueAsString(body) + .toByteArray(StandardCharsets.UTF_8) + .toRequestBody() + requestBuilder.post(requestBody).build() + } catch (e: Exception) { + throw RuntimeException(e) + } + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt new file mode 100644 index 00000000..f770de60 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt @@ -0,0 +1,199 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.EncodingManager +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.completions.TotalUsageExceededException +import ee.carlrobert.codegpt.conversations.ConversationsState +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.persona.PersonaSettings +import ee.carlrobert.codegpt.settings.service.google.GoogleSettings +import ee.carlrobert.codegpt.util.file.FileUtil +import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent +import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest +import ee.carlrobert.llm.client.google.completion.GoogleContentPart +import ee.carlrobert.llm.client.google.completion.GoogleGenerationConfig +import ee.carlrobert.llm.client.google.models.GoogleModel +import java.io.IOException +import java.nio.file.Files +import java.nio.file.Path + +class GoogleRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): GoogleCompletionRequest { + val configuration = service().state + val messages = buildGoogleMessages(service().state.model, callParameters) + return GoogleCompletionRequest.Builder(messages) + .generationConfig( + GoogleGenerationConfig.Builder() + .maxOutputTokens(configuration.maxTokens) + .temperature(configuration.temperature.toDouble()).build() + ) + .build() + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): GoogleCompletionRequest { + val configuration = service().state + return GoogleCompletionRequest.Builder( + listOf( + GoogleCompletionContent("user", listOf(systemPrompt)), + GoogleCompletionContent("model", listOf("Understood.")), + GoogleCompletionContent("user", listOf(userPrompt)) + ) + ) + .generationConfig( + GoogleGenerationConfig.Builder() + .maxOutputTokens(configuration.maxTokens) + .temperature(configuration.temperature.toDouble()).build() + ) + .build() + } + + private fun buildGoogleMessages( + model: String?, + callParameters: CallParameters + ): List { + val messages = buildGoogleMessages(callParameters) + + if (model == null) { + return messages + } + + val encodingManager = service() + val totalUsage = messages.parallelStream() + .mapToInt { message -> + encodingManager.countMessageTokens( + message.role, + message.parts.joinToString(",") { it.text ?: "" } + ) + } + .sum() + service().state.maxTokens + + return GoogleModel.findByCode(model)?.let { googleModel -> + if (totalUsage <= googleModel.maxTokens) { + messages + } else { + tryReducingGoogleMessagesOrThrow( + messages, + callParameters.conversation.isDiscardTokenLimit, + totalUsage, + googleModel.maxTokens + ) + } + } ?: messages + } + + private fun buildGoogleMessages(callParameters: CallParameters): List { + val message = callParameters.message + val messages = mutableListOf() + + when (callParameters.conversationType) { + ConversationType.DEFAULT -> { + messages.add( + GoogleCompletionContent( + "user", + listOf(PersonaSettings.getSystemPrompt()) + ) + ) + messages.add(GoogleCompletionContent("model", listOf("Understood."))) + } + + ConversationType.FIX_COMPILE_ERRORS -> { + messages.add( + GoogleCompletionContent("user", listOf(FIX_COMPILE_ERRORS_SYSTEM_PROMPT)) + ) + messages.add(GoogleCompletionContent("model", listOf("Understood."))) + } + + else -> {} + } + + for (prevMessage in callParameters.conversation.messages) { + if (callParameters.isRetry && prevMessage.id == message.id) { + break + } + + prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath -> + try { + val imageData = Files.readAllBytes(Path.of(imagePath)) + val imageMediaType = + FileUtil.getImageMediaType(Path.of(imagePath).fileName.toString()) + messages.add( + GoogleCompletionContent( + listOf( + GoogleContentPart( + null, + GoogleContentPart.Blob(imageMediaType, imageData) + ), + GoogleContentPart(prevMessage.prompt) + ), "user" + ) + ) + } catch (e: IOException) { + throw RuntimeException(e) + } + } ?: messages.add(GoogleCompletionContent("user", listOf(prevMessage.prompt))) + + messages.add(GoogleCompletionContent("model", listOf(prevMessage.response))) + } + + if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { + messages.add( + GoogleCompletionContent( + listOf( + GoogleContentPart( + null, + GoogleContentPart.Blob( + callParameters.imageMediaType, + callParameters.imageData + ) + ), + GoogleContentPart(message.prompt) + ), "user" + ) + ) + } else { + messages.add(GoogleCompletionContent("user", listOf(message.prompt))) + } + + return messages + } + + private fun tryReducingGoogleMessagesOrThrow( + messages: List, + discardTokenLimit: Boolean, + totalUsage: Int, + modelMaxTokens: Int + ): List { + if (!service().state!!.discardAllTokenLimits) { + if (!discardTokenLimit) { + throw TotalUsageExceededException() + } + } + + val encodingManager = EncodingManager.getInstance() + var currentUsage = totalUsage + + // skip the system prompt + val updatedMessages = messages.mapIndexed { index, message -> + if (index == 0 || currentUsage <= modelMaxTokens) { + message + } else { + currentUsage -= encodingManager.countMessageTokens( + message.role, + message.parts.joinToString(",") { it.text } + ) + null + } + } + + return updatedMessages.filterNotNull() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt new file mode 100644 index 00000000..747f32dd --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt @@ -0,0 +1,78 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.completions.llama.LlamaModel +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState +import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest + +class LlamaRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): LlamaCompletionRequest { + val settings = service().state + val promptTemplate = if (settings.isRunLocalServer) { + if (settings.isUseCustomModel) + settings.localModelPromptTemplate + else + LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).promptTemplate + } else { + settings.remoteModelPromptTemplate + } + + val systemPrompt = + if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) + FIX_COMPILE_ERRORS_SYSTEM_PROMPT + else + getSystemPrompt() + + val prompt = promptTemplate.buildPrompt( + systemPrompt, + callParameters.message.prompt, + callParameters.conversation.messages + ) + val configuration = getState() + return LlamaCompletionRequest.Builder(prompt) + .setN_predict(configuration.maxTokens) + .setTemperature(configuration.temperature.toDouble()) + .setTop_k(settings.topK) + .setTop_p(settings.topP) + .setMin_p(settings.minP) + .setRepeat_penalty(settings.repeatPenalty) + .setStop(promptTemplate.stopTokens) + .build() + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): LlamaCompletionRequest { + val settings = service().state + val promptTemplate = if (settings.isRunLocalServer) { + if (settings.isUseCustomModel) + settings.localModelPromptTemplate + else + LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).promptTemplate + } else { + settings.remoteModelPromptTemplate + } + val configuration = service().state + val finalPrompt = + promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf()) + return LlamaCompletionRequest.Builder(finalPrompt) + .setN_predict(configuration.maxTokens) + .setTemperature(configuration.temperature.toDouble()) + .setTop_k(settings.topK) + .setTop_p(settings.topP) + .setMin_p(settings.minP) + .setStream(stream) + .setRepeat_penalty(settings.repeatPenalty) + .build() + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt new file mode 100644 index 00000000..80fc9954 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -0,0 +1,101 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.BaseRequestFactory +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.persona.PersonaSettings +import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings +import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage +import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest +import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters +import java.io.IOException +import java.nio.file.Files +import java.nio.file.Path +import java.util.* + +class OllamaRequestFactory : BaseRequestFactory() { + + override fun createChatCompletionRequest(callParameters: CallParameters): OllamaChatCompletionRequest { + val configuration = service().state + val settings = service().state + return OllamaChatCompletionRequest.Builder( + settings.model, + buildOllamaMessages(callParameters) + ) + .setStream(true) + .setOptions( + OllamaParameters.Builder() + .numPredict(configuration.maxTokens) + .temperature(configuration.temperature.toDouble()) + .build() + ) + .build() + } + + override fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + stream: Boolean + ): OllamaChatCompletionRequest { + return OllamaChatCompletionRequest.Builder( + service().state.model, + listOf( + OllamaChatCompletionMessage("system", systemPrompt, null), + OllamaChatCompletionMessage("user", userPrompt, null) + ) + ) + .setStream(stream) + .build() + } + + private fun buildOllamaMessages(callParameters: CallParameters): List { + val message = callParameters.message + val messages = mutableListOf() + + when (callParameters.conversationType) { + ConversationType.DEFAULT -> messages.add( + OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null) + ) + + ConversationType.FIX_COMPILE_ERRORS -> messages.add( + OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null) + ) + + else -> {} + } + + for (prevMessage in callParameters.conversation.messages) { + if (callParameters.isRetry && prevMessage.id == message.id) break + + prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath -> + try { + val imageBytes = Files.readAllBytes(Path.of(imagePath)) + val imageBase64 = Base64.getEncoder().encodeToString(imageBytes) + messages.add( + OllamaChatCompletionMessage( + "user", + prevMessage.prompt, + listOf(imageBase64) + ) + ) + } catch (e: IOException) { + throw RuntimeException(e) + } + } ?: messages.add(OllamaChatCompletionMessage("user", prevMessage.prompt, null)) + + messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null)) + } + + if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { + val imageBase64 = Base64.getEncoder().encodeToString(callParameters.imageData) + messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64))) + } else { + messages.add(OllamaChatCompletionMessage("user", message.prompt, null)) + } + + return messages + } +} \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt new file mode 100644 index 00000000..6401a488 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -0,0 +1,227 @@ +package ee.carlrobert.codegpt.completions.factory + +import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.EncodingManager +import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.CompletionRequestFactory +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_METHOD_NAMES_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.completions.TotalUsageExceededException +import ee.carlrobert.codegpt.conversations.ConversationsState +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState +import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.codegpt.util.file.FileUtil.getImageMediaType +import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel +import ee.carlrobert.llm.client.openai.completion.request.* +import ee.carlrobert.llm.completion.CompletionRequest +import java.io.IOException +import java.nio.file.Files +import java.nio.file.Path + +class OpenAIRequestFactory : CompletionRequestFactory { + + override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + val model = service().state.model + val configuration = service().state + val requestBuilder: OpenAIChatCompletionRequest.Builder = + OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) + .setModel(model) + .setMaxTokens(configuration.maxTokens) + .setStream(true) + .setTemperature(configuration.temperature.toDouble()) + return requestBuilder.build() + } + + override fun createEditCodeCompletionRequest(input: String): OpenAIChatCompletionRequest { + return buildEditCodeRequest(input, service().state.model) + } + + override fun createCommitMessageCompletionRequest( + systemPrompt: String, + gitDiff: String + ): CompletionRequest { + return createBasicCompletionRequest(systemPrompt, gitDiff, isStream = true) + } + + override fun createLookupCompletionRequest(prompt: String): CompletionRequest { + return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt) + } + + companion object { + fun buildEditCodeRequest( + input: String, + model: String? = null + ): OpenAIChatCompletionRequest { + return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, model, true) + } + + fun buildOpenAIMessages( + model: String?, + callParameters: CallParameters + ): List { + val messages = buildOpenAIMessages(callParameters) + + if (model == null) { + return messages + } + + val encodingManager = EncodingManager.getInstance() + val totalUsage = messages.parallelStream() + .mapToInt { message: OpenAIChatCompletionMessage? -> + encodingManager.countMessageTokens( + message + ) + } + .sum() + getState().maxTokens + val modelMaxTokens: Int + try { + modelMaxTokens = OpenAIChatCompletionModel.findByCode(model).maxTokens + + if (totalUsage <= modelMaxTokens) { + return messages + } + } catch (ex: NoSuchElementException) { + return messages + } + return tryReducingMessagesOrThrow( + messages, + callParameters.conversation.isDiscardTokenLimit, + totalUsage, + modelMaxTokens + ) + } + + private fun buildOpenAIMessages( + callParameters: CallParameters + ): MutableList { + val message = callParameters.message + val messages = mutableListOf() + if (callParameters.conversationType == ConversationType.DEFAULT) { + val sessionPersonaDetails = callParameters.message.personaDetails + if (callParameters.message.personaDetails == null) { + messages.add( + OpenAIChatCompletionStandardMessage("system", getSystemPrompt()) + ) + } else { + messages.add( + OpenAIChatCompletionStandardMessage( + "system", + sessionPersonaDetails.instructions + ) + ) + } + } + if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) { + messages.add( + OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT) + ) + } + + for (prevMessage in callParameters.conversation.messages) { + if (callParameters.isRetry && prevMessage.id == message.id) { + break + } + val prevMessageImageFilePath = prevMessage.imageFilePath + if (!prevMessageImageFilePath.isNullOrEmpty()) { + try { + val imageFilePath = Path.of(prevMessageImageFilePath) + val imageData = Files.readAllBytes(imageFilePath) + val imageMediaType = getImageMediaType(imageFilePath.fileName.toString()) + messages.add( + OpenAIChatCompletionDetailedMessage( + "user", + listOf( + OpenAIMessageImageURLContent( + OpenAIImageUrl( + imageMediaType, + imageData + ) + ), + OpenAIMessageTextContent(prevMessage.prompt) + ) + ) + ) + } catch (e: IOException) { + throw RuntimeException(e) + } + } else { + messages.add(OpenAIChatCompletionStandardMessage("user", prevMessage.prompt)) + } + messages.add( + OpenAIChatCompletionStandardMessage("assistant", prevMessage.response) + ) + } + + if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { + messages.add( + OpenAIChatCompletionDetailedMessage( + "user", + listOf( + OpenAIMessageImageURLContent( + OpenAIImageUrl( + callParameters.imageMediaType, + callParameters.imageData + ) + ), + OpenAIMessageTextContent(message.prompt) + ) + ) + ) + } else { + messages.add(OpenAIChatCompletionStandardMessage("user", message.prompt)) + } + return messages + } + + private fun tryReducingMessagesOrThrow( + messages: MutableList, + discardTokenLimit: Boolean, + totalInputUsage: Int, + modelMaxTokens: Int + ): List { + val result: MutableList = messages.toMutableList() + var totalUsage = totalInputUsage + if (!ConversationsState.getInstance().discardAllTokenLimits) { + if (!discardTokenLimit) { + throw TotalUsageExceededException() + } + } + val encodingManager = EncodingManager.getInstance() + // skip the system prompt + for (i in 1 until result.size - 1) { + if (totalUsage <= modelMaxTokens) { + break + } + + val message = result[i] + if (message is OpenAIChatCompletionStandardMessage) { + totalUsage -= encodingManager.countMessageTokens(message) + result[i] = null + } + } + + return result.filterNotNull() + } + + fun createBasicCompletionRequest( + systemPrompt: String, + userPrompt: String, + model: String? = null, + isStream: Boolean = false + ): OpenAIChatCompletionRequest { + return OpenAIChatCompletionRequest.Builder( + listOf( + OpenAIChatCompletionStandardMessage("system", systemPrompt), + OpenAIChatCompletionStandardMessage("user", userPrompt) + ) + ) + .setModel(model) + .setStream(isStream) + .build() + } + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/configuration/ConfigurationSettings.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/configuration/ConfigurationSettings.kt index a330cdb2..ab8d5201 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/configuration/ConfigurationSettings.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/configuration/ConfigurationSettings.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.settings.configuration import com.intellij.openapi.components.* import ee.carlrobert.codegpt.actions.editor.EditorActionsUtil -import ee.carlrobert.codegpt.completions.CompletionRequestProvider.GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT import kotlin.math.max import kotlin.math.min diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/LlamaServiceConfigurable.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/LlamaServiceConfigurable.kt index 921fda1f..12fa965e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/LlamaServiceConfigurable.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/LlamaServiceConfigurable.kt @@ -1,6 +1,7 @@ package ee.carlrobert.codegpt.settings.service import com.intellij.openapi.components.service +import com.intellij.openapi.diagnostic.thisLogger import com.intellij.openapi.options.Configurable import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.LLAMA_API_KEY import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceChatCompletionForm.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceChatCompletionForm.kt index 9f8d0385..56c90bc0 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceChatCompletionForm.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceChatCompletionForm.kt @@ -4,8 +4,8 @@ import com.intellij.openapi.application.runInEdt import com.intellij.openapi.ui.MessageType import com.intellij.util.ui.FormBuilder import ee.carlrobert.codegpt.CodeGPTBundle -import ee.carlrobert.codegpt.completions.CompletionRequestProvider import ee.carlrobert.codegpt.completions.CompletionRequestService +import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequestFactory import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState import ee.carlrobert.codegpt.settings.service.custom.CustomServiceFormTabbedPane import ee.carlrobert.codegpt.ui.OverlayUtil @@ -16,7 +16,6 @@ import okhttp3.sse.EventSource import java.awt.BorderLayout import javax.swing.JButton import javax.swing.JPanel -import javax.swing.SwingUtilities class CustomServiceChatCompletionForm( state: CustomServiceChatCompletionSettingsState, @@ -73,7 +72,7 @@ class CustomServiceChatCompletionForm( private fun testConnection() { CompletionRequestService.getInstance().getCustomOpenAIChatCompletionAsync( - CompletionRequestProvider.buildCustomOpenAICompletionRequest( + CustomOpenAIRequestFactory.buildCustomOpenAICompletionRequest( "Test", urlField.text, tabbedPane.headers, diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index 3c8f033d..d6051cc1 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -1,6 +1,7 @@ package ee.carlrobert.codegpt.completions import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.ConversationService import ee.carlrobert.codegpt.conversations.message.Message @@ -13,147 +14,157 @@ import testsupport.IntegrationTest class CompletionRequestProviderTest : IntegrationTest() { - fun testChatCompletionRequestWithSystemPromptOverride() { - useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val conversation = ConversationService.getInstance().startConversation() - val firstMessage = createDummyMessage(500) - val secondMessage = createDummyMessage(250) - conversation.addMessage(firstMessage) - conversation.addMessage(secondMessage) + fun testChatCompletionRequestWithSystemPromptOverride() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage(500) + val secondMessage = createDummyMessage(250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) - val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.code, - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false)) + val request = OpenAIRequestFactory().createChatCompletionRequest( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) + ) - assertThat(request.messages) - .extracting("role", "content") - .containsExactly( - Tuple.tuple("system", "TEST_SYSTEM_PROMPT"), - Tuple.tuple("user", "TEST_PROMPT"), - Tuple.tuple("assistant", firstMessage.response), - Tuple.tuple("user", "TEST_PROMPT"), - Tuple.tuple("assistant", secondMessage.response), - Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) - } + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", "TEST_SYSTEM_PROMPT"), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", secondMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT") + ) + } - fun testChatCompletionRequestWithoutSystemPromptOverride() { - useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) - service().state.selectedPersona.instructions = DEFAULT_PROMPT - val conversation = ConversationService.getInstance().startConversation() - val firstMessage = createDummyMessage(500) - val secondMessage = createDummyMessage(250) - conversation.addMessage(firstMessage) - conversation.addMessage(secondMessage) + fun testChatCompletionRequestWithoutSystemPromptOverride() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) + service().state.selectedPersona.instructions = DEFAULT_PROMPT + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage(500) + val secondMessage = createDummyMessage(250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) - val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.code, - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false)) + val request = OpenAIRequestFactory().createChatCompletionRequest( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) + ) - assertThat(request.messages) - .extracting("role", "content") - .containsExactly( - Tuple.tuple("system", DEFAULT_PROMPT), - Tuple.tuple("user", "TEST_PROMPT"), - Tuple.tuple("assistant", firstMessage.response), - Tuple.tuple("user", "TEST_PROMPT"), - Tuple.tuple("assistant", secondMessage.response), - Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) - } + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", DEFAULT_PROMPT), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", secondMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT") + ) + } - fun testChatCompletionRequestRetry() { - useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val conversation = ConversationService.getInstance().startConversation() - val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500) - val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250) - conversation.addMessage(firstMessage) - conversation.addMessage(secondMessage) + fun testChatCompletionRequestRetry() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500) + val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) - val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.code, - CallParameters( - conversation, - ConversationType.DEFAULT, - secondMessage, - null, - true)) + val request = OpenAIRequestFactory().createChatCompletionRequest( + CallParameters( + conversation, + ConversationType.DEFAULT, + secondMessage, + null, + true + ) + ) - assertThat(request.messages) - .extracting("role", "content") - .containsExactly( - Tuple.tuple("system", "TEST_SYSTEM_PROMPT"), - Tuple.tuple("user", "FIRST_TEST_PROMPT"), - Tuple.tuple("assistant", firstMessage.response), - Tuple.tuple("user", "SECOND_TEST_PROMPT")) - } + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", "TEST_SYSTEM_PROMPT"), + Tuple.tuple("user", "FIRST_TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "SECOND_TEST_PROMPT") + ) + } - fun testReducedChatCompletionRequest() { - useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) - service().state.selectedPersona.instructions = DEFAULT_PROMPT - val conversation = Conversation() - conversation.addMessage(createDummyMessage(50)) - conversation.addMessage(createDummyMessage(100)) - conversation.addMessage(createDummyMessage(150)) - conversation.addMessage(createDummyMessage(1000)) - val remainingMessage = createDummyMessage(1000) - conversation.addMessage(remainingMessage) - conversation.discardTokenLimits() + fun testReducedChatCompletionRequest() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) + service().state.selectedPersona.instructions = DEFAULT_PROMPT + val conversation = Conversation() + conversation.addMessage(createDummyMessage(50)) + conversation.addMessage(createDummyMessage(100)) + conversation.addMessage(createDummyMessage(150)) + conversation.addMessage(createDummyMessage(1000)) + val remainingMessage = createDummyMessage(1000) + conversation.addMessage(remainingMessage) + conversation.discardTokenLimits() - val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.code, - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false)) + val request = OpenAIRequestFactory().createChatCompletionRequest( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) + ) - assertThat(request.messages) - .extracting("role", "content") - .containsExactly( - Tuple.tuple("system", DEFAULT_PROMPT), - Tuple.tuple("user", "TEST_PROMPT"), - Tuple.tuple("assistant", remainingMessage.response), - Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) - } + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", DEFAULT_PROMPT), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", remainingMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT") + ) + } - fun testTotalUsageExceededException() { - useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) - val conversation = ConversationService.getInstance().startConversation() - conversation.addMessage(createDummyMessage(1500)) - conversation.addMessage(createDummyMessage(1500)) - conversation.addMessage(createDummyMessage(1500)) + fun testTotalUsageExceededException() { + useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code) + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(createDummyMessage(1500)) + conversation.addMessage(createDummyMessage(1500)) + conversation.addMessage(createDummyMessage(1500)) - assertThrows(TotalUsageExceededException::class.java) { - CompletionRequestProvider.buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.code, - CallParameters( - conversation, - ConversationType.DEFAULT, - createDummyMessage(100), - null, - false)) } - } + assertThrows(TotalUsageExceededException::class.java) { + OpenAIRequestFactory().createChatCompletionRequest( + CallParameters( + conversation, + ConversationType.DEFAULT, + createDummyMessage(100), + null, + false + ) + ) + } + } - private fun createDummyMessage(tokenSize: Int): Message { - return createDummyMessage("TEST_PROMPT", tokenSize) - } + private fun createDummyMessage(tokenSize: Int): Message { + return createDummyMessage("TEST_PROMPT", tokenSize) + } - private fun createDummyMessage(prompt: String, tokenSize: Int): Message { - val message = Message(prompt) - // 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3), - message.response = "zz".repeat((tokenSize) - 6 - 7) - return message - } + private fun createDummyMessage(prompt: String, tokenSize: Int): Message { + val message = Message(prompt) + // 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3), + message.response = "zz".repeat((tokenSize) - 6 - 7) + return message + } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/telemetry/core/service/segment/IdentifyTraitsPersistenceTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/telemetry/core/service/segment/IdentifyTraitsPersistenceTest.kt deleted file mode 100644 index 689f784a..00000000 --- a/src/test/kotlin/ee/carlrobert/codegpt/telemetry/core/service/segment/IdentifyTraitsPersistenceTest.kt +++ /dev/null @@ -1,67 +0,0 @@ -package ee.carlrobert.codegpt.telemetry.core.service.segment - -import com.google.gson.Gson -import com.google.gson.JsonSyntaxException -import com.intellij.util.io.write -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import kotlin.io.path.Path -import kotlin.io.path.createTempFile -import kotlin.io.path.readText -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertNull - -private const val NOT_JSON = "}NOT]:JSON{" - -class IdentifyTraitsPersistenceTest { - private val gson = Gson() - private val persistence = IdentifyTraitsPersistence.INSTANCE - private val identifyTraits = IdentifyTraits("locale", "timezone", "os", "version", "distribution") - - @BeforeEach - fun setUp() { - persistence.identifyTraits = null - IdentifyTraitsPersistence.FILE = createTempFile() - } - - @Test - fun `get returns null when file does not exist`() { - IdentifyTraitsPersistence.FILE = Path(" ") - assertNull(persistence.get()) - } - - @Test - fun `get throws JsonSyntaxException when file contains malformed JSON`() { - IdentifyTraitsPersistence.FILE.write(NOT_JSON) - assertFailsWith { - persistence.get() - } - } - - @Test - fun `set saves the event to the file overwriting it`() { - IdentifyTraitsPersistence.FILE.write(NOT_JSON) - persistence.set(identifyTraits) - assertEquals(IdentifyTraitsPersistence.FILE.readText(), gson.toJson(identifyTraits)) - } - - @Test - fun `set saves the event to the file when file does not exist`() { - persistence.set(identifyTraits) - assertEquals(IdentifyTraitsPersistence.FILE.readText(), gson.toJson(identifyTraits)) - } - - @Test - fun `get returns the deserialized event`() { - IdentifyTraitsPersistence.FILE.write(gson.toJson(identifyTraits)) - assertEquals(identifyTraits, persistence.get()) - } - - @Test - fun `set throws IOException when file cannot be written and returns false`() { - IdentifyTraitsPersistence.FILE = IdentifyTraitsPersistence.FILE.resolve(" xyz ") - assertEquals(persistence.set(identifyTraits), false) - } - -} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt index a15274ff..829d9743 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt @@ -4,7 +4,7 @@ import com.intellij.openapi.components.service import ee.carlrobert.codegpt.CodeGPTKeys import ee.carlrobert.codegpt.EncodingManager import ee.carlrobert.codegpt.ReferencedFile -import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.completions.HuggingFaceModel import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA @@ -15,270 +15,117 @@ import ee.carlrobert.codegpt.settings.persona.PersonaSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.llm.client.http.RequestEntity import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange -import ee.carlrobert.llm.client.util.JSONUtil.e -import ee.carlrobert.llm.client.util.JSONUtil.jsonArray -import ee.carlrobert.llm.client.util.JSONUtil.jsonMap -import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse +import ee.carlrobert.llm.client.util.JSONUtil.* import org.apache.http.HttpHeaders import org.assertj.core.api.Assertions.assertThat import testsupport.IntegrationTest import java.io.IOException import java.nio.file.Files import java.nio.file.Path -import java.util.Base64 -import java.util.Objects +import java.util.* class ChatToolWindowTabPanelTest : IntegrationTest() { - fun testSendingOpenAIMessage() { - useOpenAIService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("Hello!") - val conversation = ConversationService.getInstance().startConversation() - val panel = ChatToolWindowTabPanel(project, conversation) - expectOpenAI(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to "Hello!"))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) - - panel.sendMessage(message) - - waitExpecting { - val messages = conversation.messages - messages.isNotEmpty() && "Hello!" == messages[0].response - && panel.tokenDetails.conversationTokens > 0 - } - val encodingManager = EncodingManager.getInstance() - assertThat(panel.tokenDetails).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countConversationTokens(conversation), - 0, - 0) - assertThat(panel.conversation) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.id, - conversation.model, - conversation.clientCode, - false) - val messages = panel.conversation.messages - assertThat(messages).hasSize(1) - assertThat(messages[0]) - .extracting("id", "prompt", "response") - .containsExactly(message.id, message.prompt, message.response) - } - - fun testSendingOpenAIMessageWithReferencedContext() { - project.putUserData(CodeGPTKeys.SELECTED_FILES, listOf( - ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), - ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), - ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))) - useOpenAIService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_MESSAGE") - message.userMessage = "TEST_MESSAGE" - message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") - val conversation = ConversationService.getInstance().startConversation() - val panel = ChatToolWindowTabPanel(project, conversation) - expectOpenAI(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to """ - Use the following context to answer question at the end: - - File Path: TEST_FILE_PATH_1 - File Content: - ```TEST_FILE_NAME_1 - TEST_FILE_CONTENT_1 - ``` - - File Path: TEST_FILE_PATH_2 - File Content: - ```TEST_FILE_NAME_2 - TEST_FILE_CONTENT_2 - ``` - - File Path: TEST_FILE_PATH_3 - File Content: - ```TEST_FILE_NAME_3 - TEST_FILE_CONTENT_3 - ``` - - Question: TEST_MESSAGE""".trimIndent()))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) - - panel.sendMessage(message) - - waitExpecting { - val messages = conversation.messages - messages.isNotEmpty() && "Hello!" == messages[0].response - && panel.tokenDetails.conversationTokens > 0 - } - val encodingManager = EncodingManager.getInstance() - assertThat(panel.tokenDetails).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countConversationTokens(conversation), - 0, - 0) - assertThat(panel.conversation) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.id, - conversation.model, - conversation.clientCode, - false) - val messages = panel.conversation.messages - assertThat(messages).hasSize(1) - assertThat(messages[0]) - .extracting("id", "prompt", "response", "referencedFilePaths") - .containsExactly( - message.id, - message.prompt, - message.response, - listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")) - } - - fun testSendingOpenAIMessageWithImage() { - val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path - project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath) - useOpenAIService("gpt-4-vision-preview") - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_MESSAGE") - val conversation = ConversationService.getInstance().startConversation() - val panel = ChatToolWindowTabPanel(project, conversation) - expectOpenAI(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - try { - val testImageUrl = ("data:image/png;base64," - + Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath)))) - assertThat(request.body) - .extracting("model", "messages") - .containsExactly( - "gpt-4-vision-preview", + fun testSendingOpenAIMessage() { + useOpenAIService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("Hello!") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages" + ) + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "Hello!") + ) + ) listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to listOf( - mapOf( - "type" to "image_url", - "image_url" to mapOf("url" to testImageUrl)), - mapOf("type" to "text", "text" to "TEST_MESSAGE") - )))) - } catch (e: IOException) { - throw RuntimeException(e) - } - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) - panel.sendMessage(message) + panel.sendMessage(message) - waitExpecting { - val messages = conversation.messages - messages.isNotEmpty() && "Hello!" == messages[0].response - && panel.tokenDetails.conversationTokens > 0 + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens" + ) + .containsExactly( + encodingManager.countTokens("TEST_SYSTEM_PROMPT"), + encodingManager.countConversationTokens(conversation), + 0, + 0 + ) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false + ) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response") + .containsExactly(message.id, message.prompt, message.response) } - val encodingManager = EncodingManager.getInstance() - assertThat(panel.tokenDetails).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countConversationTokens(conversation), - 0, - 0) - assertThat(panel.conversation) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.id, - conversation.model, - conversation.clientCode, - false) - val messages = panel.conversation.messages - assertThat(messages).hasSize(1) - assertThat(messages[0]) - .extracting("id", "prompt", "response", "imageFilePath") - .containsExactly( - message.id, - message.prompt, - message.response, - message.imageFilePath) - } - fun testFixCompileErrorsWithOpenAIService() { - project.putUserData( - CodeGPTKeys.SELECTED_FILES, listOf( - ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), - ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), - ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))) - useOpenAIService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_MESSAGE") - message.userMessage = "TEST_MESSAGE" - message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") - val conversation = ConversationService.getInstance().startConversation() - val panel = ChatToolWindowTabPanel(project, conversation) - expectOpenAI(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - listOf( - mapOf("role" to "system", "content" to FIX_COMPILE_ERRORS_SYSTEM_PROMPT), - mapOf("role" to "user", "content" to """ + fun testSendingOpenAIMessageWithReferencedContext() { + project.putUserData( + CodeGPTKeys.SELECTED_FILES, listOf( + ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), + ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), + ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3") + ) + ) + useOpenAIService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_MESSAGE") + message.userMessage = "TEST_MESSAGE" + message.referencedFilePaths = + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages" + ) + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf( + "role" to "user", "content" to """ Use the following context to answer question at the end: File Path: TEST_FILE_PATH_1 @@ -299,118 +146,331 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { TEST_FILE_CONTENT_3 ``` - Question: TEST_MESSAGE""".trimIndent()))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) + Question: TEST_MESSAGE""".trimIndent() + ) + ) + ) + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) - panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS) + panel.sendMessage(message) - waitExpecting { - val messages = conversation.messages - messages.isNotEmpty() && "Hello!" == messages[0].response - && panel.tokenDetails.conversationTokens > 0 + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens" + ) + .containsExactly( + encodingManager.countTokens("TEST_SYSTEM_PROMPT"), + encodingManager.countConversationTokens(conversation), + 0, + 0 + ) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false + ) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "referencedFilePaths") + .containsExactly( + message.id, + message.prompt, + message.response, + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + ) } - val encodingManager = EncodingManager.getInstance() - assertThat(panel.tokenDetails).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens("TEST_SYSTEM_PROMPT"), - encodingManager.countConversationTokens(conversation), - 0, - 0) - assertThat(panel.conversation) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.id, - conversation.model, - conversation.clientCode, - false) - val messages = panel.conversation.messages - assertThat(messages).hasSize(1) - assertThat(messages[0]) - .extracting("id", "prompt", "response", "referencedFilePaths") - .containsExactly( - message.id, - message.prompt, - message.response, - listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")) - } - fun testSendingLlamaMessage() { - useLlamaService() - val configurationState = service().state - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - configurationState.maxTokens = 1000 - configurationState.temperature = 0.1f - val llamaSettings = LlamaSettings.getCurrentState() - llamaSettings.isUseCustomModel = false - llamaSettings.huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4 - llamaSettings.topK = 30 - llamaSettings.topP = 0.8 - llamaSettings.minP = 0.03 - llamaSettings.repeatPenalty = 1.3 - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - val panel = ChatToolWindowTabPanel(project, conversation) - expectLlama(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/completion") - assertThat(request.body) - .extracting( - "prompt", - "n_predict", - "stream", - "temperature", - "top_k", - "top_p", - "min_p", - "repeat_penalty") - .containsExactly( - LLAMA.buildPrompt( - "TEST_SYSTEM_PROMPT", - "TEST_PROMPT", - conversation.messages), - configurationState.maxTokens, - true, - configurationState.temperature.toDouble(), - llamaSettings.topK, - llamaSettings.topP, - llamaSettings.minP, - llamaSettings.repeatPenalty) - listOf( - jsonMapResponse("content", "Hel"), - jsonMapResponse("content", "lo!"), - jsonMapResponse( - e("content", ""), - e("stop", true))) - }) + fun testSendingOpenAIMessageWithImage() { + val testImagePath = + Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path + project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath) + useOpenAIService("gpt-4-vision-preview") + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_MESSAGE") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + try { + val testImageUrl = ("data:image/png;base64," + + Base64.getEncoder() + .encodeToString(Files.readAllBytes(Path.of(testImagePath)))) + assertThat(request.body) + .extracting("model", "messages") + .containsExactly( + "gpt-4-vision-preview", + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf( + "role" to "user", "content" to listOf( + mapOf( + "type" to "image_url", + "image_url" to mapOf("url" to testImageUrl) + ), + mapOf("type" to "text", "text" to "TEST_MESSAGE") + ) + ) + ) + ) + } catch (e: IOException) { + throw RuntimeException(e) + } + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) - panel.sendMessage(message, ConversationType.DEFAULT) + panel.sendMessage(message) - waitExpecting { - val messages = conversation.messages - messages.isNotEmpty() && "Hello!" == messages[0].response - && panel.tokenDetails.conversationTokens > 0 + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens" + ) + .containsExactly( + encodingManager.countTokens("TEST_SYSTEM_PROMPT"), + encodingManager.countConversationTokens(conversation), + 0, + 0 + ) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false + ) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "imageFilePath") + .containsExactly( + message.id, + message.prompt, + message.response, + message.imageFilePath + ) + } + + fun testFixCompileErrorsWithOpenAIService() { + project.putUserData( + CodeGPTKeys.SELECTED_FILES, listOf( + ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), + ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), + ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3") + ) + ) + useOpenAIService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_MESSAGE") + message.userMessage = "TEST_MESSAGE" + message.referencedFilePaths = + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages" + ) + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to FIX_COMPILE_ERRORS_SYSTEM_PROMPT), + mapOf( + "role" to "user", "content" to """ + Use the following context to answer question at the end: + + File Path: TEST_FILE_PATH_1 + File Content: + ```TEST_FILE_NAME_1 + TEST_FILE_CONTENT_1 + ``` + + File Path: TEST_FILE_PATH_2 + File Content: + ```TEST_FILE_NAME_2 + TEST_FILE_CONTENT_2 + ``` + + File Path: TEST_FILE_PATH_3 + File Content: + ```TEST_FILE_NAME_3 + TEST_FILE_CONTENT_3 + ``` + + Question: TEST_MESSAGE""".trimIndent() + ) + ) + ) + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) + + panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens" + ) + .containsExactly( + encodingManager.countTokens("TEST_SYSTEM_PROMPT"), + encodingManager.countConversationTokens(conversation), + 0, + 0 + ) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false + ) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "referencedFilePaths") + .containsExactly( + message.id, + message.prompt, + message.response, + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + ) + } + + fun testSendingLlamaMessage() { + useLlamaService() + val configurationState = service().state + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + configurationState.maxTokens = 1000 + configurationState.temperature = 0.1f + val llamaSettings = LlamaSettings.getCurrentState() + llamaSettings.isUseCustomModel = false + llamaSettings.huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4 + llamaSettings.topK = 30 + llamaSettings.topP = 0.8 + llamaSettings.minP = 0.03 + llamaSettings.repeatPenalty = 1.3 + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectLlama(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/completion") + assertThat(request.body) + .extracting( + "prompt", + "n_predict", + "stream", + "temperature", + "top_k", + "top_p", + "min_p", + "repeat_penalty" + ) + .containsExactly( + LLAMA.buildPrompt( + "TEST_SYSTEM_PROMPT", + "TEST_PROMPT", + conversation.messages + ), + configurationState.maxTokens, + true, + configurationState.temperature.toDouble(), + llamaSettings.topK, + llamaSettings.topP, + llamaSettings.minP, + llamaSettings.repeatPenalty + ) + listOf( + jsonMapResponse("content", "Hel"), + jsonMapResponse("content", "lo!"), + jsonMapResponse( + e("content", ""), + e("stop", true) + ) + ) + }) + + panel.sendMessage(message, ConversationType.DEFAULT) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + && panel.tokenDetails.conversationTokens > 0 + } + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false + ) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response") + .containsExactly(message.id, message.prompt, message.response) } - assertThat(panel.conversation) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.id, - conversation.model, - conversation.clientCode, - false) - val messages = panel.conversation.messages - assertThat(messages).hasSize(1) - assertThat(messages[0]) - .extracting("id", "prompt", "response") - .containsExactly(message.id, message.prompt, message.response) - } }