mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 11:30:15 +00:00
777 lines
21 KiB
Go
777 lines
21 KiB
Go
package knownhosts
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func resetKnownHostsFns() {
|
|
mkdirAllFn = defaultMkdirAllFn
|
|
statFn = defaultStatFn
|
|
openFileFn = defaultOpenFileFn
|
|
openFn = defaultOpenFn
|
|
appendOpenFileFn = defaultAppendOpenFileFn
|
|
keyscanCmdRunner = defaultKeyscanCmdRunner
|
|
}
|
|
|
|
func TestEnsureCreatesFileAndCaches(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
var calls int
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
calls++
|
|
return []byte(host + " ssh-ed25519 AAAA"), nil
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
if err := mgr.Ensure(ctx, "example.com"); err != nil {
|
|
t.Fatalf("Ensure: %v", err)
|
|
}
|
|
if _, err := os.Stat(path); err != nil {
|
|
t.Fatalf("known_hosts not created: %v", err)
|
|
}
|
|
|
|
if err := mgr.Ensure(ctx, "example.com"); err != nil {
|
|
t.Fatalf("Ensure second call: %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("expected keyscan once, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestEnsureUsesSanitizedOutput(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
return []byte(`# comment
|
|
example.com ssh-ed25519 AAAA
|
|
example.com,192.0.2.10 ssh-rsa BBBB
|
|
other.com ssh-ed25519 CCCC
|
|
`), nil
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.Ensure(context.Background(), "example.com"); err != nil {
|
|
t.Fatalf("Ensure: %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile: %v", err)
|
|
}
|
|
if want := "example.com ssh-ed25519 AAAA\nexample.com ssh-rsa BBBB\n"; string(data) != want {
|
|
t.Fatalf("unexpected known_hosts contents\nwant:\n%s\ngot:\n%s", want, data)
|
|
}
|
|
}
|
|
|
|
func TestEnsureReturnsErrorWhenNoEntries(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
return []byte("|1|hash|salt ssh-ed25519 AAAA\n"), nil
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
err = mgr.Ensure(context.Background(), "example.com")
|
|
if !errors.Is(err, ErrNoHostKeys) {
|
|
t.Fatalf("expected ErrNoHostKeys, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEnsureRespectsContextCancellation(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
<-ctx.Done()
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan), WithTimeout(50*time.Millisecond))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
if err := mgr.Ensure(ctx, "example.com"); err == nil {
|
|
t.Fatalf("expected context error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestNewManagerEmptyPath(t *testing.T) {
|
|
if _, err := NewManager(""); err == nil {
|
|
t.Fatal("expected error for empty path")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithPortMissingHost(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithPort(context.Background(), "", 22); err == nil {
|
|
t.Fatal("expected error for missing host")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithPortDefaultsPort(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
var gotPort int
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
gotPort = port
|
|
return []byte(host + " ssh-ed25519 AAAA"), nil
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithPort(context.Background(), "example.com", 0); err != nil {
|
|
t.Fatalf("EnsureWithPort: %v", err)
|
|
}
|
|
if gotPort != 22 {
|
|
t.Fatalf("expected port 22, got %d", gotPort)
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithPortCustomPort(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
return []byte("[example.com]:2222 ssh-ed25519 AAAA"), nil
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithPort(context.Background(), "example.com", 2222); err != nil {
|
|
t.Fatalf("EnsureWithPort: %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile: %v", err)
|
|
}
|
|
if got := strings.TrimSpace(string(data)); got != "[example.com]:2222 ssh-ed25519 AAAA" {
|
|
t.Fatalf("unexpected known_hosts contents: %s", got)
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithPortKeyscanError(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
keyscan := func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
return nil, errors.New("scan failed")
|
|
}
|
|
|
|
mgr, err := NewManager(path, WithKeyscanFunc(keyscan))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithPort(context.Background(), "example.com", 22); err == nil {
|
|
t.Fatal("expected keyscan error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesMissingHost(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected missing host error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesNoEntries(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, nil); err == nil {
|
|
t.Fatal("expected no entries error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesNormalizeError(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("invalid")}); err == nil {
|
|
t.Fatal("expected normalize error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesEnsureKnownHostsFileError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
mkdirAllFn = func(string, os.FileMode) error {
|
|
return errors.New("mkdir failed")
|
|
}
|
|
|
|
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected ensureKnownHostsFile error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesFindHostKeyLineError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
openFn = func(string) (*os.File, error) {
|
|
return nil, errors.New("open failed")
|
|
}
|
|
|
|
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected open error")
|
|
}
|
|
}
|
|
|
|
func TestHostCandidates(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want []string
|
|
}{
|
|
{"example.com", []string{"example.com"}},
|
|
{"example.com:2222", []string{"example.com:2222", "example.com"}},
|
|
{"[example.com]:2222", []string{"[example.com]:2222", "example.com"}},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := hostCandidates(tt.input)
|
|
if len(got) != len(tt.want) {
|
|
t.Fatalf("hostCandidates(%q) len = %d, want %d", tt.input, len(got), len(tt.want))
|
|
}
|
|
for i := range got {
|
|
if got[i] != tt.want[i] {
|
|
t.Fatalf("hostCandidates(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesDetectsChange(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
entry := []byte("example.com ssh-ed25519 AAAA")
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{entry}); err != nil {
|
|
t.Fatalf("EnsureWithEntries: %v", err)
|
|
}
|
|
|
|
// Same entry should be a no-op
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{entry}); err != nil {
|
|
t.Fatalf("EnsureWithEntries repeat: %v", err)
|
|
}
|
|
|
|
// Different key should trigger change error
|
|
changeEntry := []byte("example.com ssh-ed25519 BBBB")
|
|
err = mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{changeEntry})
|
|
var changeErr *HostKeyChangeError
|
|
if !errors.As(err, &changeErr) {
|
|
t.Fatalf("expected HostKeyChangeError, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesAppendsNewKeyTypes(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
if err := mgr.EnsureWithEntries(ctx, "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err != nil {
|
|
t.Fatalf("EnsureWithEntries ed25519: %v", err)
|
|
}
|
|
if err := mgr.EnsureWithEntries(ctx, "example.com", 22, [][]byte{[]byte("example.com ssh-rsa BBBB")}); err != nil {
|
|
t.Fatalf("EnsureWithEntries rsa: %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile: %v", err)
|
|
}
|
|
got := string(data)
|
|
if !strings.Contains(got, "ssh-ed25519 AAAA") || !strings.Contains(got, "ssh-rsa BBBB") {
|
|
t.Fatalf("expected both key types, got %s", got)
|
|
}
|
|
}
|
|
|
|
func TestHostKeyChangeErrorError(t *testing.T) {
|
|
tests := []struct {
|
|
host string
|
|
want string
|
|
}{
|
|
{"example.com", "knownhosts: host key for example.com changed"},
|
|
{"192.168.1.1", "knownhosts: host key for 192.168.1.1 changed"},
|
|
{"[example.com]:2222", "knownhosts: host key for [example.com]:2222 changed"},
|
|
{"", "knownhosts: host key for changed"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
err := &HostKeyChangeError{Host: tt.host}
|
|
if got := err.Error(); got != tt.want {
|
|
t.Errorf("HostKeyChangeError{Host: %q}.Error() = %q, want %q", tt.host, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHostKeyChangeErrorUnwrap(t *testing.T) {
|
|
err := &HostKeyChangeError{
|
|
Host: "example.com",
|
|
Existing: "example.com ssh-ed25519 AAAA",
|
|
Provided: "example.com ssh-ed25519 BBBB",
|
|
}
|
|
|
|
if !errors.Is(err, ErrHostKeyChanged) {
|
|
t.Error("errors.Is(HostKeyChangeError, ErrHostKeyChanged) = false, want true")
|
|
}
|
|
|
|
unwrapped := err.Unwrap()
|
|
if unwrapped != ErrHostKeyChanged {
|
|
t.Errorf("Unwrap() = %v, want %v", unwrapped, ErrHostKeyChanged)
|
|
}
|
|
}
|
|
|
|
func TestManagerPath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
}{
|
|
{"simple path", "/tmp/known_hosts"},
|
|
{"nested path", "/home/user/.ssh/known_hosts"},
|
|
{"relative-like path", "/opt/pulse/data/known_hosts"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mgr, err := NewManager(tt.path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager(%q): %v", tt.path, err)
|
|
}
|
|
if got := mgr.Path(); got != tt.path {
|
|
t.Errorf("Path() = %q, want %q", got, tt.path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHostFieldMatches(t *testing.T) {
|
|
tests := []struct {
|
|
host string
|
|
field string
|
|
want bool
|
|
}{
|
|
// Exact matches
|
|
{"example.com", "example.com", true},
|
|
{"192.168.1.1", "192.168.1.1", true},
|
|
|
|
// Case insensitive
|
|
{"EXAMPLE.COM", "example.com", true},
|
|
{"example.com", "EXAMPLE.COM", true},
|
|
|
|
// Comma-separated hosts
|
|
{"example.com", "example.com,192.168.1.1", true},
|
|
{"192.168.1.1", "example.com,192.168.1.1", true},
|
|
{"other.com", "example.com,192.168.1.1", false},
|
|
|
|
// Bracketed hosts with ports
|
|
{"[example.com]:2222", "[example.com]:2222", true},
|
|
{"example.com", "[example.com]:2222", true},
|
|
|
|
// Host:port format
|
|
{"example.com:2222", "example.com:2222", true},
|
|
{"example.com", "example.com:2222", true},
|
|
|
|
// No match
|
|
{"other.com", "example.com", false},
|
|
{"example.org", "example.com", false},
|
|
|
|
// Empty cases
|
|
{"example.com", "", false},
|
|
{"", "example.com", false},
|
|
{"", "", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
name := tt.host + "_" + tt.field
|
|
t.Run(name, func(t *testing.T) {
|
|
if got := HostFieldMatches(tt.host, tt.field); got != tt.want {
|
|
t.Errorf("HostFieldMatches(%q, %q) = %v, want %v", tt.host, tt.field, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEnsureKnownHostsFileMkdirError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
mkdirAllFn = func(string, os.FileMode) error {
|
|
return errors.New("mkdir failed")
|
|
}
|
|
|
|
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
|
|
if err := m.ensureKnownHostsFile(); err == nil {
|
|
t.Fatal("expected mkdir error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureKnownHostsFileStatError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
statFn = func(string) (os.FileInfo, error) {
|
|
return nil, errors.New("stat failed")
|
|
}
|
|
|
|
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
|
|
if err := m.ensureKnownHostsFile(); err == nil {
|
|
t.Fatal("expected stat error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureKnownHostsFileCreateError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
statFn = func(string) (os.FileInfo, error) {
|
|
return nil, os.ErrNotExist
|
|
}
|
|
openFileFn = func(string, int, os.FileMode) (*os.File, error) {
|
|
return nil, errors.New("open failed")
|
|
}
|
|
|
|
m := &manager{path: filepath.Join(t.TempDir(), "known_hosts")}
|
|
if err := m.ensureKnownHostsFile(); err == nil {
|
|
t.Fatal("expected create error")
|
|
}
|
|
}
|
|
|
|
func TestAppendHostKeyOpenError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
appendOpenFileFn = func(string) (io.WriteCloser, error) {
|
|
return nil, errors.New("open failed")
|
|
}
|
|
|
|
if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected open error")
|
|
}
|
|
}
|
|
|
|
func TestAppendHostKeyWriteError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
appendOpenFileFn = func(string) (io.WriteCloser, error) {
|
|
return errWriteCloser{err: errors.New("write failed")}, nil
|
|
}
|
|
|
|
if err := appendHostKey("ignored", [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected write error")
|
|
}
|
|
}
|
|
|
|
func TestNormalizeHostEntryWithComment(t *testing.T) {
|
|
entry := []byte("example.com ssh-ed25519 AAAA comment here")
|
|
normalized, keyType, err := normalizeHostEntry("example.com", entry)
|
|
if err != nil {
|
|
t.Fatalf("normalizeHostEntry error: %v", err)
|
|
}
|
|
if keyType != "ssh-ed25519" {
|
|
t.Fatalf("expected key type ssh-ed25519, got %s", keyType)
|
|
}
|
|
if string(normalized) != "example.com ssh-ed25519 AAAA comment here" {
|
|
t.Fatalf("unexpected normalized entry: %s", string(normalized))
|
|
}
|
|
}
|
|
|
|
func TestFindHostKeyLineNotExists(t *testing.T) {
|
|
line, err := findHostKeyLine(filepath.Join(t.TempDir(), "missing"), "example.com", "")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if line != "" {
|
|
t.Fatalf("expected empty line, got %q", line)
|
|
}
|
|
}
|
|
|
|
func TestFindHostKeyLineOpenError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
openFn = func(string) (*os.File, error) {
|
|
return nil, errors.New("open failed")
|
|
}
|
|
|
|
if _, err := findHostKeyLine("ignored", "example.com", ""); err == nil {
|
|
t.Fatal("expected open error")
|
|
}
|
|
}
|
|
|
|
func TestFindHostKeyLineScannerError(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
longLine := strings.Repeat("a", 70000)
|
|
if err := os.WriteFile(path, []byte(longLine+"\n"), 0600); err != nil {
|
|
t.Fatalf("failed to write file: %v", err)
|
|
}
|
|
|
|
if _, err := findHostKeyLine(path, "example.com", ""); err == nil {
|
|
t.Fatal("expected scanner error")
|
|
}
|
|
}
|
|
|
|
func TestHostLineMatchesSkips(t *testing.T) {
|
|
if hostLineMatches("example.com", "") {
|
|
t.Fatal("expected empty line to be false")
|
|
}
|
|
if hostLineMatches("example.com", "# comment") {
|
|
t.Fatal("expected comment line to be false")
|
|
}
|
|
if hostLineMatches("example.com", "|1|hash|salt ssh-ed25519 AAAA") {
|
|
t.Fatal("expected hashed entry to be false")
|
|
}
|
|
}
|
|
|
|
func TestDefaultKeyscanSuccess(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
|
|
return []byte("example.com ssh-ed25519 AAAA"), nil
|
|
}
|
|
|
|
out, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second)
|
|
if err != nil {
|
|
t.Fatalf("defaultKeyscan error: %v", err)
|
|
}
|
|
if string(out) != "example.com ssh-ed25519 AAAA" {
|
|
t.Fatalf("unexpected output: %s", string(out))
|
|
}
|
|
}
|
|
|
|
func TestDefaultKeyscanError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
|
|
return []byte("boom"), errors.New("scan failed")
|
|
}
|
|
|
|
if _, err := defaultKeyscan(context.Background(), "example.com", 22, time.Second); err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesDefaultsPort(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
|
|
mgr, err := NewManager(path)
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 0, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err != nil {
|
|
t.Fatalf("EnsureWithEntries: %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile: %v", err)
|
|
}
|
|
if got := strings.TrimSpace(string(data)); got != "example.com ssh-ed25519 AAAA" {
|
|
t.Fatalf("unexpected known_hosts contents: %s", got)
|
|
}
|
|
}
|
|
|
|
func TestEnsureWithEntriesAppendError(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
appendOpenFileFn = func(string) (io.WriteCloser, error) {
|
|
return nil, errors.New("open failed")
|
|
}
|
|
|
|
mgr, err := NewManager(filepath.Join(t.TempDir(), "known_hosts"))
|
|
if err != nil {
|
|
t.Fatalf("NewManager: %v", err)
|
|
}
|
|
|
|
if err := mgr.EnsureWithEntries(context.Background(), "example.com", 22, [][]byte{[]byte("example.com ssh-ed25519 AAAA")}); err == nil {
|
|
t.Fatal("expected append error")
|
|
}
|
|
}
|
|
|
|
func TestAppendHostKeySkipsEmptyEntry(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
buf := &bufferWriteCloser{}
|
|
appendOpenFileFn = func(string) (io.WriteCloser, error) {
|
|
return buf, nil
|
|
}
|
|
|
|
if err := appendHostKey("ignored", [][]byte{nil, []byte("example.com ssh-ed25519 AAAA")}); err != nil {
|
|
t.Fatalf("appendHostKey error: %v", err)
|
|
}
|
|
if !strings.Contains(buf.String(), "example.com ssh-ed25519 AAAA") {
|
|
t.Fatalf("expected entry to be written, got %q", buf.String())
|
|
}
|
|
}
|
|
|
|
func TestFindHostKeyLineSkipsInvalidLines(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "known_hosts")
|
|
contents := strings.Join([]string{
|
|
"other.com ssh-ed25519 AAAA",
|
|
"example.com ssh-ed25519",
|
|
"example.com ssh-ed25519 AAAA",
|
|
}, "\n") + "\n"
|
|
if err := os.WriteFile(path, []byte(contents), 0600); err != nil {
|
|
t.Fatalf("failed to write file: %v", err)
|
|
}
|
|
|
|
line, err := findHostKeyLine(path, "example.com", "ssh-ed25519")
|
|
if err != nil {
|
|
t.Fatalf("findHostKeyLine error: %v", err)
|
|
}
|
|
if line != "example.com ssh-ed25519 AAAA" {
|
|
t.Fatalf("unexpected line: %q", line)
|
|
}
|
|
}
|
|
|
|
func TestDefaultKeyscanArgs(t *testing.T) {
|
|
t.Cleanup(resetKnownHostsFns)
|
|
|
|
var gotArgs []string
|
|
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
|
|
gotArgs = append([]string{}, args...)
|
|
return []byte("ok"), nil
|
|
}
|
|
|
|
if _, err := defaultKeyscan(context.Background(), "example.com", 0, 0); err != nil {
|
|
t.Fatalf("defaultKeyscan error: %v", err)
|
|
}
|
|
for _, arg := range gotArgs {
|
|
if arg == "-p" {
|
|
t.Fatal("did not expect -p for default port")
|
|
}
|
|
}
|
|
if len(gotArgs) < 3 || gotArgs[len(gotArgs)-1] != "example.com" {
|
|
t.Fatalf("unexpected args: %v", gotArgs)
|
|
}
|
|
|
|
keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
|
|
gotArgs = append([]string{}, args...)
|
|
return []byte("ok"), nil
|
|
}
|
|
if _, err := defaultKeyscan(context.Background(), "example.com", 2222, time.Second); err != nil {
|
|
t.Fatalf("defaultKeyscan error: %v", err)
|
|
}
|
|
hasPort := false
|
|
for i := 0; i < len(gotArgs)-1; i++ {
|
|
if gotArgs[i] == "-p" && gotArgs[i+1] == "2222" {
|
|
hasPort = true
|
|
break
|
|
}
|
|
}
|
|
if !hasPort {
|
|
t.Fatalf("expected -p 2222 in args, got %v", gotArgs)
|
|
}
|
|
}
|
|
|
|
func TestKeyscanCmdRunnerDefault(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("ssh-keyscan helper script requires sh")
|
|
}
|
|
t.Cleanup(resetKnownHostsFns)
|
|
|
|
dir := t.TempDir()
|
|
scriptPath := filepath.Join(dir, "ssh-keyscan")
|
|
script := []byte("#!/bin/sh\necho example.com ssh-ed25519 AAAA\n")
|
|
if err := os.WriteFile(scriptPath, script, 0700); err != nil {
|
|
t.Fatalf("failed to write script: %v", err)
|
|
}
|
|
|
|
oldPath := os.Getenv("PATH")
|
|
if err := os.Setenv("PATH", dir+string(os.PathListSeparator)+oldPath); err != nil {
|
|
t.Fatalf("failed to set PATH: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) })
|
|
|
|
output, err := keyscanCmdRunner(context.Background(), "example.com")
|
|
if err != nil {
|
|
t.Fatalf("keyscanCmdRunner error: %v", err)
|
|
}
|
|
if strings.TrimSpace(string(output)) != "example.com ssh-ed25519 AAAA" {
|
|
t.Fatalf("unexpected output: %s", string(output))
|
|
}
|
|
}
|
|
|
|
type errWriteCloser struct {
|
|
err error
|
|
}
|
|
|
|
func (e errWriteCloser) Write(p []byte) (int, error) {
|
|
return 0, e.err
|
|
}
|
|
|
|
func (e errWriteCloser) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// bufferWriteCloser embeds bytes.Buffer and adds a no-op Close, so tests can
// capture everything appendHostKey writes through an io.WriteCloser.
type bufferWriteCloser struct {
	bytes.Buffer
}

var _ io.WriteCloser = (*bufferWriteCloser)(nil)

// Close satisfies io.Closer; the in-memory buffer has nothing to release.
func (b *bufferWriteCloser) Close() error { return nil }
|