mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-20 09:24:08 +00:00
feat: support qwen2.5 and o1 models
This commit is contained in:
parent
5c9253278f
commit
24ae263a39
32 changed files with 521 additions and 314 deletions
|
|
@ -12,7 +12,7 @@ jsoup = "1.17.2"
|
|||
jtokkit = "1.1.0"
|
||||
junit = "5.11.0"
|
||||
kotlin = "2.0.0"
|
||||
llm-client = "0.8.18"
|
||||
llm-client = "0.8.19"
|
||||
okio = "3.9.0"
|
||||
tree-sitter = "0.22.6a"
|
||||
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@ public final class Icons {
|
|||
public static final Icon Azure = IconLoader.getIcon("/icons/azure.svg", Icons.class);
|
||||
public static final Icon Databricks = IconLoader.getIcon("/icons/dbrx.svg", Icons.class);
|
||||
public static final Icon DeepSeek = IconLoader.getIcon("/icons/deepseek.png", Icons.class);
|
||||
public static final Icon Qwen = IconLoader.getIcon("/icons/qwen.png", Icons.class);
|
||||
public static final Icon Google = IconLoader.getIcon("/icons/google.svg", Icons.class);
|
||||
public static final Icon Llama = IconLoader.getIcon("/icons/llama.svg", Icons.class);
|
||||
public static final Icon OpenAI = IconLoader.getIcon("/icons/openai.svg", Icons.class);
|
||||
public static final Icon Meta = IconLoader.getIcon("/icons/meta.svg", Icons.class);
|
||||
public static final Icon Mistral = IconLoader.getIcon("/icons/mistral.svg", Icons.class);
|
||||
public static final Icon Send = IconLoader.getIcon("/icons/send.svg", Icons.class);
|
||||
public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class);
|
||||
public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class);
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import com.intellij.vcs.commit.CommitWorkflowUi;
|
|||
import ee.carlrobert.codegpt.CodeGPTBundle;
|
||||
import ee.carlrobert.codegpt.EncodingManager;
|
||||
import ee.carlrobert.codegpt.Icons;
|
||||
import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService;
|
||||
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
|
||||
import ee.carlrobert.codegpt.ui.OverlayUtil;
|
||||
|
|
@ -85,8 +86,9 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI);
|
||||
if (commitWorkflowUi != null) {
|
||||
CompletionRequestService.getInstance().getCommitMessageAsync(
|
||||
project.getService(CommitMessageTemplate.class).getSystemPrompt(),
|
||||
gitDiff,
|
||||
new CommitMessageRequestParameters(
|
||||
gitDiff,
|
||||
project.getService(CommitMessageTemplate.class).getSystemPrompt()),
|
||||
getEventListener(project, commitWorkflowUi));
|
||||
}
|
||||
}
|
||||
|
|
@ -162,11 +164,22 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
@Override
|
||||
public void onMessage(String message, EventSource eventSource) {
|
||||
messageBuilder.append(message);
|
||||
var application = ApplicationManager.getApplication();
|
||||
application.invokeLater(() ->
|
||||
application.runWriteAction(() ->
|
||||
WriteCommandAction.runWriteCommandAction(project, () ->
|
||||
commitWorkflowUi.getCommitMessageUi().setText(messageBuilder.toString()))));
|
||||
updateCommitMessage(messageBuilder.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(StringBuilder result) {
|
||||
if (messageBuilder.isEmpty()) {
|
||||
updateCommitMessage(result.toString());
|
||||
}
|
||||
}
|
||||
|
||||
private void updateCommitMessage(String message) {
|
||||
ApplicationManager.getApplication().invokeLater(() ->
|
||||
WriteCommandAction.runWriteCommandAction(project, () ->
|
||||
commitWorkflowUi.getCommitMessageUi().setText(message)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import ee.carlrobert.codegpt.events.CodeGPTEvent;
|
||||
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
|
||||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
|
||||
import ee.carlrobert.llm.completion.CompletionEventListener;
|
||||
import okhttp3.sse.EventSource;
|
||||
|
||||
public class ChatCompletionEventListener implements CompletionEventListener<String> {
|
||||
|
||||
private final CallParameters callParameters;
|
||||
private final CompletionResponseEventListener eventListener;
|
||||
private final StringBuilder messageBuilder = new StringBuilder();
|
||||
|
||||
public ChatCompletionEventListener(
|
||||
CallParameters callParameters,
|
||||
CompletionResponseEventListener eventListener) {
|
||||
this.callParameters = callParameters;
|
||||
this.eventListener = eventListener;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(String data) {
|
||||
try {
|
||||
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
|
||||
eventListener.handleCodeGPTEvent(event);
|
||||
} catch (JsonProcessingException e) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(String message, EventSource eventSource) {
|
||||
messageBuilder.append(message);
|
||||
callParameters.getMessage().setResponse(messageBuilder.toString());
|
||||
eventListener.handleMessage(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(StringBuilder messageBuilder) {
|
||||
eventListener.handleCompleted(messageBuilder.toString(), callParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancelled(StringBuilder messageBuilder) {
|
||||
eventListener.handleCompleted(messageBuilder.toString(), callParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(ErrorDetails error, Throwable ex) {
|
||||
try {
|
||||
eventListener.handleError(error, ex);
|
||||
} finally {
|
||||
sendError(error, ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendError(ErrorDetails error, Throwable ex) {
|
||||
var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
|
||||
if ("insufficient_quota".equals(error.getCode())) {
|
||||
telemetryMessage
|
||||
.property("type", "USER")
|
||||
.property("code", "INSUFFICIENT_QUOTA");
|
||||
} else {
|
||||
telemetryMessage
|
||||
.property("conversationId", callParameters.getConversation().getId().toString())
|
||||
.property("model", callParameters.getConversation().getModel())
|
||||
.error(new RuntimeException(error.toString(), ex));
|
||||
}
|
||||
telemetryMessage.send();
|
||||
}
|
||||
}
|
||||
|
|
@ -89,7 +89,6 @@ public class CompletionClientProvider {
|
|||
return builder.build(getDefaultClientBuilder());
|
||||
}
|
||||
|
||||
|
||||
public static GoogleClient getGoogleClient() {
|
||||
return new GoogleClient.Builder(getCredential(CredentialKey.GOOGLE_API_KEY))
|
||||
.build(getDefaultClientBuilder());
|
||||
|
|
|
|||
|
|
@ -1,129 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import ee.carlrobert.codegpt.events.CodeGPTEvent;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
|
||||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
|
||||
import ee.carlrobert.llm.completion.CompletionEventListener;
|
||||
import okhttp3.sse.EventSource;
|
||||
|
||||
public class CompletionRequestHandler {
|
||||
|
||||
private final StringBuilder messageBuilder = new StringBuilder();
|
||||
private final CompletionResponseEventListener completionResponseEventListener;
|
||||
private EventSource eventSource;
|
||||
|
||||
public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) {
|
||||
this.completionResponseEventListener = completionResponseEventListener;
|
||||
}
|
||||
|
||||
public void call(CallParameters callParameters) {
|
||||
try {
|
||||
eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters));
|
||||
} catch (TotalUsageExceededException e) {
|
||||
completionResponseEventListener.handleTokensExceeded(
|
||||
callParameters.getConversation(),
|
||||
callParameters.getMessage());
|
||||
} finally {
|
||||
sendInfo(callParameters);
|
||||
}
|
||||
}
|
||||
|
||||
public void cancel() {
|
||||
if (eventSource != null) {
|
||||
eventSource.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
private EventSource startCall(
|
||||
CallParameters callParameters,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
try {
|
||||
return CompletionRequestService.getInstance()
|
||||
.getChatCompletionAsync(callParameters, eventListener);
|
||||
} catch (Throwable ex) {
|
||||
handleCallException(ex);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
private void handleCallException(Throwable ex) {
|
||||
var errorMessage = "Something went wrong";
|
||||
if (ex instanceof TotalUsageExceededException) {
|
||||
errorMessage =
|
||||
"The length of the context exceeds the maximum limit that the model can handle. "
|
||||
+ "Try reducing the input message or maximum completion token size.";
|
||||
}
|
||||
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
|
||||
}
|
||||
|
||||
class RequestCompletionEventListener implements CompletionEventListener<String> {
|
||||
|
||||
private final CallParameters callParameters;
|
||||
|
||||
public RequestCompletionEventListener(CallParameters callParameters) {
|
||||
this.callParameters = callParameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(String data) {
|
||||
try {
|
||||
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
|
||||
completionResponseEventListener.handleCodeGPTEvent(event);
|
||||
} catch (JsonProcessingException e) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(String message, EventSource eventSource) {
|
||||
messageBuilder.append(message);
|
||||
callParameters.getMessage().setResponse(messageBuilder.toString());
|
||||
completionResponseEventListener.handleMessage(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(StringBuilder messageBuilder) {
|
||||
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancelled(StringBuilder messageBuilder) {
|
||||
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(ErrorDetails error, Throwable ex) {
|
||||
try {
|
||||
completionResponseEventListener.handleError(error, ex);
|
||||
} finally {
|
||||
sendError(error, ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendError(ErrorDetails error, Throwable ex) {
|
||||
var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage();
|
||||
if ("insufficient_quota".equals(error.getCode())) {
|
||||
telemetryMessage
|
||||
.property("type", "USER")
|
||||
.property("code", "INSUFFICIENT_QUOTA");
|
||||
} else {
|
||||
telemetryMessage
|
||||
.property("conversationId", callParameters.getConversation().getId().toString())
|
||||
.property("model", callParameters.getConversation().getModel())
|
||||
.error(new RuntimeException(error.toString(), ex));
|
||||
}
|
||||
telemetryMessage.send();
|
||||
}
|
||||
}
|
||||
|
||||
private void sendInfo(CallParameters callParameters) {
|
||||
TelemetryAction.COMPLETION.createActionMessage()
|
||||
.property("conversationId", callParameters.getConversation().getId().toString())
|
||||
.property("model", callParameters.getConversation().getModel())
|
||||
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
|
||||
.send();
|
||||
}
|
||||
}
|
||||
|
|
@ -3,7 +3,9 @@ package ee.carlrobert.codegpt.completions;
|
|||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.components.Service;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams;
|
||||
import com.intellij.openapi.progress.ProgressIndicator;
|
||||
import com.intellij.openapi.progress.ProgressManager;
|
||||
import com.intellij.openapi.progress.Task;
|
||||
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
|
||||
|
|
@ -26,12 +28,15 @@ import ee.carlrobert.llm.completion.CompletionEventListener;
|
|||
import ee.carlrobert.llm.completion.CompletionRequest;
|
||||
import java.io.IOException;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Stream;
|
||||
import javax.swing.SwingUtilities;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSources;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
@Service
|
||||
public final class CompletionRequestService {
|
||||
|
|
@ -63,50 +68,50 @@ public final class CompletionRequestService {
|
|||
new OpenAIChatCompletionEventSourceListener(eventListener));
|
||||
}
|
||||
|
||||
public String getLookupCompletion(String prompt) {
|
||||
return getChatCompletion(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createLookupRequest(prompt));
|
||||
public String getLookupCompletion(LookupRequestCallParameters params) {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createLookupRequest(params);
|
||||
return getChatCompletion(request);
|
||||
}
|
||||
|
||||
public EventSource getCommitMessageAsync(
|
||||
String systemPrompt,
|
||||
String gitDiff,
|
||||
CommitMessageRequestParameters params,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createCommitMessageRequest(systemPrompt, gitDiff),
|
||||
eventListener);
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createCommitMessageRequest(params);
|
||||
return getChatCompletionAsync(request, eventListener);
|
||||
}
|
||||
|
||||
public EventSource getEditCodeCompletionAsync(
|
||||
EditCodeRequestParams params,
|
||||
EditCodeRequestParameters params,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createEditCodeRequest(input),
|
||||
eventListener);
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createEditCodeRequest(params);
|
||||
return getChatCompletionAsync(request, eventListener);
|
||||
}
|
||||
|
||||
public EventSource getChatCompletionAsync(
|
||||
CallParameters callParameters,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createChatRequest(callParameters),
|
||||
eventListener);
|
||||
}
|
||||
|
||||
private EventSource getChatCompletionAsync(
|
||||
CompletionRequest request,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
|
||||
return switch (GeneralSettings.getSelectedService()) {
|
||||
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
case OPENAI -> CompletionClientProvider.getOpenAIClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
case CODEGPT -> {
|
||||
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
|
||||
yield getO1ChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
yield CompletionClientProvider.getCodeGPTClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
case OPENAI -> {
|
||||
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
|
||||
yield getO1ChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
yield CompletionClientProvider.getOpenAIClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
case AZURE -> CompletionClientProvider.getAzureClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
default -> throw new RuntimeException("Unknown service selected");
|
||||
|
|
@ -142,7 +147,33 @@ public final class CompletionRequestService {
|
|||
throw new IllegalStateException("Unknown request type: " + request.getClass());
|
||||
}
|
||||
|
||||
private String getChatCompletion(CompletionRequest request) {
|
||||
private EventSource getO1ChatCompletionAsync(
|
||||
OpenAIChatCompletionRequest request,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
ProgressManager.getInstance()
|
||||
.run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") {
|
||||
@Override
|
||||
public void run(@NotNull ProgressIndicator indicator) {
|
||||
indicator.setIndeterminate(true);
|
||||
var response = CompletionRequestService.getInstance().getChatCompletion(request);
|
||||
SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response)));
|
||||
}
|
||||
});
|
||||
|
||||
return new EventSource() {
|
||||
@Override
|
||||
public @NotNull Request request() {
|
||||
return new Request.Builder().build(); // dummy
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
eventListener.onCancelled(new StringBuilder("Cancelled"));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public String getChatCompletion(CompletionRequest request) {
|
||||
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
|
||||
var response = switch (GeneralSettings.getSelectedService()) {
|
||||
case CODEGPT -> CompletionClientProvider.getCodeGPTClient()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ public interface CompletionResponseEventListener {
|
|||
default void handleTokensExceeded(Conversation conversation, Message message) {
|
||||
}
|
||||
|
||||
default void handleCompleted(String fullMessage) {
|
||||
}
|
||||
|
||||
default void handleCompleted(String fullMessage, CallParameters callParameters) {
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ public class MethodNameLookupListener implements LookupManagerListener {
|
|||
Application application,
|
||||
String prompt) {
|
||||
try {
|
||||
var response = CompletionRequestService.getInstance().getLookupCompletion(prompt);
|
||||
var response = CompletionRequestService.getInstance()
|
||||
.getLookupCompletion(new LookupRequestCallParameters(prompt));
|
||||
if (!response.isEmpty()) {
|
||||
for (var value : response.split(",")) {
|
||||
application.invokeLater(() -> application.runReadAction(() -> {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
|
||||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
|
||||
import okhttp3.sse.EventSource;
|
||||
|
||||
public class ToolwindowChatCompletionRequestHandler {
|
||||
|
||||
private final CompletionResponseEventListener completionResponseEventListener;
|
||||
private EventSource eventSource;
|
||||
|
||||
public ToolwindowChatCompletionRequestHandler(
|
||||
CompletionResponseEventListener completionResponseEventListener) {
|
||||
this.completionResponseEventListener = completionResponseEventListener;
|
||||
}
|
||||
|
||||
public void call(CallParameters callParameters) {
|
||||
try {
|
||||
eventSource = startCall(callParameters);
|
||||
} catch (TotalUsageExceededException e) {
|
||||
completionResponseEventListener.handleTokensExceeded(
|
||||
callParameters.getConversation(),
|
||||
callParameters.getMessage());
|
||||
} finally {
|
||||
sendInfo(callParameters);
|
||||
}
|
||||
}
|
||||
|
||||
public void cancel() {
|
||||
if (eventSource != null) {
|
||||
eventSource.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
private EventSource startCall(CallParameters callParameters) {
|
||||
try {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createChatRequest(new ChatCompletionRequestParameters(callParameters));
|
||||
return CompletionRequestService.getInstance().getChatCompletionAsync(
|
||||
request,
|
||||
new ChatCompletionEventListener(callParameters, completionResponseEventListener));
|
||||
} catch (Throwable ex) {
|
||||
handleCallException(ex);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
private void handleCallException(Throwable ex) {
|
||||
var errorMessage = "Something went wrong";
|
||||
if (ex instanceof TotalUsageExceededException) {
|
||||
errorMessage =
|
||||
"The length of the context exceeds the maximum limit that the model can handle. "
|
||||
+ "Try reducing the input message or maximum completion token size.";
|
||||
}
|
||||
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
|
||||
}
|
||||
|
||||
private void sendInfo(CallParameters callParameters) {
|
||||
TelemetryAction.COMPLETION.createActionMessage()
|
||||
.property("conversationId", callParameters.getConversation().getId().toString())
|
||||
.property("model", callParameters.getConversation().getModel())
|
||||
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
|
||||
.send();
|
||||
}
|
||||
}
|
||||
|
|
@ -11,8 +11,8 @@ public class AdvancedSettingsState {
|
|||
private boolean proxyAuthSelected;
|
||||
private String proxyUsername;
|
||||
private String proxyPassword;
|
||||
private int connectTimeout = 30;
|
||||
private int readTimeout = 30;
|
||||
private int connectTimeout = 120;
|
||||
private int readTimeout = 120;
|
||||
|
||||
public String getProxyHost() {
|
||||
return proxyHost;
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ import ee.carlrobert.codegpt.CodeGPTKeys;
|
|||
import ee.carlrobert.codegpt.ReferencedFile;
|
||||
import ee.carlrobert.codegpt.actions.ActionType;
|
||||
import ee.carlrobert.codegpt.completions.CallParameters;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestHandler;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil;
|
||||
import ee.carlrobert.codegpt.completions.ConversationType;
|
||||
import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler;
|
||||
import ee.carlrobert.codegpt.conversations.Conversation;
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
|
|
@ -60,7 +60,7 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
private final TotalTokensPanel totalTokensPanel;
|
||||
private final ChatToolWindowScrollablePanel toolWindowScrollablePanel;
|
||||
|
||||
private @Nullable CompletionRequestHandler requestHandler;
|
||||
private @Nullable ToolwindowChatCompletionRequestHandler requestHandler;
|
||||
|
||||
public ChatToolWindowTabPanel(@NotNull Project project, @NotNull Conversation conversation) {
|
||||
this.project = project;
|
||||
|
|
@ -250,7 +250,7 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
return;
|
||||
}
|
||||
|
||||
requestHandler = new CompletionRequestHandler(
|
||||
requestHandler = new ToolwindowChatCompletionRequestHandler(
|
||||
new ToolWindowCompletionResponseEventListener(
|
||||
conversationService,
|
||||
responsePanel,
|
||||
|
|
|
|||
|
|
@ -112,6 +112,9 @@ abstract class ToolWindowCompletionResponseEventListener implements
|
|||
try {
|
||||
responsePanel.enableActions();
|
||||
responseContainer.enableActions();
|
||||
if (!responseContainer.isResponseReceived() && !fullMessage.isEmpty()) {
|
||||
responseContainer.withResponse(fullMessage);
|
||||
}
|
||||
totalTokensPanel.updateUserPromptTokens(textArea.getText());
|
||||
totalTokensPanel.updateConversationTokens(callParameters.getConversation());
|
||||
} finally {
|
||||
|
|
|
|||
|
|
@ -113,6 +113,10 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
}
|
||||
|
||||
public ChatMessageResponseBody withResponse(String response) {
|
||||
if (!responseReceived) {
|
||||
removeAll();
|
||||
}
|
||||
|
||||
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
|
||||
currentlyProcessedEditorPanel = null;
|
||||
currentlyProcessedTextPane = null;
|
||||
|
|
@ -362,4 +366,8 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
panel.add(listPanel, BorderLayout.CENTER);
|
||||
return panel;
|
||||
}
|
||||
|
||||
public boolean isResponseReceived() {
|
||||
return responseReceived;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -123,9 +123,10 @@ public class ModelComboBoxAction extends ComboBoxAction {
|
|||
var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI");
|
||||
openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI);
|
||||
List.of(
|
||||
OpenAIChatCompletionModel.O_1_PREVIEW,
|
||||
OpenAIChatCompletionModel.O_1_MINI,
|
||||
OpenAIChatCompletionModel.GPT_4_O,
|
||||
OpenAIChatCompletionModel.GPT_4_O_MINI,
|
||||
OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW,
|
||||
OpenAIChatCompletionModel.GPT_4_0125_128k)
|
||||
.forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation)));
|
||||
actionGroup.add(openaiGroup);
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ class EditCodeCompletionListener(
|
|||
}
|
||||
|
||||
override fun onComplete(messageBuilder: StringBuilder) {
|
||||
runInEdt { cleanupAndFormat() }
|
||||
runInEdt {
|
||||
if (replacedLength == 0 && messageBuilder.isNotEmpty()) {
|
||||
handleDiff(messageBuilder.toString())
|
||||
}
|
||||
cleanupAndFormat()
|
||||
}
|
||||
observableProperties.loading.set(false)
|
||||
}
|
||||
|
||||
|
|
@ -73,7 +78,6 @@ class EditCodeCompletionListener(
|
|||
val document = editor.document
|
||||
val startOffset = selectionTextRange.startOffset
|
||||
val endOffset = selectionTextRange.endOffset
|
||||
|
||||
runUndoTransparentWriteAction {
|
||||
val remainingOriginalLength = endOffset - (startOffset + replacedLength)
|
||||
if (remainingOriginalLength > 0) {
|
||||
|
|
|
|||
|
|
@ -9,10 +9,9 @@ import com.intellij.openapi.util.TextRange
|
|||
import com.intellij.openapi.util.text.StringUtil
|
||||
import com.jetbrains.rd.util.AtomicReference
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService
|
||||
import ee.carlrobert.codegpt.completions.EditCodeRequestParameters
|
||||
import ee.carlrobert.codegpt.ui.ObservableProperties
|
||||
|
||||
data class EditCodeRequestParams(val prompt: String, val selectedText: String)
|
||||
|
||||
class EditCodeSubmissionHandler(
|
||||
private val editor: Editor,
|
||||
private val observableProperties: ObservableProperties,
|
||||
|
|
@ -36,7 +35,7 @@ class EditCodeSubmissionHandler(
|
|||
runInEdt { editor.selectionModel.removeSelection() }
|
||||
|
||||
service<CompletionRequestService>().getEditCodeCompletionAsync(
|
||||
EditCodeRequestParams(userPrompt, selectedText),
|
||||
EditCodeRequestParameters(userPrompt, selectedText),
|
||||
EditCodeCompletionListener(editor, observableProperties, selectionTextRange)
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
interface CompletionCallParameters
|
||||
|
||||
data class ChatCompletionRequestParameters(
|
||||
val callParameters: CallParameters
|
||||
) : CompletionCallParameters
|
||||
|
||||
data class CommitMessageRequestParameters(
|
||||
val gitDiff: String,
|
||||
val systemPrompt: String
|
||||
) : CompletionCallParameters
|
||||
|
||||
data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters
|
||||
|
||||
data class EditCodeRequestParameters(
|
||||
val prompt: String,
|
||||
val selectedText: String
|
||||
) : CompletionCallParameters
|
||||
|
|
@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
|
|||
import ee.carlrobert.llm.completion.CompletionRequest
|
||||
|
||||
interface CompletionRequestFactory {
|
||||
fun createChatRequest(callParameters: CallParameters): CompletionRequest
|
||||
fun createEditCodeRequest(input: String): CompletionRequest
|
||||
fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest
|
||||
fun createLookupRequest(prompt: String): CompletionRequest
|
||||
fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest
|
||||
fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest
|
||||
fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest
|
||||
fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
|
|
@ -30,24 +30,23 @@ interface CompletionRequestFactory {
|
|||
}
|
||||
|
||||
abstract class BaseRequestFactory : CompletionRequestFactory {
|
||||
override fun createEditCodeRequest(input: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true)
|
||||
override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest {
|
||||
val prompt = "${params.prompt}\n\n${params.selectedText}"
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true)
|
||||
}
|
||||
|
||||
override fun createCommitMessageRequest(
|
||||
systemPrompt: String,
|
||||
gitDiff: String
|
||||
): CompletionRequest {
|
||||
return createBasicCompletionRequest(systemPrompt, gitDiff, true)
|
||||
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest {
|
||||
return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true)
|
||||
}
|
||||
|
||||
override fun createLookupRequest(prompt: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt)
|
||||
override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest {
|
||||
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512)
|
||||
}
|
||||
|
||||
abstract fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int = 4096,
|
||||
stream: Boolean = false
|
||||
): CompletionRequest
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
|
||||
|
|
@ -10,10 +10,10 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class AzureRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters))
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, params.callParameters))
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setStream(true)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
|
|
@ -23,6 +23,7 @@ class AzureRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): CompletionRequest {
|
||||
return OpenAIRequestFactory.createBasicCompletionRequest(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings
|
||||
|
|
@ -11,7 +11,8 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class ClaudeRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest {
|
||||
val (callParameters) = params
|
||||
return ClaudeCompletionRequest().apply {
|
||||
model = service<AnthropicSettings>().state.model
|
||||
maxTokens = service<ConfigurationSettings>().state.maxTokens
|
||||
|
|
@ -57,15 +58,16 @@ class ClaudeRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): CompletionRequest {
|
||||
return ClaudeCompletionRequest().apply {
|
||||
system = systemPrompt
|
||||
isStream = stream
|
||||
maxTokens = service<ConfigurationSettings>().state.maxTokens
|
||||
model = service<AnthropicSettings>().state.model
|
||||
messages =
|
||||
listOf<ClaudeCompletionMessage>(ClaudeCompletionStandardMessage("user", userPrompt))
|
||||
this.maxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildBasicO1Request
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
|
||||
|
|
@ -11,15 +12,26 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe
|
|||
|
||||
class CodeGPTRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
|
||||
.setModel(model)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
if ("o1-mini" == model || "o1-preview" == model) {
|
||||
requestBuilder
|
||||
.setMaxCompletionTokens(configuration.maxTokens)
|
||||
.setStream(false)
|
||||
.setMaxTokens(null)
|
||||
.setTemperature(null)
|
||||
} else {
|
||||
requestBuilder
|
||||
.setStream(true)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
}
|
||||
|
||||
if (callParameters.message.isWebSearchIncluded) {
|
||||
requestBuilder.setWebSearchIncluded(true)
|
||||
}
|
||||
|
|
@ -36,12 +48,17 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): OpenAIChatCompletionRequest {
|
||||
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens)
|
||||
}
|
||||
return OpenAIRequestFactory.createBasicCompletionRequest(
|
||||
systemPrompt,
|
||||
userPrompt,
|
||||
service<CodeGPTServiceSettings>().state.chatCompletionSettings.model,
|
||||
model,
|
||||
stream
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
|
|
@ -19,7 +19,8 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
|
|||
|
||||
class CustomOpenAIRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest {
|
||||
val (callParameters) = params
|
||||
val request = buildCustomOpenAIChatCompletionRequest(
|
||||
service<CustomServiceSettings>()
|
||||
.state
|
||||
|
|
@ -34,6 +35,7 @@ class CustomOpenAIRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): CompletionRequest {
|
||||
val request = buildCustomOpenAIChatCompletionRequest(
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.*
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.TotalUsageExceededException
|
||||
import ee.carlrobert.codegpt.conversations.ConversationsState
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
|
||||
|
|
@ -23,7 +20,8 @@ import java.nio.file.Path
|
|||
|
||||
class GoogleRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest {
|
||||
val (callParameters) = params
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
|
||||
return GoogleCompletionRequest.Builder(messages)
|
||||
|
|
@ -38,6 +36,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): GoogleCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
|
|
@ -50,7 +49,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
)
|
||||
.generationConfig(
|
||||
GoogleGenerationConfig.Builder()
|
||||
.maxOutputTokens(configuration.maxTokens)
|
||||
.maxOutputTokens(maxTokens)
|
||||
.temperature(configuration.temperature.toDouble()).build()
|
||||
)
|
||||
.build()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel
|
||||
|
|
@ -14,7 +14,8 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
|
|||
|
||||
class LlamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest {
|
||||
val (callParameters) = params
|
||||
val promptTemplate = getPromptTemplate()
|
||||
val systemPrompt =
|
||||
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
|
||||
|
|
@ -33,6 +34,7 @@ class LlamaRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): LlamaCompletionRequest {
|
||||
val promptTemplate = getPromptTemplate()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
|
|
@ -18,7 +19,8 @@ import java.util.*
|
|||
|
||||
class OllamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val settings = service<OllamaSettings>().state
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
|
|
@ -38,6 +40,7 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
override fun createBasicCompletionRequest(
|
||||
systemPrompt: String,
|
||||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): OllamaChatCompletionRequest {
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
|
|
|
|||
|
|
@ -2,13 +2,10 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.*
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_METHOD_NAMES_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.TotalUsageExceededException
|
||||
import ee.carlrobert.codegpt.conversations.ConversationsState
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState
|
||||
|
|
@ -17,62 +14,93 @@ import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
|
|||
import ee.carlrobert.codegpt.util.file.FileUtil.getImageMediaType
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
|
||||
import ee.carlrobert.llm.client.openai.completion.request.*
|
||||
import ee.carlrobert.llm.completion.CompletionRequest
|
||||
import java.io.IOException
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
|
||||
class OpenAIRequestFactory : CompletionRequestFactory {
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
|
||||
.setModel(model)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
if ("o1-mini" == model || "o1-preview" == model) {
|
||||
requestBuilder
|
||||
.setMaxCompletionTokens(configuration.maxTokens)
|
||||
.setStream(false)
|
||||
.setMaxTokens(null)
|
||||
.setTemperature(null)
|
||||
.setPresencePenalty(null)
|
||||
.setFrequencyPenalty(null)
|
||||
} else {
|
||||
requestBuilder
|
||||
.setStream(true)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
}
|
||||
return requestBuilder.build()
|
||||
}
|
||||
|
||||
override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest {
|
||||
return buildEditCodeRequest(input, service<OpenAISettings>().state.model)
|
||||
override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, params.prompt, EDIT_CODE_SYSTEM_PROMPT)
|
||||
}
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, params.prompt, model, true)
|
||||
}
|
||||
|
||||
override fun createCommitMessageRequest(
|
||||
systemPrompt: String,
|
||||
gitDiff: String
|
||||
): CompletionRequest {
|
||||
return createBasicCompletionRequest(
|
||||
systemPrompt,
|
||||
gitDiff,
|
||||
service<OpenAISettings>().state.model,
|
||||
true
|
||||
)
|
||||
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val (gitDiff, systemPrompt) = params
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, gitDiff, systemPrompt)
|
||||
}
|
||||
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
|
||||
}
|
||||
|
||||
override fun createLookupRequest(prompt: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(
|
||||
GENERATE_METHOD_NAMES_SYSTEM_PROMPT,
|
||||
prompt,
|
||||
service<OpenAISettings>().state.model
|
||||
)
|
||||
override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val (prompt) = params
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, prompt, GENERATE_METHOD_NAMES_SYSTEM_PROMPT)
|
||||
}
|
||||
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt, model)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun buildEditCodeRequest(
|
||||
input: String,
|
||||
model: String? = null
|
||||
fun buildBasicO1Request(
|
||||
model: String,
|
||||
prompt: String,
|
||||
systemPrompt: String = "",
|
||||
maxCompletionTokens: Int = 4096
|
||||
): OpenAIChatCompletionRequest {
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, model, true)
|
||||
val messages = if (systemPrompt.isEmpty()) {
|
||||
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
|
||||
} else {
|
||||
listOf(
|
||||
OpenAIChatCompletionStandardMessage("user", systemPrompt),
|
||||
OpenAIChatCompletionStandardMessage("user", prompt)
|
||||
)
|
||||
}
|
||||
return OpenAIChatCompletionRequest.Builder(messages)
|
||||
.setModel(model)
|
||||
.setMaxCompletionTokens(maxCompletionTokens)
|
||||
.setStream(false)
|
||||
.setTemperature(null)
|
||||
.setFrequencyPenalty(null)
|
||||
.setPresencePenalty(null)
|
||||
.setMaxTokens(null)
|
||||
.build()
|
||||
}
|
||||
|
||||
fun buildOpenAIMessages(
|
||||
model: String?,
|
||||
callParameters: CallParameters
|
||||
): List<OpenAIChatCompletionMessage> {
|
||||
val messages = buildOpenAIMessages(callParameters)
|
||||
val messages = buildOpenAIChatMessages(model, callParameters)
|
||||
|
||||
if (model == null) {
|
||||
return messages
|
||||
|
|
@ -104,21 +132,24 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
)
|
||||
}
|
||||
|
||||
private fun buildOpenAIMessages(
|
||||
private fun buildOpenAIChatMessages(
|
||||
model: String?,
|
||||
callParameters: CallParameters
|
||||
): MutableList<OpenAIChatCompletionMessage> {
|
||||
val message = callParameters.message
|
||||
val messages = mutableListOf<OpenAIChatCompletionMessage>()
|
||||
val role = if ("o1-mini" == model || "o1-preview" == model) "user" else "system"
|
||||
|
||||
if (callParameters.conversationType == ConversationType.DEFAULT) {
|
||||
val sessionPersonaDetails = callParameters.message.personaDetails
|
||||
if (callParameters.message.personaDetails == null) {
|
||||
messages.add(
|
||||
OpenAIChatCompletionStandardMessage("system", getSystemPrompt())
|
||||
OpenAIChatCompletionStandardMessage(role, getSystemPrompt())
|
||||
)
|
||||
} else {
|
||||
messages.add(
|
||||
OpenAIChatCompletionStandardMessage(
|
||||
"system",
|
||||
role,
|
||||
sessionPersonaDetails.instructions
|
||||
)
|
||||
)
|
||||
|
|
@ -126,7 +157,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
}
|
||||
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) {
|
||||
messages.add(
|
||||
OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT)
|
||||
OpenAIChatCompletionStandardMessage(role, FIX_COMPILE_ERRORS_SYSTEM_PROMPT)
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,49 +14,46 @@ object CodeGPTAvailableModels {
|
|||
fun getToolWindowModels(pricingPlan: PricingPlan?): List<CodeGPTModel> {
|
||||
return when (pricingPlan) {
|
||||
null, ANONYMOUS -> listOf(
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
|
||||
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
|
||||
CodeGPTModel("DeepSeek Coder V2 - FREE", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS),
|
||||
CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
|
||||
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS)
|
||||
)
|
||||
|
||||
FREE -> listOf(
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
|
||||
CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE),
|
||||
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
|
||||
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS),
|
||||
CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE),
|
||||
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE),
|
||||
)
|
||||
|
||||
INDIVIDUAL -> listOf(
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
|
||||
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
|
||||
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
val ALL_CHAT_MODELS: List<CodeGPTModel> = listOf(
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL),
|
||||
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE),
|
||||
CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
|
||||
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
|
||||
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
|
||||
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS),
|
||||
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
|
||||
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("DeepSeek Coder (33B)", "deepseek-coder-33b", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("WizardLM-2 (8x22B)", "wizardlm-2-8x22b", Icons.CodeGPTModel, FREE)
|
||||
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE),
|
||||
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE),
|
||||
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE),
|
||||
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE),
|
||||
CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE),
|
||||
)
|
||||
|
||||
@JvmStatic
|
||||
|
|
@ -65,7 +62,6 @@ object CodeGPTAvailableModels {
|
|||
CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("WizardCoder Python (34B)", "wizardcoder-python", Icons.CodeGPTModel, FREE),
|
||||
CodeGPTModel("Phind Code LLaMA v2 (34B)", "phind-codellama", Icons.CodeGPTModel, FREE)
|
||||
)
|
||||
|
||||
@JvmStatic
|
||||
|
|
|
|||
32
src/main/resources/icons/mistral.svg
Normal file
32
src/main/resources/icons/mistral.svg
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="13px" height="13px" viewBox="0 0 256 233" version="1.1" xmlns="http://www.w3.org/2000/svg" preserveAspectRatio="xMidYMid">
|
||||
<title>Mistral AI</title>
|
||||
<g>
|
||||
<rect fill="#000000" x="186.181818" y="0" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F7D046" x="209.454545" y="0" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="0" y="0" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="0" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="0" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="0" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="0" y="186.181818" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F7D046" x="23.2727273" y="0" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F2A73B" x="209.454545" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F2A73B" x="23.2727273" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="139.636364" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F2A73B" x="162.909091" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#F2A73B" x="69.8181818" y="46.5454545" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EE792F" x="116.363636" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EE792F" x="162.909091" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EE792F" x="69.8181818" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="93.0909091" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EB5829" x="116.363636" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EE792F" x="209.454545" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EE792F" x="23.2727273" y="93.0909091" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="186.181818" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EB5829" x="209.454545" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#000000" x="186.181818" y="186.181818" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EB5829" x="23.2727273" y="139.636364" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EA3326" x="209.454545" y="186.181818" width="46.5454545" height="46.5454545"></rect>
|
||||
<rect fill="#EA3326" x="23.2727273" y="186.181818" width="46.5454545" height="46.5454545"></rect>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.7 KiB |
BIN
src/main/resources/icons/qwen.png
Normal file
BIN
src/main/resources/icons/qwen.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 713 B |
|
|
@ -24,12 +24,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -55,12 +57,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -86,12 +90,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
secondMessage,
|
||||
null,
|
||||
true
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
secondMessage,
|
||||
null,
|
||||
true
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -118,12 +124,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.discardTokenLimits()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
null,
|
||||
false
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -146,12 +154,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
|
||||
assertThrows(TotalUsageExceededException::class.java) {
|
||||
OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
createDummyMessage(100),
|
||||
null,
|
||||
false
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
createDummyMessage(100),
|
||||
null,
|
||||
false
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,14 +14,17 @@ import org.apache.http.HttpHeaders
|
|||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
||||
class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
||||
class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
|
||||
|
||||
fun testOpenAIChatCompletionCall() {
|
||||
useOpenAIService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
|
|
@ -77,7 +80,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
val message = Message("TEST_PROMPT")
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
|
|
@ -91,7 +97,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("Ping", "Pong"))
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.body)
|
||||
|
|
@ -125,7 +134,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/api/chat")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
|
|
@ -171,7 +183,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectGoogle(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
|
|
@ -207,7 +222,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
Loading…
Add table
Add a link
Reference in a new issue