feat: support qwen2.5 and o1 models

This commit is contained in:
Carl-Robert Linnupuu 2024-10-01 12:36:37 +03:00
parent c614411315
commit 6b4e22b545
33 changed files with 521 additions and 314 deletions

View file

@ -15,10 +15,12 @@ public final class Icons {
public static final Icon Azure = IconLoader.getIcon("/icons/azure.svg", Icons.class);
public static final Icon Databricks = IconLoader.getIcon("/icons/dbrx.svg", Icons.class);
public static final Icon DeepSeek = IconLoader.getIcon("/icons/deepseek.png", Icons.class);
public static final Icon Qwen = IconLoader.getIcon("/icons/qwen.png", Icons.class);
public static final Icon Google = IconLoader.getIcon("/icons/google.svg", Icons.class);
public static final Icon Llama = IconLoader.getIcon("/icons/llama.svg", Icons.class);
public static final Icon OpenAI = IconLoader.getIcon("/icons/openai.svg", Icons.class);
public static final Icon Meta = IconLoader.getIcon("/icons/meta.svg", Icons.class);
public static final Icon Mistral = IconLoader.getIcon("/icons/mistral.svg", Icons.class);
public static final Icon Send = IconLoader.getIcon("/icons/send.svg", Icons.class);
public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class);
public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class);

View file

@ -18,6 +18,7 @@ import com.intellij.vcs.commit.CommitWorkflowUi;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
import ee.carlrobert.codegpt.ui.OverlayUtil;
@ -85,8 +86,9 @@ public class GenerateGitCommitMessageAction extends AnAction {
var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI);
if (commitWorkflowUi != null) {
CompletionRequestService.getInstance().getCommitMessageAsync(
project.getService(CommitMessageTemplate.class).getSystemPrompt(),
gitDiff,
new CommitMessageRequestParameters(
gitDiff,
project.getService(CommitMessageTemplate.class).getSystemPrompt()),
getEventListener(project, commitWorkflowUi));
}
}
@ -162,11 +164,22 @@ public class GenerateGitCommitMessageAction extends AnAction {
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
var application = ApplicationManager.getApplication();
application.invokeLater(() ->
application.runWriteAction(() ->
WriteCommandAction.runWriteCommandAction(project, () ->
commitWorkflowUi.getCommitMessageUi().setText(messageBuilder.toString()))));
updateCommitMessage(messageBuilder.toString());
}
/**
 * Completion callback. When nothing has been accumulated in {@code messageBuilder}
 * (i.e. {@code onMessage} never appended any text), the supplied result is written
 * to the commit message field; otherwise the already-streamed text is left as is.
 */
@Override
public void onComplete(StringBuilder result) {
  if (!messageBuilder.isEmpty()) {
    return;
  }
  updateCommitMessage(result.toString());
}
/**
 * Replaces the text of the commit message field with {@code message}.
 * The update is queued with {@code invokeLater} and executed inside a write command
 * for the current project.
 */
private void updateCommitMessage(String message) {
  Runnable applyText = () -> commitWorkflowUi.getCommitMessageUi().setText(message);
  ApplicationManager.getApplication().invokeLater(
      () -> WriteCommandAction.runWriteCommandAction(project, applyText));
}
@Override

View file

@ -0,0 +1,74 @@
package ee.carlrobert.codegpt.completions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import okhttp3.sse.EventSource;
/**
 * Adapts the low-level {@link CompletionEventListener} callbacks of a chat completion
 * call to a {@link CompletionResponseEventListener}, and reports failures to telemetry.
 */
public class ChatCompletionEventListener implements CompletionEventListener<String> {

  // Shared instance: building an ObjectMapper for every incoming event is wasteful,
  // and a configured mapper can be reused for reads.
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  private final CallParameters callParameters;
  private final CompletionResponseEventListener eventListener;
  // Accumulates the chunks received through onMessage.
  private final StringBuilder messageBuilder = new StringBuilder();

  /**
   * @param callParameters parameters of the call; the message it carries is updated
   *                       with the partial response as chunks arrive
   * @param eventListener  listener that receives the translated callbacks
   */
  public ChatCompletionEventListener(
      CallParameters callParameters,
      CompletionResponseEventListener eventListener) {
    this.callParameters = callParameters;
    this.eventListener = eventListener;
  }

