feat: support qwen2.5 and o1 models

This commit is contained in:
Carl-Robert Linnupuu 2024-10-01 12:36:37 +03:00
parent 5c9253278f
commit 24ae263a39
32 changed files with 521 additions and 314 deletions

View file

@ -12,7 +12,7 @@ jsoup = "1.17.2"
jtokkit = "1.1.0"
junit = "5.11.0"
kotlin = "2.0.0"
llm-client = "0.8.18"
llm-client = "0.8.19"
okio = "3.9.0"
tree-sitter = "0.22.6a"

View file

@ -15,10 +15,12 @@ public final class Icons {
public static final Icon Azure = IconLoader.getIcon("/icons/azure.svg", Icons.class);
public static final Icon Databricks = IconLoader.getIcon("/icons/dbrx.svg", Icons.class);
public static final Icon DeepSeek = IconLoader.getIcon("/icons/deepseek.png", Icons.class);
public static final Icon Qwen = IconLoader.getIcon("/icons/qwen.png", Icons.class);
public static final Icon Google = IconLoader.getIcon("/icons/google.svg", Icons.class);
public static final Icon Llama = IconLoader.getIcon("/icons/llama.svg", Icons.class);
public static final Icon OpenAI = IconLoader.getIcon("/icons/openai.svg", Icons.class);
public static final Icon Meta = IconLoader.getIcon("/icons/meta.svg", Icons.class);
public static final Icon Mistral = IconLoader.getIcon("/icons/mistral.svg", Icons.class);
public static final Icon Send = IconLoader.getIcon("/icons/send.svg", Icons.class);
public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class);
public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class);

View file

@ -18,6 +18,7 @@ import com.intellij.vcs.commit.CommitWorkflowUi;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
import ee.carlrobert.codegpt.ui.OverlayUtil;
@ -85,8 +86,9 @@ public class GenerateGitCommitMessageAction extends AnAction {
var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI);
if (commitWorkflowUi != null) {
CompletionRequestService.getInstance().getCommitMessageAsync(
project.getService(CommitMessageTemplate.class).getSystemPrompt(),
gitDiff,
new CommitMessageRequestParameters(
gitDiff,
project.getService(CommitMessageTemplate.class).getSystemPrompt()),
getEventListener(project, commitWorkflowUi));
}
}
@ -162,11 +164,22 @@ public class GenerateGitCommitMessageAction extends AnAction {
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
var application = ApplicationManager.getApplication();
application.invokeLater(() ->
application.runWriteAction(() ->
WriteCommandAction.runWriteCommandAction(project, () ->
commitWorkflowUi.getCommitMessageUi().setText(messageBuilder.toString()))));
updateCommitMessage(messageBuilder.toString());
}
@Override
public void onComplete(StringBuilder result) {
if (messageBuilder.isEmpty()) {
updateCommitMessage(result.toString());
}
}
private void updateCommitMessage(String message) {
ApplicationManager.getApplication().invokeLater(() ->
WriteCommandAction.runWriteCommandAction(project, () ->
commitWorkflowUi.getCommitMessageUi().setText(message)
)
);
}
@Override

View file

