Pulse/internal/ssh/knownhosts/manager.go
2025-12-29 17:25:21 +00:00

422 lines
9.9 KiB
Go

package knownhosts
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// Manager exposes operations for ensuring SSH host keys exist locally.
//
// The implementation returned by NewManager appends missing keys to a single
// known_hosts file and returns a *HostKeyChangeError (which unwraps to
// ErrHostKeyChanged) when a conflicting entry is already recorded.
type Manager interface {
// Ensure guarantees that the host key for the provided host exists in the
// managed known_hosts file. The NewManager implementation treats this as
// EnsureWithPort with port 22.
Ensure(ctx context.Context, host string) error
// EnsureWithPort guarantees that the host key for the provided host:port exists
// in the managed known_hosts file. The NewManager implementation treats a
// non-positive port as 22 and fetches the keys with its keyscan function.
EnsureWithPort(ctx context.Context, host string, port int) error
// EnsureWithEntries installs provided host key entries for the given host/port
// without scanning the host. An empty entries slice is an error in the
// NewManager implementation.
EnsureWithEntries(ctx context.Context, host string, port int, entries [][]byte) error
// Path returns the path to the managed known_hosts file. The NewManager
// implementation returns the path exactly as it was supplied to NewManager;
// it is not converted to an absolute path.
Path() string
}
// manager is the Manager implementation returned by NewManager.
type manager struct {
// path is the known_hosts file this manager reads and appends to.
path string
// cache holds "host:port" keys for which EnsureWithEntries has already
// succeeded, letting EnsureWithPort skip the keyscan on later calls.
cache map[string]struct{}
// mu guards cache and is held while the known_hosts file is inspected and
// updated by EnsureWithEntries.
mu sync.Mutex
// keyscanFn fetches raw host key lines; NewManager sets it to defaultKeyscan
// unless WithKeyscanFunc overrides it.
keyscanFn keyscanFunc
// keyscanTimeout is passed to keyscanFn on every scan; NewManager sets it to
// defaultKeyscanTimeout unless WithTimeout overrides it.
keyscanTimeout time.Duration
}
// keyscanFunc fetches host key lines for host:port, given a timeout. The
// returned bytes are consumed by EnsureWithPort as newline-separated
// known_hosts-style lines.
type keyscanFunc func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error)
const (
// defaultKeyscanTimeout is the keyscan timeout NewManager uses when no
// WithTimeout option overrides it; defaultKeyscan also falls back to it when
// the requested timeout rounds to zero seconds or less.
defaultKeyscanTimeout = 5 * time.Second
)
var (
	// mkdirAllFn, statFn, openFileFn and openFn indirect the os calls used by
	// this package so they can be substituted (presumably by tests).
	mkdirAllFn = os.MkdirAll
	statFn     = os.Stat
	openFileFn = os.OpenFile
	openFn     = os.Open

	// appendOpenFileFn opens an existing file write-only in append mode. It
	// does not create the file.
	appendOpenFileFn = func(path string) (io.WriteCloser, error) {
		return openFileFn(path, os.O_APPEND|os.O_WRONLY, 0o600)
	}

	// keyscanCmdRunner runs the ssh-keyscan binary with args.
	//
	// On success it returns only the command's standard output. ssh-keyscan
	// writes banners and per-connection diagnostics to standard error, and a
	// diagnostic such as "host: Connection closed by remote host" would
	// otherwise pass the host match for "host" and be stored as if it were a
	// key entry. On failure it returns stdout followed by stderr so the caller
	// can include both in its error message.
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		cmd := exec.CommandContext(ctx, "ssh-keyscan", args...)
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		stdout, err := cmd.Output()
		if err != nil {
			return append(stdout, stderr.Bytes()...), err
		}
		return stdout, nil
	}

	// ErrNoHostKeys is returned when ssh-keyscan yields no usable entries.
	ErrNoHostKeys = errors.New("knownhosts: no host keys discovered")
	// ErrHostKeyChanged signals that a host key already exists with a different fingerprint.
	ErrHostKeyChanged = errors.New("knownhosts: host key changed")
)
// The default* variables capture the values the overridable hooks hold when
// the package is initialized.
// NOTE(review): nothing in this file reads them; presumably tests use them to
// restore the hooks after overriding — confirm against the test files.
var (
defaultMkdirAllFn = mkdirAllFn
defaultStatFn = statFn
defaultOpenFileFn = openFileFn
defaultOpenFn = openFn
defaultAppendOpenFileFn = appendOpenFileFn
defaultKeyscanCmdRunner = keyscanCmdRunner
)
// HostKeyChangeError describes a detected host key mismatch. It unwraps to
// ErrHostKeyChanged.
type HostKeyChangeError struct {
// Host is the known_hosts host field that was checked: "host" for port 22,
// "[host]:port" otherwise.
Host string
// Existing is the conflicting line found in the known_hosts file.
Existing string
// Provided is the normalized line that was about to be installed.
Provided string
}
// Error implements the error interface. The message names only the host; the
// conflicting lines are available through the Existing and Provided fields.
func (e *HostKeyChangeError) Error() string {
return fmt.Sprintf("knownhosts: host key for %s changed", e.Host)
}
// Unwrap returns ErrHostKeyChanged so callers can detect a mismatch with
// errors.Is(err, ErrHostKeyChanged).
func (e *HostKeyChangeError) Unwrap() error {
return ErrHostKeyChanged
}
// Option allows customizing Manager construction. NewManager applies options
// in the order they are passed, after the defaults have been set.
type Option func(*manager)
// WithTimeout sets the timeout handed to the keyscan function. A zero or
// negative duration leaves the manager's current timeout (5 seconds by
// default) untouched.
func WithTimeout(d time.Duration) Option {
	if d <= 0 {
		// Nothing to override; hand back an option that changes nothing.
		return func(*manager) {}
	}
	return func(target *manager) {
		target.keyscanTimeout = d
	}
}
// WithKeyscanFunc replaces the function the manager uses to fetch host keys
// (mainly useful in tests). A nil fn leaves the current function in place.
func WithKeyscanFunc(fn keyscanFunc) Option {
	if fn == nil {
		// Nothing to override; hand back an option that changes nothing.
		return func(*manager) {}
	}
	return func(target *manager) {
		target.keyscanFn = fn
	}
}
// NewManager builds a Manager that maintains the known_hosts file at path.
//
// path must contain something other than whitespace; it is stored exactly as
// given. The manager starts with defaultKeyscan and defaultKeyscanTimeout, and
// opts are then applied in order.
func NewManager(path string, opts ...Option) (Manager, error) {
	if len(strings.TrimSpace(path)) == 0 {
		return nil, fmt.Errorf("knownhosts: empty path")
	}

	mgr := new(manager)
	mgr.path = path
	mgr.cache = make(map[string]struct{})
	mgr.keyscanFn = defaultKeyscan
	mgr.keyscanTimeout = defaultKeyscanTimeout

	for i := range opts {
		opts[i](mgr)
	}
	return mgr, nil
}
// Ensure implements Manager.Ensure (uses default port 22). It delegates to
// EnsureWithPort and returns that call's error unchanged.
func (m *manager) Ensure(ctx context.Context, host string) error {
return m.EnsureWithPort(ctx, host, 22)
}
// EnsureWithPort implements Manager.EnsureWithPort.
//
// A host/port pair that has already been ensured by this manager returns
// immediately. Otherwise the keyscan function is invoked, its output is
// reduced to the lines whose host field matches, and those lines are passed to
// EnsureWithEntries. A non-positive port means 22. If no line survives the
// filter the returned error wraps ErrNoHostKeys.
func (m *manager) EnsureWithPort(ctx context.Context, host string, port int) error {
	if strings.TrimSpace(host) == "" {
		return fmt.Errorf("knownhosts: missing host")
	}
	if port < 1 {
		port = 22 // fall back to the standard SSH port
	}

	// Only the cache lookup happens under the lock; the scan itself can be slow.
	m.mu.Lock()
	_, alreadyEnsured := m.cache[fmt.Sprintf("%s:%d", host, port)]
	m.mu.Unlock()
	if alreadyEnsured {
		return nil
	}

	scanned, scanErr := m.keyscanFn(ctx, host, port, m.keyscanTimeout)
	if scanErr != nil {
		return fmt.Errorf("knownhosts: ssh-keyscan failed for %s:%d: %w", host, port, scanErr)
	}

	// known_hosts uses the bare host for port 22 and "[host]:port" otherwise.
	spec := host
	if port != 22 {
		spec = fmt.Sprintf("[%s]:%d", host, port)
	}

	usable := sanitizeKeyscanOutput(spec, scanned)
	if len(usable) == 0 {
		return fmt.Errorf("%w for %s:%d", ErrNoHostKeys, host, port)
	}
	return m.EnsureWithEntries(ctx, host, port, usable)
}
// EnsureWithEntries installs the provided host key entries for host:port.
//
// Each entry is rewritten so its host field is "host" (port 22, or a
// non-positive port) or "[host]:port", and is then compared with the first
// line already recorded in the known_hosts file for that host and key type:
//   - no recorded line: the entry is queued and appended after all entries
//     have been checked;
//   - same key type and key data: nothing is written for that entry;
//   - different key data: a *HostKeyChangeError is returned and nothing is
//     written for any entry.
//
// Only key type and key data are compared. Host lookup is case-insensitive and
// accepts comma-separated host lists, so comparing whole lines would report a
// changed key when merely the host spelling, the host list or the trailing
// comment differs.
//
// On success the host:port pair is cached so later EnsureWithPort calls skip
// the keyscan. The whole operation runs under the manager's mutex. ctx is
// currently unused.
func (m *manager) EnsureWithEntries(ctx context.Context, host string, port int, entries [][]byte) error {
	if strings.TrimSpace(host) == "" {
		return fmt.Errorf("knownhosts: missing host")
	}
	if port <= 0 {
		port = 22
	}
	if len(entries) == 0 {
		return fmt.Errorf("knownhosts: no host key entries provided for %s", host)
	}

	cacheKey := fmt.Sprintf("%s:%d", host, port)
	hostSpec := host
	if port != 22 {
		hostSpec = fmt.Sprintf("[%s]:%d", host, port)
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if err := m.ensureKnownHostsFile(); err != nil {
		return err
	}

	var toAppend [][]byte
	for _, entry := range entries {
		normalized, keyType, err := normalizeHostEntry(hostSpec, entry)
		if err != nil {
			return err
		}
		existing, err := findHostKeyLine(m.path, hostSpec, keyType)
		if err != nil {
			return err
		}
		if existing == "" {
			toAppend = append(toAppend, normalized)
			continue
		}
		if !sameHostKeyMaterial(existing, string(normalized)) {
			return &HostKeyChangeError{
				Host:     hostSpec,
				Existing: existing,
				Provided: string(normalized),
			}
		}
	}

	if len(toAppend) > 0 {
		if err := appendHostKey(m.path, toAppend); err != nil {
			return err
		}
	}

	m.cache[cacheKey] = struct{}{}
	return nil
}

// sameHostKeyMaterial reports whether two known_hosts lines carry the same key
// type and key data (second and third whitespace-separated fields). Lines with
// fewer than three fields are compared verbatim.
func sameHostKeyMaterial(a, b string) bool {
	fa, fb := strings.Fields(a), strings.Fields(b)
	if len(fa) < 3 || len(fb) < 3 {
		return a == b
	}
	return fa[1] == fb[1] && fa[2] == fb[2]
}
// Path implements Manager.Path. It returns the path exactly as it was passed
// to NewManager.
func (m *manager) Path() string {
return m.path
}
// ensureKnownHostsFile makes sure the managed known_hosts file exists.
//
// The parent directory is created with mode 0700 if needed. When the file is
// already present nothing else happens; when it is missing an empty file is
// created with mode 0600. A stat failure other than "not exist" is returned
// as is.
func (m *manager) ensureKnownHostsFile() error {
	parent := filepath.Dir(m.path)
	if mkErr := mkdirAllFn(parent, 0o700); mkErr != nil {
		return fmt.Errorf("knownhosts: mkdir %s: %w", parent, mkErr)
	}

	_, statErr := statFn(m.path)
	switch {
	case statErr == nil:
		return nil
	case !os.IsNotExist(statErr):
		return statErr
	}

	created, createErr := openFileFn(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
	if createErr != nil {
		return fmt.Errorf("knownhosts: create %s: %w", m.path, createErr)
	}
	return created.Close()
}
// appendHostKey appends each non-empty entry, followed by a newline, to the
// file at path using appendOpenFileFn.
//
// Every line is assembled in a fresh buffer so the caller's slices (and any
// larger buffer they alias) are never written to. A failure to close the file
// is reported when no earlier error occurred, since a write-mode close can
// surface errors from the preceding writes.
func appendHostKey(path string, entries [][]byte) (err error) {
	f, err := appendOpenFileFn(path)
	if err != nil {
		return fmt.Errorf("knownhosts: open %s: %w", path, err)
	}
	defer func() {
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("knownhosts: close %s: %w", path, closeErr)
		}
	}()

	for _, entry := range entries {
		if len(entry) == 0 {
			continue
		}
		line := make([]byte, 0, len(entry)+1)
		line = append(line, entry...)
		line = append(line, '\n')
		if _, writeErr := f.Write(line); writeErr != nil {
			return fmt.Errorf("knownhosts: write entry: %w", writeErr)
		}
	}
	return nil
}
// sanitizeKeyscanOutput splits raw keyscan output into lines and keeps the
// ones hostLineMatches accepts for host. Surrounding whitespace is trimmed
// from every kept line and blank lines are dropped. The returned slices point
// into raw.
func sanitizeKeyscanOutput(host string, raw []byte) [][]byte {
	var kept [][]byte
	for _, rawLine := range bytes.Split(raw, []byte{'\n'}) {
		candidate := bytes.TrimSpace(rawLine)
		if len(candidate) == 0 || !hostLineMatches(host, string(candidate)) {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// normalizeHostEntry rewrites a known_hosts-style line so that its first
// field is host and all remaining fields are separated by single spaces.
//
// It returns the rewritten line and the key type (the second field). An entry
// with fewer than three whitespace-separated fields is rejected.
func normalizeHostEntry(host string, entry []byte) ([]byte, string, error) {
	parts := strings.Fields(string(entry))
	if len(parts) < 3 {
		return nil, "", fmt.Errorf("knownhosts: invalid host key entry for %s", host)
	}

	// Keep key type, key data and any comment words; only the host field is replaced.
	rebuilt := append([]string{host}, parts[1:]...)
	return []byte(strings.Join(rebuilt, " ")), parts[1], nil
}
// findHostKeyLine scans the known_hosts file at path and returns the first
// line, trimmed, that hostLineMatches accepts for host, has at least three
// fields, and whose key type equals keyType (any key type when keyType is
// empty).
//
// A missing file or no matching line yields "" and a nil error. Other open
// errors and scanner errors are returned unchanged.
func findHostKeyLine(path, host, keyType string) (string, error) {
	file, openErr := openFn(path)
	switch {
	case openErr == nil:
		// proceed to scan
	case os.IsNotExist(openErr):
		return "", nil
	default:
		return "", openErr
	}
	defer file.Close()

	lines := bufio.NewScanner(file)
	for lines.Scan() {
		text := lines.Text()
		if !hostLineMatches(host, text) {
			continue
		}
		cols := strings.Fields(text)
		if len(cols) < 3 {
			continue
		}
		if keyType == "" || cols[1] == keyType {
			return strings.TrimSpace(text), nil
		}
	}
	return "", lines.Err()
}
// hostLineMatches reports whether a known_hosts line belongs to host. Blank
// lines, comment lines and hashed entries never match; otherwise the first
// whitespace-separated field is checked with hostFieldMatches.
func hostLineMatches(host, line string) bool {
	entry := strings.TrimSpace(line)
	switch {
	case entry == "":
		return false
	case strings.HasPrefix(entry, "#"):
		return false
	case strings.HasPrefix(entry, "|"):
		// Hashed host names cannot be matched; only clear-text hosts are managed.
		return false
	}
	hostField := strings.Fields(entry)[0]
	return hostFieldMatches(host, hostField)
}
// hostFieldMatches reports whether host equals, ignoring case, any candidate
// that hostCandidates derives from the comma-separated parts of a known_hosts
// host field.
func hostFieldMatches(host, field string) bool {
	parts := strings.Split(field, ",")
	for i := range parts {
		names := hostCandidates(parts[i])
		for j := range names {
			if strings.EqualFold(names[j], host) {
				return true
			}
		}
	}
	return false
}
// HostFieldMatches reports whether a known_hosts host field matches the provided host.
// It is the exported form of hostFieldMatches: field may be a comma-separated
// list, and the comparison ignores case.
func HostFieldMatches(host, field string) bool {
return hostFieldMatches(host, field)
}
// hostCandidates lists the strings a single host-field part may be matched
// against. The trimmed part itself always comes first. A bracketed part such
// as "[host]:port" additionally yields the non-empty text between the
// brackets; an unbracketed part with exactly one colon that is not its first
// character additionally yields the text before the colon. An empty part
// yields nil.
func hostCandidates(part string) []string {
	token := strings.TrimSpace(part)
	if token == "" {
		return nil
	}

	out := []string{token}
	if strings.HasPrefix(token, "[") {
		if end := strings.IndexByte(token, ']'); end >= 0 {
			if inner := token[1:end]; inner != "" {
				out = append(out, inner)
			}
			return out
		}
		// No closing bracket: fall through and treat it like a plain part.
	}

	if strings.Count(token, ":") == 1 {
		if colon := strings.IndexByte(token, ':'); colon > 0 {
			out = append(out, token[:colon])
		}
	}
	return out
}
// defaultKeyscan fetches host keys by running ssh-keyscan through
// keyscanCmdRunner.
//
// timeout bounds both the command's context and ssh-keyscan's own -T option
// (rounded to whole seconds). A non-positive timeout falls back to
// defaultKeyscanTimeout for both, so the context is never created already
// expired. A non-positive port means 22; "-p" is only passed for other ports.
//
// A host beginning with "-" is rejected before anything is executed, because
// ssh-keyscan would parse it as a command-line option rather than a host.
//
// On failure the runner's error is wrapped together with the trimmed command
// output.
func defaultKeyscan(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
	if strings.HasPrefix(host, "-") {
		return nil, fmt.Errorf("knownhosts: invalid host %q", host)
	}
	if timeout <= 0 {
		timeout = defaultKeyscanTimeout
	}
	seconds := int(timeout.Round(time.Second) / time.Second)
	if seconds <= 0 {
		// Sub-second timeouts round to zero, which ssh-keyscan would not accept.
		seconds = int(defaultKeyscanTimeout / time.Second)
	}
	if port <= 0 {
		port = 22
	}

	scanCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	args := []string{"-T", strconv.Itoa(seconds)}
	if port != 22 {
		args = append(args, "-p", strconv.Itoa(port))
	}
	args = append(args, host)

	output, err := keyscanCmdRunner(scanCtx, args...)
	if err != nil {
		return nil, fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(string(output)))
	}
	return output, nil
}