  /**
   * Parses {@code data} as a {@link CodeGPTEvent} and forwards it to the listener.
   * Payloads that cannot be parsed are skipped.
   */
  @Override
  public void onEvent(String data) {
    try {
      var event = OBJECT_MAPPER.readValue(data, CodeGPTEvent.class);
      eventListener.handleCodeGPTEvent(event);
    } catch (JsonProcessingException ignored) {
      // Deliberately skipped: the payload is not a readable CodeGPTEvent.
      // NOTE(review): presumably other event shapes are expected here - confirm.
    }
  }

  /**
   * Appends the chunk to the accumulated response, stores the accumulated text on the
   * call's message, and forwards the chunk (not the accumulated text) to the listener.
   */
  @Override
  public void onMessage(String message, EventSource eventSource) {
    messageBuilder.append(message);
    callParameters.getMessage().setResponse(messageBuilder.toString());
    eventListener.handleMessage(message);
  }

  /**
   * Forwards the final text to the listener. Note that the parameter shadows the
   * {@code messageBuilder} field; the value passed in by the caller is the one used.
   */
  @Override
  public void onComplete(StringBuilder messageBuilder) {
    eventListener.handleCompleted(messageBuilder.toString(), callParameters);
  }

  /**
   * Treated the same way as a normal completion: the text passed in is forwarded
   * to {@code handleCompleted}. The parameter shadows the field of the same name.
   */
  @Override
  public void onCancelled(StringBuilder messageBuilder) {
    eventListener.handleCompleted(messageBuilder.toString(), callParameters);
  }

  /**
   * Forwards the error to the listener and always sends an error telemetry message,
   * even if the listener itself throws.
   */
  @Override
  public void onError(ErrorDetails error, Throwable ex) {
    try {
      eventListener.handleError(error, ex);
    } finally {
      sendError(error, ex);
    }
  }

  /**
   * Sends a COMPLETION_ERROR telemetry message. An {@code insufficient_quota} error is
   * tagged as a user error; any other error is reported with the conversation id, the
   * model and the error wrapped in a {@link RuntimeException}.
   */
  private void sendError(ErrorDetails error, Throwable ex) {
    var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
    if ("insufficient_quota".equals(error.getCode())) {
      telemetryMessage
          .property("type", "USER")
          .property("code", "INSUFFICIENT_QUOTA");
    } else {
      telemetryMessage
          .property("conversationId", callParameters.getConversation().getId().toString())
          .property("model", callParameters.getConversation().getModel())
          .error(new RuntimeException(error.toString(), ex));
    }
    telemetryMessage.send();
  }
}

View file

