feat: add Google Gemini API support (#535)

This commit is contained in:
Phil 2024-05-08 15:51:32 +02:00 committed by GitHub
parent f5a63eb889
commit 74fc2e6219
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 379 additions and 15 deletions

View file

@ -10,6 +10,7 @@ public final class Icons {
IconLoader.getIcon("/icons/codegpt-small.svg", Icons.class);
public static final Icon Anthropic = IconLoader.getIcon("/icons/anthropic.svg", Icons.class);
public static final Icon Azure = IconLoader.getIcon("/icons/azure.svg", Icons.class);
public static final Icon Google = IconLoader.getIcon("/icons/google.svg", Icons.class);
public static final Icon Llama = IconLoader.getIcon("/icons/llama.svg", Icons.class);
public static final Icon OpenAI = IconLoader.getIcon("/icons/openai.svg", Icons.class);
public static final Icon Send = IconLoader.getIcon("/icons/send.svg", Icons.class);

View file

@ -14,6 +14,7 @@ import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.anthropic.ClaudeClient;
import ee.carlrobert.llm.client.azure.AzureClient;
import ee.carlrobert.llm.client.azure.AzureCompletionRequestParams;
import ee.carlrobert.llm.client.google.GoogleClient;
import ee.carlrobert.llm.client.llama.LlamaClient;
import ee.carlrobert.llm.client.ollama.OllamaClient;
import ee.carlrobert.llm.client.openai.OpenAIClient;
@ -105,6 +106,12 @@ public class CompletionClientProvider {
.build(getDefaultClientBuilder());
}
public static GoogleClient getGoogleClient() {
return new GoogleClient.Builder(getCredential(CredentialKey.GOOGLE_API_KEY))
.build(getDefaultClientBuilder());
}
public static OkHttpClient.Builder getDefaultClientBuilder() {
OkHttpClient.Builder builder = new OkHttpClient.Builder();
var advancedSettings = AdvancedSettings.getCurrentState();

View file

@ -40,6 +40,12 @@ import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageImageContent;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageTextContent;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.google.completion.GoogleContentPart;
import ee.carlrobert.llm.client.google.completion.GoogleContentPart.Blob;
import ee.carlrobert.llm.client.google.completion.GoogleGenerationConfig;
import ee.carlrobert.llm.client.google.models.GoogleModel;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
@ -221,6 +227,16 @@ public class CompletionRequestProvider {
.setTemperature(configuration.getTemperature()).build();
}
public GoogleCompletionRequest buildGoogleChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getCurrentState();
return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters))
.generationConfig(new GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature()).build()).build();
}
public Request buildCustomOpenAIChatCompletionRequest(
CustomServiceChatCompletionSettingsState settings,
CallParameters callParameters) {
@ -448,6 +464,83 @@ public class CompletionRequestProvider {
return tryReducingMessagesOrThrow(messages, totalUsage, modelMaxTokens);
}
private List<GoogleCompletionContent> buildGoogleMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<GoogleCompletionContent>();
// Gemini API does not support direct 'system' prompts:
// see https://www.reddit.com/r/Bard/comments/1b90i8o/does_gemini_have_a_system_prompt_option_while/
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
String systemPrompt = ConfigurationSettings.getCurrentState().getSystemPrompt();
messages.add(new GoogleCompletionContent("user", List.of(systemPrompt)));
messages.add(new GoogleCompletionContent("model", List.of("Understood.")));
}
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
new GoogleCompletionContent("user", List.of(FIX_COMPILE_ERRORS_SYSTEM_PROMPT)));
messages.add(new GoogleCompletionContent("model", List.of("Understood.")));
}
for (var prevMessage : conversation.getMessages()) {
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
break;
}
var prevMessageImageFilePath = prevMessage.getImageFilePath();
if (prevMessageImageFilePath != null && !prevMessageImageFilePath.isEmpty()) {
try {
var imageFilePath = Path.of(prevMessageImageFilePath);
var imageData = Files.readAllBytes(imageFilePath);
var imageMediaType = FileUtil.getImageMediaType(imageFilePath.getFileName().toString());
messages.add(new GoogleCompletionContent(
List.of(
new GoogleContentPart(null, new Blob(imageMediaType, imageData)),
new GoogleContentPart(prevMessage.getPrompt())), "user"));
} catch (IOException e) {
throw new RuntimeException(e);
}
} else {
messages.add(new GoogleCompletionContent("user", List.of(prevMessage.getPrompt())));
}
messages.add(new GoogleCompletionContent("model", List.of(prevMessage.getResponse())));
}
if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) {
messages.add(new GoogleCompletionContent(
List.of(
new GoogleContentPart(null,
new Blob(callParameters.getImageMediaType(), callParameters.getImageData())),
new GoogleContentPart(message.getPrompt())), "user"));
} else {
messages.add(new GoogleCompletionContent("user", List.of(message.getPrompt())));
}
return messages;
}
private List<GoogleCompletionContent> buildGoogleMessages(
@Nullable String model,
CallParameters callParameters) {
var messages = buildGoogleMessages(callParameters);
if (model == null) {
return messages;
}
int totalUsage = messages.parallelStream()
.mapToInt(message -> encodingManager.countMessageTokens(message.getRole(),
String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList())))
.sum() + ConfigurationSettings.getCurrentState().getMaxTokens();
int modelMaxTokens;
try {
modelMaxTokens = GoogleModel.findByCode(model).getMaxTokens();
if (totalUsage <= modelMaxTokens) {
return messages;
}
} catch (NoSuchElementException ex) {
return messages;
}
return tryReducingGoogleMessagesOrThrow(messages, totalUsage, modelMaxTokens);
}
private List<OpenAIChatCompletionMessage> tryReducingMessagesOrThrow(
List<OpenAIChatCompletionMessage> messages,
int totalUsage,
@ -473,4 +566,29 @@ public class CompletionRequestProvider {
return messages.stream().filter(Objects::nonNull).toList();
}
private List<GoogleCompletionContent> tryReducingGoogleMessagesOrThrow(
List<GoogleCompletionContent> messages,
int totalUsage,
int modelMaxTokens) {
if (!ConversationsState.getInstance().discardAllTokenLimits) {
if (!conversation.isDiscardTokenLimit()) {
throw new TotalUsageExceededException();
}
}
// skip the system prompt
for (int i = 1; i < messages.size(); i++) {
if (totalUsage <= modelMaxTokens) {
break;
}
var message = messages.get(i);
totalUsage -= encodingManager.countMessageTokens(message.getRole(),
String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList()));
messages.set(i, null);
}
return messages.stream().filter(Objects::nonNull).toList();
}
}

