From e44a6fdadd6512e75dbd6ca01949367ed851e26b Mon Sep 17 00:00:00 2001 From: rcourtman Date: Wed, 17 Dec 2025 16:08:10 +0000 Subject: [PATCH] test(envdetect): cover environment detection decisions --- .../envdetect/detect_environment_test.go | 287 ++++++++++++++++++ pkg/discovery/envdetect/envdetect.go | 128 +++++--- 2 files changed, 375 insertions(+), 40 deletions(-) create mode 100644 pkg/discovery/envdetect/detect_environment_test.go diff --git a/pkg/discovery/envdetect/detect_environment_test.go b/pkg/discovery/envdetect/detect_environment_test.go new file mode 100644 index 000000000..a562acf75 --- /dev/null +++ b/pkg/discovery/envdetect/detect_environment_test.go @@ -0,0 +1,287 @@ +package envdetect + +import ( + "errors" + "net" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +type fakeFileInfo struct{} + +func (fakeFileInfo) Name() string { return "fake" } +func (fakeFileInfo) Size() int64 { return 0 } +func (fakeFileInfo) Mode() os.FileMode { return 0 } +func (fakeFileInfo) ModTime() (t time.Time) { + return t +} +func (fakeFileInfo) IsDir() bool { return false } +func (fakeFileInfo) Sys() interface{} { return nil } + +type fakeEnvironmentProbe struct { + lookPathPresent map[string]bool + commandOutput map[string][]byte + commandErr map[string]error + fileData map[string][]byte + fileErr map[string]error + statPresent map[string]bool + interfaces []ifaceInfo + interfacesErr error +} + +func (p fakeEnvironmentProbe) LookPath(file string) (string, error) { + if p.lookPathPresent[file] { + return "/usr/bin/" + file, nil + } + return "", &exec.Error{Name: file, Err: exec.ErrNotFound} +} + +func (p fakeEnvironmentProbe) CommandCombinedOutput(name string, args ...string) ([]byte, error) { + key := name + "\x00" + strings.Join(args, "\x00") + if out, ok := p.commandOutput[key]; ok { + return out, p.commandErr[key] + } + return nil, errors.New("unexpected command invocation") +} + +func (p fakeEnvironmentProbe) Stat(name string) (os.FileInfo, error) { + if p.statPresent[name] { + return fakeFileInfo{}, nil + } + return nil, &os.PathError{Op: "stat", Path: name, Err: os.ErrNotExist} +} + +func (p fakeEnvironmentProbe) ReadFile(name string) ([]byte, error) { + if err, ok := p.fileErr[name]; ok && err != nil { + return nil, err + } + if data, ok := p.fileData[name]; ok { + return data, nil + } + return nil, &os.PathError{Op: "open", Path: name, Err: os.ErrNotExist} +} + +func (p fakeEnvironmentProbe) Interfaces() ([]ifaceInfo, error) { + if p.interfacesErr != nil { + return nil, p.interfacesErr + } + return p.interfaces, nil +} + +func mustIPNet(t *testing.T, cidr string) *net.IPNet { + t.Helper() + _, network, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatalf("net.ParseCIDR(%q): %v", cidr, err) + } + return network +} + +func TestDetectEnvironment_Native(t *testing.T) { + probe := fakeEnvironmentProbe{ + interfaces: []ifaceInfo{ + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(192, 168, 1, 10), Mask: net.CIDRMask(24, 32)}, + }, + }, + }, + } + + profile, err := detectEnvironment(probe) + if err != nil { + t.Fatalf("detectEnvironment returned error: %v", err) + } + if profile.Type != Native { + t.Fatalf("Type = %v, want %v", profile.Type, Native) + } + if profile.Metadata["container_detected"] != "false" { + t.Fatalf("container_detected = %q, want false", profile.Metadata["container_detected"]) + } + if len(profile.Phases) != 1 || profile.Phases[0].Name != "local_networks" { + t.Fatalf("Phases = %#v, want single local_networks phase", profile.Phases) + } + if got := profile.Phases[0].Subnets[0].String(); got != mustIPNet(t, "192.168.1.0/24").String() { + t.Fatalf("subnet = %q, want %q", got, "192.168.1.0/24") + } +} + +func TestDetectEnvironment_DockerHostMode(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT", + "lo\t00000000\t00000000\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0", + "eth1\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth2\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth3\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + }, "\n") + + probe := fakeEnvironmentProbe{ + statPresent: map[string]bool{"/.dockerenv": true}, + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + interfaces: []ifaceInfo{ + {Name: "lo", Flags: net.FlagUp | net.FlagLoopback}, + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(10, 0, 0, 10), Mask: net.CIDRMask(24, 32)}, + }, + }, + {Name: "eth1", Flags: net.FlagUp}, + {Name: "eth2", Flags: net.FlagUp}, + }, + } + + profile, err := detectEnvironment(probe) + if err != nil { + t.Fatalf("detectEnvironment returned error: %v", err) + } + if profile.Type != DockerHost { + t.Fatalf("Type = %v, want %v", profile.Type, DockerHost) + } + if profile.Metadata["docker_mode"] != "host" { + t.Fatalf("docker_mode = %q, want host", profile.Metadata["docker_mode"]) + } + if len(profile.Phases) != 1 || profile.Phases[0].Name != "host_networks" { + t.Fatalf("Phases = %#v, want single host_networks phase", profile.Phases) + } + if got := profile.Phases[0].Subnets[0].String(); got != mustIPNet(t, "10.0.0.0/24").String() { + t.Fatalf("subnet = %q, want %q", got, "10.0.0.0/24") + } +} + +func TestDetectEnvironment_DockerBridge_InferredHostNetwork(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT", + "eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0", + }, "\n") + + probe := fakeEnvironmentProbe{ + statPresent: map[string]bool{"/.dockerenv": true}, + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + interfaces: []ifaceInfo{ + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(172, 17, 0, 2), Mask: net.CIDRMask(16, 32)}, + }, + }, + }, + } + + profile, err := detectEnvironment(probe) + if err != nil { + t.Fatalf("detectEnvironment returned error: %v", err) + } + if profile.Type != DockerBridge { + t.Fatalf("Type = %v, want %v", profile.Type, DockerBridge) + } + + var containerFound, inferredFound bool + for _, phase := range profile.Phases { + switch phase.Name { + case "container_network": + containerFound = true + if got := phase.Subnets[0].String(); got != mustIPNet(t, "172.17.0.0/16").String() { + t.Fatalf("container subnet = %q, want %q", got, "172.17.0.0/16") + } + case "inferred_host_network": + inferredFound = true + if got := phase.Subnets[0].String(); got != mustIPNet(t, "172.17.0.0/24").String() { + t.Fatalf("inferred subnet = %q, want %q", got, "172.17.0.0/24") + } + if phase.Confidence != 0.7 { + t.Fatalf("inferred confidence = %v, want 0.7", phase.Confidence) + } + } + } + + if !containerFound || !inferredFound { + t.Fatalf("expected both container_network and inferred_host_network phases, got %#v", profile.Phases) + } +} + +func TestDetectEnvironment_LXCPrivileged_SystemdDetectVirt(t *testing.T) { + probe := fakeEnvironmentProbe{ + lookPathPresent: map[string]bool{"systemd-detect-virt": true}, + commandOutput: map[string][]byte{ + "systemd-detect-virt\x00--container": []byte("lxc\n"), + }, + fileData: map[string][]byte{ + "/proc/self/uid_map": []byte("0 0 4294967295\n"), + }, + interfaces: []ifaceInfo{ + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(192, 168, 50, 10), Mask: net.CIDRMask(24, 32)}, + }, + }, + }, + } + + profile, err := detectEnvironment(probe) + if err != nil { + t.Fatalf("detectEnvironment returned error: %v", err) + } + if profile.Type != LXCPrivileged { + t.Fatalf("Type = %v, want %v", profile.Type, LXCPrivileged) + } + if profile.Metadata["lxc_privileged"] != "true" { + t.Fatalf("lxc_privileged = %q, want true", profile.Metadata["lxc_privileged"]) + } + if len(profile.Phases) != 1 || profile.Phases[0].Name != "lxc_host_networks" { + t.Fatalf("Phases = %#v, want single lxc_host_networks phase", profile.Phases) + } +} + +func TestGetDefaultGateway_DefaultRouteNotFound(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT", + "eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0", + }, "\n") + + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + } + + _, err := getDefaultGateway(probe) + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), "default gateway not found") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestCountKernelRoutes_SkipsHeaderAndBlankLines(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT", + "", + "eth0\t00000000\t010011AC\t0003\t0\t0\t0\t00000000\t0\t0\t0", + "eth0\t000011AC\t00000000\t0001\t0\t0\t0\t0000FFFF\t0\t0\t0", + " ", + }, "\n") + + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + } + + count, warn := countKernelRoutes(probe) + if warn != "" { + t.Fatalf("warn = %q, want empty", warn) + } + if count != 2 { + t.Fatalf("count = %d, want 2", count) + } +} diff --git a/pkg/discovery/envdetect/envdetect.go b/pkg/discovery/envdetect/envdetect.go index aaab3a8c3..d6c2cccbe 100644 --- a/pkg/discovery/envdetect/envdetect.go +++ b/pkg/discovery/envdetect/envdetect.go @@ -15,6 +15,52 @@ import ( "github.com/rs/zerolog/log" ) +type ifaceInfo struct { + Name string + Flags net.Flags + Addrs []net.Addr + AddrsErr error +} + +type environmentProbe interface { + LookPath(file string) (string, error) + CommandCombinedOutput(name string, args ...string) ([]byte, error) + Stat(name string) (os.FileInfo, error) + ReadFile(name string) ([]byte, error) + Interfaces() ([]ifaceInfo, error) +} + +type systemEnvironmentProbe struct{} + +func (systemEnvironmentProbe) LookPath(file string) (string, error) { return exec.LookPath(file) } + +func (systemEnvironmentProbe) CommandCombinedOutput(name string, args ...string) ([]byte, error) { + return exec.Command(name, args...).CombinedOutput() +} + +func (systemEnvironmentProbe) Stat(name string) (os.FileInfo, error) { return os.Stat(name) } +func (systemEnvironmentProbe) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) } + +func (systemEnvironmentProbe) Interfaces() ([]ifaceInfo, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + out := make([]ifaceInfo, 0, len(interfaces)) + for _, iface := range interfaces { + addrs, addrsErr := iface.Addrs() + out = append(out, ifaceInfo{ + Name: iface.Name, + Flags: iface.Flags, + Addrs: addrs, + AddrsErr: addrsErr, + }) + } + + return out, nil +} + // Environment represents the runtime environment type. type Environment int @@ -95,6 +141,10 @@ type EnvironmentProfile struct { // DetectEnvironment performs environment detection and returns a profile. func DetectEnvironment() (*EnvironmentProfile, error) { + return detectEnvironment(systemEnvironmentProbe{}) +} + +func detectEnvironment(probe environmentProbe) (*EnvironmentProfile, error) { profile := &EnvironmentProfile{ Type: Unknown, Phases: []SubnetPhase{}, @@ -106,7 +156,7 @@ func DetectEnvironment() (*EnvironmentProfile, error) { log.Info().Msg("Detecting runtime environment") - isContainer, containerType := detectContainer() + isContainer, containerType := detectContainer(probe) profile.Metadata["container_detected"] = strconv.FormatBool(isContainer) if containerType != "" { profile.Metadata["container_type"] = containerType @@ -115,11 +165,11 @@ func DetectEnvironment() (*EnvironmentProfile, error) { var err error switch { case !isContainer: - profile, err = detectNativeEnvironment(profile) + profile, err = detectNativeEnvironment(profile, probe) case containerType == "docker": - profile, err = detectDockerEnvironment(profile) + profile, err = detectDockerEnvironment(profile, probe) case containerType == "lxc": - profile, err = detectLXCEnvironment(profile) + profile, err = detectLXCEnvironment(profile, probe) default: profile.Type = Unknown profile.Confidence = 0.3 @@ -152,13 +202,12 @@ func DetectEnvironment() (*EnvironmentProfile, error) { } // detectContainer inspects the host to determine whether we are inside a container. -func detectContainer() (bool, string) { +func detectContainer(probe environmentProbe) (bool, string) { containerType := "" // 1. systemd-detect-virt --container - if _, err := exec.LookPath("systemd-detect-virt"); err == nil { - cmd := exec.Command("systemd-detect-virt", "--container") - output, err := cmd.CombinedOutput() + if _, err := probe.LookPath("systemd-detect-virt"); err == nil { + output, err := probe.CommandCombinedOutput("systemd-detect-virt", "--container") if len(output) > 0 { result := strings.TrimSpace(string(output)) if result != "" && result != "none" { @@ -179,17 +228,17 @@ func detectContainer() (bool, string) { } // 2. Marker files - if _, err := os.Stat("/.dockerenv"); err == nil { + if _, err := probe.Stat("/.dockerenv"); err == nil { log.Debug().Msg("Detected /.dockerenv marker (Docker container)") return true, "docker" } - if _, err := os.Stat("/run/.containerenv"); err == nil { + if _, err := probe.Stat("/run/.containerenv"); err == nil { log.Debug().Msg("Detected /run/.containerenv marker (Podman/OCI container)") return true, "docker" } // 3. /proc/1/cgroup - if data, err := os.ReadFile("/proc/1/cgroup"); err == nil { + if data, err := probe.ReadFile("/proc/1/cgroup"); err == nil { text := string(data) switch { case strings.Contains(text, "docker"), strings.Contains(text, "kubepods"), strings.Contains(text, "containerd"): @@ -204,7 +253,7 @@ func detectContainer() (bool, string) { } // 4. /proc/1/environ - if data, err := os.ReadFile("/proc/1/environ"); err == nil { + if data, err := probe.ReadFile("/proc/1/environ"); err == nil { text := string(data) switch { case strings.Contains(text, "container=lxc"): @@ -223,8 +272,8 @@ func detectContainer() (bool, string) { } // detectNativeEnvironment builds an EnvironmentProfile for native or VM deployments. -func detectNativeEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) { - subnets, err := getAllLocalSubnets() +func detectNativeEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) { + subnets, err := getAllLocalSubnets(probe) if err != nil { return addFallbackSubnets(profileWithWarning(profile, fmt.Sprintf("Failed to enumerate interfaces: %v", err))) } @@ -249,14 +298,14 @@ func detectNativeEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, } // detectDockerEnvironment determines whether Docker uses host or bridge networking. -func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) { - hostMode, hostModeWarnings := isDockerHostMode() +func detectDockerEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) { + hostMode, hostModeWarnings := isDockerHostMode(probe) if len(hostModeWarnings) > 0 { profile.Warnings = append(profile.Warnings, hostModeWarnings...) } if hostMode { - subnets, err := getAllLocalSubnets() + subnets, err := getAllLocalSubnets(probe) if err != nil { return addFallbackSubnets(profileWithWarning(profile, fmt.Sprintf("Docker host mode: failed to enumerate subnets: %v", err))) } @@ -285,7 +334,7 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, profile.Confidence = 0.85 profile.Metadata["docker_mode"] = "bridge" - containerSubnets, err := getAllLocalSubnets() + containerSubnets, err := getAllLocalSubnets(probe) if err != nil { profile.Warnings = append(profile.Warnings, fmt.Sprintf("Docker bridge: failed to enumerate container subnets: %v", err)) } else if len(containerSubnets) > 0 { @@ -301,7 +350,7 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, } if profile.Policy.ScanGateways { - hostSubnets, confidence, warnings := detectHostNetworkFromContainer() + hostSubnets, confidence, warnings := detectHostNetworkFromContainer(probe) if len(warnings) > 0 { profile.Warnings = append(profile.Warnings, warnings...) } @@ -324,13 +373,13 @@ func detectDockerEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, } // detectLXCEnvironment evaluates privilege level and prepares scanning phases. -func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, error) { - privileged, warn := isLXCPrivileged() +func detectLXCEnvironment(profile *EnvironmentProfile, probe environmentProbe) (*EnvironmentProfile, error) { + privileged, warn := isLXCPrivileged(probe) if warn != "" { profile.Warnings = append(profile.Warnings, warn) } - containerSubnets, err := getAllLocalSubnets() + containerSubnets, err := getAllLocalSubnets(probe) if err != nil { profile.Warnings = append(profile.Warnings, fmt.Sprintf("LXC: failed to enumerate container subnets: %v", err)) } @@ -371,7 +420,7 @@ func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, err } if profile.Policy.ScanGateways { - hostSubnets, confidence, warnings := detectHostNetworkFromContainer() + hostSubnets, confidence, warnings := detectHostNetworkFromContainer(probe) if len(warnings) > 0 { profile.Warnings = append(profile.Warnings, warnings...) } @@ -394,17 +443,17 @@ func detectLXCEnvironment(profile *EnvironmentProfile) (*EnvironmentProfile, err } // isDockerHostMode attempts to determine whether Docker is using host networking. -func isDockerHostMode() (bool, []string) { +func isDockerHostMode(probe environmentProbe) (bool, []string) { var warnings []string - interfaces, err := net.Interfaces() + interfaces, err := probe.Interfaces() if err != nil { log.Debug().Err(err).Msg("Failed to enumerate interfaces while detecting Docker mode") warnings = append(warnings, fmt.Sprintf("Unable to enumerate interfaces: %v", err)) return false, warnings } - routeCount, routeWarn := countKernelRoutes() + routeCount, routeWarn := countKernelRoutes(probe) if routeWarn != "" { warnings = append(warnings, routeWarn) } @@ -423,8 +472,8 @@ func isDockerHostMode() (bool, []string) { } // isLXCPrivileged inspects UID mappings to determine privilege level. -func isLXCPrivileged() (bool, string) { - data, err := os.ReadFile("/proc/self/uid_map") +func isLXCPrivileged(probe environmentProbe) (bool, string) { + data, err := probe.ReadFile("/proc/self/uid_map") if err != nil { if errors.Is(err, os.ErrPermission) { return false, "Unable to read /proc/self/uid_map (permission denied); assuming unprivileged LXC" @@ -456,8 +505,8 @@ func isLXCPrivileged() (bool, string) { } // getAllLocalSubnets enumerates non-loopback, UP IPv4 subnets. -func getAllLocalSubnets() ([]net.IPNet, error) { - interfaces, err := net.Interfaces() +func getAllLocalSubnets(probe environmentProbe) ([]net.IPNet, error) { + interfaces, err := probe.Interfaces() if err != nil { return nil, fmt.Errorf("failed to list interfaces: %w", err) } @@ -470,13 +519,12 @@ func getAllLocalSubnets() ([]net.IPNet, error) { continue } - addrs, err := iface.Addrs() - if err != nil { - log.Debug().Err(err).Str("interface", iface.Name).Msg("Skipping interface due to address enumeration failure") + if iface.AddrsErr != nil { + log.Debug().Err(iface.AddrsErr).Str("interface", iface.Name).Msg("Skipping interface due to address enumeration failure") continue } - for _, addr := range addrs { + for _, addr := range iface.Addrs { ipNet, ok := addr.(*net.IPNet) if !ok || ipNet == nil { continue @@ -512,10 +560,10 @@ func getAllLocalSubnets() ([]net.IPNet, error) { } // detectHostNetworkFromContainer infers host LAN subnets from container context. -func detectHostNetworkFromContainer() ([]net.IPNet, float64, []string) { +func detectHostNetworkFromContainer(probe environmentProbe) ([]net.IPNet, float64, []string) { var warnings []string - gateway, err := getDefaultGateway() + gateway, err := getDefaultGateway(probe) if err != nil { warnings = append(warnings, fmt.Sprintf("Could not determine default gateway: %v", err)) return tryCommonSubnets(), 0.3, warnings @@ -553,8 +601,8 @@ func detectHostNetworkFromContainer() ([]net.IPNet, float64, []string) { } // getDefaultGateway parses /proc/net/route for the default gateway. -func getDefaultGateway() (net.IP, error) { - data, err := os.ReadFile("/proc/net/route") +func getDefaultGateway(probe environmentProbe) (net.IP, error) { + data, err := probe.ReadFile("/proc/net/route") if err != nil { return nil, fmt.Errorf("failed to read /proc/net/route: %w", err) } @@ -660,8 +708,8 @@ func addFallbackSubnets(profile *EnvironmentProfile) (*EnvironmentProfile, error } // countKernelRoutes parses /proc/net/route and returns the number of route entries. -func countKernelRoutes() (int, string) { - data, err := os.ReadFile("/proc/net/route") +func countKernelRoutes(probe environmentProbe) (int, string) { + data, err := probe.ReadFile("/proc/net/route") if err != nil { return 0, fmt.Sprintf("Unable to read /proc/net/route: %v", err) }