test(envdetect): cover environment detection decisions

This commit is contained in:
rcourtman 2025-12-17 16:08:10 +00:00
parent 71e1b5dc86
commit e44a6fdadd
2 changed files with 375 additions and 40 deletions

View file

@ -0,0 +1,287 @@
package envdetect
import (
"errors"
"net"
"os"
"os/exec"
"strings"
"testing"
"time"
)
type fakeFileInfo struct{}
func (fakeFileInfo) Name() string { return "fake" }
func (fakeFileInfo) Size() int64 { return 0 }
func (fakeFileInfo) Mode() os.FileMode { return 0 }
func (fakeFileInfo) ModTime() (t time.Time) {
return t
}
func (fakeFileInfo) IsDir() bool { return false }
func (fakeFileInfo) Sys() interface{} { return nil }
type fakeEnvironmentProbe struct {
lookPathPresent map[string]bool
commandOutput map[string][]byte
commandErr map[string]error
fileData map[string][]byte
fileErr map[string]error
statPresent map[string]bool
interfaces []ifaceInfo
interfacesErr error
}
func (p fakeEnvironmentProbe) LookPath(file string) (string, error) {
if p.lookPathPresent[file] {
return "/usr/bin/" + file, nil
}
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
}
func (p fakeEnvironmentProbe) CommandCombinedOutput(name string, args ...string) ([]byte, error) {
key := name + "\x00" + strings.Join(args, "\x00")
if out, ok := p.commandOutput[key]; ok {
return out, p.commandErr[key]
}
return nil, errors.New("unexpected command invocation")
}
func (p fakeEnvironmentProbe) Stat(name string) (os.FileInfo, error) {
if p.statPresent[name] {
return fakeFileInfo{}, nil
}
return nil, &os.PathError{Op: "stat", Path: name, Err: os.ErrNotExist}
}
func (p fakeEnvironmentProbe) ReadFile(name string) ([]byte, error) {
if err, ok := p.fileErr[name]; ok && err != nil {
return nil, err
}
if data, ok := p.fileData[name]; ok {
return data, nil
}
return nil, &os.PathError{Op: "open", Path: name, Err: os.ErrNotExist}
}
func (p fakeEnvironmentProbe) Interfaces() ([]ifaceInfo, error) {
if p.interfacesErr != nil {
return nil, p.interfacesErr
}
return p.interfaces, nil
}
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
t.Helper()
_, network, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatalf("net.ParseCIDR(%q): %v", cidr, err)
}
return network
}
func TestDetectEnvironment_Native(t *testing.T) {
probe := fakeEnvironmentProbe{
interfaces: []ifaceInfo{
{
Name: "eth0",
Flags: net.FlagUp,
Addrs: []net.Addr{
&net.IPNet{IP: net.IPv4(192, 168, 1, 10), Mask: net.CIDRMask(24, 32)},
},
},
},
}
profile, err := detectEnvironment(probe)
if err != nil {
t.Fatalf("detectEnvironment returned error: %v", err)
}
if profile.Type != Native {
t.Fatalf("Type = %v, want %v", profile.Type, Native)
}
if profile.Metadata["container_detected"] != "false" {
t.Fatalf("container_detected = %q, want false", profile.Metadata["container_detected"])
}
if len(profile.Phases) != 1 || profile.Phases[0].Name != "local_networks" {
t.Fatalf("Phases = %#v, want single local_networks phase", profile.Phases)
}
if got := profile.Phases[0].Subnets[0].String(); got != mustIPNet(t, "192.168.1.0/24").String() {
t.Fatalf("subnet = %q, want %q", got, "192.168.1.0/24")
}
}
func TestDetectEnvironment_DockerHostMode(t *testing.T) {
route := strings.Join([]string{
"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT",
"lo\t00000000\t00000000\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0",
"eth1\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth2\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth3\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
}, "\n")
probe := fakeEnvironmentProbe{
statPresent: map[string]bool{"/.dockerenv": true},
fileData: map[string][]byte{"/proc/net/route": []byte(route)},
interfaces: []ifaceInfo{
{Name: "lo", Flags: net.FlagUp | net.FlagLoopback},
{
Name: "eth0",
Flags: net.FlagUp,
Addrs: []net.Addr{
&net.IPNet{IP: net.IPv4(10, 0, 0, 10), Mask: net.CIDRMask(24, 32)},
},
},
{Name: "eth1", Flags: net.FlagUp},
{Name: "eth2", Flags: net.FlagUp},
},
}
profile, err := detectEnvironment(probe)
if err != nil {
t.Fatalf("detectEnvironment returned error: %v", err)
}
if profile.Type != DockerHost {
t.Fatalf("Type = %v, want %v", profile.Type, DockerHost)
}
if profile.Metadata["docker_mode"] != "host" {
t.Fatalf("docker_mode = %q, want host", profile.Metadata["docker_mode"])
}
if len(profile.Phases) != 1 || profile.Phases[0].Name != "host_networks" {
t.Fatalf("Phases = %#v, want single host_networks phase", profile.Phases)
}
if got := profile.Phases[0].Subnets[0].String(); got != mustIPNet(t, "10.0.0.0/24").String() {
t.Fatalf("subnet = %q, want %q", got, "10.0.0.0/24")
}
}
func TestDetectEnvironment_DockerBridge_InferredHostNetwork(t *testing.T) {
route := strings.Join([]string{
"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT",
"eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0",
}, "\n")
probe := fakeEnvironmentProbe{
statPresent: map[string]bool{"/.dockerenv": true},
fileData: map[string][]byte{"/proc/net/route": []byte(route)},
interfaces: []ifaceInfo{
{
Name: "eth0",
Flags: net.FlagUp,
Addrs: []net.Addr{
&net.IPNet{IP: net.IPv4(172, 17, 0, 2), Mask: net.CIDRMask(16, 32)},
},
},
},
}
profile, err := detectEnvironment(probe)
if err != nil {
t.Fatalf("detectEnvironment returned error: %v", err)
}
if profile.Type != DockerBridge {
t.Fatalf("Type = %v, want %v", profile.Type, DockerBridge)
}
var containerFound, inferredFound bool
for _, phase := range profile.Phases {
switch phase.Name {
case "container_network":
containerFound = true
if got := phase.Subnets[0].String(); got != mustIPNet(t, "172.17.0.0/16").String() {
t.Fatalf("container subnet = %q, want %q", got, "172.17.0.0/16")
}
case "inferred_host_network":
inferredFound = true
if got := phase.Subnets[0].String(); got != mustIPNet(t, "172.17.0.0/24").String() {
t.Fatalf("inferred subnet = %q, want %q", got, "172.17.0.0/24")
}
if phase.Confidence != 0.7 {
t.Fatalf("inferred confidence = %v, want 0.7", phase.Confidence)
}
}
}
if !containerFound || !inferredFound {
t.Fatalf("expected both container_network and inferred_host_network phases, got %#v", profile.Phases)
}
}
func TestDetectEnvironment_LXCPrivileged_SystemdDetectVirt(t *testing.T) {
probe := fakeEnvironmentProbe{
lookPathPresent: map[string]bool{"systemd-detect-virt": true},
commandOutput: map[string][]byte{
"systemd-detect-virt\x00--container": []byte("lxc\n"),
},
fileData: map[string][]byte{
"/proc/self/uid_map": []byte("0 0 4294967295\n"),
},
interfaces: []ifaceInfo{
{
Name: "eth0",
Flags: net.FlagUp,
Addrs: []net.Addr{
&net.IPNet{IP: net.IPv4(192, 168, 50, 10), Mask: net.CIDRMask(24, 32)},
},
},
},
}
profile, err := detectEnvironment(probe)
if err != nil {
t.Fatalf("detectEnvironment returned error: %v", err)
}
if profile.Type != LXCPrivileged {
t.Fatalf("Type = %v, want %v", profile.Type, LXCPrivileged)
}
if profile.Metadata["lxc_privileged"] != "true" {
t.Fatalf("lxc_privileged = %q, want true", profile.Metadata["lxc_privileged"])
}
if len(profile.Phases) != 1 || profile.Phases[0].Name != "lxc_host_networks" {
t.Fatalf("Phases = %#v, want single lxc_host_networks phase", profile.Phases)
}
}
func TestGetDefaultGateway_DefaultRouteNotFound(t *testing.T) {
route := strings.Join([]string{
"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT",
"eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0",
}, "\n")
probe := fakeEnvironmentProbe{
fileData: map[string][]byte{"/proc/net/route": []byte(route)},
}
_, err := getDefaultGateway(probe)
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), "default gateway not found") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestCountKernelRoutes_SkipsHeaderAndBlankLines(t *testing.T) {
route := strings.Join([]string{
"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT",
"",
"eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0",
"eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0",
" ",
}, "\n")
probe := fakeEnvironmentProbe{
fileData: map[string][]byte{"/proc/net/route": []byte(route)},
}
count, warn := countKernelRoutes(probe)
if warn != "" {
t.Fatalf("warn = %q, want empty", warn)
}
if count != 2 {
t.Fatalf("count = %d, want 2", count)
}
}