@ -89,7 +89,6 @@ public class CompletionClientProvider {
return builder.build(getDefaultClientBuilder());
}
public static GoogleClient getGoogleClient() {
return new GoogleClient.Builder(getCredential(CredentialKey.GOOGLE_API_KEY))
.build(getDefaultClientBuilder());

View file

@ -1,129 +0,0 @@
package ee.carlrobert.codegpt.completions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import okhttp3.sse.EventSource;
/**
 * Starts a chat completion call, relays its streaming callbacks to a
 * {@link CompletionResponseEventListener}, and sends telemetry about the call.
 */
public class CompletionRequestHandler {

  // Shared instance: building an ObjectMapper for every incoming event is wasteful,
  // and a configured mapper can be reused for reads.
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  // Accumulates the chunks received by RequestCompletionEventListener#onMessage.
  private final StringBuilder messageBuilder = new StringBuilder();
  private final CompletionResponseEventListener completionResponseEventListener;
  // Handle of the running call; stays null when the call could not be started.
  private EventSource eventSource;

  /**
   * @param completionResponseEventListener listener that receives the relayed callbacks
   */
  public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) {
    this.completionResponseEventListener = completionResponseEventListener;
  }

  /**
   * Starts the call. A {@link TotalUsageExceededException} is reported through
   * {@code handleTokensExceeded}; any other failure propagates to the caller after the
   * listener has been notified via {@code handleError}. The COMPLETION telemetry
   * message is sent in every case.
   */
  public void call(CallParameters callParameters) {
    try {
      eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters));
    } catch (TotalUsageExceededException e) {
      completionResponseEventListener.handleTokensExceeded(
          callParameters.getConversation(),
          callParameters.getMessage());
    } finally {
      sendInfo(callParameters);
    }
  }

  /** Cancels the running call, if one was started. */
  public void cancel() {
    if (eventSource != null) {
      eventSource.cancel();
    }
  }

  /**
   * Delegates to {@link CompletionRequestService}. Any throwable is first reported to
   * the listener and then rethrown unchanged.
   */
  private EventSource startCall(
      CallParameters callParameters,
      CompletionEventListener<String> eventListener) {
    try {
      return CompletionRequestService.getInstance()
          .getChatCompletionAsync(callParameters, eventListener);
    } catch (Throwable ex) {
      handleCallException(ex);
      throw ex;
    }
  }

  /** Reports {@code ex} to the listener with a user-facing message. */
  private void handleCallException(Throwable ex) {
    var errorMessage = "Something went wrong";
    if (ex instanceof TotalUsageExceededException) {
      errorMessage =
          "The length of the context exceeds the maximum limit that the model can handle. "
              + "Try reducing the input message or maximum completion token size.";
    }
    completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
  }

  /**
   * Inner (non-static) listener: it relies on the enclosing handler's
   * {@code messageBuilder} and {@code completionResponseEventListener}.
   */
  class RequestCompletionEventListener implements CompletionEventListener<String> {

    private final CallParameters callParameters;

    public RequestCompletionEventListener(CallParameters callParameters) {
      this.callParameters = callParameters;
    }

    /**
     * Parses {@code data} as a {@link CodeGPTEvent} and forwards it to the listener.
     * Payloads that cannot be parsed are skipped.
     */
    @Override
    public void onEvent(String data) {
      try {
        var event = OBJECT_MAPPER.readValue(data, CodeGPTEvent.class);
        completionResponseEventListener.handleCodeGPTEvent(event);
      } catch (JsonProcessingException ignored) {
        // Deliberately skipped: the payload is not a readable CodeGPTEvent.
        // NOTE(review): presumably other event shapes are expected here - confirm.
      }
    }

    /**
     * Appends the chunk to the accumulated response, stores the accumulated text on
     * the call's message, and forwards the chunk itself to the listener.
     */
    @Override
    public void onMessage(String message, EventSource eventSource) {
      messageBuilder.append(message);
      callParameters.getMessage().setResponse(messageBuilder.toString());
      completionResponseEventListener.handleMessage(message);
    }

    /** Forwards the final text; the parameter shadows the outer {@code messageBuilder}. */
    @Override
    public void onComplete(StringBuilder messageBuilder) {
      completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
    }

    /** Handled the same way as a normal completion. */
    @Override
    public void onCancelled(StringBuilder messageBuilder) {
      completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
    }

    /** Forwards the error and always sends error telemetry, even if the listener throws. */
    @Override
    public void onError(ErrorDetails error, Throwable ex) {
      try {
        completionResponseEventListener.handleError(error, ex);
      } finally {
        sendError(error, ex);
      }
    }

    /**
     * Sends a COMPLETION_ERROR telemetry message. An {@code insufficient_quota} error
     * is tagged as a user error; anything else carries the conversation id, the model
     * and the error wrapped in a {@link RuntimeException}.
     */
    private void sendError(ErrorDetails error, Throwable ex) {
      var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
      if ("insufficient_quota".equals(error.getCode())) {
        telemetryMessage
            .property("type", "USER")
            .property("code", "INSUFFICIENT_QUOTA");
      } else {
        telemetryMessage
            .property("conversationId", callParameters.getConversation().getId().toString())
            .property("model", callParameters.getConversation().getModel())
            .error(new RuntimeException(error.toString(), ex));
      }
      telemetryMessage.send();
    }
  }

  /** Sends the COMPLETION telemetry message for this call. */
  private void sendInfo(CallParameters callParameters) {
    TelemetryAction.COMPLETION.createActionMessage()
        .property("conversationId", callParameters.getConversation().getId().toString())
        .property("model", callParameters.getConversation().getModel())
        // Locale.ROOT keeps the reported code stable regardless of the user's default
        // locale (e.g. the Turkish dotless-i mapping for 'I').
        .property(
            "service",
            GeneralSettings.getSelectedService().getCode().toLowerCase(java.util.Locale.ROOT))
        .send();
  }
}