View file

@ -20,18 +20,23 @@ import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.google.completion.GoogleGenerationConfig;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest.Builder;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice;
@ -110,6 +115,16 @@ public final class CompletionRequestService {
case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
requestProvider.buildOllamaChatCompletionRequest(callParameters),
eventListener);
case GOOGLE -> {
var settings = ApplicationManager.getApplication()
.getService(GoogleSettings.class).getState();
yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync(
requestProvider.buildGoogleChatCompletionRequest(
settings.getModel(),
callParameters),
settings.getModel(),
eventListener);
}
};
}
@ -142,7 +157,7 @@ public final class CompletionRequestService {
String gitDiff,
CompletionEventListener<String> eventListener) {
var configuration = ConfigurationSettings.getCurrentState();
var openaiRequest = new OpenAIChatCompletionRequest.Builder(List.of(
var openaiRequest = new Builder(List.of(
new OpenAIChatCompletionStandardMessage("system", systemPrompt),
new OpenAIChatCompletionStandardMessage("user", gitDiff)))
.setModel(OpenAISettings.getCurrentState().getModel())
@ -212,6 +227,21 @@ public final class CompletionRequestService {
).build();
CompletionClientProvider.getOllamaClient().getChatCompletionAsync(request, eventListener);
break;
case GOOGLE:
GoogleSettingsState state = ApplicationManager.getApplication()
.getService(GoogleSettings.class).getState();
CompletionClientProvider.getGoogleClient()
.getChatCompletionAsync(new GoogleCompletionRequest.Builder(
List.of(
new GoogleCompletionContent("user", List.of(systemPrompt)),
new GoogleCompletionContent("model", List.of("Understood.")),
new GoogleCompletionContent("user", List.of(gitDiff))
))
.generationConfig(new GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature()).build())
.build(), state.getModel(), eventListener);
break;
default:
LOG.debug("Unknown service: {}", selectedService);
break;
@ -255,6 +285,7 @@ public final class CompletionRequestService {
: CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
case CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP, OLLAMA -> true;
case YOU -> false;
case GOOGLE -> CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.GOOGLE_API_KEY);
};
}