@ -0,0 +1,74 @@
package ee.carlrobert.codegpt.completions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import okhttp3.sse.EventSource;
public class ChatCompletionEventListener implements CompletionEventListener<String> {
private final CallParameters callParameters;
private final CompletionResponseEventListener eventListener;
private final StringBuilder messageBuilder = new StringBuilder();
public ChatCompletionEventListener(
CallParameters callParameters,
CompletionResponseEventListener eventListener) {
this.callParameters = callParameters;
this.eventListener = eventListener;
}
@Override
public void onEvent(String data) {
try {
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
eventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
}
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
callParameters.getMessage().setResponse(messageBuilder.toString());
eventListener.handleMessage(message);
}
@Override
public void onComplete(StringBuilder messageBuilder) {
eventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onCancelled(StringBuilder messageBuilder) {
eventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
eventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}
private void sendError(ErrorDetails error, Throwable ex) {
var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
if ("insufficient_quota".equals(error.getCode())) {
telemetryMessage
.property("type", "USER")
.property("code", "INSUFFICIENT_QUOTA");
} else {
telemetryMessage
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.error(new RuntimeException(error.toString(), ex));
}
telemetryMessage.send();
}
}

View file

@ -89,7 +89,6 @@ public class CompletionClientProvider {
return builder.build(getDefaultClientBuilder());
}
public static GoogleClient getGoogleClient() {
return new GoogleClient.Builder(getCredential(CredentialKey.GOOGLE_API_KEY))
.build(getDefaultClientBuilder());

View file

@ -1,129 +0,0 @@
package ee.carlrobert.codegpt.completions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import okhttp3.sse.EventSource;
public class CompletionRequestHandler {
private final StringBuilder messageBuilder = new StringBuilder();
private final CompletionResponseEventListener completionResponseEventListener;
private EventSource eventSource;
public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) {
this.completionResponseEventListener = completionResponseEventListener;
}
public void call(CallParameters callParameters) {
try {
eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters));
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(callParameters);
}
}
public void cancel() {
if (eventSource != null) {
eventSource.cancel();
}
}
private EventSource startCall(
CallParameters callParameters,
CompletionEventListener<String> eventListener) {
try {
return CompletionRequestService.getInstance()
.getChatCompletionAsync(callParameters, eventListener);
} catch (Throwable ex) {
handleCallException(ex);
throw ex;
}
}
private void handleCallException(Throwable ex) {
var errorMessage = "Something went wrong";
if (ex instanceof TotalUsageExceededException) {
errorMessage =
"The length of the context exceeds the maximum limit that the model can handle. "
+ "Try reducing the input message or maximum completion token size.";
}
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
}
class RequestCompletionEventListener implements CompletionEventListener<String> {
private final CallParameters callParameters;
public RequestCompletionEventListener(CallParameters callParameters) {
this.callParameters = callParameters;
}
@Override
public void onEvent(String data) {
try {
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
}
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
callParameters.getMessage().setResponse(messageBuilder.toString());
completionResponseEventListener.handleMessage(message);
}
@Override
public void onComplete(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onCancelled(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}
private void sendError(ErrorDetails error, Throwable ex) {
var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
if ("insufficient_quota".equals(error.getCode())) {
telemetryMessage
.property("type", "USER")
.property("code", "INSUFFICIENT_QUOTA");
} else {
telemetryMessage
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.error(new RuntimeException(error.toString(), ex));
}
telemetryMessage.send();
}
}
private void sendInfo(CallParameters callParameters) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
.send();
}
}

View file

@ -3,7 +3,9 @@ package ee.carlrobert.codegpt.completions;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.progress.Task;
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
@ -26,12 +28,15 @@ import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionRequest;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import javax.swing.SwingUtilities;
import okhttp3.Request;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.NotNull;
@Service
public final class CompletionRequestService {
@ -63,50 +68,50 @@ public final class CompletionRequestService {
new OpenAIChatCompletionEventSourceListener(eventListener));
}
public String getLookupCompletion(String prompt) {
return getChatCompletion(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createLookupRequest(prompt));
public String getLookupCompletion(LookupRequestCallParameters params) {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createLookupRequest(params);
return getChatCompletion(request);
}
public EventSource getCommitMessageAsync(
String systemPrompt,
String gitDiff,
CommitMessageRequestParameters params,
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createCommitMessageRequest(systemPrompt, gitDiff),
eventListener);
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createCommitMessageRequest(params);
return getChatCompletionAsync(request, eventListener);
}
public EventSource getEditCodeCompletionAsync(
EditCodeRequestParams params,
EditCodeRequestParameters params,
CompletionEventListener<String> eventListener) {
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createEditCodeRequest(input),
eventListener);
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createEditCodeRequest(params);
return getChatCompletionAsync(request, eventListener);
}
public EventSource getChatCompletionAsync(
CallParameters callParameters,
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createChatRequest(callParameters),
eventListener);
}
private EventSource getChatCompletionAsync(
CompletionRequest request,
CompletionEventListener<String> eventListener) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
return switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()
.getChatCompletionAsync(completionRequest, eventListener);
case OPENAI -> CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
case CODEGPT -> {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
yield getO1ChatCompletionAsync(completionRequest, eventListener);
}
yield CompletionClientProvider.getCodeGPTClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
case OPENAI -> {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
yield getO1ChatCompletionAsync(completionRequest, eventListener);
}
yield CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
case AZURE -> CompletionClientProvider.getAzureClient()
.getChatCompletionAsync(completionRequest, eventListener);
default -> throw new RuntimeException("Unknown service selected");
@ -142,7 +147,33 @@ public final class CompletionRequestService {
throw new IllegalStateException("Unknown request type: " + request.getClass());
}
private String getChatCompletion(CompletionRequest request) {
private EventSource getO1ChatCompletionAsync(
OpenAIChatCompletionRequest request,
CompletionEventListener<String> eventListener) {
ProgressManager.getInstance()
.run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") {
@Override
public void run(@NotNull ProgressIndicator indicator) {
indicator.setIndeterminate(true);
var response = CompletionRequestService.getInstance().getChatCompletion(request);
SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response)));
}
});
return new EventSource() {
@Override
public @NotNull Request request() {
return new Request.Builder().build(); // dummy
}
@Override
public void cancel() {
eventListener.onCancelled(new StringBuilder("Cancelled"));
}
};
}
public String getChatCompletion(CompletionRequest request) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
var response = switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()