View file

@ -15,6 +15,52 @@ import (
"github.com/rs/zerolog/log"
)
type ifaceInfo struct {
Name string
Flags net.Flags
Addrs []net.Addr
AddrsErr error
}
type environmentProbe interface {
LookPath(file string) (string, error)
CommandCombinedOutput(name string, args ...string) ([]byte, error)
Stat(name string) (os.FileInfo, error)
ReadFile(name string) ([]byte, error)
Interfaces() ([]ifaceInfo, error)
}
type systemEnvironmentProbe struct{}
func (systemEnvironmentProbe) LookPath(file string) (string, error) { return exec.LookPath(file) }
func (systemEnvironmentProbe) CommandCombinedOutput(name string, args ...string) ([]byte, error) {
return exec.Command(name, args...).CombinedOutput()
}
func (systemEnvironmentProbe) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }
func (systemEnvironmentProbe) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) }
func (systemEnvironmentProbe) Interfaces() ([]ifaceInfo, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
out := make([]ifaceInfo, 0, len(interfaces))
for _, iface := range interfaces {
addrs, addrsErr := iface.Addrs()
out = append(out, ifaceInfo{
Name: iface.Name,
Flags: iface.Flags,
Addrs: addrs,
AddrsErr: addrsErr,
})
}
return out, nil
}
// Environment represents the runtime environment type.
type Environment int
@ -95,6 +141,10 @@ type EnvironmentProfile struct {
// DetectEnvironment performs environment detection and returns a profile.
func DetectEnvironment() (*EnvironmentProfile, error) {
return detectEnvironment(systemEnvironmentProbe{})
}
func detectEnvironment(probe environmentProbe) (*EnvironmentProfile, error) {
profile := &EnvironmentProfile{
Type: Unknown,
Phases: []SubnetPhase{},
@ -106,7 +156,7 @@ func DetectEnvironment() (*EnvironmentProfile, error) {
log.Info().Msg("Detecting runtime environment")
isContainer, containerType := detectContainer()
isContainer, containerType := detectContainer(probe)
profile.Metadata["container_detected"] = strconv.FormatBool(isContainer)
if containerType != "" {
profile.Metadata["container_type"] = containerType
@ -115,11 +165,11 @@ func DetectEnvironment() (*EnvironmentProfile, error) {
var err error
switch {
case !isContainer:
profile, err = detectNativeEnvironment(profile)
profile, err = detectNativeEnvironment(profile, probe)
case containerType == "docker":
profile, err = detectDockerEnvironment(profile)
profile, err = detectDockerEnvironment(profile, probe)
case containerType == "lxc":
profile, err = detectLXCEnvironment(profile)
profile, err = detectLXCEnvironment(profile, probe)
default:
profile.Type = Unknown
profile.Confidence = 0.3
@ -152,13 +202,12 @@ func DetectEnvironment() (*EnvironmentProfile, error) {
}
// detectContainer inspects the host to determine whether we are inside a container.
func detectContainer() (bool, string) {
func detectContainer(probe environmentProbe) (bool, string) {
containerType := ""
// 1. systemd-detect-virt --container
if _, err := exec.LookPath("systemd-detect-virt"); err == nil {
cmd := exec.Command("systemd-detect-virt", "--container")
output, err := cmd.CombinedOutput()
if _, err := probe.LookPath("systemd-detect-virt"); err == nil {
output, err := probe.CommandCombinedOutput("systemd-detect-virt", "--container")
if len(output) > 0 {
result := strings.TrimSpace(string(output))
if result != "" && result != "none" {
@ -179,17 +228,17 @@ func detectContainer() (bool, string) {
}
// 2. Marker files
if _, err := os.Stat("/.dockerenv"); err == nil {
if _, err := probe.Stat("/.dockerenv"); err == nil {
log.Debug().Msg("Detected /.dockerenv marker (Docker container)")
return true, "docker"
}
if _, err := os.Stat("/run/.containerenv"); err == nil {
if _, err := probe.Stat("/run/.containerenv"); err == nil {
log.Debug().Msg("Detected /run/.containerenv marker (Podman/OCI container)")
return true, "docker"
}
// 3. /proc/1/cgroup
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
if data, err := probe.ReadFile("/proc/1/cgroup"); err == nil {
text := string(data)
switch {
case strings.Contains(text, "docker"), strings.Contains(text, "kubepods"), strings.Contains(text, "containerd"):
@ -204,7 +253,7 @@ func detectContainer() (bool, string) {
}
// 4. /proc/1/environ
if data, err := os.ReadFile("/proc/1/environ"); err == nil {
if data, err := probe.ReadFile("/proc/1/environ"); err == nil {
text := string(data)
switch {
case strings.Contains(text, "container=lxc"):
@ -223,8 +272,8 @@ func detectContainer() (bool, string) {
}
// detectNativeEnvironment builds an EnvironmentProfile for native or VM deployments.
func detectNativeEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) {
subnets, err := getAllLocalSubnets()
func detectNativeEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) {
subnets, err := getAllLocalSubnets(probe)
if err != nil {
return addFallbackSubnets(profileWithWarning(profile, fmt.Sprintf("Failed to enumerate interfaces: %v", err)))
}
@ -249,14 +298,14 @@ func detectNativeEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile,
}
// detectDockerEnvironment determines whether Docker uses host or bridge networking.
func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) {
hostMode, hostModeWarnings := isDockerHostMode()
func detectDockerEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) {
hostMode, hostModeWarnings := isDockerHostMode(probe)
if len(hostModeWarnings) > 0 {
profile.Warnings = append(profile.Warnings, hostModeWarnings...)
}
if hostMode {
subnets, err := getAllLocalSubnets()
subnets, err := getAllLocalSubnets(probe)
if err != nil {
return addFallbackSubnets(profileWithWarning(profile, fmt.Sprintf("Docker host mode: failed to enumerate subnets: %v", err)))
}
@ -285,7 +334,7 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile,
profile.Confidence = 0.85
profile.Metadata["docker_mode"] = "bridge"
containerSubnets, err := getAllLocalSubnets()
containerSubnets, err := getAllLocalSubnets(probe)
if err != nil {
profile.Warnings = append(profile.Warnings, fmt.Sprintf("Docker bridge: failed to enumerate container subnets: %v", err))
} else if len(containerSubnets) > 0 {
@ -301,7 +350,7 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile,
}
if profile.Policy.ScanGateways {
hostSubnets, confidence, warnings := detectHostNetworkFromContainer()
hostSubnets, confidence, warnings := detectHostNetworkFromContainer(probe)
if len(warnings) > 0 {
profile.Warnings = append(profile.Warnings, warnings...)
}
@ -324,13 +373,13 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile,
}
// detectLXCEnvironment evaluates privilege level and prepares scanning phases.
func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) {
privileged, warn := isLXCPrivileged()
func detectLXCEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) {
privileged, warn := isLXCPrivileged(probe)
if warn != "" {
profile.Warnings = append(profile.Warnings, warn)
}
containerSubnets, err := getAllLocalSubnets()
containerSubnets, err := getAllLocalSubnets(probe)
if err != nil {
profile.Warnings = append(profile.Warnings, fmt.Sprintf("LXC: failed to enumerate container subnets: %v", err))
}
@ -371,7 +420,7 @@ func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, err
}
if profile.Policy.ScanGateways {
hostSubnets, confidence, warnings := detectHostNetworkFromContainer()
hostSubnets, confidence, warnings := detectHostNetworkFromContainer(probe)
if len(warnings) > 0 {
profile.Warnings = append(profile.Warnings, warnings...)
}
@ -394,17 +443,17 @@ func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, err
}
// isDockerHostMode attempts to determine whether Docker is using host networking.
func isDockerHostMode() (bool, []string) {
func isDockerHostMode(probe environmentProbe) (bool, []string) {
var warnings []string
interfaces, err := net.Interfaces()
interfaces, err := probe.Interfaces()
if err != nil {
log.Debug().Err(err).Msg("Failed to enumerate interfaces while detecting Docker mode")
warnings = append(warnings, fmt.Sprintf("Unable to enumerate interfaces: %v", err))
return false, warnings
}
routeCount, routeWarn := countKernelRoutes()
routeCount, routeWarn := countKernelRoutes(probe)
if routeWarn != "" {
warnings = append(warnings, routeWarn)
}
@ -423,8 +472,8 @@ func isDockerHostMode() (bool, []string) {
}
// isLXCPrivileged inspects UID mappings to determine privilege level.
func isLXCPrivileged() (bool, string) {
data, err := os.ReadFile("/proc/self/uid_map")
func isLXCPrivileged(probe environmentProbe) (bool, string) {
data, err := probe.ReadFile("/proc/self/uid_map")
if err != nil {
if errors.Is(err, os.ErrPermission) {
return false, "Unable to read /proc/self/uid_map (permission denied); assuming unprivileged LXC"
@ -456,8 +505,8 @@ func isLXCPrivileged() (bool, string) {
}
// getAllLocalSubnets enumerates non-loopback, UP IPv4 subnets.
func getAllLocalSubnets() ([]net.IPNet, error) {
interfaces, err := net.Interfaces()
func getAllLocalSubnets(probe environmentProbe) ([]net.IPNet, error) {
interfaces, err := probe.Interfaces()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %w", err)
}
@ -470,13 +519,12 @@ func getAllLocalSubnets() ([]net.IPNet, error) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
log.Debug().Err(err).Str("interface", iface.Name).Msg("Skipping interface due to address enumeration failure")
if iface.AddrsErr != nil {
log.Debug().Err(iface.AddrsErr).Str("interface", iface.Name).Msg("Skipping interface due to address enumeration failure")
continue
}
for _, addr := range addrs {
for _, addr := range iface.Addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok || ipNet == nil {
continue
@ -512,10 +560,10 @@ func getAllLocalSubnets() ([]net.IPNet, error) {
}
// detectHostNetworkFromContainer infers host LAN subnets from container context.
func detectHostNetworkFromContainer() ([]net.IPNet, float64, []string) {
func detectHostNetworkFromContainer(probe environmentProbe) ([]net.IPNet, float64, []string) {
var warnings []string
gateway, err := getDefaultGateway()
gateway, err := getDefaultGateway(probe)
if err != nil {
warnings = append(warnings, fmt.Sprintf("Could not determine default gateway: %v", err))
return tryCommonSubnets(), 0.3, warnings
@ -553,8 +601,8 @@ func detectHostNetworkFromContainer() ([]net.IPNet, float64, []string) {
}
// getDefaultGateway parses /proc/net/route for the default gateway.
func getDefaultGateway() (net.IP, error) {
data, err := os.ReadFile("/proc/net/route")
func getDefaultGateway(probe environmentProbe) (net.IP, error) {
data, err := probe.ReadFile("/proc/net/route")
if err != nil {
return nil, fmt.Errorf("failed to read /proc/net/route: %w", err)
}
@ -660,8 +708,8 @@ func addFallbackSubnets(profile *EnvironmentProfile) (*EnvironmentProfile, error
}
// countKernelRoutes parses /proc/net/route and returns the number of route entries.
func countKernelRoutes() (int, string) {
data, err := os.ReadFile("/proc/net/route")
func countKernelRoutes(probe environmentProbe) (int, string) {
data, err := probe.ReadFile("/proc/net/route")
if err != nil {
return 0, fmt.Sprintf("Unable to read /proc/net/route: %v", err)
}