mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 19:41:17 +00:00
152 lines
4.5 KiB
Go
152 lines
4.5 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestChatStream_ToolResultsConnection(t *testing.T) {
|
|
// Setup a mock server to capture the request sent by ChatStream
|
|
var capturedBody geminiRequest
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/models/gemini-pro:streamGenerateContent" {
|
|
json.NewDecoder(r.Body).Decode(&capturedBody)
|
|
// Return a minimal SSE response to complete the request
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Write([]byte("data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"Response\"}]}}]}\n\n"))
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
client := NewGeminiClient("fake-key", "gemini-pro", ts.URL, 10*time.Second)
|
|
|
|
// Create a conversation history with tool usage
|
|
// 1. User asks question
|
|
// 2. Assistant calls tool
|
|
// 3. Tool returns result
|
|
toolID := "list_alerts_0"
|
|
toolName := "list_alerts"
|
|
|
|
messages := []Message{
|
|
{
|
|
Role: "user",
|
|
Content: "List alerts",
|
|
},
|
|
{
|
|
Role: "assistant", // Model calls use "model" role in Gemini, "assistant" in generic
|
|
ToolCalls: []ToolCall{
|
|
{
|
|
ID: toolID,
|
|
Name: toolName,
|
|
Input: map[string]interface{}{
|
|
"limit": 10,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Role: "user",
|
|
ToolResult: &ToolResult{
|
|
ToolUseID: toolID, // Result uses ID
|
|
Content: `{"alerts": []}`,
|
|
},
|
|
},
|
|
}
|
|
|
|
req := ChatRequest{
|
|
Messages: messages,
|
|
Model: "gemini-pro",
|
|
}
|
|
|
|
err := client.ChatStream(context.Background(), req, func(event StreamEvent) {})
|
|
assert.NoError(t, err)
|
|
|
|
// Verify the request sent to Gemini
|
|
// We expect the ToolResult message to be converted to "functionResponse"
|
|
// AND the name should be resolved to "list_alerts" (from previous assistant msg) not "list_alerts_0"
|
|
|
|
// Check content structure
|
|
// Index 0: User "List alerts"
|
|
// Index 1: Model "functionCall"
|
|
// Index 2: User "functionResponse"
|
|
|
|
assert.Equal(t, 3, len(capturedBody.Contents))
|
|
|
|
// Verify Index 2 (Tool Result)
|
|
lastContent := capturedBody.Contents[2]
|
|
assert.Equal(t, "user", lastContent.Role)
|
|
assert.Equal(t, 1, len(lastContent.Parts))
|
|
|
|
part := lastContent.Parts[0]
|
|
assert.NotNil(t, part.FunctionResponse)
|
|
// THIS IS THE KEY ASSERTION: Name must match function name, not ID
|
|
assert.Equal(t, toolName, part.FunctionResponse.Name)
|
|
assert.Equal(t, `{"alerts": []}`, part.FunctionResponse.Response.Content)
|
|
}
|
|
|
|
func TestChatStream_ToolResults_MultipleMerged(t *testing.T) {
|
|
// Setup a mock server
|
|
var capturedBody geminiRequest
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewDecoder(r.Body).Decode(&capturedBody)
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Write([]byte("data: {}\n\n"))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
client := NewGeminiClient("fake-key", "gemini-pro", ts.URL, 10*time.Second)
|
|
|
|
messages := []Message{
|
|
{Role: "user", Content: "Run"},
|
|
{
|
|
Role: "assistant", // Assistant calls 3 tools
|
|
ToolCalls: []ToolCall{
|
|
{ID: "call_1", Name: "func1", Input: nil},
|
|
{ID: "call_2", Name: "func2", Input: nil},
|
|
{ID: "call_3", Name: "func3", Input: nil},
|
|
},
|
|
},
|
|
{
|
|
Role: "user",
|
|
ToolResult: &ToolResult{ToolUseID: "call_1", Content: "res1"},
|
|
},
|
|
{
|
|
Role: "user",
|
|
ToolResult: &ToolResult{ToolUseID: "call_2", Content: "res2"},
|
|
},
|
|
{
|
|
Role: "user",
|
|
ToolResult: &ToolResult{ToolUseID: "call_3", Content: "res3"},
|
|
},
|
|
}
|
|
|
|
req := ChatRequest{Messages: messages}
|
|
client.ChatStream(context.Background(), req, func(e StreamEvent) {})
|
|
|
|
// Expect merged content for the tool results
|
|
// Contents: [User, Model(3 calls), User(merged 3 results)]
|
|
assert.Equal(t, 3, len(capturedBody.Contents))
|
|
|
|
mergedUserMsg := capturedBody.Contents[2]
|
|
assert.Equal(t, "user", mergedUserMsg.Role)
|
|
assert.Equal(t, 3, len(mergedUserMsg.Parts))
|
|
|
|
// Check correctness of resolved names for ALL parts
|
|
// Previously, the 3rd part likely failed name resolution
|
|
assert.Equal(t, "func1", mergedUserMsg.Parts[0].FunctionResponse.Name)
|
|
assert.Equal(t, "res1", mergedUserMsg.Parts[0].FunctionResponse.Response.Content)
|
|
|
|
assert.Equal(t, "func2", mergedUserMsg.Parts[1].FunctionResponse.Name)
|
|
assert.Equal(t, "res2", mergedUserMsg.Parts[1].FunctionResponse.Response.Content)
|
|
|
|
assert.Equal(t, "func3", mergedUserMsg.Parts[2].FunctionResponse.Name)
|
|
assert.Equal(t, "res3", mergedUserMsg.Parts[2].FunctionResponse.Response.Content)
|
|
}
|