View file

@ -16,6 +16,9 @@ public interface CompletionResponseEventListener {
default void handleTokensExceeded(Conversation conversation, Message message) {
}
default void handleCompleted(String fullMessage) {
}
default void handleCompleted(String fullMessage, CallParameters callParameters) {
}

View file

@ -56,7 +56,8 @@ public class MethodNameLookupListener implements LookupManagerListener {
Application application,
String prompt) {
try {
var response = CompletionRequestService.getInstance().getLookupCompletion(prompt);
var response = CompletionRequestService.getInstance()
.getLookupCompletion(new LookupRequestCallParameters(prompt));
if (!response.isEmpty()) {
for (var value : response.split(",")) {
application.invokeLater(() -> application.runReadAction(() -> {

View file

@ -0,0 +1,67 @@
package ee.carlrobert.codegpt.completions;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import okhttp3.sse.EventSource;
public class ToolwindowChatCompletionRequestHandler {
private final CompletionResponseEventListener completionResponseEventListener;
private EventSource eventSource;
public ToolwindowChatCompletionRequestHandler(
CompletionResponseEventListener completionResponseEventListener) {
this.completionResponseEventListener = completionResponseEventListener;
}
public void call(CallParameters callParameters) {
try {
eventSource = startCall(callParameters);
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(callParameters);
}
}
public void cancel() {
if (eventSource != null) {
eventSource.cancel();
}
}
private EventSource startCall(CallParameters callParameters) {
try {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createChatRequest(new ChatCompletionRequestParameters(callParameters));
return CompletionRequestService.getInstance().getChatCompletionAsync(
request,
new ChatCompletionEventListener(callParameters, completionResponseEventListener));
} catch (Throwable ex) {
handleCallException(ex);
throw ex;
}
}
private void handleCallException(Throwable ex) {
var errorMessage = "Something went wrong";
if (ex instanceof TotalUsageExceededException) {
errorMessage =
"The length of the context exceeds the maximum limit that the model can handle. "
+ "Try reducing the input message or maximum completion token size.";
}
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
}
private void sendInfo(CallParameters callParameters) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
.send();
}
}

View file

@ -11,8 +11,8 @@ public class AdvancedSettingsState {
private boolean proxyAuthSelected;
private String proxyUsername;
private String proxyPassword;
private int connectTimeout = 30;
private int readTimeout = 30;
private int connectTimeout = 120;
private int readTimeout = 120;
public String getProxyHost() {
return proxyHost;

View file

@ -15,10 +15,10 @@ import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestHandler;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.completions.CompletionRequestUtil;
import ee.carlrobert.codegpt.completions.ConversationType;
import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
@ -60,7 +60,7 @@ public class ChatToolWindowTabPanel implements Disposable {
private final TotalTokensPanel totalTokensPanel;
private final ChatToolWindowScrollablePanel toolWindowScrollablePanel;
private @Nullable CompletionRequestHandler requestHandler;
private @Nullable ToolwindowChatCompletionRequestHandler requestHandler;
public ChatToolWindowTabPanel(@NotNull Project project, @NotNull Conversation conversation) {
this.project = project;
@ -250,7 +250,7 @@ public class ChatToolWindowTabPanel implements Disposable {
return;
}
requestHandler = new CompletionRequestHandler(
requestHandler = new ToolwindowChatCompletionRequestHandler(
new ToolWindowCompletionResponseEventListener(
conversationService,
responsePanel,

View file

@ -112,6 +112,9 @@ abstract class ToolWindowCompletionResponseEventListener implements
try {
responsePanel.enableActions();
responseContainer.enableActions();
if (!responseContainer.isResponseReceived() && !fullMessage.isEmpty()) {
responseContainer.withResponse(fullMessage);
}
totalTokensPanel.updateUserPromptTokens(textArea.getText());
totalTokensPanel.updateConversationTokens(callParameters.getConversation());
} finally {

View file

@ -113,6 +113,10 @@ public class ChatMessageResponseBody extends JPanel {
}
public ChatMessageResponseBody withResponse(String response) {
if (!responseReceived) {
removeAll();
}
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
currentlyProcessedEditorPanel = null;
currentlyProcessedTextPane = null;
@ -362,4 +366,8 @@ public class ChatMessageResponseBody extends JPanel {
panel.add(listPanel, BorderLayout.CENTER);
return panel;
}
public boolean isResponseReceived() {
return responseReceived;
}
}

View file

@ -123,9 +123,10 @@ public class ModelComboBoxAction extends ComboBoxAction {
var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI");
openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI);
List.of(
OpenAIChatCompletionModel.O_1_PREVIEW,
OpenAIChatCompletionModel.O_1_MINI,
OpenAIChatCompletionModel.GPT_4_O,
OpenAIChatCompletionModel.GPT_4_O_MINI,
OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW,
OpenAIChatCompletionModel.GPT_4_0125_128k)
.forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation)));
actionGroup.add(openaiGroup);

View file

@ -34,7 +34,12 @@ class EditCodeCompletionListener(
}
override fun onComplete(messageBuilder: StringBuilder) {
runInEdt { cleanupAndFormat() }
runInEdt {
if (replacedLength == 0 && messageBuilder.isNotEmpty()) {
handleDiff(messageBuilder.toString())
}
cleanupAndFormat()
}
observableProperties.loading.set(false)
}
@ -73,7 +78,6 @@ class EditCodeCompletionListener(
val document = editor.document
val startOffset = selectionTextRange.startOffset
val endOffset = selectionTextRange.endOffset
runUndoTransparentWriteAction {
val remainingOriginalLength = endOffset - (startOffset + replacedLength)
if (remainingOriginalLength > 0) {

View file

@ -9,10 +9,9 @@ import com.intellij.openapi.util.TextRange
import com.intellij.openapi.util.text.StringUtil
import com.jetbrains.rd.util.AtomicReference
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.completions.EditCodeRequestParameters
import ee.carlrobert.codegpt.ui.ObservableProperties
data class EditCodeRequestParams(val prompt: String, val selectedText: String)
class EditCodeSubmissionHandler(
private val editor: Editor,
private val observableProperties: ObservableProperties,
@ -36,7 +35,7 @@ class EditCodeSubmissionHandler(
runInEdt { editor.selectionModel.removeSelection() }
service<CompletionRequestService>().getEditCodeCompletionAsync(
EditCodeRequestParams(userPrompt, selectedText),
EditCodeRequestParameters(userPrompt, selectedText),
EditCodeCompletionListener(editor, observableProperties, selectionTextRange)
)
}

View file

@ -0,0 +1,19 @@
package ee.carlrobert.codegpt.completions
interface CompletionCallParameters
data class ChatCompletionRequestParameters(
val callParameters: CallParameters
) : CompletionCallParameters
data class CommitMessageRequestParameters(
val gitDiff: String,
val systemPrompt: String
) : CompletionCallParameters
data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters
data class EditCodeRequestParameters(
val prompt: String,
val selectedText: String
) : CompletionCallParameters

View file

@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.llm.completion.CompletionRequest
interface CompletionRequestFactory {
fun createChatRequest(callParameters: CallParameters): CompletionRequest
fun createEditCodeRequest(input: String): CompletionRequest
fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest
fun createLookupRequest(prompt: String): CompletionRequest
fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest
fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest
fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest
fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest
companion object {
@JvmStatic
@ -30,24 +30,23 @@ interface CompletionRequestFactory {
}
abstract class BaseRequestFactory : CompletionRequestFactory {
override fun createEditCodeRequest(input: String): CompletionRequest {
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true)
override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest {
val prompt = "${params.prompt}\n\n${params.selectedText}"
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true)
}
override fun createCommitMessageRequest(
systemPrompt: String,
gitDiff: String
): CompletionRequest {
return createBasicCompletionRequest(systemPrompt, gitDiff, true)
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest {
return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true)
}
override fun createLookupRequest(prompt: String): CompletionRequest {
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt)
override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest {
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512)
}
abstract fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int = 4096,
stream: Boolean = false
): CompletionRequest
}

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
@ -10,10 +10,10 @@ import ee.carlrobert.llm.completion.CompletionRequest
class AzureRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters))
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, params.callParameters))
.setMaxTokens(configuration.maxTokens)
.setStream(true)
.setTemperature(configuration.temperature.toDouble())
@ -23,6 +23,7 @@ class AzureRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): CompletionRequest {
return OpenAIRequestFactory.createBasicCompletionRequest(

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings
@ -11,7 +11,8 @@ import ee.carlrobert.llm.completion.CompletionRequest
class ClaudeRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest {
val (callParameters) = params
return ClaudeCompletionRequest().apply {
model = service<AnthropicSettings>().state.model
maxTokens = service<ConfigurationSettings>().state.maxTokens
@ -57,15 +58,16 @@ class ClaudeRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): CompletionRequest {
return ClaudeCompletionRequest().apply {
system = systemPrompt
isStream = stream
maxTokens = service<ConfigurationSettings>().state.maxTokens
model = service<AnthropicSettings>().state.model
messages =
listOf<ClaudeCompletionMessage>(ClaudeCompletionStandardMessage("user", userPrompt))
this.maxTokens = maxTokens
}
}
}

View file

@ -2,7 +2,8 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildBasicO1Request
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
@ -11,15 +12,26 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe
class CodeGPTRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
val (callParameters) = params
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.maxTokens)
if ("o1-mini" == model || "o1-preview" == model) {
requestBuilder
.setMaxCompletionTokens(configuration.maxTokens)
.setStream(false)
.setMaxTokens(null)
.setTemperature(null)
} else {
requestBuilder
.setStream(true)
.setMaxTokens(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
}
if (callParameters.message.isWebSearchIncluded) {
requestBuilder.setWebSearchIncluded(true)
}
@ -36,12 +48,17 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): OpenAIChatCompletionRequest {
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens)
}
return OpenAIRequestFactory.createBasicCompletionRequest(
systemPrompt,
userPrompt,
service<CodeGPTServiceSettings>().state.chatCompletionSettings.model,
model,
stream
)
}

View file

@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
@ -19,7 +19,8 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
class CustomOpenAIRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest {
val (callParameters) = params
val request = buildCustomOpenAIChatCompletionRequest(
service<CustomServiceSettings>()
.state
@ -34,6 +35,7 @@ class CustomOpenAIRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): CompletionRequest {
val request = buildCustomOpenAIChatCompletionRequest(

View file

@ -2,11 +2,8 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.*
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.TotalUsageExceededException
import ee.carlrobert.codegpt.conversations.ConversationsState
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
@ -23,7 +20,8 @@ import java.nio.file.Path
class GoogleRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest {
val (callParameters) = params
val configuration = service<ConfigurationSettings>().state
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
return GoogleCompletionRequest.Builder(messages)
@ -38,6 +36,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): GoogleCompletionRequest {
val configuration = service<ConfigurationSettings>().state
@ -50,7 +49,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
)
.generationConfig(
GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.maxTokens)
.maxOutputTokens(maxTokens)
.temperature(configuration.temperature.toDouble()).build()
)
.build()

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.llama.LlamaModel
@ -14,7 +14,8 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
class LlamaRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest {
val (callParameters) = params
val promptTemplate = getPromptTemplate()
val systemPrompt =
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
@ -33,6 +34,7 @@ class LlamaRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): LlamaCompletionRequest {
val promptTemplate = getPromptTemplate()

View file

@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
@ -18,7 +19,8 @@ import java.util.*
class OllamaRequestFactory : BaseRequestFactory() {
override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest {
val (callParameters) = params
val configuration = service<ConfigurationSettings>().state
val settings = service<OllamaSettings>().state
return OllamaChatCompletionRequest.Builder(
@ -38,6 +40,7 @@ class OllamaRequestFactory : BaseRequestFactory() {
override fun createBasicCompletionRequest(
systemPrompt: String,
userPrompt: String,
maxTokens: Int,
stream: Boolean
): OllamaChatCompletionRequest {
return OllamaChatCompletionRequest.Builder(

View file

@ -2,13 +2,10 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.CompletionRequestFactory
import ee.carlrobert.codegpt.completions.*
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_METHOD_NAMES_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.TotalUsageExceededException
import ee.carlrobert.codegpt.conversations.ConversationsState
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState
@ -17,62 +14,93 @@ import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.codegpt.util.file.FileUtil.getImageMediaType
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
import ee.carlrobert.llm.client.openai.completion.request.*
import ee.carlrobert.llm.completion.CompletionRequest
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path
class OpenAIRequestFactory : CompletionRequestFactory {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
val (callParameters) = params
val model = service<OpenAISettings>().state.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.maxTokens)
if ("o1-mini" == model || "o1-preview" == model) {
requestBuilder
.setMaxCompletionTokens(configuration.maxTokens)
.setStream(false)
.setMaxTokens(null)
.setTemperature(null)
.setPresencePenalty(null)
.setFrequencyPenalty(null)
} else {
requestBuilder
.setStream(true)
.setMaxTokens(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
}
return requestBuilder.build()
}
override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest {
return buildEditCodeRequest(input, service<OpenAISettings>().state.model)
override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, params.prompt, EDIT_CODE_SYSTEM_PROMPT)
}
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, params.prompt, model, true)
}
override fun createCommitMessageRequest(
systemPrompt: String,
gitDiff: String
): CompletionRequest {
return createBasicCompletionRequest(
systemPrompt,
gitDiff,
service<OpenAISettings>().state.model,
true
)
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val (gitDiff, systemPrompt) = params
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, gitDiff, systemPrompt)
}
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
}
override fun createLookupRequest(prompt: String): CompletionRequest {
return createBasicCompletionRequest(
GENERATE_METHOD_NAMES_SYSTEM_PROMPT,
prompt,
service<OpenAISettings>().state.model
)
override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val (prompt) = params
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, prompt, GENERATE_METHOD_NAMES_SYSTEM_PROMPT)
}
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt, model)
}
companion object {
fun buildEditCodeRequest(
input: String,
model: String? = null
fun buildBasicO1Request(
model: String,
prompt: String,
systemPrompt: String = "",
maxCompletionTokens: Int = 4096
): OpenAIChatCompletionRequest {
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, model, true)
val messages = if (systemPrompt.isEmpty()) {
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
} else {
listOf(
OpenAIChatCompletionStandardMessage("user", systemPrompt),
OpenAIChatCompletionStandardMessage("user", prompt)
)
}
return OpenAIChatCompletionRequest.Builder(messages)
.setModel(model)
.setMaxCompletionTokens(maxCompletionTokens)
.setStream(false)
.setTemperature(null)
.setFrequencyPenalty(null)
.setPresencePenalty(null)
.setMaxTokens(null)
.build()
}
fun buildOpenAIMessages(
model: String?,
callParameters: CallParameters
): List<OpenAIChatCompletionMessage> {
val messages = buildOpenAIMessages(callParameters)
val messages = buildOpenAIChatMessages(model, callParameters)
if (model == null) {
return messages
@ -104,21 +132,24 @@ class OpenAIRequestFactory : CompletionRequestFactory {
)
}
private fun buildOpenAIMessages(
private fun buildOpenAIChatMessages(
model: String?,
callParameters: CallParameters
): MutableList<OpenAIChatCompletionMessage> {
val message = callParameters.message
val messages = mutableListOf<OpenAIChatCompletionMessage>()
val role = if ("o1-mini" == model || "o1-preview" == model) "user" else "system"
if (callParameters.conversationType == ConversationType.DEFAULT) {
val sessionPersonaDetails = callParameters.message.personaDetails
if (callParameters.message.personaDetails == null) {
messages.add(
OpenAIChatCompletionStandardMessage("system", getSystemPrompt())
OpenAIChatCompletionStandardMessage(role, getSystemPrompt())
)
} else {
messages.add(
OpenAIChatCompletionStandardMessage(
"system",
role,
sessionPersonaDetails.instructions
)
)
@ -126,7 +157,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
}
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT)
OpenAIChatCompletionStandardMessage(role, FIX_COMPILE_ERRORS_SYSTEM_PROMPT)
)
}

View file

@ -14,49 +14,46 @@ object CodeGPTAvailableModels {
fun getToolWindowModels(pricingPlan: PricingPlan?): List<CodeGPTModel> {
return when (pricingPlan) {
null, ANONYMOUS -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2 - FREE", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS),
CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS)
)
FREE -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE),
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE),
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS),
CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE),
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE),
)
INDIVIDUAL -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE),
)
}
}
@JvmStatic
val ALL_CHAT_MODELS: List<CodeGPTModel> = listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS),
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE),
CodeGPTModel("DeepSeek Coder (33B)", "deepseek-coder-33b", Icons.CodeGPTModel, FREE),
CodeGPTModel("WizardLM-2 (8x22B)", "wizardlm-2-8x22b", Icons.CodeGPTModel, FREE)
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE),
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE),
CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE),
)
@JvmStatic
@ -65,7 +62,6 @@ object CodeGPTAvailableModels {
CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE),
CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE),
CodeGPTModel("WizardCoder Python (34B)", "wizardcoder-python", Icons.CodeGPTModel, FREE),
CodeGPTModel("Phind Code LLaMA v2 (34B)", "phind-codellama", Icons.CodeGPTModel, FREE)
)
@JvmStatic

