mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-30 04:20:20 +00:00
440 lines
11 KiB
Go
440 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"os/exec"
|
|
"os/user"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
func TestParseLogLevel(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want zerolog.Level
|
|
}{
|
|
{"trace", zerolog.TraceLevel},
|
|
{"debug", zerolog.DebugLevel},
|
|
{"DEBUG", zerolog.DebugLevel},
|
|
{"info", zerolog.InfoLevel},
|
|
{"warn", zerolog.WarnLevel},
|
|
{"warning", zerolog.WarnLevel},
|
|
{"error", zerolog.ErrorLevel},
|
|
{"fatal", zerolog.FatalLevel},
|
|
{"panic", zerolog.PanicLevel},
|
|
{"disabled", zerolog.Disabled},
|
|
{"none", zerolog.Disabled},
|
|
{"unknown", zerolog.InfoLevel},
|
|
{"", zerolog.InfoLevel},
|
|
}
|
|
for _, tt := range tests {
|
|
if got := parseLogLevel(tt.input); got != tt.want {
|
|
t.Errorf("parseLogLevel(%q) = %v, want %v", tt.input, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDropPrivileges(t *testing.T) {
|
|
// Save original functions
|
|
oldGeteuid := osGeteuid
|
|
oldResolve := resolveUserSpecFunc
|
|
oldSetgroups := unixSetgroups
|
|
oldSetgid := unixSetgid
|
|
oldSetuid := unixSetuid
|
|
defer func() {
|
|
osGeteuid = oldGeteuid
|
|
resolveUserSpecFunc = oldResolve
|
|
unixSetgroups = oldSetgroups
|
|
unixSetgid = oldSetgid
|
|
unixSetuid = oldSetuid
|
|
}()
|
|
|
|
// Mock for root user
|
|
osGeteuid = func() int { return 0 }
|
|
resolveUserSpecFunc = func(u string) (*userSpec, error) {
|
|
return &userSpec{name: u, uid: 1000, gid: 1000, groups: []int{1000}}, nil
|
|
}
|
|
unixSetgroups = func(g []int) error { return nil }
|
|
unixSetgid = func(g int) error { return nil }
|
|
unixSetuid = func(u int) error { return nil }
|
|
|
|
// Test success path
|
|
spec, err := dropPrivileges("testuser")
|
|
if err != nil {
|
|
t.Errorf("dropPrivileges failed: %v", err)
|
|
}
|
|
if spec == nil {
|
|
t.Fatal("expected spec, got nil")
|
|
}
|
|
if spec.uid != 1000 {
|
|
t.Errorf("expected uid 1000, got %d", spec.uid)
|
|
}
|
|
|
|
// Test non-root (should return nil, nil)
|
|
osGeteuid = func() int { return 1000 }
|
|
spec, err = dropPrivileges("testuser")
|
|
if err != nil {
|
|
t.Errorf("unexpected error for non-root: %v", err)
|
|
}
|
|
if spec != nil {
|
|
t.Error("expected nil spec for non-root")
|
|
}
|
|
}
|
|
|
|
func TestResolveUserSpec_PasswdFallback(t *testing.T) {
|
|
// Mock passwd file
|
|
tmpDir := t.TempDir()
|
|
pPath, _ := os.CreateTemp(tmpDir, "passwd")
|
|
pPath.WriteString("testuser:x:1001:1001::/home/testuser:/bin/sh\n")
|
|
pPath.Close()
|
|
|
|
origPath := passwdPath
|
|
defer func() { passwdPath = origPath }()
|
|
passwdPath = pPath.Name()
|
|
|
|
// Test
|
|
spec, err := lookupUserFromPasswd("testuser")
|
|
if err != nil {
|
|
t.Fatalf("lookup failed: %v", err)
|
|
}
|
|
if spec.uid != 1001 || spec.gid != 1001 {
|
|
t.Errorf("mismatch: %+v", spec)
|
|
}
|
|
|
|
// Test not found
|
|
_, err = lookupUserFromPasswd("nonexistent")
|
|
if err == nil {
|
|
t.Error("expected error for nonexistent user")
|
|
}
|
|
}
|
|
|
|
func TestResolveUserSpec(t *testing.T) {
|
|
// Mock lookupUserFromPasswd if we could, but resolveUserSpec calls user.Lookup first.
|
|
// We can test that it falls back to passwd if user.Lookup fails (which it might in test env).
|
|
// But lookupUserFromPasswd reads /etc/passwd path which is hardcoded.
|
|
// We can't mock os.Open easily without refactoring.
|
|
|
|
// However, we can test with a known user if possible, or mock the fallback.
|
|
// Let's assume user.Lookup works for current user.
|
|
u, err := user.Current()
|
|
if err != nil {
|
|
t.Skip("cannot get current user")
|
|
}
|
|
|
|
spec, err := resolveUserSpec(u.Username)
|
|
if err != nil {
|
|
t.Fatalf("resolveUserSpec failed: %v", err)
|
|
}
|
|
|
|
expectedUid, _ := strconv.Atoi(u.Uid)
|
|
if spec.uid != expectedUid {
|
|
t.Errorf("expected uid %d, got %d", expectedUid, spec.uid)
|
|
}
|
|
}
|
|
|
|
func TestResolveUserSpec_Root(t *testing.T) {
|
|
// Test success path with existing user
|
|
// We use "root" as it should exist on almost all unix systems
|
|
// Check if we can look it up first
|
|
if _, err := exec.LookPath("id"); err != nil {
|
|
t.Skip("id command not found")
|
|
}
|
|
// Or just try resolving it
|
|
spec, err := resolveUserSpec("root")
|
|
if err != nil {
|
|
t.Skipf("root user resolution failed (maybe minimal env?): %v", err)
|
|
}
|
|
|
|
if spec.name != "root" {
|
|
t.Errorf("expected name root, got %s", spec.name)
|
|
}
|
|
// UID 0 check
|
|
if spec.uid != 0 {
|
|
t.Logf("root uid is %d, not 0 (unusual but possible)", spec.uid)
|
|
}
|
|
}
|
|
|
|
func TestEnsureSSHKeypair(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
proxy := &Proxy{sshKeyPath: tmpDir}
|
|
|
|
// Mock exec for ssh-keygen
|
|
origExec := execCommandFunc
|
|
defer func() { execCommandFunc = origExec }()
|
|
execCommandFunc = func(name string, arg ...string) *exec.Cmd {
|
|
args := strings.Join(arg, " ")
|
|
if strings.Contains(args, "ssh-keygen") {
|
|
// Parse shell command to find -f
|
|
// cmd string is in arg[1] (after -c)
|
|
if len(arg) > 1 {
|
|
cmdStr := arg[1]
|
|
parts := strings.Fields(cmdStr)
|
|
for i, p := range parts {
|
|
if p == "-f" && i+1 < len(parts) {
|
|
path := parts[i+1]
|
|
os.WriteFile(path, []byte("priv"), 0600)
|
|
os.WriteFile(path+".pub", []byte("pub"), 0644)
|
|
}
|
|
}
|
|
}
|
|
return mockExecCommand("")
|
|
}
|
|
return mockExecCommand("")
|
|
}
|
|
|
|
// First run: generate
|
|
if err := proxy.ensureSSHKeypair(); err != nil {
|
|
t.Fatalf("ensureSSHKeypair failed: %v", err)
|
|
}
|
|
|
|
if _, err := os.Stat(filepath.Join(tmpDir, "id_ed25519")); err != nil {
|
|
t.Error("private key not created")
|
|
}
|
|
|
|
// Second run: existing
|
|
// Restore exec to fail if called (should not be called)
|
|
execCommandFunc = func(name string, arg ...string) *exec.Cmd {
|
|
return errorExecCommand("should not be called")
|
|
}
|
|
if err := proxy.ensureSSHKeypair(); err != nil {
|
|
t.Fatalf("ensureSSHKeypair existing failed: %v", err)
|
|
}
|
|
}
|
|
|
|
type mockListener struct {
|
|
net.Listener
|
|
closed bool
|
|
}
|
|
|
|
func (m *mockListener) Close() error {
|
|
m.closed = true
|
|
return nil
|
|
}
|
|
|
|
func (m *mockListener) Accept() (net.Conn, error) {
|
|
// Block until closed
|
|
select {}
|
|
}
|
|
|
|
func (m *mockListener) Addr() net.Addr {
|
|
return &net.UnixAddr{Name: "/tmp/sock", Net: "unix"}
|
|
}
|
|
|
|
func TestProxy_StartStop(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
sshDir := filepath.Join(tmpDir, "ssh")
|
|
socketPath := filepath.Join(tmpDir, "sock")
|
|
|
|
// Mock net.Listen
|
|
origListen := netListen
|
|
defer func() { netListen = origListen }()
|
|
listenCalled := false
|
|
netListen = func(network, address string) (net.Listener, error) {
|
|
listenCalled = true
|
|
return &mockListener{}, nil
|
|
}
|
|
|
|
// Mock exec for key gen
|
|
origExec := execCommandFunc
|
|
defer func() { execCommandFunc = origExec }()
|
|
execCommandFunc = func(name string, arg ...string) *exec.Cmd {
|
|
args := strings.Join(arg, " ")
|
|
if strings.Contains(args, "ssh-keygen") {
|
|
// Create dummy key files
|
|
for i, a := range arg {
|
|
if a == "-f" && i+1 < len(arg) {
|
|
os.MkdirAll(filepath.Dir(arg[i+1]), 0755)
|
|
os.WriteFile(arg[i+1], []byte("priv"), 0600)
|
|
os.WriteFile(arg[i+1]+".pub", []byte("pub"), 0644)
|
|
}
|
|
}
|
|
return mockExecCommand("")
|
|
}
|
|
return mockExecCommand("")
|
|
}
|
|
|
|
proxy := &Proxy{
|
|
sshKeyPath: sshDir,
|
|
socketPath: socketPath,
|
|
metrics: NewProxyMetrics("test"),
|
|
}
|
|
|
|
if err := proxy.Start(); err != nil {
|
|
t.Fatalf("Start failed: %v", err)
|
|
}
|
|
|
|
if !listenCalled {
|
|
t.Error("net.Listen not called")
|
|
}
|
|
|
|
// Check directories created
|
|
if _, err := os.Stat(sshDir); err != nil {
|
|
t.Error("ssh dir not created")
|
|
}
|
|
|
|
// Stop
|
|
proxy.Stop()
|
|
// Should close listener -> our mock doesn't block Stop.
|
|
|
|
// Check socket removed (Start code removes it first)
|
|
// But our mock listener doesn't create file.
|
|
// The Start() function calls os.RemoveAll(p.socketPath).
|
|
}
|
|
|
|
// Helpers for http_server_test.go which I might have deleted if they were in main_test.go
|
|
func mockExecCommand(output string) *exec.Cmd {
|
|
cs := []string{"-test.run=TestHelperProcess", "--", output}
|
|
cmd := exec.Command(os.Args[0], cs...)
|
|
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1", "GO_HELPER_OUTPUT=" + output}
|
|
return cmd
|
|
}
|
|
|
|
func errorExecCommand(msg string) *exec.Cmd {
|
|
cs := []string{"-test.run=TestHelperProcess", "--", msg}
|
|
cmd := exec.Command(os.Args[0], cs...)
|
|
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1", "GO_HELPER_ERROR=" + msg}
|
|
return cmd
|
|
}
|
|
|
|
func TestHelperProcess(t *testing.T) {
|
|
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
|
return
|
|
}
|
|
|
|
if msg := os.Getenv("GO_HELPER_ERROR"); msg != "" {
|
|
fmt.Fprint(os.Stderr, msg)
|
|
os.Exit(1)
|
|
}
|
|
|
|
output := os.Getenv("GO_HELPER_OUTPUT")
|
|
fmt.Fprint(os.Stdout, output)
|
|
|
|
if codeStr := os.Getenv("GO_HELPER_EXIT_CODE"); codeStr != "" {
|
|
if code, err := strconv.Atoi(codeStr); err == nil {
|
|
os.Exit(code)
|
|
}
|
|
}
|
|
os.Exit(0)
|
|
}
|
|
|
|
func TestFetchAuthorizedNodes(t *testing.T) {
|
|
// Mock server
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("X-Proxy-Token") != "token" {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
resp := struct {
|
|
Nodes []struct {
|
|
Name string `json:"name"`
|
|
IP string `json:"ip"`
|
|
} `json:"nodes"`
|
|
Hash string `json:"hash"`
|
|
RefreshInterval int `json:"refresh_interval"`
|
|
}{
|
|
Nodes: []struct {
|
|
Name string `json:"name"`
|
|
IP string `json:"ip"`
|
|
}{
|
|
{Name: "node1", IP: "10.0.0.1"},
|
|
},
|
|
Hash: "abc",
|
|
RefreshInterval: 60,
|
|
}
|
|
json.NewEncoder(w).Encode(resp)
|
|
})
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
nv, _ := newNodeValidator(nil, nil)
|
|
proxy := &Proxy{
|
|
controlPlaneCfg: &ControlPlaneConfig{
|
|
URL: server.URL,
|
|
},
|
|
controlPlaneToken: "token",
|
|
nodeValidator: nv,
|
|
}
|
|
|
|
// Need a client
|
|
client := server.Client()
|
|
|
|
if err := proxy.fetchAuthorizedNodes(client); err != nil {
|
|
t.Fatalf("fetchAuthorizedNodes failed: %v", err)
|
|
}
|
|
|
|
if proxy.controlPlaneCfg.RefreshIntervalSec != 60 {
|
|
t.Errorf("expected refresh interval 60, got %d", proxy.controlPlaneCfg.RefreshIntervalSec)
|
|
}
|
|
|
|
// Need context for Validate
|
|
if err := proxy.nodeValidator.Validate(context.Background(), "node1"); err != nil {
|
|
t.Errorf("expected node1 to be valid: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDefaultExtractPeerCredentials_Real(t *testing.T) {
|
|
if runtime.GOOS != "linux" {
|
|
t.Skip("Skipping Linux-specific test")
|
|
}
|
|
|
|
// Create unix socket
|
|
tmpDir := t.TempDir()
|
|
sockPath := filepath.Join(tmpDir, "test.sock")
|
|
l, err := net.Listen("unix", sockPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l.Close()
|
|
|
|
serverConnCh := make(chan net.Conn)
|
|
|
|
// Accept in goroutine
|
|
go func() {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
close(serverConnCh)
|
|
return
|
|
}
|
|
serverConnCh <- c
|
|
}()
|
|
|
|
clientConn, err := net.Dial("unix", sockPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
serverConn := <-serverConnCh
|
|
if serverConn == nil {
|
|
t.Fatal("failed to accept connection")
|
|
}
|
|
defer serverConn.Close()
|
|
|
|
creds, err := defaultExtractPeerCredentials(serverConn)
|
|
if err != nil {
|
|
t.Fatalf("defaultExtractPeerCredentials failed: %v", err)
|
|
}
|
|
|
|
// Check against current process
|
|
if creds.uid != uint32(os.Getuid()) {
|
|
t.Errorf("expected uid %d, got %d", os.Getuid(), creds.uid)
|
|
}
|
|
if creds.gid != uint32(os.Getgid()) {
|
|
t.Errorf("expected gid %d, got %d", os.Getgid(), creds.gid)
|
|
}
|
|
if creds.pid != uint32(os.Getpid()) {
|
|
t.Errorf("expected pid %d, got %d", os.Getpid(), creds.pid)
|
|
}
|
|
}
|