View file

@ -8,6 +8,7 @@ import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
@ -203,6 +204,11 @@ public final class ConversationService {
.getService(OllamaSettings.class)
.getState()
.getModel();
case GOOGLE ->
ApplicationManager.getApplication()
.getService(GoogleSettings.class)
.getState()
.getModel();
};
}
}

View file

@ -10,6 +10,7 @@ import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
@ -107,6 +108,11 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
.getService(OllamaSettings.class)
.getState()
.getModel();
case GOOGLE:
return ApplicationManager.getApplication()
.getService(GoogleSettings.class)
.getState()
.getModel();
default:
return "Unknown";
}

View file

@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.settings;
import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.GOOGLE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
@ -19,9 +20,9 @@ import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettingsForm;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
@ -49,6 +50,7 @@ public class GeneralSettingsComponent {
private final YouSettingsForm youSettingsForm;
private final LlamaSettingsForm llamaSettingsForm;
private final OllamaSettingsForm ollamaSettingsForm;
private final GoogleSettingsForm googleSettingsForm;
public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings settings) {
displayNameField = new JBTextField(settings.getState().getDisplayName(), 20);
@ -59,6 +61,7 @@ public class GeneralSettingsComponent {
youSettingsForm = new YouSettingsForm(YouSettings.getCurrentState(), parentDisposable);
llamaSettingsForm = new LlamaSettingsForm(LlamaSettings.getCurrentState());
ollamaSettingsForm = new OllamaSettingsForm();
googleSettingsForm = new GoogleSettingsForm();
var cardLayout = new DynamicCardLayout();
var cards = new JPanel(cardLayout);
@ -69,6 +72,7 @@ public class GeneralSettingsComponent {
cards.add(youSettingsForm, YOU.getCode());
cards.add(llamaSettingsForm, LLAMA_CPP.getCode());
cards.add(ollamaSettingsForm.getForm(), OLLAMA.getCode());
cards.add(googleSettingsForm.getForm(), GOOGLE.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()).toList());
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
@ -121,6 +125,10 @@ public class GeneralSettingsComponent {
return ollamaSettingsForm;
}
public GoogleSettingsForm getGoogleSettingsForm() {
return googleSettingsForm;
}
public ServiceType getSelectedService() {
return serviceComboBox.getItem();
}
@ -153,6 +161,7 @@ public class GeneralSettingsComponent {
youSettingsForm.resetForm();
llamaSettingsForm.resetForm();
ollamaSettingsForm.resetForm();
googleSettingsForm.resetForm();
}
static class DynamicCardLayout extends CardLayout {

View file

@ -4,6 +4,7 @@ import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.A
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.GOOGLE_API_KEY;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.LLAMA_API_KEY;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY;
@ -18,10 +19,9 @@ import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettingsForm;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
@ -71,7 +71,8 @@ public class GeneralSettingsConfigurable implements Configurable {
|| AzureSettings.getInstance().isModified(component.getAzureSettingsForm())
|| YouSettings.getInstance().isModified(component.getYouSettingsForm())
|| LlamaSettings.getInstance().isModified(component.getLlamaSettingsForm())
|| component.getOllamaSettingsForm().isModified();
|| component.getOllamaSettingsForm().isModified()
|| component.getGoogleSettingsForm().isModified();
}
@Override
@ -88,6 +89,7 @@ public class GeneralSettingsConfigurable implements Configurable {
applyYouSettings(component.getYouSettingsForm());
applyLlamaSettings(component.getLlamaSettingsForm());
component.getOllamaSettingsForm().applyChanges();
applyGoogleSettings(component.getGoogleSettingsForm());
var serviceChanged = component.getSelectedService() != settings.getSelectedService();
var modelChanged = !OpenAISettings.getCurrentState().getModel()
@ -137,8 +139,9 @@ public class GeneralSettingsConfigurable implements Configurable {
form.getActiveDirectoryToken());
}
private void applyOllamaSettings(OllamaSettingsForm form) {
private void applyGoogleSettings(GoogleSettingsForm form) {
form.applyChanges();
CredentialsStore.INSTANCE.setCredential(GOOGLE_API_KEY, form.getApiKey());
}
@Override

View file

@ -9,7 +9,8 @@ public enum ServiceType {
AZURE("AZURE", "service.azure.title", "azure.chat.completion"),
YOU("YOU", "service.you.title", "you.chat.completion"),
LLAMA_CPP("LLAMA_CPP", "service.llama.title", "llama.chat.completion"),
OLLAMA("OLLAMA", "service.ollama.title", "ollama.chat.completion");
OLLAMA("OLLAMA", "service.ollama.title", "ollama.chat.completion"),
GOOGLE("GOOGLE", "service.google.title", "google.chat.completion");
private final String code;
private final String label;

View file

@ -113,6 +113,12 @@ public class ModelComboBoxAction extends ComboBoxAction {
actionGroup.addSeparator("Ollama");
ollamaSettings.getAvailableModels().forEach(model ->
actionGroup.add(createOllamaModelAction(model, presentation)));
actionGroup.addSeparator();
actionGroup.add(createModelAction(
ServiceType.GOOGLE,
"Google (Gemini)",
Icons.Google,
presentation));
if (YouUserManager.getInstance().isSubscribed()) {
actionGroup.addSeparator("You.com");
@ -193,6 +199,10 @@ public class ModelComboBoxAction extends ComboBoxAction {
templatePresentation.setIcon(Icons.Ollama);
templatePresentation.setText(ollamaSettings.getModel());
break;
case GOOGLE:
templatePresentation.setText("Google (Gemini)");
templatePresentation.setIcon(Icons.Google);
break;
default:
break;
}

View file

@ -27,6 +27,9 @@ public class ModelIconLabel extends JBLabel {
if ("llama.chat.completion".equals(clientCode)) {
setIcon(Icons.Llama);
}
if ("google.chat.completion".equals(clientCode)) {
setIcon(Icons.Google);
}
setText(formatModelName(modelCode));
setFont(JBFont.small());
setHorizontalAlignment(SwingConstants.LEADING);

View file

@ -32,6 +32,7 @@ abstract class CodeCompletionFeatureToggleActions(
ANTHROPIC,
AZURE,
YOU,
GOOGLE,
null -> { /* no-op for these services */
}
}
@ -50,6 +51,7 @@ abstract class CodeCompletionFeatureToggleActions(
ANTHROPIC,
AZURE,
YOU,
GOOGLE,
null -> false
}
}
@ -66,6 +68,7 @@ abstract class CodeCompletionFeatureToggleActions(
OLLAMA -> service<OllamaSettings>().state.codeCompletionsEnabled
ANTHROPIC,
AZURE,
GOOGLE,
YOU -> false
}
}

View file

@ -75,6 +75,7 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
ServiceType.ANTHROPIC,
ServiceType.AZURE,
ServiceType.YOU,
ServiceType.GOOGLE,
null -> false
}
return event is InlineCompletionEvent.DocumentChange && codeCompletionsEnabled

View file

@ -40,6 +40,7 @@ object CredentialsStore {
AZURE_OPENAI_API_KEY,
AZURE_ACTIVE_DIRECTORY_TOKEN,
YOU_ACCOUNT_PASSWORD,
LLAMA_API_KEY
LLAMA_API_KEY,
GOOGLE_API_KEY
}
}

View file

@ -0,0 +1,11 @@
package ee.carlrobert.codegpt.settings.service.google
import com.intellij.openapi.components.*
import ee.carlrobert.llm.client.google.models.GoogleModel
@State(name = "CodeGPT_GoogleSettings_210", storages = [Storage("CodeGPT_GoogleSettings_210.xml")])
class GoogleSettings : SimplePersistentStateComponent<GoogleSettingsState>(GoogleSettingsState())
class GoogleSettingsState : BaseState() {
var model by string(GoogleModel.GEMINI_PRO.code)
}

View file

@ -0,0 +1,92 @@
package ee.carlrobert.codegpt.settings.service.google
import com.intellij.openapi.components.service
import com.intellij.openapi.ui.ComboBox
import com.intellij.ui.EnumComboBoxModel
import com.intellij.ui.TitledSeparator
import com.intellij.ui.components.JBPasswordField
import com.intellij.util.ui.FormBuilder
import com.intellij.util.ui.UI
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.ui.UIUtil
import ee.carlrobert.llm.client.google.models.GoogleModel
import javax.swing.JPanel
import javax.swing.event.HyperlinkEvent
class GoogleSettingsForm {
private val apiKeyField = JBPasswordField()
private val completionModelComboBox: ComboBox<GoogleModel>
init {
val state = service<GoogleSettings>().state
apiKeyField.columns = 30
apiKeyField.text =
getCredential(CredentialKey.GOOGLE_API_KEY)
completionModelComboBox = ComboBox(
EnumComboBoxModel(GoogleModel::class.java)
)
completionModelComboBox.selectedItem = GoogleModel.findByCode(state.model)
}
fun getForm(): JPanel = FormBuilder.createFormBuilder()
.addComponent(TitledSeparator(CodeGPTBundle.get("shared.configuration")))
.addComponent(
UIUtil.withEmptyLeftBorder(
UI.PanelFactory.grid()
.add(
UI.PanelFactory.panel(apiKeyField)
.withLabel(CodeGPTBundle.get("settingsConfigurable.shared.apiKey.label"))
.resizeX(false)
.withComment(CodeGPTBundle.get("settingsConfigurable.service.google.apiKey.comment"))
.withCommentHyperlinkListener { event: HyperlinkEvent? ->
UIUtil.handleHyperlinkClicked(
event
)
})
.add(
UI.PanelFactory.panel(completionModelComboBox)
.withLabel(CodeGPTBundle.get("settingsConfigurable.shared.model.label"))
.resizeX(false)
.withComment(CodeGPTBundle.get("settingsConfigurable.service.google.model.comment"))
.withCommentHyperlinkListener { event: HyperlinkEvent? ->
UIUtil.handleHyperlinkClicked(
event
)
}
)
.createPanel()
)
)
.addComponentFillVertically(JPanel(), 0)
.panel
fun getApiKey(): String? = String(apiKeyField.password).ifEmpty { null }
fun getModel(): String = (completionModelComboBox.model
.selectedItem as GoogleModel)
.code
fun getCurrentState() = GoogleSettingsState().apply { model = getModel() }
fun resetForm() {
val state = service<GoogleSettings>().state
apiKeyField.text =
getCredential(CredentialKey.GOOGLE_API_KEY)
completionModelComboBox.selectedItem = GoogleModel.findByCode(state.model)
}
fun isModified(): Boolean = service<GoogleSettings>().state.run {
model != getModel() || getApiKey() != getCredential(CredentialKey.GOOGLE_API_KEY)
}
fun applyChanges() {
service<GoogleSettings>().state.run {
model = getModel()
}
}
}

View file

@ -37,6 +37,7 @@
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.you.YouSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.llama.LlamaSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.google.GoogleSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.IncludedFilesSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.advanced.AdvancedSettings"/>

View file

@ -0,0 +1,14 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16">
<path
d="M16 8.016A8.522 8.522 0 008.016 16h-.032A8.521 8.521 0 000 8.016v-.032A8.521 8.521 0 007.984 0h.032A8.522 8.522 0 0016 7.984v.032z"
fill="url(#prefix__paint0_radial_980_20147)"/>
<defs>
<radialGradient id="prefix__paint0_radial_980_20147" cx="0" cy="0" r="1"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(16.1326 5.4553 -43.70045 129.2322 1.588 6.503)">
<stop offset=".067" stop-color="#9168C0"/>
<stop offset=".343" stop-color="#5684D1"/>
<stop offset=".672" stop-color="#1BA1E3"/>
</radialGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 660 B

View file

@ -23,6 +23,8 @@ settingsConfigurable.service.openai.apiKey.comment=You can find the API key in y
settingsConfigurable.service.openai.customModel.label=Custom model:
settingsConfigurable.service.openai.organization.label=Organization:
settingsConfigurable.section.openai.organization.comment=Useful when you are part of multiple organizations <sup><strong>optional</strong></sup>
settingsConfigurable.service.google.apiKey.comment=You can find the API key in your <a href="https://aistudio.google.com/app/apikey">User settings</a>.
settingsConfigurable.service.google.model.comment=Note: Gemini Vision models <a href="https://ai.google.dev/gemini-api/docs/get-started/web?multi-turn-conversations-chat&hl=en#multi-turn-conversations-chat">do not yet support chats</a>.
settingsConfigurable.service.anthropic.apiKey.comment=You can find the API key in your <a href="https://console.anthropic.com/settings/keys">User settings</a>.
settingsConfigurable.service.anthropic.apiVersion.comment=We always recommend using the <a href="https://docs.anthropic.com/claude/reference/versions">latest API version</a> whenever possible.
settingsConfigurable.service.anthropic.model.comment=For details on model comparison metrics, see <a href="https://docs.anthropic.com/claude/docs/models-overview#model-comparison">model comparison</a>.
@ -175,6 +177,7 @@ service.azure.title=Azure Service
service.you.title=You.com Service (Free, Cloud)
service.llama.title=LLaMA C/C++ Port (Free, Local)
service.ollama.title=Ollama (Free, Local)
service.google.title=Google Service
validation.error.fieldRequired=This field is required.
validation.error.invalidEmail=The email you entered is invalid.
validation.error.mustBeNumber=Value must be number.

View file

@ -7,10 +7,7 @@ import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
import ee.carlrobert.llm.client.util.JSONUtil.e
import ee.carlrobert.llm.client.util.JSONUtil.jsonArray
import ee.carlrobert.llm.client.util.JSONUtil.jsonMap
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
import ee.carlrobert.llm.client.util.JSONUtil.*
import org.apache.http.HttpHeaders
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
@ -171,6 +168,43 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
waitExpecting { "Hello!" == message.response }
}
fun testGoogleChatCompletionCall() {
useGoogleService()
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
expectGoogle(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
assertThat(request.method).isEqualTo("POST")
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
assertThat(request.body)
.extracting("contents")
.isEqualTo(
listOf(
mapOf("parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")), "role" to "user"),
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
)
)
listOf(
jsonMapResponse(
"candidates",
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "Hello")))))
),
jsonMapResponse(
"candidates",
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!")))))
)
)
})
requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
waitExpecting { "Hello!" == message.response }
}
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
return object : CompletionResponseEventListener {
override fun handleCompleted(fullMessage: String, callParameters: CallParameters) {

View file

@ -1,14 +1,16 @@
package testsupport.mixin
import com.intellij.openapi.components.service
import com.intellij.testFramework.PlatformTestUtil
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.*
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.llm.client.google.models.GoogleModel
import java.util.function.BooleanSupplier
interface ShortcutsTestMixin {
@ -41,6 +43,13 @@ interface ShortcutsTestMixin {
LlamaSettings.getCurrentState().serverPort = null
}
fun useGoogleService() {
GeneralSettings.getCurrentState().selectedService = ServiceType.GOOGLE
setCredential(GOOGLE_API_KEY, "TEST_API_KEY")
service<GoogleSettings>().state.model = GoogleModel.GEMINI_PRO.code
}
fun waitExpecting(condition: BooleanSupplier?) {
PlatformTestUtil.waitWithEventsDispatching(
"Waiting for message response timed out or did not meet expected conditions",