From ffb829957176631677aa29d51f38fe4e46f5ed49 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Fri, 8 Dec 2023 01:03:41 +0200 Subject: [PATCH] fix: azure host and path overriding --- .../codegpt.java-conventions.gradle.kts | 2 +- .../completions/CompletionClientProvider.java | 10 +++-- .../DefaultCompletionRequestHandlerTest.java | 38 +++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts index 2788e29a..49380b3d 100644 --- a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts @@ -23,7 +23,7 @@ checkstyle { } dependencies { - implementation("ee.carlrobert:llm-client:0.1.1") + implementation("ee.carlrobert:llm-client:0.1.2") } tasks { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java index 9821f1db..a8ad98b8 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java @@ -39,9 +39,13 @@ public class CompletionClientProvider { settings.getResourceName(), settings.getDeploymentId(), settings.getApiVersion()); - return new AzureClient.Builder(AzureCredentialsManager.getInstance().getSecret(), params) - .setActiveDirectoryAuthentication(settings.isUseAzureActiveDirectoryAuthentication()) - .build(); + var builder = new AzureClient.Builder(AzureCredentialsManager.getInstance().getSecret(), params) + .setActiveDirectoryAuthentication(settings.isUseAzureActiveDirectoryAuthentication()); + var baseHost = settings.getBaseHost(); + if (baseHost != null) { + builder.setUrl(String.format(baseHost, params.getResourceName())); + } + return builder.build(); } public static YouClient getYouClient() { diff --git a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java index 0585b07a..0a24808f 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java @@ -16,6 +16,7 @@ import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; import java.util.List; import java.util.Map; @@ -90,6 +91,43 @@ public class DefaultCompletionRequestHandlerTest extends IntegrationTest { await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); } + public void testAzureChatCompletionCallWithCustomSettings() { + useAzureService(); + AzureSettingsState.getInstance().setPath("/codegpt/deployments/%s/completions?api-version=%s"); + var conversationService = ConversationService.getInstance(); + var prevMessage = new Message("TEST_PREV_PROMPT"); + prevMessage.setResponse("TEST_PREV_RESPONSE"); + var conversation = conversationService.startConversation(); + conversation.addMessage(prevMessage); + conversationService.saveConversation(conversation); + expectAzure((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()) + .isEqualTo("/codegpt/deployments/TEST_DEPLOYMENT_ID/completions"); + assertThat(request.getUri().getQuery()).isEqualTo("api-version=TEST_API_VERSION"); + assertThat(request.getHeaders().get("Api-key").get(0)).isEqualTo("TEST_API_KEY"); + assertThat(request.getHeaders().get("X-application-name").get(0)).isEqualTo("CODEGPT"); + assertThat(request.getBody()) + .extracting("messages") + .isEqualTo( + List.of( + Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), + Map.of("role", "user", "content", "TEST_PREV_PROMPT"), + Map.of("role", "assistant", "content", "TEST_PREV_RESPONSE"), + Map.of("role", "user", "content", "TEST_PROMPT"))); + return List.of( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); + }); + var message = new Message("TEST_PROMPT"); + var requestHandler = new CompletionRequestHandler(false, getRequestEventListener(message)); + + requestHandler.call(conversation, message, false); + + await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); + } + public void testYouChatCompletionCall() { useYouService(); var message = new Message("TEST_PROMPT");