mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 03:20:11 +00:00
test: add stream, restart, and fallback tests for AI handlers and providers
This commit is contained in:
parent
66e8460196
commit
1ac53fa9f1
3 changed files with 281 additions and 0 deletions
152
internal/ai/providers/gemini_stream_test.go
Normal file
152
internal/ai/providers/gemini_stream_test.go
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChatStream_ToolResultsConnection(t *testing.T) {
|
||||
// Setup a mock server to capture the request sent by ChatStream
|
||||
var capturedBody geminiRequest
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/models/gemini-pro:streamGenerateContent" {
|
||||
json.NewDecoder(r.Body).Decode(&capturedBody)
|
||||
// Return a minimal SSE response to complete the request
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Write([]byte("data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"Response\"}]}}]}\n\n"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := NewGeminiClient("fake-key", "gemini-pro", ts.URL, 10*time.Second)
|
||||
|
||||
// Create a conversation history with tool usage
|
||||
// 1. User asks question
|
||||
// 2. Assistant calls tool
|
||||
// 3. Tool returns result
|
||||
toolID := "list_alerts_0"
|
||||
toolName := "list_alerts"
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "List alerts",
|
||||
},
|
||||
{
|
||||
Role: "assistant", // Model calls use "model" role in Gemini, "assistant" in generic
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: toolID,
|
||||
Name: toolName,
|
||||
Input: map[string]interface{}{
|
||||
"limit": 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ToolResult: &ToolResult{
|
||||
ToolUseID: toolID, // Result uses ID
|
||||
Content: `{"alerts": []}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := ChatRequest{
|
||||
Messages: messages,
|
||||
Model: "gemini-pro",
|
||||
}
|
||||
|
||||
err := client.ChatStream(context.Background(), req, func(event StreamEvent) {})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the request sent to Gemini
|
||||
// We expect the ToolResult message to be converted to "functionResponse"
|
||||
// AND the name should be resolved to "list_alerts" (from previous assistant msg) not "list_alerts_0"
|
||||
|
||||
// Check content structure
|
||||
// Index 0: User "List alerts"
|
||||
// Index 1: Model "functionCall"
|
||||
// Index 2: User "functionResponse"
|
||||
|
||||
assert.Equal(t, 3, len(capturedBody.Contents))
|
||||
|
||||
// Verify Index 2 (Tool Result)
|
||||
lastContent := capturedBody.Contents[2]
|
||||
assert.Equal(t, "user", lastContent.Role)
|
||||
assert.Equal(t, 1, len(lastContent.Parts))
|
||||
|
||||
part := lastContent.Parts[0]
|
||||
assert.NotNil(t, part.FunctionResponse)
|
||||
// THIS IS THE KEY ASSERTION: Name must match function name, not ID
|
||||
assert.Equal(t, toolName, part.FunctionResponse.Name)
|
||||
assert.Equal(t, `{"alerts": []}`, part.FunctionResponse.Response.Content)
|
||||
}
|
||||
|
||||
func TestChatStream_ToolResults_MultipleMerged(t *testing.T) {
|
||||
// Setup a mock server
|
||||
var capturedBody geminiRequest
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&capturedBody)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Write([]byte("data: {}\n\n"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := NewGeminiClient("fake-key", "gemini-pro", ts.URL, 10*time.Second)
|
||||
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Run"},
|
||||
{
|
||||
Role: "assistant", // Assistant calls 3 tools
|
||||
ToolCalls: []ToolCall{
|
||||
{ID: "call_1", Name: "func1", Input: nil},
|
||||
{ID: "call_2", Name: "func2", Input: nil},
|
||||
{ID: "call_3", Name: "func3", Input: nil},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ToolResult: &ToolResult{ToolUseID: "call_1", Content: "res1"},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ToolResult: &ToolResult{ToolUseID: "call_2", Content: "res2"},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
ToolResult: &ToolResult{ToolUseID: "call_3", Content: "res3"},
|
||||
},
|
||||
}
|
||||
|
||||
req := ChatRequest{Messages: messages}
|
||||
client.ChatStream(context.Background(), req, func(e StreamEvent) {})
|
||||
|
||||
// Expect merged content for the tool results
|
||||
// Contents: [User, Model(3 calls), User(merged 3 results)]
|
||||
assert.Equal(t, 3, len(capturedBody.Contents))
|
||||
|
||||
mergedUserMsg := capturedBody.Contents[2]
|
||||
assert.Equal(t, "user", mergedUserMsg.Role)
|
||||
assert.Equal(t, 3, len(mergedUserMsg.Parts))
|
||||
|
||||
// Check correctness of resolved names for ALL parts
|
||||
// Previously, the 3rd part likely failed name resolution
|
||||
assert.Equal(t, "func1", mergedUserMsg.Parts[0].FunctionResponse.Name)
|
||||
assert.Equal(t, "res1", mergedUserMsg.Parts[0].FunctionResponse.Response.Content)
|
||||
|
||||
assert.Equal(t, "func2", mergedUserMsg.Parts[1].FunctionResponse.Name)
|
||||
assert.Equal(t, "res2", mergedUserMsg.Parts[1].FunctionResponse.Response.Content)
|
||||
|
||||
assert.Equal(t, "func3", mergedUserMsg.Parts[2].FunctionResponse.Name)
|
||||
assert.Equal(t, "res3", mergedUserMsg.Parts[2].FunctionResponse.Response.Content)
|
||||
}
|
||||
44
internal/api/ai_handler_restart_test.go
Normal file
44
internal/api/ai_handler_restart_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/chat"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestRestart_StartIfStopped(t *testing.T) {
|
||||
// Mock newChatService factory
|
||||
oldNewService := newChatService
|
||||
defer func() { newChatService = oldNewService }()
|
||||
|
||||
mockSvc := new(MockAIService)
|
||||
newChatService = func(cfg chat.Config) AIService {
|
||||
return mockSvc
|
||||
}
|
||||
|
||||
mockPersist := new(MockAIPersistence)
|
||||
h := newTestAIHandler(nil, mockPersist, nil)
|
||||
// We need h.legacyService to be non-nil for the Restart check to proceed past first nil check
|
||||
// But it must return IsRunning() = false
|
||||
h.legacyService = mockSvc
|
||||
|
||||
// Config allows enabling
|
||||
aiCfg := &config.AIConfig{Enabled: true}
|
||||
mockPersist.On("LoadAIConfig").Return(aiCfg, nil)
|
||||
|
||||
// Service is NOT running
|
||||
mockSvc.On("IsRunning").Return(false)
|
||||
|
||||
// Start should be called because Enabled=true
|
||||
mockSvc.On("Start", mock.Anything).Return(nil)
|
||||
|
||||
err := h.Restart(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
mockPersist.AssertExpectations(t)
|
||||
}
|
||||
85
internal/api/metrics_history_fallback_test.go
Normal file
85
internal/api/metrics_history_fallback_test.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/monitoring"
|
||||
)
|
||||
|
||||
type metricsHistoryResponse struct {
|
||||
ResourceType string `json:"resourceType"`
|
||||
ResourceId string `json:"resourceId"`
|
||||
Metric string `json:"metric"`
|
||||
Range string `json:"range"`
|
||||
Source string `json:"source"`
|
||||
Points []struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Value float64 `json:"value"`
|
||||
} `json:"points"`
|
||||
}
|
||||
|
||||
func TestMetricsHistoryFallbackUsesLivePoint(t *testing.T) {
|
||||
state := models.NewState()
|
||||
vm := models.VM{
|
||||
ID: "pve1:node1:101",
|
||||
VMID: 101,
|
||||
Name: "vm-101",
|
||||
Node: "node1",
|
||||
Instance: "pve1",
|
||||
Status: "running",
|
||||
Type: "qemu",
|
||||
CPU: 0.42,
|
||||
Memory: models.Memory{
|
||||
Usage: 55.0,
|
||||
},
|
||||
Disk: models.Disk{
|
||||
Usage: 33.0,
|
||||
},
|
||||
}
|
||||
state.UpdateVMsForInstance("pve1", []models.VM{vm})
|
||||
|
||||
monitor := &monitoring.Monitor{}
|
||||
setUnexportedField(t, monitor, "state", state)
|
||||
setUnexportedField(t, monitor, "metricsHistory", monitoring.NewMetricsHistory(10, time.Hour))
|
||||
|
||||
tempDir := t.TempDir()
|
||||
mtp := config.NewMultiTenantPersistence(tempDir)
|
||||
if _, err := mtp.GetPersistence("default"); err != nil {
|
||||
t.Fatalf("failed to init persistence: %v", err)
|
||||
}
|
||||
|
||||
router := &Router{
|
||||
monitor: monitor,
|
||||
licenseHandlers: NewLicenseHandlers(mtp),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/metrics-store/history?resourceType=vm&resourceId=pve1:node1:101&metric=cpu&range=24h", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.handleMetricsHistory(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var resp metricsHistoryResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Source != "live" {
|
||||
t.Fatalf("expected source live, got %q", resp.Source)
|
||||
}
|
||||
if len(resp.Points) != 1 {
|
||||
t.Fatalf("expected 1 point, got %d", len(resp.Points))
|
||||
}
|
||||
if math.Abs(resp.Points[0].Value-42.0) > 0.001 {
|
||||
t.Fatalf("expected value 42.0, got %f", resp.Points[0].Value)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue