mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 11:30:15 +00:00
422 lines
9.9 KiB
Go
422 lines
9.9 KiB
Go
package knownhosts
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Manager is the public surface for recording SSH host keys in a managed
// known_hosts file.
type Manager interface {
	// Ensure makes sure a host key for host is present in the managed
	// known_hosts file.
	Ensure(ctx context.Context, host string) error

	// EnsureWithPort is the port-aware variant of Ensure: it makes sure a host
	// key for host:port is present in the managed known_hosts file.
	EnsureWithPort(ctx context.Context, host string, port int) error

	// EnsureWithEntries records caller-supplied host key entries for the given
	// host and port.
	EnsureWithEntries(ctx context.Context, host string, port int, entries [][]byte) error

	// Path returns the path of the managed known_hosts file.
	Path() string
}
|
|
|
|
type manager struct {
|
|
path string
|
|
cache map[string]struct{}
|
|
mu sync.Mutex
|
|
keyscanFn keyscanFunc
|
|
keyscanTimeout time.Duration
|
|
}
|
|
|
|
// keyscanFunc fetches raw host key output for host:port, honoring timeout.
// The returned bytes are filtered line by line by sanitizeKeyscanOutput.
type keyscanFunc func(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error)
|
|
|
|
// defaultKeyscanTimeout is the ssh-keyscan timeout used when the caller does
// not supply one via WithTimeout.
const defaultKeyscanTimeout = 5 * time.Second
|
|
|
|
// Replaceable seams around the filesystem and the ssh-keyscan binary. They are
// package variables so that they can be swapped out (for example in tests).
var (
	mkdirAllFn = os.MkdirAll
	statFn     = os.Stat
	openFileFn = os.OpenFile
	openFn     = os.Open

	// appendOpenFileFn opens an existing file for appending. It goes through
	// openFileFn at call time, so replacing openFileFn affects it as well.
	appendOpenFileFn = func(path string) (io.WriteCloser, error) {
		const flags = os.O_APPEND | os.O_WRONLY
		return openFileFn(path, flags, 0o600)
	}

	// keyscanCmdRunner runs ssh-keyscan with args and returns its combined
	// stdout and stderr.
	keyscanCmdRunner = func(ctx context.Context, args ...string) ([]byte, error) {
		return exec.CommandContext(ctx, "ssh-keyscan", args...).CombinedOutput()
	}
)

// Sentinel errors exposed by the package.
var (
	// ErrNoHostKeys is returned when ssh-keyscan yields no usable entries.
	ErrNoHostKeys = errors.New("knownhosts: no host keys discovered")

	// ErrHostKeyChanged signals that a host key already exists with a
	// different fingerprint.
	ErrHostKeyChanged = errors.New("knownhosts: host key changed")
)
|
|
|
|
var (
|
|
defaultMkdirAllFn = mkdirAllFn
|
|
defaultStatFn = statFn
|
|
defaultOpenFileFn = openFileFn
|
|
defaultOpenFn = openFn
|
|
defaultAppendOpenFileFn = appendOpenFileFn
|
|
defaultKeyscanCmdRunner = keyscanCmdRunner
|
|
)
|
|
|
|
// HostKeyChangeError reports that the known_hosts file already holds a key for
// a host that differs from the one being installed. It unwraps to
// ErrHostKeyChanged.
type HostKeyChangeError struct {
	// Host is the host spec the mismatch was detected for ("host" or
	// "[host]:port").
	Host string
	// Existing is the line currently stored in the known_hosts file.
	Existing string
	// Provided is the normalized line that was about to be installed.
	Provided string
}
|
|
|
|
func (e *HostKeyChangeError) Error() string {
|
|
return fmt.Sprintf("knownhosts: host key for %s changed", e.Host)
|
|
}
|
|
|
|
func (e *HostKeyChangeError) Unwrap() error {
|
|
return ErrHostKeyChanged
|
|
}
|
|
|
|
// Option allows customizing Manager construction.
|
|
type Option func(*manager)
|
|
|
|
// WithTimeout overrides the ssh-keyscan timeout (defaults to 5 seconds).
|
|
func WithTimeout(d time.Duration) Option {
|
|
return func(m *manager) {
|
|
if d > 0 {
|
|
m.keyscanTimeout = d
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithKeyscanFunc overrides the function used to fetch host keys (mainly for tests).
|
|
func WithKeyscanFunc(fn keyscanFunc) Option {
|
|
return func(m *manager) {
|
|
if fn != nil {
|
|
m.keyscanFn = fn
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewManager returns a Manager writing to the supplied known_hosts path.
|
|
func NewManager(path string, opts ...Option) (Manager, error) {
|
|
if strings.TrimSpace(path) == "" {
|
|
return nil, fmt.Errorf("knownhosts: empty path")
|
|
}
|
|
|
|
m := &manager{
|
|
path: path,
|
|
cache: make(map[string]struct{}),
|
|
keyscanFn: defaultKeyscan,
|
|
keyscanTimeout: defaultKeyscanTimeout,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(m)
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
// Ensure implements Manager.Ensure (uses default port 22).
|
|
func (m *manager) Ensure(ctx context.Context, host string) error {
|
|
return m.EnsureWithPort(ctx, host, 22)
|
|
}
|
|
|
|
// EnsureWithPort implements Manager.EnsureWithPort.
|
|
func (m *manager) EnsureWithPort(ctx context.Context, host string, port int) error {
|
|
if strings.TrimSpace(host) == "" {
|
|
return fmt.Errorf("knownhosts: missing host")
|
|
}
|
|
if port <= 0 {
|
|
port = 22 // Default to standard SSH port
|
|
}
|
|
|
|
hostSpec := host
|
|
if port != 22 {
|
|
hostSpec = fmt.Sprintf("[%s]:%d", host, port)
|
|
}
|
|
|
|
cacheKey := fmt.Sprintf("%s:%d", host, port)
|
|
m.mu.Lock()
|
|
_, cached := m.cache[cacheKey]
|
|
m.mu.Unlock()
|
|
if cached {
|
|
return nil
|
|
}
|
|
|
|
keyData, err := m.keyscanFn(ctx, host, port, m.keyscanTimeout)
|
|
if err != nil {
|
|
return fmt.Errorf("knownhosts: ssh-keyscan failed for %s:%d: %w", host, port, err)
|
|
}
|
|
|
|
entries := sanitizeKeyscanOutput(hostSpec, keyData)
|
|
if len(entries) == 0 {
|
|
return fmt.Errorf("%w for %s:%d", ErrNoHostKeys, host, port)
|
|
}
|
|
|
|
return m.EnsureWithEntries(ctx, host, port, entries)
|
|
}
|
|
|
|
// EnsureWithEntries installs the provided host key entries for host:port.
|
|
func (m *manager) EnsureWithEntries(ctx context.Context, host string, port int, entries [][]byte) error {
|
|
if strings.TrimSpace(host) == "" {
|
|
return fmt.Errorf("knownhosts: missing host")
|
|
}
|
|
if port <= 0 {
|
|
port = 22
|
|
}
|
|
if len(entries) == 0 {
|
|
return fmt.Errorf("knownhosts: no host key entries provided for %s", host)
|
|
}
|
|
|
|
cacheKey := fmt.Sprintf("%s:%d", host, port)
|
|
hostSpec := host
|
|
if port != 22 {
|
|
hostSpec = fmt.Sprintf("[%s]:%d", host, port)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if err := m.ensureKnownHostsFile(); err != nil {
|
|
return err
|
|
}
|
|
|
|
var toAppend [][]byte
|
|
for _, entry := range entries {
|
|
normalized, keyType, err := normalizeHostEntry(hostSpec, entry)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
existing, err := findHostKeyLine(m.path, hostSpec, keyType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if existing != "" {
|
|
if existing != string(normalized) {
|
|
return &HostKeyChangeError{
|
|
Host: hostSpec,
|
|
Existing: existing,
|
|
Provided: string(normalized),
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
toAppend = append(toAppend, normalized)
|
|
}
|
|
|
|
if len(toAppend) > 0 {
|
|
if err := appendHostKey(m.path, toAppend); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
m.cache[cacheKey] = struct{}{}
|
|
return nil
|
|
}
|
|
|
|
// Path implements Manager.Path.
|
|
func (m *manager) Path() string {
|
|
return m.path
|
|
}
|
|
|
|
func (m *manager) ensureKnownHostsFile() error {
|
|
dir := filepath.Dir(m.path)
|
|
if err := mkdirAllFn(dir, 0o700); err != nil {
|
|
return fmt.Errorf("knownhosts: mkdir %s: %w", dir, err)
|
|
}
|
|
|
|
if _, err := statFn(m.path); err == nil {
|
|
return nil
|
|
} else if !os.IsNotExist(err) {
|
|
return err
|
|
}
|
|
|
|
f, err := openFileFn(m.path, os.O_CREATE|os.O_WRONLY, 0o600)
|
|
if err != nil {
|
|
return fmt.Errorf("knownhosts: create %s: %w", m.path, err)
|
|
}
|
|
return f.Close()
|
|
}
|
|
|
|
func appendHostKey(path string, entries [][]byte) error {
|
|
f, err := appendOpenFileFn(path)
|
|
if err != nil {
|
|
return fmt.Errorf("knownhosts: open %s: %w", path, err)
|
|
}
|
|
defer f.Close()
|
|
|
|
for _, entry := range entries {
|
|
if len(entry) == 0 {
|
|
continue
|
|
}
|
|
if _, err := f.Write(append(entry, '\n')); err != nil {
|
|
return fmt.Errorf("knownhosts: write entry: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sanitizeKeyscanOutput(host string, raw []byte) [][]byte {
|
|
var entries [][]byte
|
|
|
|
lines := bytes.Split(raw, []byte{'\n'})
|
|
for _, line := range lines {
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
if hostLineMatches(host, string(line)) {
|
|
entries = append(entries, line)
|
|
}
|
|
}
|
|
|
|
return entries
|
|
}
|
|
|
|
// normalizeHostEntry rewrites a known_hosts-style entry so that its host field
// is host and its fields are separated by single spaces.
//
// The entry must contain at least three whitespace-separated fields (host
// pattern, key type, key data). Any further fields are kept as a trailing
// comment. It returns the rewritten line, the key type, and an error when the
// entry has fewer than three fields.
func normalizeHostEntry(host string, entry []byte) ([]byte, string, error) {
	parts := strings.Fields(string(entry))
	if len(parts) < 3 {
		return nil, "", fmt.Errorf("knownhosts: invalid host key entry for %s", host)
	}

	// Swap the original host pattern for host; everything after it (key type,
	// key data, optional comment words) is kept in order.
	rebuilt := make([]string, 0, len(parts))
	rebuilt = append(rebuilt, host)
	rebuilt = append(rebuilt, parts[1:]...)

	return []byte(strings.Join(rebuilt, " ")), parts[1], nil
}
|
|
|
|
func findHostKeyLine(path, host, keyType string) (string, error) {
|
|
f, err := openFn(path)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return "", nil
|
|
}
|
|
return "", err
|
|
}
|
|
defer f.Close()
|
|
|
|
scanner := bufio.NewScanner(f)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !hostLineMatches(host, line) {
|
|
continue
|
|
}
|
|
|
|
fields := strings.Fields(line)
|
|
if len(fields) < 3 {
|
|
continue
|
|
}
|
|
if keyType != "" && fields[1] != keyType {
|
|
continue
|
|
}
|
|
return strings.TrimSpace(line), nil
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
return "", err
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func hostLineMatches(host, line string) bool {
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
|
return false
|
|
}
|
|
if strings.HasPrefix(trimmed, "|") {
|
|
return false // hashed entry; we only manage clear-text hosts
|
|
}
|
|
|
|
fields := strings.Fields(trimmed)
|
|
return hostFieldMatches(host, fields[0])
|
|
}
|
|
|
|
func hostFieldMatches(host, field string) bool {
|
|
for _, part := range strings.Split(field, ",") {
|
|
for _, candidate := range hostCandidates(part) {
|
|
if strings.EqualFold(candidate, host) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// HostFieldMatches reports whether a known_hosts host field matches the provided host.
|
|
func HostFieldMatches(host, field string) bool {
|
|
return hostFieldMatches(host, field)
|
|
}
|
|
|
|
// hostCandidates lists the strings that a single host pattern may be matched
// against.
//
// The trimmed pattern itself always comes first. In addition:
//   - "[name]..." (a leading "[" with a closing "]") also yields the bracketed
//     name, unless it is empty;
//   - any other pattern containing exactly one ":" that is not its first
//     character also yields the text before the colon.
//
// A blank pattern yields nil.
func hostCandidates(part string) []string {
	pattern := strings.TrimSpace(part)
	if pattern == "" {
		return nil
	}

	out := []string{pattern}

	if strings.HasPrefix(pattern, "[") {
		if end := strings.IndexByte(pattern, ']'); end >= 0 {
			if inner := pattern[1:end]; inner != "" {
				out = append(out, inner)
			}
			return out
		}
		// No closing bracket: fall through and treat it like a plain pattern.
	}

	// Exactly one colon means the first and last colon are the same one.
	first := strings.IndexByte(pattern, ':')
	if first > 0 && first == strings.LastIndexByte(pattern, ':') {
		out = append(out, pattern[:first])
	}
	return out
}
|
|
|
|
func defaultKeyscan(ctx context.Context, host string, port int, timeout time.Duration) ([]byte, error) {
|
|
seconds := int(timeout.Round(time.Second) / time.Second)
|
|
if seconds <= 0 {
|
|
seconds = int(defaultKeyscanTimeout / time.Second)
|
|
}
|
|
if port <= 0 {
|
|
port = 22
|
|
}
|
|
|
|
scanCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
args := []string{"-T", strconv.Itoa(seconds)}
|
|
if port != 22 {
|
|
args = append(args, "-p", strconv.Itoa(port))
|
|
}
|
|
args = append(args, host)
|
|
|
|
output, err := keyscanCmdRunner(scanCtx, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(string(output)))
|
|
}
|
|
return output, nil
|
|
}
|