View file

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="13px" height="13px" viewBox="0 0 256 233" version="1.1" xmlns="http://www.w3.org/2000/svg" preserveAspectRatio="xMidYMid">
<title>Mistral AI</title>
<g>
<rect fill="#000000" x="186.181818" y="0" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F7D046" x="209.454545" y="0" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="0" y="0" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="0" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="0" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="0" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="0" y="186.181818" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F7D046" x="23.2727273" y="0" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F2A73B" x="209.454545" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F2A73B" x="23.2727273" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="139.636364" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F2A73B" x="162.909091" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#F2A73B" x="69.8181818" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EE792F" x="116.363636" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EE792F" x="162.909091" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EE792F" x="69.8181818" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="93.0909091" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EB5829" x="116.363636" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EE792F" x="209.454545" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EE792F" x="23.2727273" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="186.181818" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EB5829" x="209.454545" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#000000" x="186.181818" y="186.181818" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EB5829" x="23.2727273" y="139.636364" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EA3326" x="209.454545" y="186.181818" width="46.5454545" height="46.5454545"></rect>
<rect fill="#EA3326" x="23.2727273" y="186.181818" width="46.5454545" height="46.5454545"></rect>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 713 B

View file

@ -24,12 +24,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
ChatCompletionRequestParameters(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
)
)
)
@ -55,12 +57,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
ChatCompletionRequestParameters(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
)
)
)
@ -86,12 +90,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
secondMessage,
null,
true
ChatCompletionRequestParameters(
CallParameters(
conversation,
ConversationType.DEFAULT,
secondMessage,
null,
true
)
)
)
@ -118,12 +124,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.discardTokenLimits()
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
ChatCompletionRequestParameters(
CallParameters(
conversation,
ConversationType.DEFAULT,
Message("TEST_CHAT_COMPLETION_PROMPT"),
null,
false
)
)
)
@ -146,12 +154,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
assertThrows(TotalUsageExceededException::class.java) {
OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
createDummyMessage(100),
null,
false
ChatCompletionRequestParameters(
CallParameters(
conversation,
ConversationType.DEFAULT,
createDummyMessage(100),
null,
false
)
)
)
}

View file

@ -14,14 +14,17 @@ import org.apache.http.HttpHeaders
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
class DefaultCompletionRequestHandlerTest : IntegrationTest() {
class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
fun testOpenAIChatCompletionCall() {
useOpenAIService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
@ -77,7 +80,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
val message = Message("TEST_PROMPT")
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
requestHandler.call(CallParameters(conversation, message))
@ -91,7 +97,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(Message("Ping", "Pong"))
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectLlama(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/completion")
assertThat(request.body)
@ -125,7 +134,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/api/chat")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
@ -171,7 +183,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectGoogle(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
assertThat(request.method).isEqualTo("POST")
@ -207,7 +222,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")