diff --git a/Gopkg.lock b/Gopkg.lock index 7d1613bb..437e76c0 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,12 +2,12 @@ [[projects]] - digest = "1:f82b8ac36058904227087141017bb82f4b0fc58272990a4cdae3e2d6d222644e" + digest = "1:6146fda730c18186631e91e818d995e759e7cbe27644d6871ccd469f6865c686" name = "github.com/StackExchange/wmi" packages = ["."] pruneopts = "" - revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338" - version = "1.0.0" + revision = "cbe66965904dbe8a6cd589e2298e5d8b986bd7dd" + version = "1.1.0" [[projects]] digest = "1:e010d6b45ee6c721df761eae89961c634ceb55feff166a48d15504729309f267" @@ -18,12 +18,12 @@ version = "v1.1.1" [[projects]] - digest = "1:3c753679736345f50125ae993e0a2614da126859921ea7faeecda6d217501ce2" + digest = "1:21caed545a1c7ef7a2627bbb45989f689872ff6d5087d49c31340ce74c36de59" name = "github.com/agext/levenshtein" packages = ["."] pruneopts = "" - revision = "0ded9c86537917af2ff89bc9c78de6bd58477894" - version = "v1.2.2" + revision = "52c14c47d03211d8ac1834e94601635e07c5a6ef" + version = "v1.2.3" [[projects]] branch = "v2.1" @@ -34,12 +34,12 @@ revision = "d27c04069d0d5dfe11c202dacbf745ae8d1ab181" [[projects]] - digest = "1:166e24c91c2732657d2f791d3ee3897e7d85ece7cbb62ad991250e6b51fc1d4c" + digest = "1:f384a8b6f89c502229e9013aa4f89ce5b5b56f09f9a4d601d7f1f026d3564fbf" name = "github.com/coreos/go-iptables" packages = ["iptables"] pruneopts = "" - revision = "78b5fff24e6df8886ef8eca9411f683a884349a5" - version = "v0.4.1" + revision = "f901d6c2a4f2a4df092b98c33366dfba1f93d7a0" + version = "v0.4.5" [[projects]] digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" @@ -61,12 +61,12 @@ version = "v1.2.4" [[projects]] - digest = "1:cc1255e2fef3819bfab3540277001e602892dd431ef9ab5499bcdbc425923d64" + digest = "1:f63933986e63230fc32512ed00bc18ea4dbb0f57b5da18561314928fd20c2ff0" name = "github.com/godbus/dbus" packages = ["."] pruneopts = "" - revision = "2ff6f7ffd60f0f2410b3105864bdd12c7894f844" - version = "v5.0.1" + revision = "37bf87eef99d69c4f1d3528bd66e3a87dc201472" + version = "v5.0.3" [[projects]] digest = "1:e85e59c4152d8576341daf54f40d96c404c264e04941a4a36b97a0f427eb9e5e" @@ -113,20 +113,20 @@ revision = "2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384" [[projects]] - digest = "1:0b6694f306890ddbb69c96a16776510bd24e07436fae3f9b0a4e5b650f1e6fb7" + branch = "master" + digest = "1:c140772b00f0c26cf6627aee32f62d9f9d89dffcda648861266c482c36a5344a" name = "github.com/miekg/dns" packages = ["."] pruneopts = "" - revision = "b13675009d59c97f3721247d9efa8914e1866a5b" - version = "v1.1.15" + revision = "b7703d0fa022e159d01efa2de82e6173d5ec04c8" [[projects]] - digest = "1:3819cd861b7abd7d12dc1ea52ecb998ad1171826a76ecf0aefa09545781091f9" + digest = "1:b962a528cbecf7662bee4d84a600f7a0a6a130368666d7d461757ba4d1341906" name = "github.com/oschwald/maxminddb-golang" packages = ["."] pruneopts = "" - revision = "2905694a1b00c5574f1418a7dbf8a22a7d247559" - version = "v1.3.1" + revision = "6a033e62c03b7dab4c37f7c9eb2ebb3b10e8f13a" + version = "v1.6.0" [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" @@ -145,7 +145,7 @@ version = "v1.2.0" [[projects]] - digest = "1:8bf42eb2ded52ed2678b0716dbfbf30628765bc12b13222c4d5669ba4c1310e4" + digest = "1:16f319cf21ddf49f27b3a2093d68316840dc25ec5c2a0a431a4a4fc01ea707e2" name = "github.com/shirou/gopsutil" packages = [ "cpu", @@ -155,32 +155,24 @@ "process", ] pruneopts = "" - revision = "4c8b404ee5c53b04b04f34b1744a26bf5d2910de" - version = "v2.19.6" + revision = "a81cf97fce2300934e6c625b9917103346c26ba3" + version = "v2.20.4" [[projects]] - branch = "master" - digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b" - name = "github.com/shirou/w32" - packages = ["."] - pruneopts = "" - revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b" - -[[projects]] - digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe" + digest = "1:bff75d4f1a2d2c4b8f4b46ff5ac230b80b5fa49276f615900cba09fe4c97e66e" name = "github.com/spf13/cobra" packages = ["."] pruneopts = "" - revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" - version = "v0.0.5" + revision = "a684a6d7f5e37385d954dd3b5a14fc6912c6ab9d" + version = "v1.0.0" [[projects]] - digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + digest = "1:688428eeb1ca80d92599eb3254bdf91b51d7e232fead3a73844c1f201a281e51" name = "github.com/spf13/pflag" packages = ["."] pruneopts = "" - revision = "298182f68c66c05229eb03ac171abe6e309ee79a" - version = "v1.0.3" + revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab" + version = "v1.0.5" [[projects]] digest = "1:cc4eb6813da8d08694e557fcafae8fcc24f47f61a0717f952da130ca9a486dfc" @@ -208,18 +200,18 @@ [[projects]] branch = "master" - digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + digest = "1:bf61fa9b53be5ce096004599b957e5957b28a5e421b724250aa06ecb7ee6dc57" name = "golang.org/x/crypto" packages = [ "ed25519", "ed25519/internal/edwards25519", ] pruneopts = "" - revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + revision = "4b2356b1ed79e6be3deca3737a3db3d132d2847a" [[projects]] branch = "master" - digest = "1:31cd6e3c114e17c5f0c9e8b0bcaa3025ab3c221ce36323c7ce1acaa753d0d0aa" + digest = "1:ea84836e35d7a66c9b8944796295912509c80c921244bc4e098c5417219895f2" name = "golang.org/x/net" packages = [ "bpf", @@ -232,7 +224,7 @@ "publicsuffix", ] pruneopts = "" - revision = "da137c7871d730100384dbcf36e6f8fa493aef5b" + revision = "7e3656a0809f6f95abd88ac65313578f80b00df2" [[projects]] branch = "master" @@ -244,9 +236,10 @@ [[projects]] branch = "master" - digest = "1:2579a16d8afda9c9a475808c13324f5e572852e8927905ffa15bb14e71baba4f" + digest = "1:acb3b56e190190ac9497faf5f0c30c5da4d3e8278d6b7a7042f2aa3332ff7022" name = "golang.org/x/sys" packages = [ + "internal/unsafeheader", "unix", "windows", "windows/registry", @@ -256,7 +249,7 @@ "windows/svc/mgr", ] pruneopts = "" - revision = "04f50cda93cbb67f2afa353c52f342100e80e625" + revision = "bc7a7d42d5c30f4d0fe808715c002826ce2c624e" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -283,6 +276,38 @@ revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" version = "v0.3.2" +[[projects]] + branch = "master" + digest = "1:3416c611e00178b07c8fc347ba96398e4d6709fe7d3fab17f0b0fa6f933b4bd1" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "internal/event", + "internal/event/core", + "internal/event/keys", + "internal/event/label", + "internal/gocommand", + "internal/packagesinternal", + ] + pruneopts = "" + revision = "b8469989bc69e50ec6dc4e4513fc3ff9ce48b8af" + +[[projects]] + branch = "master" + digest = "1:9d4ac09a835404ae9306c6e1493cf800ecbb0f3f828f4333b3e055de4c962eea" + name = "golang.org/x/xerrors" + packages = [ + ".", + "internal", + ] + pruneopts = "" + revision = "9bdfabe68543c54f90421aeb9a60ef8061b5b544" + [[projects]] digest = "1:2efc9662a6a1ff28c65c84fc2f9030f13d3afecdb2ecad445f3b0c80e75fc281" name = "gopkg.in/yaml.v2" diff --git a/Gopkg.toml b/Gopkg.toml index 764c45b2..f449d026 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -25,3 +25,7 @@ # unused-packages = true ignored = ["github.com/safing/portbase/*"] + +[[constraint]] + name = "github.com/miekg/dns" + branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released diff --git a/firewall/api.go b/firewall/api.go index b73729c8..b435e7db 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -60,7 +60,18 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er var procsChecked []string // get process - proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process + proc, _, err := process.GetProcessByConnection( + r.Context(), + &packet.Info{ + Inbound: false, // outbound as we are looking for the process of the source address + Version: packet.IPv4, + Protocol: packet.TCP, + Src: remoteIP, // source as in the process we are looking for + SrcPort: remotePort, // source as in the process we are looking for + Dst: localIP, + DstPort: localPort, + }, + ) if err != nil { return false, fmt.Errorf("failed to get process: %s", err) } diff --git a/firewall/config.go b/firewall/config.go index 4a6a3abf..19d1b9c4 100644 --- a/firewall/config.go +++ b/firewall/config.go @@ -48,7 +48,7 @@ func registerConfig() error { Order: CfgOptionAskWithSystemNotificationsOrder, OptType: config.OptTypeBool, ExpertiseLevel: config.ExpertiseLevelUser, - ReleaseLevel: config.ReleaseLevelStable, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: true, }) if err != nil { @@ -62,7 +62,7 @@ func registerConfig() error { Order: CfgOptionAskTimeoutOrder, OptType: config.OptTypeInt, ExpertiseLevel: config.ExpertiseLevelUser, - ReleaseLevel: config.ReleaseLevelStable, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: 60, }) if err != nil { diff --git a/firewall/interception.go b/firewall/interception.go index f99c94c0..c87a8c57 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -138,7 +138,7 @@ func handlePacket(pkt packet.Packet) { // pkt.RedirToNameserver() // } - // allow ICMP, IGMP and DHCP + // allow ICMP and DHCP // TODO: actually handle these switch meta.Protocol { case packet.ICMP: @@ -149,10 +149,6 @@ func handlePacket(pkt packet.Packet) { log.Debugf("accepting ICMPv6: %s", pkt) _ = pkt.PermanentAccept() return - case packet.IGMP: - log.Debugf("accepting IGMP: %s", pkt) - _ = pkt.PermanentAccept() - return case packet.UDP: if meta.DstPort == 67 || meta.DstPort == 68 { log.Debugf("accepting DHCP: %s", pkt) @@ -218,6 +214,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { // reroute dns requests to nameserver if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { conn.Verdict = network.VerdictRerouteToNameserver + conn.Internal = true conn.StopFirewallHandler() issueVerdict(conn, pkt, 0, true) return @@ -233,7 +230,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { } log.Tracer(pkt.Ctx()).Trace("filter: starting decision process") - DecideOnConnection(conn, pkt) + DecideOnConnection(pkt.Ctx(), conn, pkt) conn.Inspecting = false // TODO: enable inspecting again switch { diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 99ad9b59..97238324 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -62,7 +62,7 @@ func Handler(packets chan packet.Packet) { } info := new.Info() - info.Direction = packetInfo.direction > 0 + info.Inbound = packetInfo.direction > 0 info.InTunnel = false info.Protocol = packet.IPProtocol(packetInfo.protocol) @@ -76,7 +76,7 @@ func Handler(packets chan packet.Packet) { // IPs if info.Version == packet.IPv4 { // IPv4 - if info.Direction { + if info.Inbound { // Inbound info.Src = convertIPv4(packetInfo.remoteIP) info.Dst = convertIPv4(packetInfo.localIP) @@ -87,7 +87,7 @@ func Handler(packets chan packet.Packet) { } } else { // IPv6 - if info.Direction { + if info.Inbound { // Inbound info.Src = convertIPv6(packetInfo.remoteIP) info.Dst = convertIPv6(packetInfo.localIP) @@ -99,7 +99,7 @@ func Handler(packets chan packet.Packet) { } // Ports - if info.Direction { + if info.Inbound { // Inbound info.SrcPort = packetInfo.remotePort info.DstPort = packetInfo.localPort @@ -113,19 +113,17 @@ func Handler(packets chan packet.Packet) { } } +// convertIPv4 as needed for data from the kernel func convertIPv4(input [4]uint32) net.IP { - return net.IPv4( - uint8(input[0]>>24&0xFF), - uint8(input[0]>>16&0xFF), - uint8(input[0]>>8&0xFF), - uint8(input[0]&0xFF), - ) + addressBuf := make([]byte, 4) + binary.BigEndian.PutUint32(addressBuf, input[0]) + return net.IP(addressBuf) } func convertIPv6(input [4]uint32) net.IP { addressBuf := make([]byte, 16) for i := 0; i < 4; i++ { - binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) + binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) } return net.IP(addressBuf) } diff --git a/firewall/master.go b/firewall/master.go index a1b30203..7f194960 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "fmt" "os" "path/filepath" @@ -10,6 +11,7 @@ import ( "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/state" "github.com/safing/portmaster/process" "github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile/endpoints" @@ -32,10 +34,10 @@ import ( // DecideOnConnection makes a decision about a connection. // When called, the connection and profile is already locked. -func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { +func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation if conn.UpdateAndCheck() { - log.Infof("filter: re-evaluating verdict on %s", conn) + log.Tracer(ctx).Infof("filter: re-evaluating verdict on %s", conn) conn.Verdict = network.VerdictUndecided if conn.Entity != nil { @@ -43,7 +45,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { } } - var deciders = []func(*network.Connection, packet.Packet) bool{ + var deciders = []func(context.Context, *network.Connection, packet.Packet) bool{ checkPortmasterConnection, checkSelfCommunication, checkProfileExists, @@ -59,7 +61,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { } for _, decider := range deciders { - if decider(conn, pkt) { + if decider(ctx, conn, pkt) { return } } @@ -70,10 +72,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // checkPortmasterConnection allows all connection that originate from // portmaster itself. -func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool { +func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // grant self if conn.Process().Pid == os.Getpid() { - log.Infof("filter: granting own connection %s", conn) + log.Tracer(ctx).Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept conn.Internal = true return true @@ -83,27 +85,29 @@ func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool { } // checkSelfCommunication checks if the process is communicating with itself. -func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { +func checkSelfCommunication(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // check if process is communicating with itself if pkt != nil { // TODO: evaluate the case where different IPs in the 127/8 net are used. pktInfo := pkt.Info() if conn.Process().Pid >= 0 && pktInfo.Src.Equal(pktInfo.Dst) { // get PID - otherPid, _, err := process.GetPidByEndpoints( - pktInfo.RemoteIP(), - pktInfo.RemotePort(), - pktInfo.LocalIP(), - pktInfo.LocalPort(), - pktInfo.Protocol, - ) + otherPid, _, err := state.Lookup(&packet.Info{ + Inbound: !pktInfo.Inbound, // we want to know the process on the other end + Version: pktInfo.Version, + Protocol: pktInfo.Protocol, + Src: pktInfo.Src, + SrcPort: pktInfo.SrcPort, + Dst: pktInfo.Dst, + DstPort: pktInfo.DstPort, + }) if err != nil { - log.Warningf("filter: failed to find local peer process PID: %s", err) + log.Tracer(ctx).Warningf("filter: failed to find local peer process PID: %s", err) } else { // get primary process - otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid) + otherProcess, err := process.GetOrFindPrimaryProcess(ctx, otherPid) if err != nil { - log.Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) + log.Tracer(ctx).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") conn.Internal = true @@ -116,7 +120,7 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { return false } -func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { +func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { if conn.Process().Profile() == nil { conn.Block("unknown process or profile") return true @@ -124,7 +128,7 @@ func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { return false } -func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { +func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason @@ -149,12 +153,12 @@ func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { return false } -func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { +func checkConnectionType(ctx context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() // check conn type switch conn.Scope { - case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + case network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: if p.BlockInbound() { if conn.Scope == network.IncomingHost { conn.Block("inbound connections blocked") @@ -174,7 +178,7 @@ func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { return false } -func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { +func checkConnectionScope(_ context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() // check scopes @@ -213,7 +217,7 @@ func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { return false } -func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { +func checkBypassPrevention(_ context.Context, conn *network.Connection, _ packet.Packet) bool { if conn.Process().Profile().PreventBypassing() { // check for bypass protection result, reason, reasonCtx := PreventBypassing(conn) @@ -230,7 +234,7 @@ func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { return false } -func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { +func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // apply privacy filter lists p := conn.Process().Profile() @@ -242,12 +246,12 @@ func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { case endpoints.NoMatch: // nothing to do default: - log.Debugf("filter: filter lists returned unsupported verdict: %s", result) + log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result) } return false } -func checkInbound(conn *network.Connection, _ packet.Packet) bool { +func checkInbound(_ context.Context, conn *network.Connection, _ packet.Packet) bool { // implicit default=block for inbound if conn.Inbound { conn.Drop("endpoint is not whitelisted (incoming is always default=block)") @@ -256,7 +260,7 @@ func checkInbound(conn *network.Connection, _ packet.Packet) bool { return false } -func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { +func checkDefaultPermit(_ context.Context, conn *network.Connection, _ packet.Packet) bool { // check default action p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionPermit { @@ -266,7 +270,7 @@ func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { return false } -func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { +func checkAutoPermitRelated(_ context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() if !p.DisableAutoPermit() { related, reason := checkRelation(conn) @@ -278,7 +282,7 @@ func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { return false } -func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool { +func checkDefaultAction(_ context.Context, conn *network.Connection, pkt packet.Packet) bool { p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionAsk { prompt(conn, pkt) diff --git a/intel/block_reason.go b/intel/block_reason.go index 26bd0a2a..ad140f4f 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -71,9 +71,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. for _, lm := range br { blockedBy, err := dns.NewRR(fmt.Sprintf( - "%s-blockedBy. 0 IN TXT %q", - strings.TrimRight(lm.Entity, "."), - strings.Join(lm.ActiveLists, ","), + `%s 0 IN TXT "blocked by filter lists %s"`, + lm.Entity, + strings.Join(lm.ActiveLists, ", "), )) if err == nil { rrs = append(rrs, blockedBy) @@ -83,9 +83,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. if len(lm.InactiveLists) > 0 { wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( - "%s-wouldBeBlockedBy. 0 IN TXT %q", - strings.TrimRight(lm.Entity, "."), - strings.Join(lm.InactiveLists, ","), + `%s 0 IN TXT "would be blocked by filter lists %s"`, + lm.Entity, + strings.Join(lm.InactiveLists, ", "), )) if err == nil { rrs = append(rrs, wouldBeBlockedBy) diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index c68e62a2..97d51e8f 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -6,6 +6,8 @@ import ( "net" "strings" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/log" @@ -22,9 +24,8 @@ import ( ) var ( - module *modules.Module - dnsServer *dns.Server - mtDNSRequest = "dns request" + module *modules.Module + dnsServer *dns.Server listenAddress = "0.0.0.0:53" ipv4Localhost = net.IPv4(127, 0, 0, 1) @@ -61,7 +62,7 @@ func prep() error { func start() error { dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"} - dns.HandleFunc(".", handleRequestAsMicroTask) + dns.HandleFunc(".", handleRequestAsWorker) module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { err := dnsServer.ListenAndServe() @@ -95,8 +96,8 @@ func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { _ = w.WriteMsg(m) } -func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) { - err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error { +func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { + err := module.RunWorker("dns request", func(ctx context.Context) error { return handleRequest(ctx, w, query) }) if err != nil { @@ -168,12 +169,13 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er // start tracer ctx, tracer := log.AddTracer(ctx) - tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) + defer tracer.Submit() + tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port)) + conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port)) // once we decided on the connection we might need to save it to the database // so we defer that check right now. @@ -191,7 +193,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return default: - log.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) + tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) } }() @@ -220,7 +222,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } // check profile before we even get intel and rr - firewall.DecideOnConnection(conn, nil) + firewall.DecideOnConnection(ctx, conn, nil) switch conn.Verdict { case network.VerdictBlock: @@ -242,7 +244,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er tracer.Infof("nameserver: %s handing over to reason-responder: %s", q.FQDN, conn.Reason) reply := responder.ReplyWithDNS(query, conn.Reason, conn.ReasonContext) if err := w.WriteMsg(reply); err != nil { - log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) + tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) } else { tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) } @@ -269,6 +271,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return nil } + tracer.Trace("nameserver: deciding on resolved dns") rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache) if rrCache == nil { sendResponse(w, query, conn.Verdict, conn.Reason, conn.ReasonContext) @@ -283,7 +286,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er m.Extra = rrCache.Extra if err := w.WriteMsg(m); err != nil { - log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) + tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) } else { tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) } diff --git a/nameserver/takeover.go b/nameserver/takeover.go index e55aa46c..ecbea5cf 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -10,7 +10,7 @@ import ( "github.com/safing/portbase/modules" "github.com/safing/portbase/notifications" "github.com/safing/portmaster/network/packet" - "github.com/safing/portmaster/process" + "github.com/safing/portmaster/network/state" ) var ( @@ -58,7 +58,15 @@ func checkForConflictingService() error { } func takeover(resolverIP net.IP) (int, error) { - pid, _, err := process.GetPidByEndpoints(resolverIP, 53, resolverIP, 65535, packet.UDP) + pid, _, err := state.Lookup(&packet.Info{ + Inbound: true, + Version: 0, // auto-detect + Protocol: packet.UDP, + Src: nil, // do not record direction + SrcPort: 0, // do not record direction + Dst: resolverIP, + DstPort: 53, + }) if err != nil { // there may be nothing listening on :53 return 0, nil diff --git a/netenv/online-status.go b/netenv/online-status.go index b7542b3b..16a6c50d 100644 --- a/netenv/online-status.go +++ b/netenv/online-status.go @@ -201,20 +201,11 @@ func triggerOnlineStatusInvestigation() { func monitorOnlineStatus(ctx context.Context) error { for { - timeout := 5 * time.Minute - /* - if GetOnlineStatus() != StatusOnline { - timeout = time.Second - log.Debugf("checking online status again in %s because current status is %s", timeout, GetOnlineStatus()) - } - */ // wait for trigger select { case <-ctx.Done(): return nil case <-onlineStatusInvestigationTrigger: - - case <-time.After(timeout): } // enable waiting diff --git a/network/clean.go b/network/clean.go index 3b1bbac9..3d3a8c23 100644 --- a/network/clean.go +++ b/network/clean.go @@ -4,6 +4,10 @@ import ( "context" "time" + "github.com/safing/portmaster/network/packet" + + "github.com/safing/portmaster/network/state" + "github.com/safing/portbase/log" "github.com/safing/portmaster/process" ) @@ -22,8 +26,12 @@ func connectionCleaner(ctx context.Context) error { ticker.Stop() return nil case <-ticker.C: + // clean connections and processes activePIDs := cleanConnections() process.CleanProcessStorage(activePIDs) + + // clean udp connection states + state.CleanUDPStates(ctx) } } } @@ -33,13 +41,10 @@ func cleanConnections() (activePIDs map[int]struct{}) { name := "clean connections" // TODO: change to new fn _ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { - activeIDs := make(map[string]struct{}) - for _, cID := range process.GetActiveConnectionIDs() { - activeIDs[cID] = struct{}{} - } - now := time.Now().Unix() - deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix() + now := time.Now().UTC() + nowUnix := now.Unix() + deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() // lock both together because we cannot fully guarantee in which map a connection lands // of course every connection should land in the correct map, but this increases resilience @@ -49,20 +54,27 @@ func cleanConnections() (activePIDs map[int]struct{}) { defer dnsConnsLock.Unlock() // network connections - for key, conn := range conns { + for _, conn := range conns { conn.Lock() // delete inactive connections switch { case conn.Ended == 0: // Step 1: check if still active - _, ok := activeIDs[key] - if ok { - activePIDs[conn.process.Pid] = struct{}{} - } else { + exists := state.Exists(&packet.Info{ + Inbound: false, // src == local + Version: conn.IPVersion, + Protocol: conn.IPProtocol, + Src: conn.LocalIP, + SrcPort: conn.LocalPort, + Dst: conn.Entity.IP, + DstPort: conn.Entity.Port, + }, now) + activePIDs[conn.process.Pid] = struct{}{} + + if !exists { // Step 2: mark end - activePIDs[conn.process.Pid] = struct{}{} - conn.Ended = now + conn.Ended = nowUnix conn.Save() } case conn.Ended < deleteOlderThan: diff --git a/network/connection.go b/network/connection.go index bb5529f1..4c5acbcf 100644 --- a/network/connection.go +++ b/network/connection.go @@ -25,11 +25,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment record.Base sync.Mutex - ID string - Scope string - Inbound bool - Entity *intel.Entity // needs locking, instance is never shared - process *process.Process + ID string + Scope string + IPVersion packet.IPVersion + Inbound bool + + // local endpoint + IPProtocol packet.IPProtocol + LocalIP net.IP + LocalPort uint16 + process *process.Process + + // remote endpoint + Entity *intel.Entity Verdict Verdict Reason string @@ -55,11 +63,22 @@ type Connection struct { //nolint:maligned // TODO: fix alignment } // NewConnectionFromDNSRequest returns a new connection based on the given dns request. -func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection { +func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection { // get Process - proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP) + proc, _, err := process.GetProcessByConnection( + ctx, + &packet.Info{ + Inbound: false, // outbound as we are looking for the process of the source address + Version: ipVersion, + Protocol: packet.UDP, + Src: localIP, // source as in the process we are looking for + SrcPort: localPort, // source as in the process we are looking for + Dst: nil, // do not record direction + DstPort: 0, // do not record direction + }, + ) if err != nil { - log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) + log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) } @@ -80,9 +99,9 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri // NewConnectionFromFirstPacket returns a new connection based on the given packet. func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // get Process - proc, inbound, err := process.GetProcessByPacket(pkt) + proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info()) if err != nil { - log.Warningf("network: failed to find process of packet %s: %s", pkt, err) + log.Debugf("network: failed to find process of packet %s: %s", pkt, err) proc = process.GetUnidentifiedProcess(pkt.Ctx()) } @@ -147,12 +166,20 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { } return &Connection{ - ID: pkt.GetConnectionID(), - Scope: scope, - Inbound: inbound, - Entity: entity, - process: proc, - Started: time.Now().Unix(), + ID: pkt.GetConnectionID(), + Scope: scope, + IPVersion: pkt.Info().Version, + Inbound: inbound, + // local endpoint + IPProtocol: pkt.Info().Protocol, + LocalIP: pkt.Info().LocalIP(), + LocalPort: pkt.Info().LocalPort(), + process: proc, + // remote endpoint + Entity: entity, + // meta + Started: time.Now().Unix(), + profileRevisionCounter: proc.Profile().RevisionCnt(), } } diff --git a/network/database.go b/network/database.go index ee42a5b1..a44a379c 100644 --- a/network/database.go +++ b/network/database.go @@ -5,6 +5,8 @@ import ( "strings" "sync" + "github.com/safing/portmaster/network/state" + "github.com/safing/portbase/database" "github.com/safing/portbase/database/iterator" "github.com/safing/portbase/database/query" @@ -57,6 +59,14 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { return conn, nil } } + case "system": + if len(splitted) >= 2 { + switch splitted[1] { + case "state": + return state.GetInfo(), nil + default: + } + } } return nil, storage.ErrNotFound diff --git a/network/iphelper/get.go b/network/iphelper/get.go new file mode 100644 index 00000000..e92f929c --- /dev/null +++ b/network/iphelper/get.go @@ -0,0 +1,71 @@ +// +build windows + +package iphelper + +import ( + "sync" + + "github.com/safing/portmaster/network/socket" +) + +var ( + ipHelper *IPHelper + + // lock locks access to the whole DLL. + // TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here. + lock sync.RWMutex +) + +// GetTCP4Table returns the system table for IPv4 TCP activity. +func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, nil, err + } + + return ipHelper.getTable(IPv4, TCP) +} + +// GetTCP6Table returns the system table for IPv6 TCP activity. +func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, nil, err + } + + return ipHelper.getTable(IPv6, TCP) +} + +// GetUDP4Table returns the system table for IPv4 UDP activity. +func GetUDP4Table() (binds []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, err + } + + _, binds, err = ipHelper.getTable(IPv4, UDP) + return +} + +// GetUDP6Table returns the system table for IPv6 UDP activity. +func GetUDP6Table() (binds []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, err + } + + _, binds, err = ipHelper.getTable(IPv6, UDP) + return +} diff --git a/network/iphelper/iphelper.go b/network/iphelper/iphelper.go new file mode 100644 index 00000000..5498879a --- /dev/null +++ b/network/iphelper/iphelper.go @@ -0,0 +1,63 @@ +// +build windows + +package iphelper + +import ( + "errors" + "fmt" + + "github.com/tevino/abool" + "golang.org/x/sys/windows" +) + +var ( + errInvalid = errors.New("IPHelper not initialzed or broken") +) + +// IPHelper represents a subset of the Windows iphlpapi.dll. +type IPHelper struct { + dll *windows.LazyDLL + + getExtendedTCPTable *windows.LazyProc + getExtendedUDPTable *windows.LazyProc + + valid *abool.AtomicBool +} + +func checkIPHelper() (err error) { + if ipHelper == nil { + ipHelper, err = New() + return err + } + return nil +} + +// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded). +func New() (*IPHelper, error) { + + new := &IPHelper{} + new.valid = abool.NewBool(false) + var err error + + // load dll + new.dll = windows.NewLazySystemDLL("iphlpapi.dll") + err = new.dll.Load() + if err != nil { + return nil, err + } + + // load functions + new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable") + err = new.getExtendedTCPTable.Find() + if err != nil { + return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) + } + new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable") + err = new.getExtendedUDPTable.Find() + if err != nil { + return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) + } + + new.valid.Set() + return new, nil +} diff --git a/process/iphelper/tables.go b/network/iphelper/tables.go similarity index 69% rename from process/iphelper/tables.go rename to network/iphelper/tables.go index 8ffecfd7..8c1fd8a7 100644 --- a/process/iphelper/tables.go +++ b/network/iphelper/tables.go @@ -3,12 +3,15 @@ package iphelper import ( + "encoding/binary" "errors" "fmt" "net" "sync" "unsafe" + "github.com/safing/portmaster/network/socket" + "golang.org/x/sys/windows" ) @@ -22,19 +25,6 @@ const ( winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER) ) -// ConnectionEntry describes a connection state table entry. -type ConnectionEntry struct { - localIP net.IP - remoteIP net.IP - localPort uint16 - remotePort uint16 - pid int -} - -func (entry *ConnectionEntry) String() string { - return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort) -} - type iphelperTCPTable struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx numEntries uint32 @@ -148,9 +138,9 @@ func increaseBufSize() int { return bufSize } -// GetTables returns the current connection state table of Windows of the given protocol and IP version. -func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) { //nolint:gocognit,gocycle // TODO - // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx +// getTable returns the current connection state table of Windows of the given protocol and IP version. +func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO + // docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable if !ipHelper.valid.IsSet() { return nil, nil, errInvalid @@ -220,26 +210,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - if row.localAddr != 0 { - new.localIP = convertIPv4(row.localAddr) - } - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - - // remote if row.state == iphelperTCPStateListen { - listeners = append(listeners, new) + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: convertPort(row.localPort), + }, + PID: int(row.owningPid), + }) } else { - new.remoteIP = convertIPv4(row.remoteAddr) - new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8) - connections = append(connections, new) + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: convertPort(row.localPort), + }, + Remote: socket.Address{ + IP: convertIPv4(row.remoteAddr), + Port: convertPort(row.remotePort), + }, + PID: int(row.owningPid), + }) } - } case protocol == TCP && ipVersion == IPv6: @@ -248,27 +239,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localIP = net.IP(row.localAddr[:]) - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - - // remote if row.state == iphelperTCPStateListen { - if new.localIP.Equal(net.IPv6zero) { - new.localIP = nil - } - listeners = append(listeners, new) + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: convertPort(row.localPort), + }, + PID: int(row.owningPid), + }) } else { - new.remoteIP = net.IP(row.remoteAddr[:]) - new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8) - connections = append(connections, new) + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: convertPort(row.localPort), + }, + Remote: socket.Address{ + IP: net.IP(row.remoteAddr[:]), + Port: convertPort(row.remotePort), + }, + PID: int(row.owningPid), + }) } - } case protocol == UDP && ipVersion == IPv4: @@ -277,19 +268,13 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - if row.localAddr == 0 { - listeners = append(listeners, new) - } else { - new.localIP = convertIPv4(row.localAddr) - connections = append(connections, new) - } + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: convertPort(row.localPort), + }, + PID: int(row.owningPid), + }) } case protocol == UDP && ipVersion == IPv6: @@ -298,32 +283,28 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localIP = net.IP(row.localAddr[:]) - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - if new.localIP.Equal(net.IPv6zero) { - new.localIP = nil - listeners = append(listeners, new) - } else { - connections = append(connections, new) - } + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: convertPort(row.localPort), + }, + PID: int(row.owningPid), + }) } } - return connections, listeners, nil + return connections, binds, nil } +// convertIPv4 as needed for iphlpapi.dll func convertIPv4(input uint32) net.IP { - return net.IPv4( - uint8(input&0xFF), - uint8(input>>8&0xFF), - uint8(input>>16&0xFF), - uint8(input>>24&0xFF), - ) + addressBuf := make([]byte, 4) + binary.LittleEndian.PutUint32(addressBuf, input) + return net.IP(addressBuf) +} + +// convertPort converts ports received from iphlpapi.dll +func convertPort(input uint32) uint16 { + return uint16(input>>8 | input<<8) } diff --git a/network/iphelper/tables_test.go b/network/iphelper/tables_test.go new file mode 100644 index 00000000..e996219e --- /dev/null +++ b/network/iphelper/tables_test.go @@ -0,0 +1,54 @@ +// +build windows + +package iphelper + +import ( + "fmt" + "testing" +) + +func TestSockets(t *testing.T) { + connections, listeners, err := GetTCP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 4 connections:") + for _, connection := range connections { + fmt.Printf("%+v\n", connection) + } + fmt.Println("\nTCP 4 listeners:") + for _, listener := range listeners { + fmt.Printf("%+v\n", listener) + } + + connections, listeners, err = GetTCP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 6 connections:") + for _, connection := range connections { + fmt.Printf("%+v\n", connection) + } + fmt.Println("\nTCP 6 listeners:") + for _, listener := range listeners { + fmt.Printf("%+v\n", listener) + } + + binds, err := GetUDP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 4 binds:") + for _, bind := range binds { + fmt.Printf("%+v\n", bind) + } + + binds, err = GetUDP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 6 binds:") + for _, bind := range binds { + fmt.Printf("%+v\n", bind) + } +} diff --git a/network/module.go b/network/module.go index 70f2fd24..a1ee69f5 100644 --- a/network/module.go +++ b/network/module.go @@ -1,17 +1,12 @@ package network import ( - "net" - "github.com/safing/portbase/modules" ) var ( module *modules.Module - dnsAddress = net.IPv4(127, 0, 0, 1) - dnsPort uint16 = 53 - defaultFirewallHandler FirewallHandler ) diff --git a/network/netutils/cleandns.go b/network/netutils/cleandns.go index f9487664..7df738d0 100644 --- a/network/netutils/cleandns.go +++ b/network/netutils/cleandns.go @@ -2,13 +2,57 @@ package netutils import ( "regexp" + + "github.com/miekg/dns" ) var ( - cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`) + cleanDomainRegex = regexp.MustCompile( + `^` + // match beginning + `(` + // start subdomain group + `(xn--)?` + // idn prefix + `[a-z0-9_-]{1,63}` + // main chunk + `\.` + // ending with a dot + `)*` + // end subdomain group, allow any number of subdomains + `(xn--)?` + // TLD idn prefix + `[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters + `\.` + // ending with a dot + `$`, // match end + ) ) // IsValidFqdn returns whether the given string is a valid fqdn. func IsValidFqdn(fqdn string) bool { - return cleanDomainRegex.MatchString(fqdn) + // root zone + if fqdn == "." { + return true + } + + // check max length + if len(fqdn) > 256 { + return false + } + + // check with regex + if !cleanDomainRegex.MatchString(fqdn) { + return false + } + + // check with miegk/dns + + // IsFqdn checks if a domain name is fully qualified. + if !dns.IsFqdn(fqdn) { + return false + } + + // IsDomainName checks if s is a valid domain name, it returns the number of + // labels and true, when a domain name is valid. Note that non fully qualified + // domain name is considered valid, in this case the last label is counted in + // the number of labels. When false is returned the number of labels is not + // defined. Also note that this function is extremely liberal; almost any + // string is a valid domain name as the DNS is 8 bit protocol. It checks if each + // label fits in 63 characters and that the entire name will fit into the 255 + // octet wire format limit. + _, ok := dns.IsDomainName(fqdn) + return ok } diff --git a/network/netutils/cleandns_test.go b/network/netutils/cleandns_test.go new file mode 100644 index 00000000..4f0dacb0 --- /dev/null +++ b/network/netutils/cleandns_test.go @@ -0,0 +1,43 @@ +package netutils + +import "testing" + +func testDomainValidity(t *testing.T, domain string, isValid bool) { + if IsValidFqdn(domain) != isValid { + t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid) + } +} + +func TestDNSValidation(t *testing.T) { + // valid + testDomainValidity(t, ".", true) + testDomainValidity(t, "at.", true) + testDomainValidity(t, "orf.at.", true) + testDomainValidity(t, "www.orf.at.", true) + testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true) + testDomainValidity(t, "a_a.com.", true) + testDomainValidity(t, "a-a.com.", true) + testDomainValidity(t, "a_a.com.", true) + testDomainValidity(t, "a-a.com.", true) + testDomainValidity(t, "xn--a.com.", true) + testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true) + + // maybe valid + testDomainValidity(t, "-.com.", true) + testDomainValidity(t, "_.com.", true) + testDomainValidity(t, "a_.com.", true) + testDomainValidity(t, "a-.com.", true) + testDomainValidity(t, "_a.com.", true) + testDomainValidity(t, "-a.com.", true) + + // invalid + testDomainValidity(t, ".com.", false) + testDomainValidity(t, ".com.", false) + testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false) + + // real world examples + testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true) +} diff --git a/network/packet/packet.go b/network/packet/packet.go index 942dd215..8076ed69 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -36,22 +36,22 @@ func (pkt *Base) SetPacketInfo(packetInfo Info) { // SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure. func (pkt *Base) SetInbound() { - pkt.info.Direction = true + pkt.info.Inbound = true } // SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure. func (pkt *Base) SetOutbound() { - pkt.info.Direction = false + pkt.info.Inbound = false } // IsInbound checks if the packet is inbound. func (pkt *Base) IsInbound() bool { - return pkt.info.Direction + return pkt.info.Inbound } // IsOutbound checks if the packet is outbound. func (pkt *Base) IsOutbound() bool { - return !pkt.info.Direction + return !pkt.info.Inbound } // HasPorts checks if the packet has a protocol that uses ports. @@ -80,13 +80,13 @@ func (pkt *Base) GetConnectionID() string { func (pkt *Base) createConnectionID() { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { - if pkt.info.Direction { + if pkt.info.Inbound { pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } else { pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } } else { - if pkt.info.Direction { + if pkt.info.Inbound { pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } else { pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) @@ -105,7 +105,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I if pkt.info.Protocol != protocol { return false } - if pkt.info.Direction != remote { + if pkt.info.Inbound != remote { if !network.Contains(pkt.info.Src) { return false } @@ -131,7 +131,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I // Remote Src Dst // func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool { - if pkt.info.Direction != endpoint { + if pkt.info.Inbound != endpoint { if network.Contains(pkt.info.Src) { return true } @@ -152,12 +152,12 @@ func (pkt *Base) String() string { // FmtPacket returns the most important information about the packet as a string func (pkt *Base) FmtPacket() string { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) @@ -170,7 +170,7 @@ func (pkt *Base) FmtProtocol() string { // FmtRemoteIP returns the remote IP address as a string func (pkt *Base) FmtRemoteIP() string { - if pkt.info.Direction { + if pkt.info.Inbound { return pkt.info.Src.String() } return pkt.info.Dst.String() @@ -179,7 +179,7 @@ func (pkt *Base) FmtRemoteIP() string { // FmtRemotePort returns the remote port as a string func (pkt *Base) FmtRemotePort() string { if pkt.info.SrcPort != 0 { - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("%d", pkt.info.SrcPort) } return fmt.Sprintf("%d", pkt.info.DstPort) diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index a98fc8a5..3e68e8af 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -6,8 +6,8 @@ import ( // Info holds IP and TCP/UDP header information type Info struct { - Direction bool - InTunnel bool + Inbound bool + InTunnel bool Version IPVersion Protocol IPProtocol @@ -17,7 +17,7 @@ type Info struct { // LocalIP returns the local IP of the packet. func (pi *Info) LocalIP() net.IP { - if pi.Direction { + if pi.Inbound { return pi.Dst } return pi.Src @@ -25,7 +25,7 @@ func (pi *Info) LocalIP() net.IP { // RemoteIP returns the remote IP of the packet. func (pi *Info) RemoteIP() net.IP { - if pi.Direction { + if pi.Inbound { return pi.Src } return pi.Dst @@ -33,7 +33,7 @@ func (pi *Info) RemoteIP() net.IP { // LocalPort returns the local port of the packet. func (pi *Info) LocalPort() uint16 { - if pi.Direction { + if pi.Inbound { return pi.DstPort } return pi.SrcPort @@ -41,7 +41,7 @@ func (pi *Info) LocalPort() uint16 { // RemotePort returns the remote port of the packet. func (pi *Info) RemotePort() uint16 { - if pi.Direction { + if pi.Inbound { return pi.SrcPort } return pi.DstPort diff --git a/process/proc/processfinder.go b/network/proc/findpid.go similarity index 92% rename from process/proc/processfinder.go rename to network/proc/findpid.go index 5e6ed7cc..6808960e 100644 --- a/process/proc/processfinder.go +++ b/network/proc/findpid.go @@ -10,6 +10,8 @@ import ( "sync" "syscall" + "github.com/safing/portmaster/network/socket" + "github.com/safing/portbase/log" ) @@ -18,8 +20,8 @@ var ( pidsByUser = make(map[int][]int) ) -// GetPidOfInode returns the pid of the given uid and socket inode. -func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO +// FindPID returns the pid of the given uid and socket inode. +func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO pidsByUserLock.Lock() defer pidsByUserLock.Unlock() @@ -38,7 +40,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO var checkedUserPids []int for _, possiblePID := range pids { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } checkedUserPids = append(checkedUserPids, possiblePID) } @@ -57,7 +59,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO // only check if not already checked if sort.SearchInts(checkedUserPids, possiblePID) == len { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } } } @@ -71,13 +73,13 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO if possibleUID != uid { for _, possiblePID := range pids { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } } } } - return unidentifiedProcessID, false + return socket.UnidentifiedProcessID } func findSocketFromPid(pid, inode int) bool { diff --git a/network/proc/tables.go b/network/proc/tables.go new file mode 100644 index 00000000..bf4a3eb0 --- /dev/null +++ b/network/proc/tables.go @@ -0,0 +1,237 @@ +// +build linux + +package proc + +import ( + "bufio" + "encoding/hex" + "fmt" + "net" + "os" + "strconv" + "strings" + "unicode" + + "github.com/safing/portmaster/network/socket" + + "github.com/safing/portbase/log" +) + +/* + +1. find socket inode + - by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP? + - /proc/net/{tcp|udp}[6] + +2. get list of processes of uid + +3. find socket inode in process fds + - if not found, refresh map of uid->pids + - if not found, check ALL pids: maybe euid != uid + +4. gather process info + +Cache every step! + +*/ + +// Network Related Constants +const ( + TCP4 uint8 = iota + UDP4 + TCP6 + UDP6 + ICMP4 + ICMP6 + + tcp4ProcFile = "/proc/net/tcp" + tcp6ProcFile = "/proc/net/tcp6" + udp4ProcFile = "/proc/net/udp" + udp6ProcFile = "/proc/net/udp6" + + UnfetchedProcessID = -2 + + tcpListenStateHex = "0A" +) + +// GetTCP4Table returns the system table for IPv4 TCP activity. +func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + return getTableFromSource(TCP4, tcp4ProcFile) +} + +// GetTCP6Table returns the system table for IPv6 TCP activity. +func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + return getTableFromSource(TCP6, tcp6ProcFile) +} + +// GetUDP4Table returns the system table for IPv4 UDP activity. +func GetUDP4Table() (binds []*socket.BindInfo, err error) { + _, binds, err = getTableFromSource(UDP4, udp4ProcFile) + return +} + +// GetUDP6Table returns the system table for IPv6 UDP activity. +func GetUDP6Table() (binds []*socket.BindInfo, err error) { + _, binds, err = getTableFromSource(UDP6, udp6ProcFile) + return +} + +const ( + // hint: we split fields by multiple delimiters, see procDelimiter + fieldIndexLocalIP = 1 + fieldIndexLocalPort = 2 + fieldIndexRemoteIP = 3 + fieldIndexRemotePort = 4 + fieldIndexUID = 11 + fieldIndexInode = 13 +) + +func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { + + var ipConverter func(string) net.IP + switch stack { + case TCP4, UDP4: + ipConverter = convertIPv4 + case TCP6, UDP6: + ipConverter = convertIPv6 + default: + return nil, nil, fmt.Errorf("unsupported table stack: %d", stack) + } + + // open file + socketData, err := os.Open(procFile) + if err != nil { + return nil, nil, err + } + defer socketData.Close() + + // file scanner + scanner := bufio.NewScanner(socketData) + scanner.Split(bufio.ScanLines) + + // parse + scanner.Scan() // skip first row + for scanner.Scan() { + fields := strings.FieldsFunc(scanner.Text(), procDelimiter) + if len(fields) < 14 { + // log.Tracef("process: too short: %s", fields) + continue + } + + localIP := ipConverter(fields[fieldIndexLocalIP]) + if localIP == nil { + continue + } + + localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16) + if err != nil { + log.Warningf("process: could not parse port: %s", err) + continue + } + + uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32) + // log.Tracef("uid: %s", fields[fieldIndexUID]) + if err != nil { + log.Warningf("process: could not parse uid %s: %s", fields[11], err) + continue + } + + inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32) + // log.Tracef("inode: %s", fields[fieldIndexInode]) + if err != nil { + log.Warningf("process: could not parse inode %s: %s", fields[13], err) + continue + } + + switch stack { + case UDP4, UDP6: + + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + + case TCP4, TCP6: + + if fields[5] == tcpListenStateHex { + // listener + + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + } else { + // connection + + remoteIP := ipConverter(fields[fieldIndexRemoteIP]) + if remoteIP == nil { + continue + } + + remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16) + if err != nil { + log.Warningf("process: could not parse port: %s", err) + continue + } + + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + Remote: socket.Address{ + IP: remoteIP, + Port: uint16(remotePort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + } + } + } + + return connections, binds, nil +} + +func procDelimiter(c rune) bool { + return unicode.IsSpace(c) || c == ':' +} + +func convertIPv4(data string) net.IP { + decoded, err := hex.DecodeString(data) + if err != nil { + log.Warningf("process: could not parse IPv4 %s: %s", data, err) + return nil + } + if len(decoded) != 4 { + log.Warningf("process: decoded IPv4 %s has wrong length", decoded) + return nil + } + ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) + return ip +} + +func convertIPv6(data string) net.IP { + decoded, err := hex.DecodeString(data) + if err != nil { + log.Warningf("process: could not parse IPv6 %s: %s", data, err) + return nil + } + if len(decoded) != 16 { + log.Warningf("process: decoded IPv6 %s has wrong length", decoded) + return nil + } + ip := net.IP(decoded) + return ip +} diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go new file mode 100644 index 00000000..eed12ce8 --- /dev/null +++ b/network/proc/tables_test.go @@ -0,0 +1,60 @@ +// +build linux + +package proc + +import ( + "fmt" + "testing" +) + +func TestSockets(t *testing.T) { + connections, listeners, err := GetTCP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 4 connections:") + for _, connection := range connections { + pid := FindPID(connection.UID, connection.Inode) + fmt.Printf("%d: %+v\n", pid, connection) + } + fmt.Println("\nTCP 4 listeners:") + for _, listener := range listeners { + pid := FindPID(listener.UID, listener.Inode) + fmt.Printf("%d: %+v\n", pid, listener) + } + + connections, listeners, err = GetTCP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 6 connections:") + for _, connection := range connections { + pid := FindPID(connection.UID, connection.Inode) + fmt.Printf("%d: %+v\n", pid, connection) + } + fmt.Println("\nTCP 6 listeners:") + for _, listener := range listeners { + pid := FindPID(listener.UID, listener.Inode) + fmt.Printf("%d: %+v\n", pid, listener) + } + + binds, err := GetUDP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 4 binds:") + for _, bind := range binds { + pid := FindPID(bind.UID, bind.Inode) + fmt.Printf("%d: %+v\n", pid, bind) + } + + binds, err = GetUDP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 6 binds:") + for _, bind := range binds { + pid := FindPID(bind.UID, bind.Inode) + fmt.Printf("%d: %+v\n", pid, bind) + } +} diff --git a/network/socket/socket.go b/network/socket/socket.go new file mode 100644 index 00000000..e8dfe1d9 --- /dev/null +++ b/network/socket/socket.go @@ -0,0 +1,31 @@ +package socket + +import "net" + +const ( + // UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops. + UnidentifiedProcessID = -1 +) + +// ConnectionInfo holds socket information returned by the system. +type ConnectionInfo struct { + Local Address + Remote Address + PID int + UID int + Inode int +} + +// BindInfo holds socket information returned by the system. +type BindInfo struct { + Local Address + PID int + UID int + Inode int +} + +// Address is an IP + Port pair. +type Address struct { + IP net.IP + Port uint16 +} diff --git a/network/state/exists.go b/network/state/exists.go new file mode 100644 index 00000000..7b308608 --- /dev/null +++ b/network/state/exists.go @@ -0,0 +1,101 @@ +package state + +import ( + "time" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/socket" +) + +const ( + // UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended. + UDPConnectionTTL = 10 * time.Minute +) + +// Exists checks if the given connection is present in the system state tables. +func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { + + // TODO: create lookup maps before running a flurry of Exists() checks. + + switch { + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: + tcp4Lock.Lock() + defer tcp4Lock.Unlock() + return existsTCP(tcp4Connections, pktInfo) + + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: + tcp6Lock.Lock() + defer tcp6Lock.Unlock() + return existsTCP(tcp6Connections, pktInfo) + + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: + udp4Lock.Lock() + defer udp4Lock.Unlock() + return existsUDP(udp4Binds, udp4States, pktInfo, now) + + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: + udp6Lock.Lock() + defer udp6Lock.Unlock() + return existsUDP(udp6Binds, udp6States, pktInfo, now) + + default: + return false + } +} + +func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + + // search connections + for _, socketInfo := range connections { + if localPort == socketInfo.Local.Port && + remotePort == socketInfo.Remote.Port && + remoteIP.Equal(socketInfo.Remote.IP) && + localIP.Equal(socketInfo.Local.IP) { + return true + } + } + + return false +} + +func existsUDP( + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + pktInfo *packet.Info, + now time.Time, +) (exists bool) { + + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + + connThreshhold := now.Add(-UDPConnectionTTL) + + // search binds + for _, socketInfo := range binds { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + + udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{ + IP: remoteIP, + Port: remotePort, + }) + switch { + case !ok: + return false + case udpConnState.lastSeen.After(connThreshhold): + return true + default: + return false + } + + } + } + + return false +} diff --git a/network/state/info.go b/network/state/info.go new file mode 100644 index 00000000..5d4b0d4d --- /dev/null +++ b/network/state/info.go @@ -0,0 +1,52 @@ +package state + +import ( + "sync" + + "github.com/safing/portbase/database/record" + + "github.com/safing/portmaster/network/socket" +) + +// Info holds network state information as provided by the system. +type Info struct { + record.Base + sync.Mutex + + TCP4Connections []*socket.ConnectionInfo + TCP4Listeners []*socket.BindInfo + TCP6Connections []*socket.ConnectionInfo + TCP6Listeners []*socket.BindInfo + UDP4Binds []*socket.BindInfo + UDP6Binds []*socket.BindInfo +} + +// GetInfo returns all system state tables. The returned data must not be modified. +func GetInfo() *Info { + info := &Info{} + + tcp4Lock.Lock() + updateTCP4Tables() + info.TCP4Connections = tcp4Connections + info.TCP4Listeners = tcp4Listeners + tcp4Lock.Unlock() + + tcp6Lock.Lock() + updateTCP6Tables() + info.TCP6Connections = tcp6Connections + info.TCP6Listeners = tcp6Listeners + tcp6Lock.Unlock() + + udp4Lock.Lock() + updateUDP4Table() + info.UDP4Binds = udp4Binds + udp4Lock.Unlock() + + udp6Lock.Lock() + updateUDP6Table() + info.UDP6Binds = udp6Binds + udp6Lock.Unlock() + + info.UpdateMeta() + return info +} diff --git a/network/state/lookup.go b/network/state/lookup.go new file mode 100644 index 00000000..5aadf7fa --- /dev/null +++ b/network/state/lookup.go @@ -0,0 +1,171 @@ +package state + +import ( + "errors" + "sync" + "time" + + "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/socket" +) + +// - TCP +// - Outbound: Match listeners (in!), then connections (out!) +// - Inbound: Match listeners (in!), then connections (out!) +// - Clean via connections +// - UDP +// - Any connection: match specific local address or zero IP +// - In or out: save direction of first packet: +// - map[]map[]{direction, lastSeen} +// - only clean if is removed by OS +// - limit to 256 entries? +// - clean after 72hrs? +// - switch direction to outbound if outbound packet is seen? +// - IP: Unidentified Process + +// Errors +var ( + ErrConnectionNotFound = errors.New("could not find connection in system state tables") + ErrPIDNotFound = errors.New("could not find pid for socket inode") +) + +var ( + tcp4Lock sync.Mutex + tcp6Lock sync.Mutex + udp4Lock sync.Mutex + udp6Lock sync.Mutex + + baseWaitTime = 3 * time.Millisecond +) + +// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. +func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { + // auto-detect version + if pktInfo.Version == 0 { + if ip := pktInfo.LocalIP().To4(); ip != nil { + pktInfo.Version = packet.IPv4 + } else { + pktInfo.Version = packet.IPv6 + } + } + + switch { + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: + tcp4Lock.Lock() + defer tcp4Lock.Unlock() + return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, pktInfo) + + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: + tcp6Lock.Lock() + defer tcp6Lock.Unlock() + return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, pktInfo) + + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: + udp4Lock.Lock() + defer udp4Lock.Unlock() + return searchUDP(udp4Binds, udp4States, updateUDP4Table, pktInfo) + + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: + udp6Lock.Lock() + defer udp6Lock.Unlock() + return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo) + + default: + return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") + } +} + +func searchTCP( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo), + pktInfo *packet.Info, +) ( + pid int, + inbound bool, + err error, +) { + + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + + // search until we find something + for i := 0; i < 7; i++ { + // always search listeners first + for _, socketInfo := range listeners { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + return checkBindPID(socketInfo, true) + } + } + + // search connections + for _, socketInfo := range connections { + if localPort == socketInfo.Local.Port && + localIP.Equal(socketInfo.Local.IP) { + return checkConnectionPID(socketInfo, false) + } + } + + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + + // refetch lists + connections, listeners = updateTables() + } + + return socket.UnidentifiedProcessID, false, ErrConnectionNotFound +} + +func searchUDP( + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + updateTable func() []*socket.BindInfo, + pktInfo *packet.Info, +) ( + pid int, + inbound bool, + err error, +) { + + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + + isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast + // TODO: Currently broadcast/multicast scopes are not checked, so we might + // attribute an incoming broadcast/multicast packet to the wrong process if + // there are multiple processes listening on the same local port, but + // binding to different addresses. This highly unusual for clients. + + // search until we find something + for i := 0; i < 5; i++ { + // search binds + for _, socketInfo := range binds { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || // zero IP + isInboundMulticast || // inbound broadcast, multicast + localIP.Equal(socketInfo.Local.IP)) { + + // do not check direction if remoteIP/Port is not given + if pktInfo.RemotePort() == 0 { + return checkBindPID(socketInfo, pktInfo.Inbound) + } + + // get direction and return + connInbound := getUDPDirection(socketInfo, udpStates, pktInfo) + return checkBindPID(socketInfo, connInbound) + } + } + + // we found nothing, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + + // refetch lists + binds = updateTable() + } + + return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound +} diff --git a/network/state/system_linux.go b/network/state/system_linux.go new file mode 100644 index 00000000..b902c58c --- /dev/null +++ b/network/state/system_linux.go @@ -0,0 +1,27 @@ +package state + +import ( + "github.com/safing/portmaster/network/proc" + "github.com/safing/portmaster/network/socket" +) + +var ( + getTCP4Table = proc.GetTCP4Table + getTCP6Table = proc.GetTCP6Table + getUDP4Table = proc.GetUDP4Table + getUDP6Table = proc.GetUDP6Table +) + +func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { + if socketInfo.PID == proc.UnfetchedProcessID { + socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) + } + return socketInfo.PID, connInbound, nil +} + +func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { + if socketInfo.PID == proc.UnfetchedProcessID { + socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) + } + return socketInfo.PID, connInbound, nil +} diff --git a/network/state/system_windows.go b/network/state/system_windows.go new file mode 100644 index 00000000..a03ea5f6 --- /dev/null +++ b/network/state/system_windows.go @@ -0,0 +1,21 @@ +package state + +import ( + "github.com/safing/portmaster/network/iphelper" + "github.com/safing/portmaster/network/socket" +) + +var ( + getTCP4Table = iphelper.GetTCP4Table + getTCP6Table = iphelper.GetTCP6Table + getUDP4Table = iphelper.GetUDP4Table + getUDP6Table = iphelper.GetUDP6Table +) + +func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { + return socketInfo.PID, connInbound, nil +} + +func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { + return socketInfo.PID, connInbound, nil +} diff --git a/network/state/tables.go b/network/state/tables.go new file mode 100644 index 00000000..2f236cc6 --- /dev/null +++ b/network/state/tables.go @@ -0,0 +1,68 @@ +package state + +import ( + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network/socket" +) + +var ( + tcp4Connections []*socket.ConnectionInfo + tcp4Listeners []*socket.BindInfo + + tcp6Connections []*socket.ConnectionInfo + tcp6Listeners []*socket.BindInfo + + udp4Binds []*socket.BindInfo + + udp6Binds []*socket.BindInfo +) + +func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { + var err error + connections, listeners, err = getTCP4Table() + if err != nil { + log.Warningf("state: failed to get TCP4 socket table: %s", err) + return + } + + tcp4Connections = connections + tcp4Listeners = listeners + return +} + +func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { + var err error + connections, listeners, err = getTCP6Table() + if err != nil { + log.Warningf("state: failed to get TCP6 socket table: %s", err) + return + } + + tcp6Connections = connections + tcp6Listeners = listeners + return +} + +func updateUDP4Table() (binds []*socket.BindInfo) { + var err error + binds, err = getUDP4Table() + if err != nil { + log.Warningf("state: failed to get UDP4 socket table: %s", err) + return + } + + udp4Binds = binds + return +} + +func updateUDP6Table() (binds []*socket.BindInfo) { + var err error + binds, err = getUDP6Table() + if err != nil { + log.Warningf("state: failed to get UDP6 socket table: %s", err) + return + } + + udp6Binds = binds + return +} diff --git a/network/state/udp.go b/network/state/udp.go new file mode 100644 index 00000000..f49b1d04 --- /dev/null +++ b/network/state/udp.go @@ -0,0 +1,125 @@ +package state + +import ( + "context" + "time" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/socket" +) + +type udpState struct { + inbound bool + lastSeen time.Time +} + +const ( + // UDPConnStateTTL is the maximum time a udp connection state is held. + UDPConnStateTTL = 72 * time.Hour + + // UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold. + UDPConnStateShortenedTTL = 3 * time.Hour + + // AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket. + AggressiveCleaningThreshold = 256 +) + +var ( + udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock + udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock +) + +func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) { + bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)] + if ok { + udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)] + return + } + + return nil, false +} + +func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) { + localKey := makeUDPStateKey(socketInfo.Local) + + bindMap, ok := udpStates[localKey] + if !ok { + bindMap = make(map[string]*udpState) + udpStates[localKey] = bindMap + } + + remoteKey := makeUDPStateKey(socket.Address{ + IP: pktInfo.RemoteIP(), + Port: pktInfo.RemotePort(), + }) + udpConnState, ok := bindMap[remoteKey] + if !ok { + bindMap[remoteKey] = &udpState{ + inbound: pktInfo.Inbound, + lastSeen: time.Now().UTC(), + } + return pktInfo.Inbound + } + + udpConnState.lastSeen = time.Now().UTC() + return udpConnState.inbound +} + +// CleanUDPStates cleans the udp connection states which save connection directions. +func CleanUDPStates(_ context.Context) { + now := time.Now().UTC() + + udp4Lock.Lock() + updateUDP4Table() + cleanStates(udp4Binds, udp4States, now) + udp4Lock.Unlock() + + udp6Lock.Lock() + updateUDP6Table() + cleanStates(udp6Binds, udp6States, now) + udp6Lock.Unlock() +} + +func cleanStates( + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + now time.Time, +) { + // compute thresholds + threshold := now.Add(-UDPConnStateTTL) + shortThreshhold := now.Add(-UDPConnStateShortenedTTL) + + // make lookup map of all active keys + bindKeys := make(map[string]struct{}) + for _, socketInfo := range binds { + bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{} + } + + // clean the udp state storage + for localKey, bindMap := range udpStates { + if _, active := bindKeys[localKey]; active { + // clean old entries + for remoteKey, udpConnState := range bindMap { + if udpConnState.lastSeen.Before(threshold) { + delete(bindMap, remoteKey) + } + } + // if there are too many clean more aggressively + if len(bindMap) > AggressiveCleaningThreshold { + for remoteKey, udpConnState := range bindMap { + if udpConnState.lastSeen.Before(shortThreshhold) { + delete(bindMap, remoteKey) + } + } + } + } else { + // delete the whole thing + delete(udpStates, localKey) + } + } +} + +func makeUDPStateKey(address socket.Address) string { + // This could potentially go wrong, but as all IPs are created by the same source, everything should be fine. + return string(address.IP) + string(address.Port) +} diff --git a/pmctl/logs.go b/pmctl/logs.go index 7ee1576d..03d05678 100644 --- a/pmctl/logs.go +++ b/pmctl/logs.go @@ -74,7 +74,7 @@ func finalizeLogFile(logFile *os.File, logFilePath string) { func initControlLogFile() *os.File { // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) @@ -93,7 +93,7 @@ func logControlError(cErr error) { } // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) @@ -114,7 +114,7 @@ func logControlError(cErr error) { //nolint:deadcode,unused // TODO func logControlStack() { // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) diff --git a/pmctl/run.go b/pmctl/run.go index 53133697..432bfa3d 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -227,7 +227,7 @@ func execute(opts *Options, args []string) (cont bool, err error) { // log files var logFile, errorFile *os.File - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", opts.ShortIdentifier) + logFileBasePath := filepath.Join(logsRoot.Path, opts.ShortIdentifier) err = logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file dir %s: %s\n", logFileBasePath, err) diff --git a/process/find.go b/process/find.go index 30f93f2d..50070949 100644 --- a/process/find.go +++ b/process/find.go @@ -3,7 +3,8 @@ package process import ( "context" "errors" - "net" + + "github.com/safing/portmaster/network/state" "github.com/safing/portbase/log" "github.com/safing/portmaster/network/packet" @@ -11,131 +12,28 @@ import ( // Errors var ( - ErrConnectionNotFound = errors.New("could not find connection in system state tables") - ErrProcessNotFound = errors.New("could not find process in system state tables") + ErrProcessNotFound = errors.New("could not find process in system state tables") ) -// GetPidByPacket returns the pid of the owner of the packet. -func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { - - var localIP net.IP - var localPort uint16 - var remoteIP net.IP - var remotePort uint16 - if pkt.IsInbound() { - localIP = pkt.Info().Dst - remoteIP = pkt.Info().Src - } else { - localIP = pkt.Info().Src - remoteIP = pkt.Info().Dst - } - if pkt.HasPorts() { - if pkt.IsInbound() { - localPort = pkt.Info().DstPort - remotePort = pkt.Info().SrcPort - } else { - localPort = pkt.Info().SrcPort - remotePort = pkt.Info().DstPort - } - } - - switch { - case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv4: - return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv4: - return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv6: - return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6: - return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - default: - return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") - } - -} - -// GetProcessByPacket returns the process that owns the given packet. -func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { - if !enableProcessDetection() { - log.Tracer(pkt.Ctx()).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(pkt.Ctx()), pkt.Info().Direction, nil - } - - log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet") - - var pid int - pid, direction, err = GetPidByPacket(pkt) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err) - return nil, direction, err - } - if pid < 0 { - log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error()) - return nil, direction, ErrConnectionNotFound - } - - process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err) - return nil, direction, err - } - - err = process.GetProfile(pkt.Ctx()) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to get profile for process %s: %s", process, err) - } - - return process, direction, nil - -} - -// GetPidByEndpoints returns the pid of the owner of the described link. -func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) { - - ipVersion := packet.IPv4 - if v4 := localIP.To4(); v4 == nil { - ipVersion = packet.IPv6 - } - - switch { - case protocol == packet.TCP && ipVersion == packet.IPv4: - return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.UDP && ipVersion == packet.IPv4: - return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.TCP && ipVersion == packet.IPv6: - return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.UDP && ipVersion == packet.IPv6: - return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, false) - default: - return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") - } - -} - -// GetProcessByEndpoints returns the process that owns the described link. -func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { +// GetProcessByConnection returns the process that owns the described connection. +func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) { if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(ctx), nil + return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil } - log.Tracer(ctx).Tracef("process: getting process and profile by endpoints") - + log.Tracer(ctx).Tracef("process: getting pid from system network state") var pid int - pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol) + pid, connInbound, err = state.Lookup(pktInfo) if err != nil { - log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err) - return nil, err - } - if pid < 0 { - log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error()) - return nil, ErrConnectionNotFound + log.Tracer(ctx).Debugf("process: failed to find PID of connection: %s", err) + return nil, connInbound, err } process, err = GetOrFindPrimaryProcess(ctx, pid) if err != nil { - log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err) - return nil, err + log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err) + return nil, connInbound, err } err = process.GetProfile(ctx) @@ -143,10 +41,5 @@ func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16 log.Tracer(ctx).Errorf("process: failed to get profile for process %s: %s", process, err) } - return process, nil -} - -// GetActiveConnectionIDs returns a list of all active connection IDs. -func GetActiveConnectionIDs() []string { - return getActiveConnectionIDs() + return process, connInbound, nil } diff --git a/process/getpid_linux.go b/process/getpid_linux.go deleted file mode 100644 index 1788f3e9..00000000 --- a/process/getpid_linux.go +++ /dev/null @@ -1,13 +0,0 @@ -package process - -import ( - "github.com/safing/portmaster/process/proc" -) - -var ( - getTCP4PacketInfo = proc.GetTCP4PacketInfo - getTCP6PacketInfo = proc.GetTCP6PacketInfo - getUDP4PacketInfo = proc.GetUDP4PacketInfo - getUDP6PacketInfo = proc.GetUDP6PacketInfo - getActiveConnectionIDs = proc.GetActiveConnectionIDs -) diff --git a/process/getpid_windows.go b/process/getpid_windows.go deleted file mode 100644 index 98b200ea..00000000 --- a/process/getpid_windows.go +++ /dev/null @@ -1,13 +0,0 @@ -package process - -import ( - "github.com/safing/portmaster/process/iphelper" -) - -var ( - getTCP4PacketInfo = iphelper.GetTCP4PacketInfo - getTCP6PacketInfo = iphelper.GetTCP6PacketInfo - getUDP4PacketInfo = iphelper.GetUDP4PacketInfo - getUDP6PacketInfo = iphelper.GetUDP6PacketInfo - getActiveConnectionIDs = iphelper.GetActiveConnectionIDs -) diff --git a/process/iphelper/get.go b/process/iphelper/get.go deleted file mode 100644 index 6487ea06..00000000 --- a/process/iphelper/get.go +++ /dev/null @@ -1,260 +0,0 @@ -// +build windows - -package iphelper - -import ( - "fmt" - "net" - "sync" - "time" -) - -const ( - unidentifiedProcessID = -1 -) - -var ( - tcp4Connections []*ConnectionEntry - tcp4Listeners []*ConnectionEntry - tcp6Connections []*ConnectionEntry - tcp6Listeners []*ConnectionEntry - - udp4Connections []*ConnectionEntry - udp4Listeners []*ConnectionEntry - udp6Connections []*ConnectionEntry - udp6Listeners []*ConnectionEntry - - ipHelper *IPHelper - lock sync.RWMutex - - waitTime = 15 * time.Millisecond -) - -func checkIPHelper() (err error) { - if ipHelper == nil { - ipHelper, err = New() - return err - } - return nil -} - -// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection. -func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection. -func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection. -func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection. -func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { //nolint:unparam // TODO: use direction, it may not be used because results caused problems, investigate. - lock.RLock() - defer lock.RUnlock() - - if pktDirection { - // inbound - pid = searchListeners(listeners, localIP, localPort) - if pid >= 0 { - return pid, true - } - pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort) - if pid >= 0 { - return pid, false - } - } else { - // outbound - pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort) - if pid >= 0 { - return pid, false - } - pid = searchListeners(listeners, localIP, localPort) - if pid >= 0 { - return pid, true - } - } - - return unidentifiedProcessID, pktDirection -} - -func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { - - for _, entry := range list { - if localPort == entry.localPort && - remotePort == entry.remotePort && - remoteIP.Equal(entry.remoteIP) && - localIP.Equal(entry.localIP) { - return entry.pid - } - } - - return unidentifiedProcessID -} - -func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) { - - for _, entry := range list { - if localPort == entry.localPort && - (entry.localIP == nil || // nil IP means zero IP, see tables.go - localIP.Equal(entry.localIP)) { - return entry.pid - } - } - - return unidentifiedProcessID -} - -// GetActiveConnectionIDs returns all currently active connection IDs. -func GetActiveConnectionIDs() (connections []string) { - lock.Lock() - defer lock.Unlock() - - for _, entry := range tcp4Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range tcp6Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range udp4Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range udp6Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - - return -} diff --git a/process/iphelper/iphelper.go b/process/iphelper/iphelper.go deleted file mode 100644 index a7c259da..00000000 --- a/process/iphelper/iphelper.go +++ /dev/null @@ -1,79 +0,0 @@ -// +build windows - -package iphelper - -import ( - "errors" - "fmt" - - "github.com/tevino/abool" - "golang.org/x/sys/windows" -) - -var ( - errInvalid = errors.New("IPHelper not initialzed or broken") -) - -// IPHelper represents a subset of the Windows iphlpapi.dll. -type IPHelper struct { - dll *windows.LazyDLL - - getExtendedTCPTable *windows.LazyProc - getExtendedUDPTable *windows.LazyProc - // getOwnerModuleFromTcpEntry *windows.LazyProc - // getOwnerModuleFromTcp6Entry *windows.LazyProc - // getOwnerModuleFromUdpEntry *windows.LazyProc - // getOwnerModuleFromUdp6Entry *windows.LazyProc - - valid *abool.AtomicBool -} - -// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded). -func New() (*IPHelper, error) { - - new := &IPHelper{} - new.valid = abool.NewBool(false) - var err error - - // load dll - new.dll = windows.NewLazySystemDLL("iphlpapi.dll") - err = new.dll.Load() - if err != nil { - return nil, err - } - - // load functions - new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable") - err = new.getExtendedTCPTable.Find() - if err != nil { - return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) - } - new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable") - err = new.getExtendedUDPTable.Find() - if err != nil { - return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) - } - // new.getOwnerModuleFromTcpEntry = new.dll.NewProc("GetOwnerModuleFromTcpEntry") - // err = new.getOwnerModuleFromTcpEntry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcpEntry: %s", err) - // } - // new.getOwnerModuleFromTcp6Entry = new.dll.NewProc("GetOwnerModuleFromTcp6Entry") - // err = new.getOwnerModuleFromTcp6Entry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcp6Entry: %s", err) - // } - // new.getOwnerModuleFromUdpEntry = new.dll.NewProc("GetOwnerModuleFromUdpEntry") - // err = new.getOwnerModuleFromUdpEntry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdpEntry: %s", err) - // } - // new.getOwnerModuleFromUdp6Entry = new.dll.NewProc("GetOwnerModuleFromUdp6Entry") - // err = new.getOwnerModuleFromUdp6Entry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdp6Entry: %s", err) - // } - - new.valid.Set() - return new, nil -} diff --git a/process/iphelper/test/main.go b/process/iphelper/test/main.go deleted file mode 100644 index 5234fbb4..00000000 --- a/process/iphelper/test/main.go +++ /dev/null @@ -1,62 +0,0 @@ -// +build windows - -package main - -import ( - "fmt" - - "github.com/safing/portmaster/process/iphelper" -) - -func main() { - iph, err := iphelper.New() - if err != nil { - panic(err) - } - - fmt.Printf("TCP4\n") - conns, lConns, err := iph.GetTables(iphelper.TCP, iphelper.IPv4) - if err != nil { - panic(err) - } - fmt.Printf("Connections:\n") - for _, conn := range conns { - fmt.Printf("%s\n", conn) - } - fmt.Printf("Listeners:\n") - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nTCP6\n") - conns, lConns, err = iph.GetTables(iphelper.TCP, iphelper.IPv6) - if err != nil { - panic(err) - } - fmt.Printf("Connections:\n") - for _, conn := range conns { - fmt.Printf("%s\n", conn) - } - fmt.Printf("Listeners:\n") - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nUDP4\n") - _, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv4) - if err != nil { - panic(err) - } - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nUDP6\n") - _, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv6) - if err != nil { - panic(err) - } - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } -} diff --git a/process/proc/gather.go b/process/proc/gather.go deleted file mode 100644 index 1413b3c9..00000000 --- a/process/proc/gather.go +++ /dev/null @@ -1,83 +0,0 @@ -// +build linux - -package proc - -import ( - "net" - "time" -) - -// PID querying return codes -const ( - Success uint8 = iota - NoSocket - NoProcess -) - -var ( - waitTime = 15 * time.Millisecond -) - -// GetPidOfConnection returns the PID of the given connection. -func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { - uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) - if !ok { - uid, inode, ok = getListeningSocket(localIP, localPort, protocol) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) - if !ok { - uid, inode, ok = getListeningSocket(localIP, localPort, protocol) - } - } - if !ok { - return unidentifiedProcessID, NoSocket - } - } - - pid, ok = GetPidOfInode(uid, inode) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - pid, ok = GetPidOfInode(uid, inode) - } - if !ok { - return unidentifiedProcessID, NoProcess - } - - return -} - -// GetPidOfIncomingConnection returns the PID of the given incoming connection. -func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { - uid, inode, ok := getListeningSocket(localIP, localPort, protocol) - if !ok { - // for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version. - switch protocol { - case TCP4: - uid, inode, ok = getListeningSocket(localIP, localPort, TCP6) - case UDP4: - uid, inode, ok = getListeningSocket(localIP, localPort, UDP6) - } - - if !ok { - return unidentifiedProcessID, NoSocket - } - } - - pid, ok = GetPidOfInode(uid, inode) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - pid, ok = GetPidOfInode(uid, inode) - } - if !ok { - return unidentifiedProcessID, NoProcess - } - - return -} diff --git a/process/proc/get.go b/process/proc/get.go deleted file mode 100644 index 52974b3e..00000000 --- a/process/proc/get.go +++ /dev/null @@ -1,66 +0,0 @@ -// +build linux - -package proc - -import ( - "errors" - "net" -) - -const ( - unidentifiedProcessID = -1 -) - -// GetTCP4PacketInfo searches the network state tables for a TCP4 connection -func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP4, localIP, localPort, pktDirection) -} - -// GetTCP6PacketInfo searches the network state tables for a TCP6 connection -func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP6, localIP, localPort, pktDirection) -} - -// GetUDP4PacketInfo searches the network state tables for a UDP4 connection -func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP4, localIP, localPort, pktDirection) -} - -// GetUDP6PacketInfo searches the network state tables for a UDP6 connection -func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP6, localIP, localPort, pktDirection) -} - -func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { - - var status uint8 - if pktDirection { - pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) - if pid >= 0 { - return pid, true, nil - } - // pid, status = GetPidOfConnection(localIP, localPort, protocol) - // if pid >= 0 { - // return pid, false, nil - // } - } else { - pid, status = GetPidOfConnection(localIP, localPort, protocol) - if pid >= 0 { - return pid, false, nil - } - // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) - // if pid >= 0 { - // return pid, true, nil - // } - } - - switch status { - case NoSocket: - return unidentifiedProcessID, direction, errors.New("could not find socket") - case NoProcess: - return unidentifiedProcessID, direction, errors.New("could not find PID") - default: - return unidentifiedProcessID, direction, nil - } - -} diff --git a/process/proc/processfinder_test.go b/process/proc/processfinder_test.go deleted file mode 100644 index 16d3d181..00000000 --- a/process/proc/processfinder_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build linux - -package proc - -import ( - "log" - "testing" -) - -func TestProcessFinder(t *testing.T) { - - updatePids() - log.Printf("pidsByUser: %v", pidsByUser) - - pid, _ := GetPidOfInode(1000, 112588) - log.Printf("pid: %d", pid) - -} diff --git a/process/proc/sockets.go b/process/proc/sockets.go deleted file mode 100644 index bcdd91d4..00000000 --- a/process/proc/sockets.go +++ /dev/null @@ -1,370 +0,0 @@ -// +build linux - -package proc - -import ( - "bufio" - "encoding/hex" - "fmt" - "net" - "os" - "strconv" - "strings" - "sync" - "unicode" - - "github.com/safing/portbase/log" -) - -/* - -1. find socket inode - - by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP? - - /proc/net/{tcp|udp}[6] - -2. get list of processes of uid - -3. find socket inode in process fds - - if not found, refresh map of uid->pids - - if not found, check ALL pids: maybe euid != uid - -4. gather process info - -Cache every step! - -*/ - -// Network Related Constants -const ( - TCP4 uint8 = iota - UDP4 - TCP6 - UDP6 - ICMP4 - ICMP6 - - TCP4Data = "/proc/net/tcp" - UDP4Data = "/proc/net/udp" - TCP6Data = "/proc/net/tcp6" - UDP6Data = "/proc/net/udp6" - ICMP4Data = "/proc/net/icmp" - ICMP6Data = "/proc/net/icmp6" -) - -var ( - // connectionSocketsLock sync.Mutex - // connectionTCP4 = make(map[string][]int) - // connectionUDP4 = make(map[string][]int) - // connectionTCP6 = make(map[string][]int) - // connectionUDP6 = make(map[string][]int) - - listeningSocketsLock sync.Mutex - addressListeningTCP4 = make(map[string][]int) - addressListeningUDP4 = make(map[string][]int) - addressListeningTCP6 = make(map[string][]int) - addressListeningUDP6 = make(map[string][]int) - globalListeningTCP4 = make(map[uint16][]int) - globalListeningUDP4 = make(map[uint16][]int) - globalListeningTCP6 = make(map[uint16][]int) - globalListeningUDP6 = make(map[uint16][]int) -) - -func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) { - // listeningSocketsLock.Lock() - // defer listeningSocketsLock.Unlock() - - var procFile string - var localIPHex string - switch protocol { - case TCP4: - procFile = TCP4Data - localIPBytes := []byte(localIP.To4()) - localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) - case UDP4: - procFile = UDP4Data - localIPBytes := []byte(localIP.To4()) - localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) - case TCP6: - procFile = TCP6Data - localIPHex = hex.EncodeToString([]byte(localIP)) - case UDP6: - procFile = UDP6Data - localIPHex = hex.EncodeToString([]byte(localIP)) - } - - localPortHex := fmt.Sprintf("%04X", localPort) - - // log.Tracef("process/proc: searching for PID of: %s:%d (%s:%s)", localIP, localPort, localIPHex, localPortHex) - - // open file - socketData, err := os.Open(procFile) - if err != nil { - log.Warningf("process/proc: could not read %s: %s", procFile, err) - return unidentifiedProcessID, unidentifiedProcessID, false - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - // log.Tracef("line: %s", line) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - - if line[1] != localIPHex { - continue - } - if line[2] != localPortHex { - continue - } - - ok := true - - uid, err := strconv.ParseInt(line[11], 10, 32) - if err != nil { - log.Warningf("process: could not parse uid %s: %s", line[11], err) - uid = -1 - ok = false - } - - inode, err := strconv.ParseInt(line[13], 10, 32) - if err != nil { - log.Warningf("process: could not parse inode %s: %s", line[13], err) - inode = -1 - ok = false - } - - // log.Tracef("process/proc: identified process of %s:%d: socket=%d uid=%d", localIP, localPort, int(inode), int(uid)) - return int(uid), int(inode), ok - - } - - return unidentifiedProcessID, unidentifiedProcessID, false - -} - -func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { - listeningSocketsLock.Lock() - defer listeningSocketsLock.Unlock() - - var addressListening map[string][]int - var globalListening map[uint16][]int - switch protocol { - case TCP4: - addressListening = addressListeningTCP4 - globalListening = globalListeningTCP4 - case UDP4: - addressListening = addressListeningUDP4 - globalListening = globalListeningUDP4 - case TCP6: - addressListening = addressListeningTCP6 - globalListening = globalListeningTCP6 - case UDP6: - addressListening = addressListeningUDP6 - globalListening = globalListeningUDP6 - } - - data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] - if !ok { - data, ok = globalListening[localPort] - } - if ok { - return data[0], data[1], true - } - updateListeners(protocol) - data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] - if !ok { - data, ok = globalListening[localPort] - } - if ok { - return data[0], data[1], true - } - - return unidentifiedProcessID, unidentifiedProcessID, false -} - -func procDelimiter(c rune) bool { - return unicode.IsSpace(c) || c == ':' -} - -func convertIPv4(data string) net.IP { - decoded, err := hex.DecodeString(data) - if err != nil { - log.Warningf("process: could not parse IPv4 %s: %s", data, err) - return nil - } - if len(decoded) != 4 { - log.Warningf("process: decoded IPv4 %s has wrong length", decoded) - return nil - } - ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) - return ip -} - -func convertIPv6(data string) net.IP { - decoded, err := hex.DecodeString(data) - if err != nil { - log.Warningf("process: could not parse IPv6 %s: %s", data, err) - return nil - } - if len(decoded) != 16 { - log.Warningf("process: decoded IPv6 %s has wrong length", decoded) - return nil - } - ip := net.IP(decoded) - return ip -} - -func updateListeners(protocol uint8) { - switch protocol { - case TCP4: - addressListeningTCP4, globalListeningTCP4 = getListenerMaps(TCP4Data, "00000000", "0A", convertIPv4) - case UDP4: - addressListeningUDP4, globalListeningUDP4 = getListenerMaps(UDP4Data, "00000000", "07", convertIPv4) - case TCP6: - addressListeningTCP6, globalListeningTCP6 = getListenerMaps(TCP6Data, "00000000000000000000000000000000", "0A", convertIPv6) - case UDP6: - addressListeningUDP6, globalListeningUDP6 = getListenerMaps(UDP6Data, "00000000000000000000000000000000", "07", convertIPv6) - } -} - -func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) { - addressListening := make(map[string][]int) - globalListening := make(map[uint16][]int) - - // open file - socketData, err := os.Open(procFile) - if err != nil { - log.Warningf("process: could not read %s: %s", procFile, err) - return addressListening, globalListening - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - if line[5] != socketStatusListening { - // skip if not listening - // log.Tracef("process: not listening %s: %s", line, line[5]) - continue - } - - port, err := strconv.ParseUint(line[2], 16, 16) - // log.Tracef("port: %s", line[2]) - if err != nil { - log.Warningf("process: could not parse port %s: %s", line[2], err) - continue - } - - uid, err := strconv.ParseInt(line[11], 10, 32) - // log.Tracef("uid: %s", line[11]) - if err != nil { - log.Warningf("process: could not parse uid %s: %s", line[11], err) - continue - } - - inode, err := strconv.ParseInt(line[13], 10, 32) - // log.Tracef("inode: %s", line[13]) - if err != nil { - log.Warningf("process: could not parse inode %s: %s", line[13], err) - continue - } - - if line[1] == zeroIP { - globalListening[uint16(port)] = []int{int(uid), int(inode)} - } else { - address := ipConverter(line[1]) - if address != nil { - addressListening[fmt.Sprintf("%s:%d", address, port)] = []int{int(uid), int(inode)} - } - } - - } - - return addressListening, globalListening -} - -// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS. -func GetActiveConnectionIDs() []string { - var connections []string - - connections = append(connections, getConnectionIDsFromSource(TCP4Data, 6, convertIPv4)...) - connections = append(connections, getConnectionIDsFromSource(UDP4Data, 17, convertIPv4)...) - connections = append(connections, getConnectionIDsFromSource(TCP6Data, 6, convertIPv6)...) - connections = append(connections, getConnectionIDsFromSource(UDP6Data, 17, convertIPv6)...) - - return connections -} - -func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string { - var connections []string - - // open file - socketData, err := os.Open(source) - if err != nil { - log.Warningf("process: could not read %s: %s", source, err) - return connections - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - - // skip listeners and closed connections - if line[5] == "0A" || line[5] == "07" { - continue - } - - localIP := ipConverter(line[1]) - if localIP == nil { - continue - } - - localPort, err := strconv.ParseUint(line[2], 16, 16) - if err != nil { - log.Warningf("process: could not parse port: %s", err) - continue - } - - remoteIP := ipConverter(line[3]) - if remoteIP == nil { - continue - } - - remotePort, err := strconv.ParseUint(line[4], 16, 16) - if err != nil { - log.Warningf("process: could not parse port: %s", err) - continue - } - - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", protocol, localIP, localPort, remoteIP, remotePort)) - } - - return connections -} diff --git a/process/proc/sockets_test.go b/process/proc/sockets_test.go deleted file mode 100644 index 44e8fd34..00000000 --- a/process/proc/sockets_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// +build linux - -package proc - -import ( - "net" - "testing" -) - -func TestSockets(t *testing.T) { - - updateListeners(TCP4) - updateListeners(UDP4) - updateListeners(TCP6) - updateListeners(UDP6) - t.Logf("addressListeningTCP4: %v", addressListeningTCP4) - t.Logf("globalListeningTCP4: %v", globalListeningTCP4) - t.Logf("addressListeningUDP4: %v", addressListeningUDP4) - t.Logf("globalListeningUDP4: %v", globalListeningUDP4) - t.Logf("addressListeningTCP6: %v", addressListeningTCP6) - t.Logf("globalListeningTCP6: %v", globalListeningTCP6) - t.Logf("addressListeningUDP6: %v", addressListeningUDP6) - t.Logf("globalListeningUDP6: %v", globalListeningUDP6) - - getListeningSocket(net.IPv4zero, 53, TCP4) - getListeningSocket(net.IPv4zero, 53, UDP4) - getListeningSocket(net.IPv6zero, 53, TCP6) - getListeningSocket(net.IPv6zero, 53, UDP6) - - // spotify: 192.168.0.102:5353 192.121.140.65:80 - localIP := net.IPv4(192, 168, 127, 10) - uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4) - t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) - - activeConnectionIDs := GetActiveConnectionIDs() - for _, connID := range activeConnectionIDs { - t.Logf("active: %s", connID) - } - -} diff --git a/process/process.go b/process/process.go index 8ef1ad73..16d04cb3 100644 --- a/process/process.go +++ b/process/process.go @@ -49,7 +49,8 @@ type Process struct { FirstSeen int64 LastSeen int64 - Virtual bool // This process is either merged into another process or is not needed. + Virtual bool // This process is either merged into another process or is not needed. + Error string // Cache errors } // Profile returns the assigned layered profile. @@ -94,6 +95,7 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { parentProcess, err := loadProcess(ctx, process.ParentPid) if err != nil { log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err) + saveFailedProcess(process.ParentPid, err.Error()) return process, nil } @@ -226,13 +228,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { pInfo, err := processInfo.NewProcess(int32(pid)) if err != nil { - // TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist - // Issue: https://github.com/shirou/gopsutil/issues/729 - _, err = pInfo.Name() - if err != nil { - // process does not exists - return nil, err - } + return nil, err } // UID @@ -375,3 +371,14 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { new.Save() return new, nil } + +func saveFailedProcess(pid int, err string) { + failed := &Process{ + Pid: pid, + FirstSeen: time.Now().Unix(), + Virtual: true, // not needed + Error: err, + } + + failed.Save() +} diff --git a/profile/config.go b/profile/config.go index 4464a513..494bdc7a 100644 --- a/profile/config.go +++ b/profile/config.go @@ -94,6 +94,7 @@ func registerConfiguration() error { Description: `The default filter action when nothing else permits or blocks a connection.`, Order: cfgOptionDefaultActionOrder, OptType: config.OptTypeString, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: "permit", ExternalOptType: "string list", ValidationRegex: "^(permit|ask|block)$", @@ -121,17 +122,12 @@ func registerConfiguration() error { cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll)) cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit - // Endpoint Filter List - err = config.Register(&config.Option{ - Name: "Endpoint Filter List", - Key: CfgOptionEndpointsKey, - Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.", - Help: `Format: + filterListHelp := `Format: Permission: "+": permit "-": block Host Matching: - IP, CIDR, Country Code, ASN, Filterlist, "*" for any + IP, CIDR, Country Code, ASN, Filterlist, Network Scope, "*" for any Domains: "example.com": exact match ".example.com": exact match + subdomains @@ -144,11 +140,20 @@ func registerConfiguration() error { Examples: + .example.com */HTTP - .example.com - + 192.168.0.1/24 + + 192.168.0.1 + + 192.168.1.1/24 + + Localhost,LAN + - AS123456789 - L:MAL - - AS0 + AT - - *`, + - *` + + // Endpoint Filter List + err = config.Register(&config.Option{ + Name: "Endpoint Filter List", + Key: CfgOptionEndpointsKey, + Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.", + Help: filterListHelp, Order: cfgOptionEndpointsOrder, OptType: config.OptTypeStringArray, DefaultValue: []string{}, @@ -163,35 +168,13 @@ Examples: // Service Endpoint Filter List err = config.Register(&config.Option{ - Name: "Service Endpoint Filter List", - Key: CfgOptionServiceEndpointsKey, - Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.", - Help: `Format: - Permission: - "+": permit - "-": block - Host Matching: - IP, CIDR, Country Code, ASN, Filterlist, "*" for any - Domains: - "example.com": exact match - ".example.com": exact match + subdomains - "*xample.com": prefix wildcard - "example.*": suffix wildcard - "*example*": prefix and suffix wildcard - Protocol and Port Matching (optional): - / - -Examples: - + .example.com */HTTP - - .example.com - + 192.168.0.1/24 - - L:MAL - - AS0 - + AT - - *`, + Name: "Service Endpoint Filter List", + Key: CfgOptionServiceEndpointsKey, + Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.", + Help: filterListHelp, Order: cfgOptionServiceEndpointsOrder, OptType: config.OptTypeStringArray, - DefaultValue: []string{}, + DefaultValue: []string{"+ Localhost"}, ExternalOptType: "endpoint list", ValidationRegex: `^(\+|\-) [A-z0-9\.:\-*/]+( [A-z0-9/]+)?$`, }) @@ -313,7 +296,7 @@ Examples: Order: cfgOptionBlockP2POrder, OptType: config.OptTypeInt, ExternalOptType: "security level", - DefaultValue: status.SecurityLevelsAll, + DefaultValue: status.SecurityLevelExtreme, ValidationRegex: "^(4|6|7)$", }) if err != nil { @@ -326,7 +309,7 @@ Examples: err = config.Register(&config.Option{ Name: "Block Inbound Connections", Key: CfgOptionBlockInboundKey, - Description: "Connections initiated towards your device. This will usually only be the case if you are running a network service or are using peer to peer software.", + Description: "Connections initiated towards your device from the LAN or Internet. This will usually only be the case if you are running a network service or are using peer to peer software.", Order: cfgOptionBlockInboundOrder, OptType: config.OptTypeInt, ExternalOptType: "security level", diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index fbd0dcf9..3cc3450f 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -31,7 +31,7 @@ type EndpointDomain struct { } func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) { - result, reason := ep.match(ep, entity, ep.Domain, "domain matches") + result, reason := ep.match(ep, entity, ep.OriginalValue, "domain matches") switch ep.MatchType { case domainMatchTypeExact: diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go new file mode 100644 index 00000000..ea22126d --- /dev/null +++ b/profile/endpoints/endpoint-scopes.go @@ -0,0 +1,106 @@ +package endpoints + +import ( + "strings" + + "github.com/safing/portmaster/network/netutils" + + "github.com/safing/portmaster/intel" +) + +const ( + scopeLocalhost = 1 + scopeLocalhostName = "Localhost" + scopeLocalhostMatcher = "localhost" + + scopeLAN = 2 + scopeLANName = "LAN" + scopeLANMatcher = "lan" + + scopeInternet = 4 + scopeInternetName = "Internet" + scopeInternetMatcher = "internet" +) + +// EndpointScope matches network scopes. +type EndpointScope struct { + EndpointBase + + scopes uint8 +} + +// Matches checks whether the given entity matches this endpoint definition. +func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { + if entity.IP == nil { + return Undeterminable, nil + } + + classification := netutils.ClassifyIP(entity.IP) + var scope uint8 + switch classification { + case netutils.HostLocal: + scope = scopeLocalhost + case netutils.LinkLocal: + scope = scopeLAN + case netutils.SiteLocal: + scope = scopeLAN + case netutils.Global: + scope = scopeInternet + case netutils.LocalMulticast: + scope = scopeLAN + case netutils.GlobalMulticast: + scope = scopeInternet + } + + if ep.scopes&scope > 0 { + return ep.match(ep, entity, ep.Scopes(), "scope matches") + } + return NoMatch, nil +} + +// Scopes returns the string representation of all scopes. +func (ep *EndpointScope) Scopes() string { + // single scope + switch ep.scopes { + case scopeLocalhost: + return scopeLocalhostName + case scopeLAN: + return scopeLANName + case scopeInternet: + return scopeInternetName + } + + // multiple scopes + var s []string + if ep.scopes&scopeLocalhost > 0 { + s = append(s, scopeLocalhostName) + } + if ep.scopes&scopeLAN > 0 { + s = append(s, scopeLANName) + } + if ep.scopes&scopeInternet > 0 { + s = append(s, scopeInternetName) + } + return strings.Join(s, ",") +} + +func (ep *EndpointScope) String() string { + return ep.renderPPP(ep.Scopes()) +} + +func parseTypeScope(fields []string) (Endpoint, error) { + ep := &EndpointScope{} + for _, val := range strings.Split(strings.ToLower(fields[1]), ",") { + switch val { + case scopeLocalhostMatcher: + ep.scopes ^= scopeLocalhost + case scopeLANMatcher: + ep.scopes ^= scopeLAN + case scopeInternetMatcher: + ep.scopes ^= scopeInternet + default: + return nil, nil + } + } + return ep.parsePPP(ep, fields) +} diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 76847ac7..2e0a4e85 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error { return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg) } -func parseEndpoint(value string) (endpoint Endpoint, err error) { +func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit fields := strings.Fields(value) if len(fields) < 2 { return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value) @@ -231,6 +231,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { if endpoint, err = parseTypeASN(fields); endpoint != nil || err != nil { return } + // scopes + if endpoint, err = parseTypeScope(fields); endpoint != nil || err != nil { + return + } // lists if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil { return diff --git a/profile/endpoints/endpoint_test.go b/profile/endpoints/endpoint_test.go index d8aabee8..21ef057e 100644 --- a/profile/endpoints/endpoint_test.go +++ b/profile/endpoints/endpoint_test.go @@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) { testParsing(t, "+ AS1234") testParsing(t, "+ AS12345") + // network scope + testParsing(t, "+ Localhost") + testParsing(t, "+ LAN") + testParsing(t, "+ Internet") + testParsing(t, "+ Localhost,LAN,Internet") + // protocol and ports testParsing(t, "+ * TCP/1-1024") testParsing(t, "+ * */DNS") diff --git a/profile/endpoints/endpoints_test.go b/profile/endpoints/endpoints_test.go index 7a275e3e..85164a4d 100644 --- a/profile/endpoints/endpoints_test.go +++ b/profile/endpoints/endpoints_test.go @@ -342,7 +342,26 @@ func TestEndpointMatching(t *testing.T) { IP: net.ParseIP("151.101.1.164"), // nytimes.com }).Init(), NoMatch) + // Scope + + ep, err = parseEndpoint("+ Localhost,LAN") + if err != nil { + t.Fatal(err) + } + + testEndpointMatch(t, ep, (&intel.Entity{ + IP: net.ParseIP("192.168.0.1"), + }).Init(), Permitted) + testEndpointMatch(t, ep, (&intel.Entity{ + IP: net.ParseIP("151.101.1.164"), // nytimes.com + }).Init(), NoMatch) + // Lists + + _, err = parseEndpoint("+ L:A,B,C") + if err != nil { + t.Fatal(err) + } // TODO: write test for lists matcher } diff --git a/profile/profile-layered.go b/profile/profile-layered.go index 45311662..ab0335a2 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -126,6 +126,18 @@ func (lp *LayeredProfile) getValidityFlag() *abool.AtomicBool { return lp.validityFlag } +// RevisionCnt returns the current profile revision counter. +func (lp *LayeredProfile) RevisionCnt() (revisionCounter uint64) { + if lp == nil { + return 0 + } + + lp.lock.Lock() + defer lp.lock.Unlock() + + return lp.revisionCounter +} + // Update checks for updated profiles and replaces any outdated profiles. func (lp *LayeredProfile) Update() (revisionCounter uint64) { lp.lock.Lock() diff --git a/resolver/block_detection.go b/resolver/block-detection.go similarity index 100% rename from resolver/block_detection.go rename to resolver/block-detection.go diff --git a/resolver/clients.go b/resolver/clients.go index 6d1ad4b2..e3456759 100644 --- a/resolver/clients.go +++ b/resolver/clients.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "crypto/tls" "net" "sync" @@ -9,6 +10,13 @@ import ( "github.com/miekg/dns" ) +const ( + defaultClientTTL = 5 * time.Minute + defaultRequestTimeout = 3 * time.Second // dns query + defaultConnectTimeout = 2 * time.Second // tcp/tls + connectionEOLGracePeriod = 7 * time.Second +) + var ( localAddrFactory func(network string) net.Addr ) @@ -27,21 +35,54 @@ func getLocalAddr(network string) net.Addr { return nil } -type clientManager struct { - dnsClient *dns.Client - factory func() *dns.Client +type dnsClientManager struct { + lock sync.Mutex - lock sync.Mutex - refreshAfter time.Time - ttl time.Duration // force refresh of connection to reduce traceability + // set by creator + serverAddress string + ttl time.Duration // force refresh of connection to reduce traceability + factory func() *dns.Client + + // internal + pool sync.Pool } -func newDNSClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // new client for every request, as we need to randomize the port +type dnsClient struct { + mgr *dnsClientManager + client *dns.Client + conn *dns.Conn + useUntil time.Time +} + +// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). +func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { + if dc.conn == nil { + dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) + if err != nil { + return nil, false, err + } + return dc.conn, true, nil + } + return dc.conn, false, nil +} + +func (dc *dnsClient) addToPool() { + dc.mgr.pool.Put(dc) +} + +func (dc *dnsClient) destroy() { + if dc.conn != nil { + _ = dc.conn.Close() + } +} + +func newDNSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: 0, // new client for every request, as we need to randomize the port factory: func() *dns.Client { return &dns.Client{ - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("udp"), }, @@ -50,25 +91,28 @@ func newDNSClientManager(_ *Resolver) *clientManager { } } -func newTCPClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTCPClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp", - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + Timeout: defaultConnectTimeout, + KeepAlive: defaultClientTTL, }, } }, } } -func newTLSClientManager(resolver *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTLSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp-tls", @@ -77,24 +121,68 @@ func newTLSClientManager(resolver *Resolver) *clientManager { ServerName: resolver.VerifyDomain, // TODO: use portbase rng }, - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + Timeout: defaultConnectTimeout, + KeepAlive: defaultClientTTL, }, } }, } } -func (cm *clientManager) getDNSClient() *dns.Client { +func (cm *dnsClientManager) getDNSClient() *dnsClient { cm.lock.Lock() defer cm.lock.Unlock() - if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) { - cm.dnsClient = cm.factory() - cm.refreshAfter = time.Now().Add(cm.ttl) + // return new immediately if a new client should be used for every request + if cm.ttl == 0 { + return &dnsClient{ + mgr: cm, + client: cm.factory(), + } } - return cm.dnsClient + // get cached client from pool + now := time.Now().UTC() + +poolLoop: + for { + dc, ok := cm.pool.Get().(*dnsClient) + switch { + case !ok || dc == nil: // cache empty (probably, pool may always return nil!) + break poolLoop // create new + case now.After(dc.useUntil): + continue // get next + default: + return dc + } + } + + // no available in pool, create new + newClient := &dnsClient{ + mgr: cm, + client: cm.factory(), + useUntil: now.Add(cm.ttl), + } + newClient.startCleaner() + + return newClient +} + +// startCleaner waits for EOL of the client and then removes it from the pool. +func (dc *dnsClient) startCleaner() { + // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. + module.StartWorker("dns client cleanup", func(ctx context.Context) error { + select { + case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod): + // destroy + case <-ctx.Done(): + // give a short time before kill for graceful request completion + time.Sleep(100 * time.Millisecond) + } + dc.destroy() + return nil + }) } diff --git a/resolver/main.go b/resolver/main.go index 9c71f5db..05d20fe3 100644 --- a/resolver/main.go +++ b/resolver/main.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "strings" "time" "github.com/safing/portbase/log" @@ -30,6 +31,7 @@ func start() error { // load resolvers from config and environment loadResolvers() + // reload after network change err := module.RegisterEventHook( "netenv", "network changed", @@ -44,6 +46,27 @@ func start() error { return err } + // reload after config change + prevNameservers := strings.Join(configuredNameServers(), " ") + err = module.RegisterEventHook( + "config", + "config change", + "update nameservers", + func(_ context.Context, _ interface{}) error { + newNameservers := strings.Join(configuredNameServers(), " ") + if newNameservers != prevNameservers { + prevNameservers = newNameservers + + loadResolvers() + log.Debug("resolver: reloaded nameservers due to config change") + } + return nil + }, + ) + if err != nil { + return err + } + module.StartServiceWorker( "mdns handler", 5*time.Second, diff --git a/resolver/mdns.go b/resolver/mdns.go index a8ba1ee5..5fe889dd 100644 --- a/resolver/mdns.go +++ b/resolver/mdns.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/safing/portmaster/network/netutils" + "github.com/miekg/dns" "github.com/safing/portbase/log" @@ -29,10 +31,11 @@ var ( questionsLock sync.Mutex mDNSResolver = &Resolver{ - Server: ServerSourceMDNS, - ServerType: ServerTypeDNS, - Source: ServerSourceMDNS, - Conn: &mDNSResolverConn{}, + Server: ServerSourceMDNS, + ServerType: ServerTypeDNS, + ServerIPScope: netutils.SiteLocal, + Source: ServerSourceMDNS, + Conn: &mDNSResolverConn{}, } ) @@ -42,10 +45,10 @@ func (mrc *mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, err return queryMulticastDNS(ctx, q) } -func (mrc *mDNSResolverConn) MarkFailed() {} +func (mrc *mDNSResolverConn) ReportFailure() {} -func (mrc *mDNSResolverConn) LastFail() time.Time { - return time.Time{} +func (mrc *mDNSResolverConn) IsFailing() bool { + return false } type savedQuestion struct { @@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { // get entry from database if saveFullRequest { + // get from database rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) + // if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired: + // create new and do not append if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() { rrCache = &RRCache{ - Domain: question.Name, - Question: dns.Type(question.Qtype), + Domain: question.Name, + Question: dns.Type(question.Qtype), + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } } } + // add all entries to RRCache for _, entry := range message.Answer { if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { @@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { continue } rrCache = &RRCache{ - Domain: v.Header().Name, - Question: dns.Type(v.Header().Class), - Answer: []dns.RR{v}, + Domain: v.Header().Name, + Question: dns.Type(v.Header().Class), + Answer: []dns.RR{v}, + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } rrCache.Clean(60) err := rrCache.Save() diff --git a/resolver/namerecord.go b/resolver/namerecord.go index d94beaaa..1a594e8f 100644 --- a/resolver/namerecord.go +++ b/resolver/namerecord.go @@ -12,7 +12,7 @@ import ( var ( recordDatabase = database.NewInterface(&database.Options{ AlwaysSetRelativateExpiry: 2592000, // 30 days - CacheSize: 128, + CacheSize: 256, }) ) diff --git a/resolver/namerecord_test.go b/resolver/namerecord_test.go new file mode 100644 index 00000000..f0e21a37 --- /dev/null +++ b/resolver/namerecord_test.go @@ -0,0 +1,27 @@ +package resolver + +import "testing" + +func TestNameRecordStorage(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + r, err := GetNameRecord(testDomain, testQuestion) + if err != nil { + t.Fatal(err) + } + + if r.Domain != testDomain || r.Question != testQuestion { + t.Fatal("mismatch") + } +} diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go new file mode 100644 index 00000000..3c03c14c --- /dev/null +++ b/resolver/pooling_test.go @@ -0,0 +1,189 @@ +package resolver + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/miekg/dns" +) + +var ( + domainFeed = make(chan string) +) + +func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { + dnsClient := brc.clientManager.getDNSClient() + + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // get connection + conn, new, err := dnsClient.getConn() + if err != nil { + t.Fatalf("failed to connect: %s", err) //nolint:staticcheck + } + if new { + atomic.AddUint32(newCnt, 1) + } + + // query server + reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) + if err != nil { + t.Fatal(err) //nolint:staticcheck + } + if reply == nil { + t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck + } + + t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) + dnsClient.addToPool() + wg.Done() +} + +func TestClientPooling(t *testing.T) { + // skip if short - this test depends on the Internet and might fail randomly + if testing.Short() { + t.Skip() + } + + go feedDomains() + + // create separate resolver for this test + resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") + if err != nil { + t.Fatal(err) + } + brc := resolver.Conn.(*BasicResolverConn) + + wg := &sync.WaitGroup{} + var newCnt uint32 + for i := 0; i < 10; i++ { + wg.Add(10) + for i := 0; i < 10; i++ { + go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck + FQDN: <-domainFeed, + QType: dns.Type(dns.TypeA), + }) + } + wg.Wait() + if newCnt > uint32(10+i) { + t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) + } + } +} + +func feedDomains() { + for { + for _, domain := range poolingTestDomains { + domainFeed <- domain + } + } +} + +// Data + +var ( + poolingTestDomains = []string{ + "facebook.com.", + "google.com.", + "youtube.com.", + "twitter.com.", + "instagram.com.", + "linkedin.com.", + "microsoft.com.", + "apple.com.", + "wikipedia.org.", + "plus.google.com.", + "en.wikipedia.org.", + "googletagmanager.com.", + "youtu.be.", + "adobe.com.", + "vimeo.com.", + "pinterest.com.", + "itunes.apple.com.", + "play.google.com.", + "maps.google.com.", + "goo.gl.", + "wordpress.com.", + "blogspot.com.", + "bit.ly.", + "github.com.", + "player.vimeo.com.", + "amazon.com.", + "wordpress.org.", + "docs.google.com.", + "yahoo.com.", + "mozilla.org.", + "tumblr.com.", + "godaddy.com.", + "flickr.com.", + "parked-content.godaddy.com.", + "drive.google.com.", + "support.google.com.", + "apache.org.", + "gravatar.com.", + "europa.eu.", + "qq.com.", + "w3.org.", + "nytimes.com.", + "reddit.com.", + "macromedia.com.", + "get.adobe.com.", + "soundcloud.com.", + "sourceforge.net.", + "sites.google.com.", + "nih.gov.", + "amazonaws.com.", + "t.co.", + "support.microsoft.com.", + "forbes.com.", + "theguardian.com.", + "cnn.com.", + "github.io.", + "bbc.co.uk.", + "dropbox.com.", + "whatsapp.com.", + "medium.com.", + "creativecommons.org.", + "www.ncbi.nlm.nih.gov.", + "httpd.apache.org.", + "archive.org.", + "ec.europa.eu.", + "php.net.", + "apps.apple.com.", + "weebly.com.", + "support.apple.com.", + "weibo.com.", + "wixsite.com.", + "issuu.com.", + "who.int.", + "paypal.com.", + "m.facebook.com.", + "oracle.com.", + "msn.com.", + "gnu.org.", + "tinyurl.com.", + "reuters.com.", + "l.facebook.com.", + "cloudflare.com.", + "wsj.com.", + "washingtonpost.com.", + "domainmarket.com.", + "imdb.com.", + "bbc.com.", + "bing.com.", + "accounts.google.com.", + "vk.com.", + "api.whatsapp.com.", + "opera.com.", + "cdc.gov.", + "slideshare.net.", + "wpa.qq.com.", + "harvard.edu.", + "mit.edu.", + "code.google.com.", + "wikimedia.org.", + } +) diff --git a/resolver/resolve.go b/resolver/resolve.go index f13d07c2..e5c406ae 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -14,8 +14,6 @@ import ( ) var ( - mtAsyncResolve = "async resolve" - // basic errors // ErrNotFound is a basic error that will match all "not found" errors @@ -114,6 +112,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) { rrCache.MixAnswers() return rrCache, nil } + log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType) // if cache is still empty or non-compliant, go ahead and just query } else { // we are the first! @@ -132,14 +131,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache { if err != nil { if err != database.ErrNotFound { log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) - log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) } return nil } // get resolver that rrCache was resolved with - resolver := getResolverByIDWithLocking(rrCache.Server) + resolver := getActiveResolverByIDWithLocking(rrCache.Server) if resolver == nil { + log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server) return nil } @@ -159,12 +158,13 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Trace("resolver: serving from cache, requesting new") // resolve async - module.StartMediumPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { + module.StartWorker("resolve async", func(ctx context.Context) error { _, _ = resolveAndCache(ctx, q) return nil }) } + log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0))) return rrCache } @@ -218,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error return nil, ErrNoCompliance } - // prep - lastFailBoundary := time.Now().Add( - -time.Duration(nameserverRetryRate()) * time.Second, - ) - // start resolving var i int @@ -231,7 +226,7 @@ resolveLoop: for i = 0; i < 2; i++ { for _, resolver := range resolvers { // check if resolver failed recently (on first run) - if i == 0 && resolver.Conn.LastFail().After(lastFailBoundary) { + if i == 0 && resolver.Conn.IsFailing() { log.Tracer(ctx).Tracef("resolver: skipping resolver %s, because it failed recently", resolver) continue } diff --git a/resolver/resolver.go b/resolver/resolver.go index 65155fab..8921e2db 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "errors" "net" "sync" "time" @@ -23,6 +24,11 @@ const ( ServerSourceMDNS = "mdns" ) +var ( + // FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed. + FailThreshold = 5 +) + // Resolver holds information about an active resolver. type Resolver struct { // Server config url (and ID) @@ -83,8 +89,8 @@ func (resolver *Resolver) String() string { // ResolverConn is an interface to implement different types of query backends. type ResolverConn interface { //nolint:go-lint // TODO Query(ctx context.Context, q *Query) (*RRCache, error) - MarkFailed() - LastFail() time.Time + ReportFailure() + IsFailing() bool } // BasicResolverConn implements ResolverConn for standard dns clients. @@ -92,12 +98,14 @@ type BasicResolverConn struct { sync.Mutex // for lastFail resolver *Resolver - clientManager *clientManager - lastFail time.Time + clientManager *dnsClientManager + + lastFail time.Time + fails int } -// MarkFailed marks the resolver as failed. -func (brc *BasicResolverConn) MarkFailed() { +// ReportFailure reports that an error occurred with this resolver. +func (brc *BasicResolverConn) ReportFailure() { if !netenv.Online() { // don't mark failed if we are offline return @@ -105,14 +113,26 @@ func (brc *BasicResolverConn) MarkFailed() { brc.Lock() defer brc.Unlock() - brc.lastFail = time.Now() + now := time.Now().UTC() + failDuration := time.Duration(nameserverRetryRate()) * time.Second + + // reset fail counter if currently not failing + if now.Add(-failDuration).After(brc.lastFail) { + brc.fails = 0 + } + + // update + brc.lastFail = now + brc.fails++ } -// LastFail returns the internal lastfail value while locking the Resolver. -func (brc *BasicResolverConn) LastFail() time.Time { +// IsFailing returns if this resolver is currently failing. +func (brc *BasicResolverConn) IsFailing() bool { brc.Lock() defer brc.Unlock() - return brc.lastFail + + failDuration := time.Duration(nameserverRetryRate()) * time.Second + return brc.fails >= FailThreshold && time.Now().UTC().Add(-failDuration).Before(brc.lastFail) } // Query executes the given query against the resolver. @@ -126,35 +146,78 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // start var reply *dns.Msg + var ttl time.Duration var err error - for i := 0; i < 3; i++ { + var conn *dns.Conn + var new bool + var tries int - // log query time - // qStart := time.Now() - reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress) - // log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) + for ; tries < 3; tries++ { - // error handling + // first get connection + dc := brc.clientManager.getDNSClient() + conn, new, err = dc.getConn() if err != nil { - log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err) + log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err) + // remove client from pool + dc.destroy() + // report that resolver had an error + brc.ReportFailure() + // hint network environment at failed connection + netenv.ReportFailedConnection() // TODO: handle special cases // 1. connect: network is unreachable // 2. timeout - // hint network environment at failed connection - netenv.ReportFailedConnection() + // try again + continue + } + if new { + log.Tracer(ctx).Tracef("resolver: created new connection to %s (%s)", resolver.Name, resolver.ServerAddress) + } else { + log.Tracer(ctx).Tracef("resolver: reusing connection to %s (%s)", resolver.Name, resolver.ServerAddress) + } + + // query server + reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn) + log.Tracer(ctx).Tracef("resolver: query took %s", ttl) + + // error handling + if err != nil { + log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err) + + // remove client from pool + dc.destroy() // temporary error if nerr, ok := err.(net.Error); ok && nerr.Timeout() { log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) + // try again continue } + // report failed if dns (nothing happens at getConn()) + if resolver.ServerType == ServerTypeDNS { + // report that resolver had an error + brc.ReportFailure() + // hint network environment at failed connection + netenv.ReportFailedConnection() + } + // permanent error break + } else if reply == nil { + // remove client from pool + dc.destroy() + + log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server) + return nil, errors.New("internal error") } + // make client available (again) + dc.addToPool() + if resolver.IsBlockedUpstream(reply) { return nil, &BlockedUpstreamError{resolver.GetName()} } @@ -166,12 +229,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er if err != nil { return nil, err // TODO: mark as failed + } else if reply == nil { + log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), tries+1) + return nil, errors.New("internal error") } // hint network environment at successful connection netenv.ReportSuccessfulConnection() - new := &RRCache{ + newRecord := &RRCache{ Domain: q.FQDN, Question: q.QType, Answer: reply.Answer, @@ -182,5 +248,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er } // TODO: check if reply.Answer is valid - return new, nil + return newRecord, nil } diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 00ad0d0e..5d65fed3 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -25,7 +25,7 @@ var ( globalResolvers []*Resolver // all (global) resolvers localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope - allResolvers map[string]*Resolver // lookup map of all resolvers + activeResolvers map[string]*Resolver // lookup map of all resolvers resolversLock sync.RWMutex dupReqMap = make(map[string]*sync.WaitGroup) @@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int { return -1 } -func getResolverByIDWithLocking(server string) *Resolver { - resolversLock.Lock() - defer resolversLock.Unlock() +func getActiveResolverByIDWithLocking(server string) *Resolver { + resolversLock.RLock() + defer resolversLock.RUnlock() - resolver, ok := allResolvers[server] + resolver, ok := activeResolvers[server] if ok { return resolver } @@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string { return address } -func clientManagerFactory(serverType string) func(*Resolver) *clientManager { +func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager { switch serverType { case ServerTypeDNS: return newDNSClientManager @@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) { } } -func getConfiguredResolvers() (resolvers []*Resolver) { - for _, server := range configuredNameServers() { +func getConfiguredResolvers(list []string) (resolvers []*Resolver) { + for _, server := range list { resolver, skip, err := createResolver(server, "config") if err != nil { // TODO(ppacher): module error @@ -199,19 +199,40 @@ func loadResolvers() { defer resolversLock.Unlock() newResolvers := append( - getConfiguredResolvers(), + getConfiguredResolvers(configuredNameServers()), getSystemResolvers()..., ) - // save resolvers - globalResolvers = newResolvers - if len(globalResolvers) == 0 { - log.Criticalf("resolver: no (valid) dns servers found in configuration and system") - // TODO(module error) + if len(newResolvers) == 0 { + msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults" + log.Warningf("resolver: %s", msg) + module.Warning("no-valid-user-resolvers", msg) + + // load defaults directly, overriding config system + newResolvers = getConfiguredResolvers(defaultNameServers) + if len(newResolvers) == 0 { + msg = "no (valid) dns servers found in configuration or system" + log.Criticalf("resolver: %s", msg) + module.Error("no-valid-default-resolvers", msg) + return + } } + // save resolvers + globalResolvers = newResolvers + + // assing resolvers to scopes setLocalAndScopeResolvers(globalResolvers) + // set active resolvers (for cache validation) + // reset + activeResolvers = make(map[string]*Resolver) + // add + for _, resolver := range newResolvers { + activeResolvers[resolver.Server] = resolver + } + activeResolvers[mDNSResolver.Server] = mDNSResolver + // log global resolvers if len(globalResolvers) > 0 { log.Trace("resolver: loaded global resolvers:") diff --git a/resolver/reverse.go b/resolver/reverse.go index 0487cf44..c236818b 100644 --- a/resolver/reverse.go +++ b/resolver/reverse.go @@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) ( for _, rr := range rrCache.Answer { switch v := rr.(type) { case *dns.A: - log.Infof("A: %s", v.A.String()) + // log.Debugf("A: %s", v.A.String()) if ip == v.A.String() { return ptrName, nil } case *dns.AAAA: - log.Infof("AAAA: %s", v.AAAA.String()) + // log.Debugf("AAAA: %s", v.AAAA.String()) if ip == v.AAAA.String() { return ptrName, nil } diff --git a/resolver/rrcache_test.go b/resolver/rrcache_test.go new file mode 100644 index 00000000..8aaa3094 --- /dev/null +++ b/resolver/rrcache_test.go @@ -0,0 +1,41 @@ +package resolver + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestCaching(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + err = rrCache.Save() + if err != nil { + t.Fatal(err) + } + + rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + if rrCache2.Domain != rrCache.Domain { + t.Fatal("something very is wrong") + } +} diff --git a/ui/module.go b/ui/module.go index 76a130ef..8fdfed9d 100644 --- a/ui/module.go +++ b/ui/module.go @@ -3,6 +3,8 @@ package ui import ( "context" + "github.com/safing/portbase/dataroot" + resources "github.com/cookieo9/resources-go" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -27,6 +29,19 @@ func prep() error { } func start() error { + // Create a dummy directory to which processes change their working directory + // to. Currently this includes the App and the Notifier. The aim is protect + // all other directories and increase compatibility should any process want + // to read or write something to the current working directory. This can also + // be useful in the future to dump data to for debugging. The permission used + // may seem dangerous, but proper permission on the parent directory provide + // (some) protection. + // Processes must _never_ read from this directory. + err := dataroot.Root().ChildDir("exec", 0777).Ensure() + if err != nil { + log.Warningf("ui: failed to create safe exec dir: %s", err) + } + return module.RegisterEventHook("ui", eventReload, "reload assets", reloadUI) } diff --git a/updates/main.go b/updates/main.go index ea15e075..84aa132e 100644 --- a/updates/main.go +++ b/updates/main.go @@ -152,7 +152,7 @@ func start() error { err = registry.LoadIndexes() if err != nil { - return err + log.Warningf("updates: failed to load indexes: %s", err) } err = registry.ScanStorage("") @@ -235,8 +235,7 @@ func checkForUpdates(ctx context.Context) (err error) { }() if err = registry.UpdateIndexes(); err != nil { - err = fmt.Errorf("failed to update indexes: %w", err) - return + log.Warningf("updates: failed to update indexes: %s", err) } err = registry.DownloadUpdates(ctx) diff --git a/updates/upgrader.go b/updates/upgrader.go index 1d754f7a..47f8ea36 100644 --- a/updates/upgrader.go +++ b/updates/upgrader.go @@ -113,15 +113,9 @@ func upgradePortmasterControl() error { return nil } - // check if registry tmp dir is ok - err := registry.TmpDir().Ensure() - if err != nil { - return fmt.Errorf("failed to prep updates tmp dir: %s", err) - } - // update portmaster-control in data root rootControlPath := filepath.Join(filepath.Dir(registry.StorageDir().Path), filename) - err = upgradeFile(rootControlPath, pmCtrlUpdate) + err := upgradeFile(rootControlPath, pmCtrlUpdate) if err != nil { return err } @@ -130,11 +124,11 @@ func upgradePortmasterControl() error { // upgrade parent process, if it's portmaster-control parent, err := processInfo.NewProcess(int32(os.Getppid())) if err != nil { - return fmt.Errorf("could not get parent process for upgrade checks: %s", err) + return fmt.Errorf("could not get parent process for upgrade checks: %w", err) } parentName, err := parent.Name() if err != nil { - return fmt.Errorf("could not get parent process name for upgrade checks: %s", err) + return fmt.Errorf("could not get parent process name for upgrade checks: %w", err) } if parentName != filename { log.Tracef("updates: parent process does not seem to be portmaster-control, name is %s", parentName) @@ -142,7 +136,7 @@ func upgradePortmasterControl() error { } parentPath, err := parent.Exe() if err != nil { - return fmt.Errorf("could not get parent process path for upgrade: %s", err) + return fmt.Errorf("could not get parent process path for upgrade: %w", err) } err = upgradeFile(parentPath, pmCtrlUpdate) if err != nil { @@ -190,7 +184,7 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { // ensure tmp dir is here err = registry.TmpDir().Ensure() if err != nil { - return fmt.Errorf("unable to check updates tmp dir for moving file that needs upgrade: %s", err) + return fmt.Errorf("could not prepare tmp directory for moving file that needs upgrade: %w", err) } // maybe we're on windows and it's in use, try moving @@ -204,17 +198,17 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { ), )) if err != nil { - return fmt.Errorf("unable to move file that needs upgrade: %s", err) + return fmt.Errorf("unable to move file that needs upgrade: %w", err) } } } // copy upgrade - err = copyFile(file.Path(), fileToUpgrade) + err = CopyFile(file.Path(), fileToUpgrade) if err != nil { // try again time.Sleep(1 * time.Second) - err = copyFile(file.Path(), fileToUpgrade) + err = CopyFile(file.Path(), fileToUpgrade) if err != nil { return err } @@ -224,23 +218,31 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { if !onWindows { info, err := os.Stat(fileToUpgrade) if err != nil { - return fmt.Errorf("failed to get file info on %s: %s", fileToUpgrade, err) + return fmt.Errorf("failed to get file info on %s: %w", fileToUpgrade, err) } if info.Mode() != 0755 { err := os.Chmod(fileToUpgrade, 0755) if err != nil { - return fmt.Errorf("failed to set permissions on %s: %s", fileToUpgrade, err) + return fmt.Errorf("failed to set permissions on %s: %w", fileToUpgrade, err) } } } return nil } -func copyFile(srcPath, dstPath string) (err error) { +// CopyFile atomically copies a file using the update registry's tmp dir. +func CopyFile(srcPath, dstPath string) (err error) { + + // check tmp dir + err = registry.TmpDir().Ensure() + if err != nil { + return fmt.Errorf("could not prepare tmp directory for copying file: %w", err) + } + // open file for writing atomicDstFile, err := renameio.TempFile(registry.TmpDir().Path, dstPath) if err != nil { - return fmt.Errorf("could not create temp file for atomic copy: %s", err) + return fmt.Errorf("could not create temp file for atomic copy: %w", err) } defer atomicDstFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway @@ -260,7 +262,7 @@ func copyFile(srcPath, dstPath string) (err error) { // finalize file err = atomicDstFile.CloseAtomicallyReplace() if err != nil { - return fmt.Errorf("updates: failed to finalize copy to file %s: %s", dstPath, err) + return fmt.Errorf("updates: failed to finalize copy to file %s: %w", dstPath, err) } return nil