View file

@ -3,7 +3,9 @@ package ee.carlrobert.codegpt.completions;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.progress.Task;
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
@ -26,12 +28,15 @@ import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionRequest;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import javax.swing.SwingUtilities;
import okhttp3.Request;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.NotNull;
@Service
public final class CompletionRequestService {
@ -63,50 +68,50 @@ public final class CompletionRequestService {
new OpenAIChatCompletionEventSourceListener(eventListener));
}
public String getLookupCompletion(String prompt) {
return getChatCompletion(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createLookupRequest(prompt));
/**
 * Builds a lookup request with the factory of the currently selected service and
 * returns the result of the blocking chat completion call for it.
 */
public String getLookupCompletion(LookupRequestCallParameters params) {
  return getChatCompletion(
      CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
          .createLookupRequest(params));
}
public EventSource getCommitMessageAsync(
String systemPrompt,
String gitDiff,
CommitMessageRequestParameters params,
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createCommitMessageRequest(systemPrompt, gitDiff),
eventListener);
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createCommitMessageRequest(params);
return getChatCompletionAsync(request, eventListener);
}
public EventSource getEditCodeCompletionAsync(
EditCodeRequestParams params,
EditCodeRequestParameters params,
CompletionEventListener<String> eventListener) {
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createEditCodeRequest(input),
eventListener);
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createEditCodeRequest(params);
return getChatCompletionAsync(request, eventListener);
}
public EventSource getChatCompletionAsync(
CallParameters callParameters,
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createChatRequest(callParameters),
eventListener);
}
private EventSource getChatCompletionAsync(
CompletionRequest request,
CompletionEventListener<String> eventListener) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
return switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()
.getChatCompletionAsync(completionRequest, eventListener);
case OPENAI -> CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
case CODEGPT -> {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
yield getO1ChatCompletionAsync(completionRequest, eventListener);
}
yield CompletionClientProvider.getCodeGPTClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
case OPENAI -> {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
yield getO1ChatCompletionAsync(completionRequest, eventListener);
}
yield CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
case AZURE -> CompletionClientProvider.getAzureClient()
.getChatCompletionAsync(completionRequest, eventListener);
default -> throw new RuntimeException("Unknown service selected");
@ -142,7 +147,33 @@ public final class CompletionRequestService {
throw new IllegalStateException("Unknown request type: " + request.getClass());
}
private String getChatCompletion(CompletionRequest request) {
/**
 * Runs the request as a single blocking {@code getChatCompletion} call inside a
 * background task and delivers the whole response through
 * {@code eventListener.onComplete}; no {@code onMessage} chunks are emitted.
 * NOTE(review): presumably used because these models are not streamed - confirm.
 *
 * <p>The listener receives exactly one terminal callback: either {@code onComplete}
 * or {@code onCancelled}, whichever is triggered first. Cancelling does not abort the
 * underlying blocking call; it only suppresses the later completion callback.
 *
 * @param request       the chat completion request to execute
 * @param eventListener listener notified on completion or cancellation
 * @return a handle whose {@code cancel()} reports cancellation to the listener
 */
private EventSource getO1ChatCompletionAsync(
    OpenAIChatCompletionRequest request,
    CompletionEventListener<String> eventListener) {
  // Set by whichever terminal event happens first, so that a response arriving after
  // cancel() (or a cancel() after completion) does not notify the listener again.
  var settled = new java.util.concurrent.atomic.AtomicBoolean(false);

  ProgressManager.getInstance()
      .run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") {
        @Override
        public void run(@NotNull ProgressIndicator indicator) {
          indicator.setIndeterminate(true);
          var response = CompletionRequestService.getInstance().getChatCompletion(request);
          SwingUtilities.invokeLater(() -> {
            // Checked on the dispatch thread so a cancel() issued while this runnable
            // was queued is still honoured.
            if (settled.compareAndSet(false, true)) {
              eventListener.onComplete(new StringBuilder(response));
            }
          });
        }
      });

  return new EventSource() {
    @Override
    public @NotNull Request request() {
      // Placeholder only. Request.Builder#build() throws when no URL has been set,
      // so a dummy URL is supplied to keep this accessor callable.
      return new Request.Builder().url("http://localhost/").build();
    }

    @Override
    public void cancel() {
      if (settled.compareAndSet(false, true)) {
        eventListener.onCancelled(new StringBuilder("Cancelled"));
      }
    }
  };
}
public String getChatCompletion(CompletionRequest request) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
var response = switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()

View file

@ -16,6 +16,9 @@ public interface CompletionResponseEventListener {
/**
 * Hook for the case where a call is rejected because the token usage limit was
 * exceeded. The default implementation does nothing.
 *
 * @param conversation the conversation the call belonged to
 * @param message      the message that was being sent
 */
default void handleTokensExceeded(Conversation conversation, Message message) {
}
/**
 * Completion hook that receives only the full response text.
 * The default implementation does nothing.
 *
 * @param fullMessage the complete response text
 */
default void handleCompleted(String fullMessage) {
}
/**
 * Completion hook that receives the full response text together with the parameters
 * of the call that produced it. The default implementation does nothing.
 *
 * @param fullMessage    the complete response text
 * @param callParameters the parameters of the finished call
 */
default void handleCompleted(String fullMessage, CallParameters callParameters) {
}

View file

@ -56,7 +56,8 @@ public class MethodNameLookupListener implements LookupManagerListener {
Application application,
String prompt) {
try {
var response = CompletionRequestService.getInstance().getLookupCompletion(prompt);
var response = CompletionRequestService.getInstance()
.getLookupCompletion(new LookupRequestCallParameters(prompt));
if (!response.isEmpty()) {
for (var value : response.split(",")) {
application.invokeLater(() -> application.runReadAction(() -> {

View file

@ -0,0 +1,67 @@
package ee.carlrobert.codegpt.completions;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import okhttp3.sse.EventSource;
/**
 * Starts a chat completion call for the tool window, wires its callbacks to a
 * {@link CompletionResponseEventListener} through a {@link ChatCompletionEventListener},
 * and sends telemetry about the call.
 */
public class ToolwindowChatCompletionRequestHandler {

  private final CompletionResponseEventListener completionResponseEventListener;
  // Handle of the running call; stays null when the call could not be started.
  private EventSource eventSource;

  /**
   * @param completionResponseEventListener listener that receives the call's callbacks
   */
  public ToolwindowChatCompletionRequestHandler(
      CompletionResponseEventListener completionResponseEventListener) {
    this.completionResponseEventListener = completionResponseEventListener;
  }

  /**
   * Starts the call. A {@link TotalUsageExceededException} is reported through
   * {@code handleTokensExceeded}; any other failure propagates to the caller after the
   * listener has been notified via {@code handleError}. The COMPLETION telemetry
   * message is sent in every case.
   *
   * @param callParameters conversation and message to send
   */
  public void call(CallParameters callParameters) {
    try {
      eventSource = startCall(callParameters);
    } catch (TotalUsageExceededException e) {
      completionResponseEventListener.handleTokensExceeded(
          callParameters.getConversation(),
          callParameters.getMessage());
    } finally {
      sendInfo(callParameters);
    }
  }

  /** Cancels the running call, if one was started. */
  public void cancel() {
    if (eventSource != null) {
      eventSource.cancel();
    }
  }

  /**
   * Builds the chat request with the factory of the currently selected service and
   * submits it. Any throwable raised while building or submitting is first reported
   * to the listener and then rethrown unchanged.
   */
  private EventSource startCall(CallParameters callParameters) {
    try {
      var request = CompletionRequestFactory
          .getFactory(GeneralSettings.getSelectedService())
          .createChatRequest(new ChatCompletionRequestParameters(callParameters));
      return CompletionRequestService.getInstance().getChatCompletionAsync(
          request,
          new ChatCompletionEventListener(callParameters, completionResponseEventListener));
    } catch (Throwable ex) {
      handleCallException(ex);
      throw ex;
    }
  }

  /** Reports {@code ex} to the listener with a user-facing message. */
  private void handleCallException(Throwable ex) {
    var errorMessage = "Something went wrong";
    if (ex instanceof TotalUsageExceededException) {
      errorMessage =
          "The length of the context exceeds the maximum limit that the model can handle. "
              + "Try reducing the input message or maximum completion token size.";
    }
    completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
  }

  /** Sends the COMPLETION telemetry message for this call. */
  private void sendInfo(CallParameters callParameters) {
    TelemetryAction.COMPLETION.createActionMessage()
        .property("conversationId", callParameters.getConversation().getId().toString())
        .property("model", callParameters.getConversation().getModel())
        // Locale.ROOT keeps the reported code stable regardless of the user's default
        // locale (e.g. the Turkish dotless-i mapping for 'I').
        .property(
            "service",
            GeneralSettings.getSelectedService().getCode().toLowerCase(java.util.Locale.ROOT))
        .send();
  }
}

View file

@ -11,8 +11,8 @@ public class AdvancedSettingsState {
private boolean proxyAuthSelected;
private String proxyUsername;
private String proxyPassword;
private int connectTimeout = 30;
private int readTimeout = 30;
private int connectTimeout = 120;
private int readTimeout = 120;
public String getProxyHost() {
return proxyHost;

View file

@ -15,10 +15,10 @@ import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestHandler;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.completions.CompletionRequestUtil;
import ee.carlrobert.codegpt.completions.ConversationType;
import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
@ -60,7 +60,7 @@ public class ChatToolWindowTabPanel implements Disposable {
private final TotalTokensPanel totalTokensPanel;
private final ChatToolWindowScrollablePanel toolWindowScrollablePanel;
private @Nullable CompletionRequestHandler requestHandler;
private @Nullable ToolwindowChatCompletionRequestHandler requestHandler;
public ChatToolWindowTabPanel(@NotNull Project project, @NotNull Conversation conversation) {
this.project = project;
@ -250,7 +250,7 @@ public class ChatToolWindowTabPanel implements Disposable {
return;
}
requestHandler = new CompletionRequestHandler(
requestHandler = new ToolwindowChatCompletionRequestHandler(
new ToolWindowCompletionResponseEventListener(
conversationService,
responsePanel,

View file

@ -112,6 +112,9 @@ abstract class ToolWindowCompletionResponseEventListener implements
try {
responsePanel.enableActions();
responseContainer.enableActions();
if (!responseContainer.isResponseReceived() && !fullMessage.isEmpty()) {
responseContainer.withResponse(fullMessage);
}
totalTokensPanel.updateUserPromptTokens(textArea.getText());
totalTokensPanel.updateConversationTokens(callParameters.getConversation());
} finally {

View file

@ -113,6 +113,10 @@ public class ChatMessageResponseBody extends JPanel {
}
public ChatMessageResponseBody withResponse(String response) {
if (!responseReceived) {
removeAll();
}
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
currentlyProcessedEditorPanel = null;
currentlyProcessedTextPane = null;
@ -362,4 +366,8 @@ public class ChatMessageResponseBody extends JPanel {
panel.add(listPanel, BorderLayout.CENTER);
return panel;
}
/**
 * Returns the current value of the {@code responseReceived} flag.
 *
 * @return {@code true} if the flag is set
 */
public boolean isResponseReceived() {
return responseReceived;
}
}

View file

@ -123,9 +123,10 @@ public class ModelComboBoxAction extends ComboBoxAction {
var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI");
openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI);
List.of(
OpenAIChatCompletionModel.O_1_PREVIEW,
OpenAIChatCompletionModel.O_1_MINI,
OpenAIChatCompletionModel.GPT_4_O,
OpenAIChatCompletionModel.GPT_4_O_MINI,
OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW,
OpenAIChatCompletionModel.GPT_4_0125_128k)
.forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation)));
actionGroup.add(openaiGroup);