mirror of
https://github.com/safing/portmaster
synced 2025-09-01 01:59:11 +00:00
Merge pull request #52 from safing/feature/firewall-resolver-improvements
Firewall and Resolver improvements
This commit is contained in:
commit
c59f680053
70 changed files with 2288 additions and 1549 deletions
107
Gopkg.lock
generated
107
Gopkg.lock
generated
|
@ -2,12 +2,12 @@
|
|||
|
||||
|
||||
[[projects]]
|
||||
digest = "1:f82b8ac36058904227087141017bb82f4b0fc58272990a4cdae3e2d6d222644e"
|
||||
digest = "1:6146fda730c18186631e91e818d995e759e7cbe27644d6871ccd469f6865c686"
|
||||
name = "github.com/StackExchange/wmi"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338"
|
||||
version = "1.0.0"
|
||||
revision = "cbe66965904dbe8a6cd589e2298e5d8b986bd7dd"
|
||||
version = "1.1.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:e010d6b45ee6c721df761eae89961c634ceb55feff166a48d15504729309f267"
|
||||
|
@ -18,12 +18,12 @@
|
|||
version = "v1.1.1"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:3c753679736345f50125ae993e0a2614da126859921ea7faeecda6d217501ce2"
|
||||
digest = "1:21caed545a1c7ef7a2627bbb45989f689872ff6d5087d49c31340ce74c36de59"
|
||||
name = "github.com/agext/levenshtein"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "0ded9c86537917af2ff89bc9c78de6bd58477894"
|
||||
version = "v1.2.2"
|
||||
revision = "52c14c47d03211d8ac1834e94601635e07c5a6ef"
|
||||
version = "v1.2.3"
|
||||
|
||||
[[projects]]
|
||||
branch = "v2.1"
|
||||
|
@ -34,12 +34,12 @@
|
|||
revision = "d27c04069d0d5dfe11c202dacbf745ae8d1ab181"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:166e24c91c2732657d2f791d3ee3897e7d85ece7cbb62ad991250e6b51fc1d4c"
|
||||
digest = "1:f384a8b6f89c502229e9013aa4f89ce5b5b56f09f9a4d601d7f1f026d3564fbf"
|
||||
name = "github.com/coreos/go-iptables"
|
||||
packages = ["iptables"]
|
||||
pruneopts = ""
|
||||
revision = "78b5fff24e6df8886ef8eca9411f683a884349a5"
|
||||
version = "v0.4.1"
|
||||
revision = "f901d6c2a4f2a4df092b98c33366dfba1f93d7a0"
|
||||
version = "v0.4.5"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77"
|
||||
|
@ -61,12 +61,12 @@
|
|||
version = "v1.2.4"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:cc1255e2fef3819bfab3540277001e602892dd431ef9ab5499bcdbc425923d64"
|
||||
digest = "1:f63933986e63230fc32512ed00bc18ea4dbb0f57b5da18561314928fd20c2ff0"
|
||||
name = "github.com/godbus/dbus"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "2ff6f7ffd60f0f2410b3105864bdd12c7894f844"
|
||||
version = "v5.0.1"
|
||||
revision = "37bf87eef99d69c4f1d3528bd66e3a87dc201472"
|
||||
version = "v5.0.3"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:e85e59c4152d8576341daf54f40d96c404c264e04941a4a36b97a0f427eb9e5e"
|
||||
|
@ -113,20 +113,20 @@
|
|||
revision = "2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:0b6694f306890ddbb69c96a16776510bd24e07436fae3f9b0a4e5b650f1e6fb7"
|
||||
branch = "master"
|
||||
digest = "1:c140772b00f0c26cf6627aee32f62d9f9d89dffcda648861266c482c36a5344a"
|
||||
name = "github.com/miekg/dns"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "b13675009d59c97f3721247d9efa8914e1866a5b"
|
||||
version = "v1.1.15"
|
||||
revision = "b7703d0fa022e159d01efa2de82e6173d5ec04c8"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:3819cd861b7abd7d12dc1ea52ecb998ad1171826a76ecf0aefa09545781091f9"
|
||||
digest = "1:b962a528cbecf7662bee4d84a600f7a0a6a130368666d7d461757ba4d1341906"
|
||||
name = "github.com/oschwald/maxminddb-golang"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "2905694a1b00c5574f1418a7dbf8a22a7d247559"
|
||||
version = "v1.3.1"
|
||||
revision = "6a033e62c03b7dab4c37f7c9eb2ebb3b10e8f13a"
|
||||
version = "v1.6.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
|
||||
|
@ -145,7 +145,7 @@
|
|||
version = "v1.2.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:8bf42eb2ded52ed2678b0716dbfbf30628765bc12b13222c4d5669ba4c1310e4"
|
||||
digest = "1:16f319cf21ddf49f27b3a2093d68316840dc25ec5c2a0a431a4a4fc01ea707e2"
|
||||
name = "github.com/shirou/gopsutil"
|
||||
packages = [
|
||||
"cpu",
|
||||
|
@ -155,32 +155,24 @@
|
|||
"process",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "4c8b404ee5c53b04b04f34b1744a26bf5d2910de"
|
||||
version = "v2.19.6"
|
||||
revision = "a81cf97fce2300934e6c625b9917103346c26ba3"
|
||||
version = "v2.20.4"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b"
|
||||
name = "github.com/shirou/w32"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe"
|
||||
digest = "1:bff75d4f1a2d2c4b8f4b46ff5ac230b80b5fa49276f615900cba09fe4c97e66e"
|
||||
name = "github.com/spf13/cobra"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5"
|
||||
version = "v0.0.5"
|
||||
revision = "a684a6d7f5e37385d954dd3b5a14fc6912c6ab9d"
|
||||
version = "v1.0.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66"
|
||||
digest = "1:688428eeb1ca80d92599eb3254bdf91b51d7e232fead3a73844c1f201a281e51"
|
||||
name = "github.com/spf13/pflag"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "298182f68c66c05229eb03ac171abe6e309ee79a"
|
||||
version = "v1.0.3"
|
||||
revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab"
|
||||
version = "v1.0.5"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:cc4eb6813da8d08694e557fcafae8fcc24f47f61a0717f952da130ca9a486dfc"
|
||||
|
@ -208,18 +200,18 @@
|
|||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328"
|
||||
digest = "1:bf61fa9b53be5ce096004599b957e5957b28a5e421b724250aa06ecb7ee6dc57"
|
||||
name = "golang.org/x/crypto"
|
||||
packages = [
|
||||
"ed25519",
|
||||
"ed25519/internal/edwards25519",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285"
|
||||
revision = "4b2356b1ed79e6be3deca3737a3db3d132d2847a"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:31cd6e3c114e17c5f0c9e8b0bcaa3025ab3c221ce36323c7ce1acaa753d0d0aa"
|
||||
digest = "1:ea84836e35d7a66c9b8944796295912509c80c921244bc4e098c5417219895f2"
|
||||
name = "golang.org/x/net"
|
||||
packages = [
|
||||
"bpf",
|
||||
|
@ -232,7 +224,7 @@
|
|||
"publicsuffix",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "da137c7871d730100384dbcf36e6f8fa493aef5b"
|
||||
revision = "7e3656a0809f6f95abd88ac65313578f80b00df2"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
|
@ -244,9 +236,10 @@
|
|||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:2579a16d8afda9c9a475808c13324f5e572852e8927905ffa15bb14e71baba4f"
|
||||
digest = "1:acb3b56e190190ac9497faf5f0c30c5da4d3e8278d6b7a7042f2aa3332ff7022"
|
||||
name = "golang.org/x/sys"
|
||||
packages = [
|
||||
"internal/unsafeheader",
|
||||
"unix",
|
||||
"windows",
|
||||
"windows/registry",
|
||||
|
@ -256,7 +249,7 @@
|
|||
"windows/svc/mgr",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "04f50cda93cbb67f2afa353c52f342100e80e625"
|
||||
revision = "bc7a7d42d5c30f4d0fe808715c002826ce2c624e"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0"
|
||||
|
@ -283,6 +276,38 @@
|
|||
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
|
||||
version = "v0.3.2"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:3416c611e00178b07c8fc347ba96398e4d6709fe7d3fab17f0b0fa6f933b4bd1"
|
||||
name = "golang.org/x/tools"
|
||||
packages = [
|
||||
"go/ast/astutil",
|
||||
"go/gcexportdata",
|
||||
"go/internal/gcimporter",
|
||||
"go/internal/packagesdriver",
|
||||
"go/packages",
|
||||
"go/types/typeutil",
|
||||
"internal/event",
|
||||
"internal/event/core",
|
||||
"internal/event/keys",
|
||||
"internal/event/label",
|
||||
"internal/gocommand",
|
||||
"internal/packagesinternal",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "b8469989bc69e50ec6dc4e4513fc3ff9ce48b8af"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:9d4ac09a835404ae9306c6e1493cf800ecbb0f3f828f4333b3e055de4c962eea"
|
||||
name = "golang.org/x/xerrors"
|
||||
packages = [
|
||||
".",
|
||||
"internal",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "9bdfabe68543c54f90421aeb9a60ef8061b5b544"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:2efc9662a6a1ff28c65c84fc2f9030f13d3afecdb2ecad445f3b0c80e75fc281"
|
||||
name = "gopkg.in/yaml.v2"
|
||||
|
|
|
@ -25,3 +25,7 @@
|
|||
# unused-packages = true
|
||||
|
||||
ignored = ["github.com/safing/portbase/*"]
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/miekg/dns"
|
||||
branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released
|
||||
|
|
|
@ -60,7 +60,18 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er
|
|||
var procsChecked []string
|
||||
|
||||
// get process
|
||||
proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process
|
||||
proc, _, err := process.GetProcessByConnection(
|
||||
r.Context(),
|
||||
&packet.Info{
|
||||
Inbound: false, // outbound as we are looking for the process of the source address
|
||||
Version: packet.IPv4,
|
||||
Protocol: packet.TCP,
|
||||
Src: remoteIP, // source as in the process we are looking for
|
||||
SrcPort: remotePort, // source as in the process we are looking for
|
||||
Dst: localIP,
|
||||
DstPort: localPort,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get process: %s", err)
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func registerConfig() error {
|
|||
Order: CfgOptionAskWithSystemNotificationsOrder,
|
||||
OptType: config.OptTypeBool,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: true,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -62,7 +62,7 @@ func registerConfig() error {
|
|||
Order: CfgOptionAskTimeoutOrder,
|
||||
OptType: config.OptTypeInt,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: 60,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -138,7 +138,7 @@ func handlePacket(pkt packet.Packet) {
|
|||
// pkt.RedirToNameserver()
|
||||
// }
|
||||
|
||||
// allow ICMP, IGMP and DHCP
|
||||
// allow ICMP and DHCP
|
||||
// TODO: actually handle these
|
||||
switch meta.Protocol {
|
||||
case packet.ICMP:
|
||||
|
@ -149,10 +149,6 @@ func handlePacket(pkt packet.Packet) {
|
|||
log.Debugf("accepting ICMPv6: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
return
|
||||
case packet.IGMP:
|
||||
log.Debugf("accepting IGMP: %s", pkt)
|
||||
_ = pkt.PermanentAccept()
|
||||
return
|
||||
case packet.UDP:
|
||||
if meta.DstPort == 67 || meta.DstPort == 68 {
|
||||
log.Debugf("accepting DHCP: %s", pkt)
|
||||
|
@ -218,6 +214,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
|
|||
// reroute dns requests to nameserver
|
||||
if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) {
|
||||
conn.Verdict = network.VerdictRerouteToNameserver
|
||||
conn.Internal = true
|
||||
conn.StopFirewallHandler()
|
||||
issueVerdict(conn, pkt, 0, true)
|
||||
return
|
||||
|
@ -233,7 +230,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
|
|||
}
|
||||
|
||||
log.Tracer(pkt.Ctx()).Trace("filter: starting decision process")
|
||||
DecideOnConnection(conn, pkt)
|
||||
DecideOnConnection(pkt.Ctx(), conn, pkt)
|
||||
conn.Inspecting = false // TODO: enable inspecting again
|
||||
|
||||
switch {
|
||||
|
|
|
@ -62,7 +62,7 @@ func Handler(packets chan packet.Packet) {
|
|||
}
|
||||
|
||||
info := new.Info()
|
||||
info.Direction = packetInfo.direction > 0
|
||||
info.Inbound = packetInfo.direction > 0
|
||||
info.InTunnel = false
|
||||
info.Protocol = packet.IPProtocol(packetInfo.protocol)
|
||||
|
||||
|
@ -76,7 +76,7 @@ func Handler(packets chan packet.Packet) {
|
|||
// IPs
|
||||
if info.Version == packet.IPv4 {
|
||||
// IPv4
|
||||
if info.Direction {
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.Src = convertIPv4(packetInfo.remoteIP)
|
||||
info.Dst = convertIPv4(packetInfo.localIP)
|
||||
|
@ -87,7 +87,7 @@ func Handler(packets chan packet.Packet) {
|
|||
}
|
||||
} else {
|
||||
// IPv6
|
||||
if info.Direction {
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.Src = convertIPv6(packetInfo.remoteIP)
|
||||
info.Dst = convertIPv6(packetInfo.localIP)
|
||||
|
@ -99,7 +99,7 @@ func Handler(packets chan packet.Packet) {
|
|||
}
|
||||
|
||||
// Ports
|
||||
if info.Direction {
|
||||
if info.Inbound {
|
||||
// Inbound
|
||||
info.SrcPort = packetInfo.remotePort
|
||||
info.DstPort = packetInfo.localPort
|
||||
|
@ -113,19 +113,17 @@ func Handler(packets chan packet.Packet) {
|
|||
}
|
||||
}
|
||||
|
||||
// convertIPv4 as needed for data from the kernel
|
||||
func convertIPv4(input [4]uint32) net.IP {
|
||||
return net.IPv4(
|
||||
uint8(input[0]>>24&0xFF),
|
||||
uint8(input[0]>>16&0xFF),
|
||||
uint8(input[0]>>8&0xFF),
|
||||
uint8(input[0]&0xFF),
|
||||
)
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(addressBuf, input[0])
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
func convertIPv6(input [4]uint32) net.IP {
|
||||
addressBuf := make([]byte, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
|
||||
}
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/network/state"
|
||||
"github.com/safing/portmaster/process"
|
||||
"github.com/safing/portmaster/profile"
|
||||
"github.com/safing/portmaster/profile/endpoints"
|
||||
|
@ -32,10 +34,10 @@ import (
|
|||
|
||||
// DecideOnConnection makes a decision about a connection.
|
||||
// When called, the connection and profile is already locked.
|
||||
func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
|
||||
func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
|
||||
// update profiles and check if communication needs reevaluation
|
||||
if conn.UpdateAndCheck() {
|
||||
log.Infof("filter: re-evaluating verdict on %s", conn)
|
||||
log.Tracer(ctx).Infof("filter: re-evaluating verdict on %s", conn)
|
||||
conn.Verdict = network.VerdictUndecided
|
||||
|
||||
if conn.Entity != nil {
|
||||
|
@ -43,7 +45,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
|
|||
}
|
||||
}
|
||||
|
||||
var deciders = []func(*network.Connection, packet.Packet) bool{
|
||||
var deciders = []func(context.Context, *network.Connection, packet.Packet) bool{
|
||||
checkPortmasterConnection,
|
||||
checkSelfCommunication,
|
||||
checkProfileExists,
|
||||
|
@ -59,7 +61,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
|
|||
}
|
||||
|
||||
for _, decider := range deciders {
|
||||
if decider(conn, pkt) {
|
||||
if decider(ctx, conn, pkt) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -70,10 +72,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
|
|||
|
||||
// checkPortmasterConnection allows all connection that originate from
|
||||
// portmaster itself.
|
||||
func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
|
||||
// grant self
|
||||
if conn.Process().Pid == os.Getpid() {
|
||||
log.Infof("filter: granting own connection %s", conn)
|
||||
log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
|
||||
conn.Verdict = network.VerdictAccept
|
||||
conn.Internal = true
|
||||
return true
|
||||
|
@ -83,27 +85,29 @@ func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool {
|
|||
}
|
||||
|
||||
// checkSelfCommunication checks if the process is communicating with itself.
|
||||
func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool {
|
||||
func checkSelfCommunication(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
|
||||
// check if process is communicating with itself
|
||||
if pkt != nil {
|
||||
// TODO: evaluate the case where different IPs in the 127/8 net are used.
|
||||
pktInfo := pkt.Info()
|
||||
if conn.Process().Pid >= 0 && pktInfo.Src.Equal(pktInfo.Dst) {
|
||||
// get PID
|
||||
otherPid, _, err := process.GetPidByEndpoints(
|
||||
pktInfo.RemoteIP(),
|
||||
pktInfo.RemotePort(),
|
||||
pktInfo.LocalIP(),
|
||||
pktInfo.LocalPort(),
|
||||
pktInfo.Protocol,
|
||||
)
|
||||
otherPid, _, err := state.Lookup(&packet.Info{
|
||||
Inbound: !pktInfo.Inbound, // we want to know the process on the other end
|
||||
Version: pktInfo.Version,
|
||||
Protocol: pktInfo.Protocol,
|
||||
Src: pktInfo.Src,
|
||||
SrcPort: pktInfo.SrcPort,
|
||||
Dst: pktInfo.Dst,
|
||||
DstPort: pktInfo.DstPort,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warningf("filter: failed to find local peer process PID: %s", err)
|
||||
log.Tracer(ctx).Warningf("filter: failed to find local peer process PID: %s", err)
|
||||
} else {
|
||||
// get primary process
|
||||
otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid)
|
||||
otherProcess, err := process.GetOrFindPrimaryProcess(ctx, otherPid)
|
||||
if err != nil {
|
||||
log.Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err)
|
||||
log.Tracer(ctx).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err)
|
||||
} else if otherProcess.Pid == conn.Process().Pid {
|
||||
conn.Accept("connection to self")
|
||||
conn.Internal = true
|
||||
|
@ -116,7 +120,7 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkProfileExists(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
if conn.Process().Profile() == nil {
|
||||
conn.Block("unknown process or profile")
|
||||
return true
|
||||
|
@ -124,7 +128,7 @@ func checkProfileExists(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
var result endpoints.EPResult
|
||||
var reason endpoints.Reason
|
||||
|
||||
|
@ -149,12 +153,12 @@ func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkConnectionType(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkConnectionType(ctx context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
p := conn.Process().Profile()
|
||||
|
||||
// check conn type
|
||||
switch conn.Scope {
|
||||
case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
|
||||
case network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
|
||||
if p.BlockInbound() {
|
||||
if conn.Scope == network.IncomingHost {
|
||||
conn.Block("inbound connections blocked")
|
||||
|
@ -174,7 +178,7 @@ func checkConnectionType(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkConnectionScope(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
p := conn.Process().Profile()
|
||||
|
||||
// check scopes
|
||||
|
@ -213,7 +217,7 @@ func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkBypassPrevention(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
if conn.Process().Profile().PreventBypassing() {
|
||||
// check for bypass protection
|
||||
result, reason, reasonCtx := PreventBypassing(conn)
|
||||
|
@ -230,7 +234,7 @@ func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkFilterLists(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
|
||||
// apply privacy filter lists
|
||||
p := conn.Process().Profile()
|
||||
|
||||
|
@ -242,12 +246,12 @@ func checkFilterLists(conn *network.Connection, _ packet.Packet) bool {
|
|||
case endpoints.NoMatch:
|
||||
// nothing to do
|
||||
default:
|
||||
log.Debugf("filter: filter lists returned unsupported verdict: %s", result)
|
||||
log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func checkInbound(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkInbound(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
// implicit default=block for inbound
|
||||
if conn.Inbound {
|
||||
conn.Drop("endpoint is not whitelisted (incoming is always default=block)")
|
||||
|
@ -256,7 +260,7 @@ func checkInbound(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkDefaultPermit(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
// check default action
|
||||
p := conn.Process().Profile()
|
||||
if p.DefaultAction() == profile.DefaultActionPermit {
|
||||
|
@ -266,7 +270,7 @@ func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool {
|
||||
func checkAutoPermitRelated(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
|
||||
p := conn.Process().Profile()
|
||||
if !p.DisableAutoPermit() {
|
||||
related, reason := checkRelation(conn)
|
||||
|
@ -278,7 +282,7 @@ func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool {
|
||||
func checkDefaultAction(_ context.Context, conn *network.Connection, pkt packet.Packet) bool {
|
||||
p := conn.Process().Profile()
|
||||
if p.DefaultAction() == profile.DefaultActionAsk {
|
||||
prompt(conn, pkt)
|
||||
|
|
|
@ -71,9 +71,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns.
|
|||
|
||||
for _, lm := range br {
|
||||
blockedBy, err := dns.NewRR(fmt.Sprintf(
|
||||
"%s-blockedBy. 0 IN TXT %q",
|
||||
strings.TrimRight(lm.Entity, "."),
|
||||
strings.Join(lm.ActiveLists, ","),
|
||||
`%s 0 IN TXT "blocked by filter lists %s"`,
|
||||
lm.Entity,
|
||||
strings.Join(lm.ActiveLists, ", "),
|
||||
))
|
||||
if err == nil {
|
||||
rrs = append(rrs, blockedBy)
|
||||
|
@ -83,9 +83,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns.
|
|||
|
||||
if len(lm.InactiveLists) > 0 {
|
||||
wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf(
|
||||
"%s-wouldBeBlockedBy. 0 IN TXT %q",
|
||||
strings.TrimRight(lm.Entity, "."),
|
||||
strings.Join(lm.InactiveLists, ","),
|
||||
`%s 0 IN TXT "would be blocked by filter lists %s"`,
|
||||
lm.Entity,
|
||||
strings.Join(lm.InactiveLists, ", "),
|
||||
))
|
||||
if err == nil {
|
||||
rrs = append(rrs, wouldBeBlockedBy)
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -22,9 +24,8 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
dnsServer *dns.Server
|
||||
mtDNSRequest = "dns request"
|
||||
module *modules.Module
|
||||
dnsServer *dns.Server
|
||||
|
||||
listenAddress = "0.0.0.0:53"
|
||||
ipv4Localhost = net.IPv4(127, 0, 0, 1)
|
||||
|
@ -61,7 +62,7 @@ func prep() error {
|
|||
|
||||
func start() error {
|
||||
dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"}
|
||||
dns.HandleFunc(".", handleRequestAsMicroTask)
|
||||
dns.HandleFunc(".", handleRequestAsWorker)
|
||||
|
||||
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
|
||||
err := dnsServer.ListenAndServe()
|
||||
|
@ -95,8 +96,8 @@ func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) {
|
|||
_ = w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) {
|
||||
err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error {
|
||||
func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
|
||||
err := module.RunWorker("dns request", func(ctx context.Context) error {
|
||||
return handleRequest(ctx, w, query)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -168,12 +169,13 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
|
||||
// start tracer
|
||||
ctx, tracer := log.AddTracer(ctx)
|
||||
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)
|
||||
defer tracer.Submit()
|
||||
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)
|
||||
|
||||
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
|
||||
|
||||
// get connection
|
||||
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
|
||||
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port))
|
||||
|
||||
// once we decided on the connection we might need to save it to the database
|
||||
// so we defer that check right now.
|
||||
|
@ -191,7 +193,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
return
|
||||
|
||||
default:
|
||||
log.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn)
|
||||
tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -220,7 +222,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
}
|
||||
|
||||
// check profile before we even get intel and rr
|
||||
firewall.DecideOnConnection(conn, nil)
|
||||
firewall.DecideOnConnection(ctx, conn, nil)
|
||||
|
||||
switch conn.Verdict {
|
||||
case network.VerdictBlock:
|
||||
|
@ -242,7 +244,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
tracer.Infof("nameserver: %s handing over to reason-responder: %s", q.FQDN, conn.Reason)
|
||||
reply := responder.ReplyWithDNS(query, conn.Reason, conn.ReasonContext)
|
||||
if err := w.WriteMsg(reply); err != nil {
|
||||
log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
|
||||
tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
|
||||
} else {
|
||||
tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process())
|
||||
}
|
||||
|
@ -269,6 +271,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
return nil
|
||||
}
|
||||
|
||||
tracer.Trace("nameserver: deciding on resolved dns")
|
||||
rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache)
|
||||
if rrCache == nil {
|
||||
sendResponse(w, query, conn.Verdict, conn.Reason, conn.ReasonContext)
|
||||
|
@ -283,7 +286,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
|
|||
m.Extra = rrCache.Extra
|
||||
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
|
||||
tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
|
||||
} else {
|
||||
tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process())
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/notifications"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/process"
|
||||
"github.com/safing/portmaster/network/state"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -58,7 +58,15 @@ func checkForConflictingService() error {
|
|||
}
|
||||
|
||||
func takeover(resolverIP net.IP) (int, error) {
|
||||
pid, _, err := process.GetPidByEndpoints(resolverIP, 53, resolverIP, 65535, packet.UDP)
|
||||
pid, _, err := state.Lookup(&packet.Info{
|
||||
Inbound: true,
|
||||
Version: 0, // auto-detect
|
||||
Protocol: packet.UDP,
|
||||
Src: nil, // do not record direction
|
||||
SrcPort: 0, // do not record direction
|
||||
Dst: resolverIP,
|
||||
DstPort: 53,
|
||||
})
|
||||
if err != nil {
|
||||
// there may be nothing listening on :53
|
||||
return 0, nil
|
||||
|
|
|
@ -201,20 +201,11 @@ func triggerOnlineStatusInvestigation() {
|
|||
|
||||
func monitorOnlineStatus(ctx context.Context) error {
|
||||
for {
|
||||
timeout := 5 * time.Minute
|
||||
/*
|
||||
if GetOnlineStatus() != StatusOnline {
|
||||
timeout = time.Second
|
||||
log.Debugf("checking online status again in %s because current status is %s", timeout, GetOnlineStatus())
|
||||
}
|
||||
*/
|
||||
// wait for trigger
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-onlineStatusInvestigationTrigger:
|
||||
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
|
||||
// enable waiting
|
||||
|
|
|
@ -4,6 +4,10 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
|
||||
"github.com/safing/portmaster/network/state"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/process"
|
||||
)
|
||||
|
@ -22,8 +26,12 @@ func connectionCleaner(ctx context.Context) error {
|
|||
ticker.Stop()
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
// clean connections and processes
|
||||
activePIDs := cleanConnections()
|
||||
process.CleanProcessStorage(activePIDs)
|
||||
|
||||
// clean udp connection states
|
||||
state.CleanUDPStates(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -33,13 +41,10 @@ func cleanConnections() (activePIDs map[int]struct{}) {
|
|||
|
||||
name := "clean connections" // TODO: change to new fn
|
||||
_ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error {
|
||||
activeIDs := make(map[string]struct{})
|
||||
for _, cID := range process.GetActiveConnectionIDs() {
|
||||
activeIDs[cID] = struct{}{}
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix()
|
||||
now := time.Now().UTC()
|
||||
nowUnix := now.Unix()
|
||||
deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix()
|
||||
|
||||
// lock both together because we cannot fully guarantee in which map a connection lands
|
||||
// of course every connection should land in the correct map, but this increases resilience
|
||||
|
@ -49,20 +54,27 @@ func cleanConnections() (activePIDs map[int]struct{}) {
|
|||
defer dnsConnsLock.Unlock()
|
||||
|
||||
// network connections
|
||||
for key, conn := range conns {
|
||||
for _, conn := range conns {
|
||||
conn.Lock()
|
||||
|
||||
// delete inactive connections
|
||||
switch {
|
||||
case conn.Ended == 0:
|
||||
// Step 1: check if still active
|
||||
_, ok := activeIDs[key]
|
||||
if ok {
|
||||
activePIDs[conn.process.Pid] = struct{}{}
|
||||
} else {
|
||||
exists := state.Exists(&packet.Info{
|
||||
Inbound: false, // src == local
|
||||
Version: conn.IPVersion,
|
||||
Protocol: conn.IPProtocol,
|
||||
Src: conn.LocalIP,
|
||||
SrcPort: conn.LocalPort,
|
||||
Dst: conn.Entity.IP,
|
||||
DstPort: conn.Entity.Port,
|
||||
}, now)
|
||||
activePIDs[conn.process.Pid] = struct{}{}
|
||||
|
||||
if !exists {
|
||||
// Step 2: mark end
|
||||
activePIDs[conn.process.Pid] = struct{}{}
|
||||
conn.Ended = now
|
||||
conn.Ended = nowUnix
|
||||
conn.Save()
|
||||
}
|
||||
case conn.Ended < deleteOlderThan:
|
||||
|
|
|
@ -25,11 +25,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
|
|||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
Scope string
|
||||
Inbound bool
|
||||
Entity *intel.Entity // needs locking, instance is never shared
|
||||
process *process.Process
|
||||
ID string
|
||||
Scope string
|
||||
IPVersion packet.IPVersion
|
||||
Inbound bool
|
||||
|
||||
// local endpoint
|
||||
IPProtocol packet.IPProtocol
|
||||
LocalIP net.IP
|
||||
LocalPort uint16
|
||||
process *process.Process
|
||||
|
||||
// remote endpoint
|
||||
Entity *intel.Entity
|
||||
|
||||
Verdict Verdict
|
||||
Reason string
|
||||
|
@ -55,11 +63,22 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
|
|||
}
|
||||
|
||||
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
|
||||
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
|
||||
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection {
|
||||
// get Process
|
||||
proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP)
|
||||
proc, _, err := process.GetProcessByConnection(
|
||||
ctx,
|
||||
&packet.Info{
|
||||
Inbound: false, // outbound as we are looking for the process of the source address
|
||||
Version: ipVersion,
|
||||
Protocol: packet.UDP,
|
||||
Src: localIP, // source as in the process we are looking for
|
||||
SrcPort: localPort, // source as in the process we are looking for
|
||||
Dst: nil, // do not record direction
|
||||
DstPort: 0, // do not record direction
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err)
|
||||
log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err)
|
||||
proc = process.GetUnidentifiedProcess(ctx)
|
||||
}
|
||||
|
||||
|
@ -80,9 +99,9 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
|
|||
// NewConnectionFromFirstPacket returns a new connection based on the given packet.
|
||||
func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
||||
// get Process
|
||||
proc, inbound, err := process.GetProcessByPacket(pkt)
|
||||
proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info())
|
||||
if err != nil {
|
||||
log.Warningf("network: failed to find process of packet %s: %s", pkt, err)
|
||||
log.Debugf("network: failed to find process of packet %s: %s", pkt, err)
|
||||
proc = process.GetUnidentifiedProcess(pkt.Ctx())
|
||||
}
|
||||
|
||||
|
@ -147,12 +166,20 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
|
|||
}
|
||||
|
||||
return &Connection{
|
||||
ID: pkt.GetConnectionID(),
|
||||
Scope: scope,
|
||||
Inbound: inbound,
|
||||
Entity: entity,
|
||||
process: proc,
|
||||
Started: time.Now().Unix(),
|
||||
ID: pkt.GetConnectionID(),
|
||||
Scope: scope,
|
||||
IPVersion: pkt.Info().Version,
|
||||
Inbound: inbound,
|
||||
// local endpoint
|
||||
IPProtocol: pkt.Info().Protocol,
|
||||
LocalIP: pkt.Info().LocalIP(),
|
||||
LocalPort: pkt.Info().LocalPort(),
|
||||
process: proc,
|
||||
// remote endpoint
|
||||
Entity: entity,
|
||||
// meta
|
||||
Started: time.Now().Unix(),
|
||||
profileRevisionCounter: proc.Profile().RevisionCnt(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/network/state"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/database/iterator"
|
||||
"github.com/safing/portbase/database/query"
|
||||
|
@ -57,6 +59,14 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
|
|||
return conn, nil
|
||||
}
|
||||
}
|
||||
case "system":
|
||||
if len(splitted) >= 2 {
|
||||
switch splitted[1] {
|
||||
case "state":
|
||||
return state.GetInfo(), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, storage.ErrNotFound
|
||||
|
|
71
network/iphelper/get.go
Normal file
71
network/iphelper/get.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
// +build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
ipHelper *IPHelper
|
||||
|
||||
// lock locks access to the whole DLL.
|
||||
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
|
||||
lock sync.RWMutex
|
||||
)
|
||||
|
||||
// GetTCP4Table returns the system table for IPv4 TCP activity.
|
||||
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ipHelper.getTable(IPv4, TCP)
|
||||
}
|
||||
|
||||
// GetTCP6Table returns the system table for IPv6 TCP activity.
|
||||
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ipHelper.getTable(IPv6, TCP)
|
||||
}
|
||||
|
||||
// GetUDP4Table returns the system table for IPv4 UDP activity.
|
||||
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, binds, err = ipHelper.getTable(IPv4, UDP)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUDP6Table returns the system table for IPv6 UDP activity.
|
||||
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err = checkIPHelper()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, binds, err = ipHelper.getTable(IPv6, UDP)
|
||||
return
|
||||
}
|
63
network/iphelper/iphelper.go
Normal file
63
network/iphelper/iphelper.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
// +build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalid = errors.New("IPHelper not initialzed or broken")
|
||||
)
|
||||
|
||||
// IPHelper represents a subset of the Windows iphlpapi.dll.
|
||||
type IPHelper struct {
|
||||
dll *windows.LazyDLL
|
||||
|
||||
getExtendedTCPTable *windows.LazyProc
|
||||
getExtendedUDPTable *windows.LazyProc
|
||||
|
||||
valid *abool.AtomicBool
|
||||
}
|
||||
|
||||
func checkIPHelper() (err error) {
|
||||
if ipHelper == nil {
|
||||
ipHelper, err = New()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
|
||||
func New() (*IPHelper, error) {
|
||||
|
||||
new := &IPHelper{}
|
||||
new.valid = abool.NewBool(false)
|
||||
var err error
|
||||
|
||||
// load dll
|
||||
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
err = new.dll.Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load functions
|
||||
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
|
||||
err = new.getExtendedTCPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
|
||||
}
|
||||
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
|
||||
err = new.getExtendedUDPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
|
||||
}
|
||||
|
||||
new.valid.Set()
|
||||
return new, nil
|
||||
}
|
|
@ -3,12 +3,15 @@
|
|||
package iphelper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
|
@ -22,19 +25,6 @@ const (
|
|||
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
|
||||
)
|
||||
|
||||
// ConnectionEntry describes a connection state table entry.
|
||||
type ConnectionEntry struct {
|
||||
localIP net.IP
|
||||
remoteIP net.IP
|
||||
localPort uint16
|
||||
remotePort uint16
|
||||
pid int
|
||||
}
|
||||
|
||||
func (entry *ConnectionEntry) String() string {
|
||||
return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)
|
||||
}
|
||||
|
||||
type iphelperTCPTable struct {
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
|
||||
numEntries uint32
|
||||
|
@ -148,9 +138,9 @@ func increaseBufSize() int {
|
|||
return bufSize
|
||||
}
|
||||
|
||||
// GetTables returns the current connection state table of Windows of the given protocol and IP version.
|
||||
func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) { //nolint:gocognit,gocycle // TODO
|
||||
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx
|
||||
// getTable returns the current connection state table of Windows of the given protocol and IP version.
|
||||
func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO
|
||||
// docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
|
||||
|
||||
if !ipHelper.valid.IsSet() {
|
||||
return nil, nil, errInvalid
|
||||
|
@ -220,26 +210,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
|
|||
table := tcpTable.table[:tcpTable.numEntries]
|
||||
|
||||
for _, row := range table {
|
||||
new := &ConnectionEntry{}
|
||||
|
||||
// PID
|
||||
new.pid = int(row.owningPid)
|
||||
|
||||
// local
|
||||
if row.localAddr != 0 {
|
||||
new.localIP = convertIPv4(row.localAddr)
|
||||
}
|
||||
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
|
||||
|
||||
// remote
|
||||
if row.state == iphelperTCPStateListen {
|
||||
listeners = append(listeners, new)
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
} else {
|
||||
new.remoteIP = convertIPv4(row.remoteAddr)
|
||||
new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8)
|
||||
connections = append(connections, new)
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: convertIPv4(row.remoteAddr),
|
||||
Port: convertPort(row.remotePort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
case protocol == TCP && ipVersion == IPv6:
|
||||
|
@ -248,27 +239,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
|
|||
table := tcpTable.table[:tcpTable.numEntries]
|
||||
|
||||
for _, row := range table {
|
||||
new := &ConnectionEntry{}
|
||||
|
||||
// PID
|
||||
new.pid = int(row.owningPid)
|
||||
|
||||
// local
|
||||
new.localIP = net.IP(row.localAddr[:])
|
||||
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
|
||||
|
||||
// remote
|
||||
if row.state == iphelperTCPStateListen {
|
||||
if new.localIP.Equal(net.IPv6zero) {
|
||||
new.localIP = nil
|
||||
}
|
||||
listeners = append(listeners, new)
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
} else {
|
||||
new.remoteIP = net.IP(row.remoteAddr[:])
|
||||
new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8)
|
||||
connections = append(connections, new)
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: net.IP(row.remoteAddr[:]),
|
||||
Port: convertPort(row.remotePort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
case protocol == UDP && ipVersion == IPv4:
|
||||
|
@ -277,19 +268,13 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
|
|||
table := udpTable.table[:udpTable.numEntries]
|
||||
|
||||
for _, row := range table {
|
||||
new := &ConnectionEntry{}
|
||||
|
||||
// PID
|
||||
new.pid = int(row.owningPid)
|
||||
|
||||
// local
|
||||
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
|
||||
if row.localAddr == 0 {
|
||||
listeners = append(listeners, new)
|
||||
} else {
|
||||
new.localIP = convertIPv4(row.localAddr)
|
||||
connections = append(connections, new)
|
||||
}
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: convertIPv4(row.localAddr),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
case protocol == UDP && ipVersion == IPv6:
|
||||
|
@ -298,32 +283,28 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
|
|||
table := udpTable.table[:udpTable.numEntries]
|
||||
|
||||
for _, row := range table {
|
||||
new := &ConnectionEntry{}
|
||||
|
||||
// PID
|
||||
new.pid = int(row.owningPid)
|
||||
|
||||
// local
|
||||
new.localIP = net.IP(row.localAddr[:])
|
||||
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
|
||||
if new.localIP.Equal(net.IPv6zero) {
|
||||
new.localIP = nil
|
||||
listeners = append(listeners, new)
|
||||
} else {
|
||||
connections = append(connections, new)
|
||||
}
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: net.IP(row.localAddr[:]),
|
||||
Port: convertPort(row.localPort),
|
||||
},
|
||||
PID: int(row.owningPid),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return connections, listeners, nil
|
||||
return connections, binds, nil
|
||||
}
|
||||
|
||||
// convertIPv4 as needed for iphlpapi.dll
|
||||
func convertIPv4(input uint32) net.IP {
|
||||
return net.IPv4(
|
||||
uint8(input&0xFF),
|
||||
uint8(input>>8&0xFF),
|
||||
uint8(input>>16&0xFF),
|
||||
uint8(input>>24&0xFF),
|
||||
)
|
||||
addressBuf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(addressBuf, input)
|
||||
return net.IP(addressBuf)
|
||||
}
|
||||
|
||||
// convertPort converts ports received from iphlpapi.dll
|
||||
func convertPort(input uint32) uint16 {
|
||||
return uint16(input>>8 | input<<8)
|
||||
}
|
54
network/iphelper/tables_test.go
Normal file
54
network/iphelper/tables_test.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
// +build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSockets(t *testing.T) {
|
||||
connections, listeners, err := GetTCP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 4 connections:")
|
||||
for _, connection := range connections {
|
||||
fmt.Printf("%+v\n", connection)
|
||||
}
|
||||
fmt.Println("\nTCP 4 listeners:")
|
||||
for _, listener := range listeners {
|
||||
fmt.Printf("%+v\n", listener)
|
||||
}
|
||||
|
||||
connections, listeners, err = GetTCP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 6 connections:")
|
||||
for _, connection := range connections {
|
||||
fmt.Printf("%+v\n", connection)
|
||||
}
|
||||
fmt.Println("\nTCP 6 listeners:")
|
||||
for _, listener := range listeners {
|
||||
fmt.Printf("%+v\n", listener)
|
||||
}
|
||||
|
||||
binds, err := GetUDP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 4 binds:")
|
||||
for _, bind := range binds {
|
||||
fmt.Printf("%+v\n", bind)
|
||||
}
|
||||
|
||||
binds, err = GetUDP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 6 binds:")
|
||||
for _, bind := range binds {
|
||||
fmt.Printf("%+v\n", bind)
|
||||
}
|
||||
}
|
|
@ -1,17 +1,12 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
|
||||
dnsAddress = net.IPv4(127, 0, 0, 1)
|
||||
dnsPort uint16 = 53
|
||||
|
||||
defaultFirewallHandler FirewallHandler
|
||||
)
|
||||
|
||||
|
|
|
@ -2,13 +2,57 @@ package netutils
|
|||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`)
|
||||
cleanDomainRegex = regexp.MustCompile(
|
||||
`^` + // match beginning
|
||||
`(` + // start subdomain group
|
||||
`(xn--)?` + // idn prefix
|
||||
`[a-z0-9_-]{1,63}` + // main chunk
|
||||
`\.` + // ending with a dot
|
||||
`)*` + // end subdomain group, allow any number of subdomains
|
||||
`(xn--)?` + // TLD idn prefix
|
||||
`[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters
|
||||
`\.` + // ending with a dot
|
||||
`$`, // match end
|
||||
)
|
||||
)
|
||||
|
||||
// IsValidFqdn returns whether the given string is a valid fqdn.
|
||||
func IsValidFqdn(fqdn string) bool {
|
||||
return cleanDomainRegex.MatchString(fqdn)
|
||||
// root zone
|
||||
if fqdn == "." {
|
||||
return true
|
||||
}
|
||||
|
||||
// check max length
|
||||
if len(fqdn) > 256 {
|
||||
return false
|
||||
}
|
||||
|
||||
// check with regex
|
||||
if !cleanDomainRegex.MatchString(fqdn) {
|
||||
return false
|
||||
}
|
||||
|
||||
// check with miegk/dns
|
||||
|
||||
// IsFqdn checks if a domain name is fully qualified.
|
||||
if !dns.IsFqdn(fqdn) {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDomainName checks if s is a valid domain name, it returns the number of
|
||||
// labels and true, when a domain name is valid. Note that non fully qualified
|
||||
// domain name is considered valid, in this case the last label is counted in
|
||||
// the number of labels. When false is returned the number of labels is not
|
||||
// defined. Also note that this function is extremely liberal; almost any
|
||||
// string is a valid domain name as the DNS is 8 bit protocol. It checks if each
|
||||
// label fits in 63 characters and that the entire name will fit into the 255
|
||||
// octet wire format limit.
|
||||
_, ok := dns.IsDomainName(fqdn)
|
||||
return ok
|
||||
}
|
||||
|
|
43
network/netutils/cleandns_test.go
Normal file
43
network/netutils/cleandns_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package netutils
|
||||
|
||||
import "testing"
|
||||
|
||||
func testDomainValidity(t *testing.T, domain string, isValid bool) {
|
||||
if IsValidFqdn(domain) != isValid {
|
||||
t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSValidation(t *testing.T) {
|
||||
// valid
|
||||
testDomainValidity(t, ".", true)
|
||||
testDomainValidity(t, "at.", true)
|
||||
testDomainValidity(t, "orf.at.", true)
|
||||
testDomainValidity(t, "www.orf.at.", true)
|
||||
testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true)
|
||||
testDomainValidity(t, "a_a.com.", true)
|
||||
testDomainValidity(t, "a-a.com.", true)
|
||||
testDomainValidity(t, "a_a.com.", true)
|
||||
testDomainValidity(t, "a-a.com.", true)
|
||||
testDomainValidity(t, "xn--a.com.", true)
|
||||
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true)
|
||||
|
||||
// maybe valid
|
||||
testDomainValidity(t, "-.com.", true)
|
||||
testDomainValidity(t, "_.com.", true)
|
||||
testDomainValidity(t, "a_.com.", true)
|
||||
testDomainValidity(t, "a-.com.", true)
|
||||
testDomainValidity(t, "_a.com.", true)
|
||||
testDomainValidity(t, "-a.com.", true)
|
||||
|
||||
// invalid
|
||||
testDomainValidity(t, ".com.", false)
|
||||
testDomainValidity(t, ".com.", false)
|
||||
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
|
||||
testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false)
|
||||
|
||||
// real world examples
|
||||
testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true)
|
||||
}
|
|
@ -36,22 +36,22 @@ func (pkt *Base) SetPacketInfo(packetInfo Info) {
|
|||
|
||||
// SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure.
|
||||
func (pkt *Base) SetInbound() {
|
||||
pkt.info.Direction = true
|
||||
pkt.info.Inbound = true
|
||||
}
|
||||
|
||||
// SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure.
|
||||
func (pkt *Base) SetOutbound() {
|
||||
pkt.info.Direction = false
|
||||
pkt.info.Inbound = false
|
||||
}
|
||||
|
||||
// IsInbound checks if the packet is inbound.
|
||||
func (pkt *Base) IsInbound() bool {
|
||||
return pkt.info.Direction
|
||||
return pkt.info.Inbound
|
||||
}
|
||||
|
||||
// IsOutbound checks if the packet is outbound.
|
||||
func (pkt *Base) IsOutbound() bool {
|
||||
return !pkt.info.Direction
|
||||
return !pkt.info.Inbound
|
||||
}
|
||||
|
||||
// HasPorts checks if the packet has a protocol that uses ports.
|
||||
|
@ -80,13 +80,13 @@ func (pkt *Base) GetConnectionID() string {
|
|||
|
||||
func (pkt *Base) createConnectionID() {
|
||||
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
|
||||
} else {
|
||||
pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
|
||||
}
|
||||
} else {
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
|
||||
} else {
|
||||
pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
|
||||
|
@ -105,7 +105,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I
|
|||
if pkt.info.Protocol != protocol {
|
||||
return false
|
||||
}
|
||||
if pkt.info.Direction != remote {
|
||||
if pkt.info.Inbound != remote {
|
||||
if !network.Contains(pkt.info.Src) {
|
||||
return false
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I
|
|||
// Remote Src Dst
|
||||
//
|
||||
func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool {
|
||||
if pkt.info.Direction != endpoint {
|
||||
if pkt.info.Inbound != endpoint {
|
||||
if network.Contains(pkt.info.Src) {
|
||||
return true
|
||||
}
|
||||
|
@ -152,12 +152,12 @@ func (pkt *Base) String() string {
|
|||
// FmtPacket returns the most important information about the packet as a string
|
||||
func (pkt *Base) FmtPacket() string {
|
||||
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
|
||||
}
|
||||
return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
|
||||
}
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
|
||||
}
|
||||
return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
|
||||
|
@ -170,7 +170,7 @@ func (pkt *Base) FmtProtocol() string {
|
|||
|
||||
// FmtRemoteIP returns the remote IP address as a string
|
||||
func (pkt *Base) FmtRemoteIP() string {
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
return pkt.info.Src.String()
|
||||
}
|
||||
return pkt.info.Dst.String()
|
||||
|
@ -179,7 +179,7 @@ func (pkt *Base) FmtRemoteIP() string {
|
|||
// FmtRemotePort returns the remote port as a string
|
||||
func (pkt *Base) FmtRemotePort() string {
|
||||
if pkt.info.SrcPort != 0 {
|
||||
if pkt.info.Direction {
|
||||
if pkt.info.Inbound {
|
||||
return fmt.Sprintf("%d", pkt.info.SrcPort)
|
||||
}
|
||||
return fmt.Sprintf("%d", pkt.info.DstPort)
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
|
||||
// Info holds IP and TCP/UDP header information
|
||||
type Info struct {
|
||||
Direction bool
|
||||
InTunnel bool
|
||||
Inbound bool
|
||||
InTunnel bool
|
||||
|
||||
Version IPVersion
|
||||
Protocol IPProtocol
|
||||
|
@ -17,7 +17,7 @@ type Info struct {
|
|||
|
||||
// LocalIP returns the local IP of the packet.
|
||||
func (pi *Info) LocalIP() net.IP {
|
||||
if pi.Direction {
|
||||
if pi.Inbound {
|
||||
return pi.Dst
|
||||
}
|
||||
return pi.Src
|
||||
|
@ -25,7 +25,7 @@ func (pi *Info) LocalIP() net.IP {
|
|||
|
||||
// RemoteIP returns the remote IP of the packet.
|
||||
func (pi *Info) RemoteIP() net.IP {
|
||||
if pi.Direction {
|
||||
if pi.Inbound {
|
||||
return pi.Src
|
||||
}
|
||||
return pi.Dst
|
||||
|
@ -33,7 +33,7 @@ func (pi *Info) RemoteIP() net.IP {
|
|||
|
||||
// LocalPort returns the local port of the packet.
|
||||
func (pi *Info) LocalPort() uint16 {
|
||||
if pi.Direction {
|
||||
if pi.Inbound {
|
||||
return pi.DstPort
|
||||
}
|
||||
return pi.SrcPort
|
||||
|
@ -41,7 +41,7 @@ func (pi *Info) LocalPort() uint16 {
|
|||
|
||||
// RemotePort returns the remote port of the packet.
|
||||
func (pi *Info) RemotePort() uint16 {
|
||||
if pi.Direction {
|
||||
if pi.Inbound {
|
||||
return pi.SrcPort
|
||||
}
|
||||
return pi.DstPort
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
|
@ -18,8 +20,8 @@ var (
|
|||
pidsByUser = make(map[int][]int)
|
||||
)
|
||||
|
||||
// GetPidOfInode returns the pid of the given uid and socket inode.
|
||||
func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
|
||||
// FindPID returns the pid of the given uid and socket inode.
|
||||
func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO
|
||||
pidsByUserLock.Lock()
|
||||
defer pidsByUserLock.Unlock()
|
||||
|
||||
|
@ -38,7 +40,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
|
|||
var checkedUserPids []int
|
||||
for _, possiblePID := range pids {
|
||||
if findSocketFromPid(possiblePID, inode) {
|
||||
return possiblePID, true
|
||||
return possiblePID
|
||||
}
|
||||
checkedUserPids = append(checkedUserPids, possiblePID)
|
||||
}
|
||||
|
@ -57,7 +59,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
|
|||
// only check if not already checked
|
||||
if sort.SearchInts(checkedUserPids, possiblePID) == len {
|
||||
if findSocketFromPid(possiblePID, inode) {
|
||||
return possiblePID, true
|
||||
return possiblePID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -71,13 +73,13 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
|
|||
if possibleUID != uid {
|
||||
for _, possiblePID := range pids {
|
||||
if findSocketFromPid(possiblePID, inode) {
|
||||
return possiblePID, true
|
||||
return possiblePID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, false
|
||||
return socket.UnidentifiedProcessID
|
||||
}
|
||||
|
||||
func findSocketFromPid(pid, inode int) bool {
|
237
network/proc/tables.go
Normal file
237
network/proc/tables.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
1. find socket inode
|
||||
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
|
||||
- /proc/net/{tcp|udp}[6]
|
||||
|
||||
2. get list of processes of uid
|
||||
|
||||
3. find socket inode in process fds
|
||||
- if not found, refresh map of uid->pids
|
||||
- if not found, check ALL pids: maybe euid != uid
|
||||
|
||||
4. gather process info
|
||||
|
||||
Cache every step!
|
||||
|
||||
*/
|
||||
|
||||
// Network Related Constants
|
||||
const (
|
||||
TCP4 uint8 = iota
|
||||
UDP4
|
||||
TCP6
|
||||
UDP6
|
||||
ICMP4
|
||||
ICMP6
|
||||
|
||||
tcp4ProcFile = "/proc/net/tcp"
|
||||
tcp6ProcFile = "/proc/net/tcp6"
|
||||
udp4ProcFile = "/proc/net/udp"
|
||||
udp6ProcFile = "/proc/net/udp6"
|
||||
|
||||
UnfetchedProcessID = -2
|
||||
|
||||
tcpListenStateHex = "0A"
|
||||
)
|
||||
|
||||
// GetTCP4Table returns the system table for IPv4 TCP activity.
|
||||
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return getTableFromSource(TCP4, tcp4ProcFile)
|
||||
}
|
||||
|
||||
// GetTCP6Table returns the system table for IPv6 TCP activity.
|
||||
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
|
||||
return getTableFromSource(TCP6, tcp6ProcFile)
|
||||
}
|
||||
|
||||
// GetUDP4Table returns the system table for IPv4 UDP activity.
|
||||
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
|
||||
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUDP6Table returns the system table for IPv6 UDP activity.
|
||||
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
|
||||
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
// hint: we split fields by multiple delimiters, see procDelimiter
|
||||
fieldIndexLocalIP = 1
|
||||
fieldIndexLocalPort = 2
|
||||
fieldIndexRemoteIP = 3
|
||||
fieldIndexRemotePort = 4
|
||||
fieldIndexUID = 11
|
||||
fieldIndexInode = 13
|
||||
)
|
||||
|
||||
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
|
||||
|
||||
var ipConverter func(string) net.IP
|
||||
switch stack {
|
||||
case TCP4, UDP4:
|
||||
ipConverter = convertIPv4
|
||||
case TCP6, UDP6:
|
||||
ipConverter = convertIPv6
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
|
||||
}
|
||||
|
||||
// open file
|
||||
socketData, err := os.Open(procFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer socketData.Close()
|
||||
|
||||
// file scanner
|
||||
scanner := bufio.NewScanner(socketData)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
// parse
|
||||
scanner.Scan() // skip first row
|
||||
for scanner.Scan() {
|
||||
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
|
||||
if len(fields) < 14 {
|
||||
// log.Tracef("process: too short: %s", fields)
|
||||
continue
|
||||
}
|
||||
|
||||
localIP := ipConverter(fields[fieldIndexLocalIP])
|
||||
if localIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
|
||||
// log.Tracef("uid: %s", fields[fieldIndexUID])
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse uid %s: %s", fields[11], err)
|
||||
continue
|
||||
}
|
||||
|
||||
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
|
||||
// log.Tracef("inode: %s", fields[fieldIndexInode])
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse inode %s: %s", fields[13], err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch stack {
|
||||
case UDP4, UDP6:
|
||||
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
PID: UnfetchedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
|
||||
case TCP4, TCP6:
|
||||
|
||||
if fields[5] == tcpListenStateHex {
|
||||
// listener
|
||||
|
||||
binds = append(binds, &socket.BindInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
PID: UnfetchedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
} else {
|
||||
// connection
|
||||
|
||||
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
|
||||
if remoteIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
connections = append(connections, &socket.ConnectionInfo{
|
||||
Local: socket.Address{
|
||||
IP: localIP,
|
||||
Port: uint16(localPort),
|
||||
},
|
||||
Remote: socket.Address{
|
||||
IP: remoteIP,
|
||||
Port: uint16(remotePort),
|
||||
},
|
||||
PID: UnfetchedProcessID,
|
||||
UID: int(uid),
|
||||
Inode: int(inode),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return connections, binds, nil
|
||||
}
|
||||
|
||||
func procDelimiter(c rune) bool {
|
||||
return unicode.IsSpace(c) || c == ':'
|
||||
}
|
||||
|
||||
func convertIPv4(data string) net.IP {
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 4 {
|
||||
log.Warningf("process: decoded IPv4 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
|
||||
return ip
|
||||
}
|
||||
|
||||
func convertIPv6(data string) net.IP {
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 16 {
|
||||
log.Warningf("process: decoded IPv6 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
ip := net.IP(decoded)
|
||||
return ip
|
||||
}
|
60
network/proc/tables_test.go
Normal file
60
network/proc/tables_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSockets(t *testing.T) {
|
||||
connections, listeners, err := GetTCP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 4 connections:")
|
||||
for _, connection := range connections {
|
||||
pid := FindPID(connection.UID, connection.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, connection)
|
||||
}
|
||||
fmt.Println("\nTCP 4 listeners:")
|
||||
for _, listener := range listeners {
|
||||
pid := FindPID(listener.UID, listener.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, listener)
|
||||
}
|
||||
|
||||
connections, listeners, err = GetTCP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nTCP 6 connections:")
|
||||
for _, connection := range connections {
|
||||
pid := FindPID(connection.UID, connection.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, connection)
|
||||
}
|
||||
fmt.Println("\nTCP 6 listeners:")
|
||||
for _, listener := range listeners {
|
||||
pid := FindPID(listener.UID, listener.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, listener)
|
||||
}
|
||||
|
||||
binds, err := GetUDP4Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 4 binds:")
|
||||
for _, bind := range binds {
|
||||
pid := FindPID(bind.UID, bind.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, bind)
|
||||
}
|
||||
|
||||
binds, err = GetUDP6Table()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println("\nUDP 6 binds:")
|
||||
for _, bind := range binds {
|
||||
pid := FindPID(bind.UID, bind.Inode)
|
||||
fmt.Printf("%d: %+v\n", pid, bind)
|
||||
}
|
||||
}
|
31
network/socket/socket.go
Normal file
31
network/socket/socket.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package socket
|
||||
|
||||
import "net"
|
||||
|
||||
const (
|
||||
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
|
||||
UnidentifiedProcessID = -1
|
||||
)
|
||||
|
||||
// ConnectionInfo holds socket information returned by the system.
|
||||
type ConnectionInfo struct {
|
||||
Local Address
|
||||
Remote Address
|
||||
PID int
|
||||
UID int
|
||||
Inode int
|
||||
}
|
||||
|
||||
// BindInfo holds socket information returned by the system.
|
||||
type BindInfo struct {
|
||||
Local Address
|
||||
PID int
|
||||
UID int
|
||||
Inode int
|
||||
}
|
||||
|
||||
// Address is an IP + Port pair.
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Port uint16
|
||||
}
|
101
network/state/exists.go
Normal file
101
network/state/exists.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
const (
|
||||
// UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended.
|
||||
UDPConnectionTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// Exists checks if the given connection is present in the system state tables.
|
||||
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
|
||||
|
||||
// TODO: create lookup maps before running a flurry of Exists() checks.
|
||||
|
||||
switch {
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
|
||||
tcp4Lock.Lock()
|
||||
defer tcp4Lock.Unlock()
|
||||
return existsTCP(tcp4Connections, pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
|
||||
tcp6Lock.Lock()
|
||||
defer tcp6Lock.Unlock()
|
||||
return existsTCP(tcp6Connections, pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
|
||||
udp4Lock.Lock()
|
||||
defer udp4Lock.Unlock()
|
||||
return existsUDP(udp4Binds, udp4States, pktInfo, now)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
|
||||
udp6Lock.Lock()
|
||||
defer udp6Lock.Unlock()
|
||||
return existsUDP(udp6Binds, udp6States, pktInfo, now)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) {
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
remoteIP := pktInfo.RemoteIP()
|
||||
remotePort := pktInfo.RemotePort()
|
||||
|
||||
// search connections
|
||||
for _, socketInfo := range connections {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
remotePort == socketInfo.Remote.Port &&
|
||||
remoteIP.Equal(socketInfo.Remote.IP) &&
|
||||
localIP.Equal(socketInfo.Local.IP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func existsUDP(
|
||||
binds []*socket.BindInfo,
|
||||
udpStates map[string]map[string]*udpState,
|
||||
pktInfo *packet.Info,
|
||||
now time.Time,
|
||||
) (exists bool) {
|
||||
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
remoteIP := pktInfo.RemoteIP()
|
||||
remotePort := pktInfo.RemotePort()
|
||||
|
||||
connThreshhold := now.Add(-UDPConnectionTTL)
|
||||
|
||||
// search binds
|
||||
for _, socketInfo := range binds {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||
|
||||
udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
|
||||
IP: remoteIP,
|
||||
Port: remotePort,
|
||||
})
|
||||
switch {
|
||||
case !ok:
|
||||
return false
|
||||
case udpConnState.lastSeen.After(connThreshhold):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
52
network/state/info.go
Normal file
52
network/state/info.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
// Info holds network state information as provided by the system.
|
||||
type Info struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
TCP4Connections []*socket.ConnectionInfo
|
||||
TCP4Listeners []*socket.BindInfo
|
||||
TCP6Connections []*socket.ConnectionInfo
|
||||
TCP6Listeners []*socket.BindInfo
|
||||
UDP4Binds []*socket.BindInfo
|
||||
UDP6Binds []*socket.BindInfo
|
||||
}
|
||||
|
||||
// GetInfo returns all system state tables. The returned data must not be modified.
|
||||
func GetInfo() *Info {
|
||||
info := &Info{}
|
||||
|
||||
tcp4Lock.Lock()
|
||||
updateTCP4Tables()
|
||||
info.TCP4Connections = tcp4Connections
|
||||
info.TCP4Listeners = tcp4Listeners
|
||||
tcp4Lock.Unlock()
|
||||
|
||||
tcp6Lock.Lock()
|
||||
updateTCP6Tables()
|
||||
info.TCP6Connections = tcp6Connections
|
||||
info.TCP6Listeners = tcp6Listeners
|
||||
tcp6Lock.Unlock()
|
||||
|
||||
udp4Lock.Lock()
|
||||
updateUDP4Table()
|
||||
info.UDP4Binds = udp4Binds
|
||||
udp4Lock.Unlock()
|
||||
|
||||
udp6Lock.Lock()
|
||||
updateUDP6Table()
|
||||
info.UDP6Binds = udp6Binds
|
||||
udp6Lock.Unlock()
|
||||
|
||||
info.UpdateMeta()
|
||||
return info
|
||||
}
|
171
network/state/lookup.go
Normal file
171
network/state/lookup.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
// - TCP
|
||||
// - Outbound: Match listeners (in!), then connections (out!)
|
||||
// - Inbound: Match listeners (in!), then connections (out!)
|
||||
// - Clean via connections
|
||||
// - UDP
|
||||
// - Any connection: match specific local address or zero IP
|
||||
// - In or out: save direction of first packet:
|
||||
// - map[<local udp bind ip+port>]map[<remote ip+port>]{direction, lastSeen}
|
||||
// - only clean if <local udp bind ip+port> is removed by OS
|
||||
// - limit <remote ip+port> to 256 entries?
|
||||
// - clean <remote ip+port> after 72hrs?
|
||||
// - switch direction to outbound if outbound packet is seen?
|
||||
// - IP: Unidentified Process
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
|
||||
ErrPIDNotFound = errors.New("could not find pid for socket inode")
|
||||
)
|
||||
|
||||
var (
|
||||
tcp4Lock sync.Mutex
|
||||
tcp6Lock sync.Mutex
|
||||
udp4Lock sync.Mutex
|
||||
udp6Lock sync.Mutex
|
||||
|
||||
baseWaitTime = 3 * time.Millisecond
|
||||
)
|
||||
|
||||
// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound.
|
||||
func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
|
||||
// auto-detect version
|
||||
if pktInfo.Version == 0 {
|
||||
if ip := pktInfo.LocalIP().To4(); ip != nil {
|
||||
pktInfo.Version = packet.IPv4
|
||||
} else {
|
||||
pktInfo.Version = packet.IPv6
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
|
||||
tcp4Lock.Lock()
|
||||
defer tcp4Lock.Unlock()
|
||||
return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
|
||||
tcp6Lock.Lock()
|
||||
defer tcp6Lock.Unlock()
|
||||
return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
|
||||
udp4Lock.Lock()
|
||||
defer udp4Lock.Unlock()
|
||||
return searchUDP(udp4Binds, udp4States, updateUDP4Table, pktInfo)
|
||||
|
||||
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
|
||||
udp6Lock.Lock()
|
||||
defer udp6Lock.Unlock()
|
||||
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
|
||||
|
||||
default:
|
||||
return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
|
||||
}
|
||||
}
|
||||
|
||||
func searchTCP(
|
||||
connections []*socket.ConnectionInfo,
|
||||
listeners []*socket.BindInfo,
|
||||
updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo),
|
||||
pktInfo *packet.Info,
|
||||
) (
|
||||
pid int,
|
||||
inbound bool,
|
||||
err error,
|
||||
) {
|
||||
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
|
||||
// search until we find something
|
||||
for i := 0; i < 7; i++ {
|
||||
// always search listeners first
|
||||
for _, socketInfo := range listeners {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||
return checkBindPID(socketInfo, true)
|
||||
}
|
||||
}
|
||||
|
||||
// search connections
|
||||
for _, socketInfo := range connections {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
localIP.Equal(socketInfo.Local.IP) {
|
||||
return checkConnectionPID(socketInfo, false)
|
||||
}
|
||||
}
|
||||
|
||||
// we found nothing, we could have been too fast, give the kernel some time to think
|
||||
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
|
||||
time.Sleep(time.Duration(i+1) * baseWaitTime)
|
||||
|
||||
// refetch lists
|
||||
connections, listeners = updateTables()
|
||||
}
|
||||
|
||||
return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
func searchUDP(
|
||||
binds []*socket.BindInfo,
|
||||
udpStates map[string]map[string]*udpState,
|
||||
updateTable func() []*socket.BindInfo,
|
||||
pktInfo *packet.Info,
|
||||
) (
|
||||
pid int,
|
||||
inbound bool,
|
||||
err error,
|
||||
) {
|
||||
|
||||
localIP := pktInfo.LocalIP()
|
||||
localPort := pktInfo.LocalPort()
|
||||
|
||||
isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast
|
||||
// TODO: Currently broadcast/multicast scopes are not checked, so we might
|
||||
// attribute an incoming broadcast/multicast packet to the wrong process if
|
||||
// there are multiple processes listening on the same local port, but
|
||||
// binding to different addresses. This highly unusual for clients.
|
||||
|
||||
// search until we find something
|
||||
for i := 0; i < 5; i++ {
|
||||
// search binds
|
||||
for _, socketInfo := range binds {
|
||||
if localPort == socketInfo.Local.Port &&
|
||||
(socketInfo.Local.IP[0] == 0 || // zero IP
|
||||
isInboundMulticast || // inbound broadcast, multicast
|
||||
localIP.Equal(socketInfo.Local.IP)) {
|
||||
|
||||
// do not check direction if remoteIP/Port is not given
|
||||
if pktInfo.RemotePort() == 0 {
|
||||
return checkBindPID(socketInfo, pktInfo.Inbound)
|
||||
}
|
||||
|
||||
// get direction and return
|
||||
connInbound := getUDPDirection(socketInfo, udpStates, pktInfo)
|
||||
return checkBindPID(socketInfo, connInbound)
|
||||
}
|
||||
}
|
||||
|
||||
// we found nothing, we could have been too fast, give the kernel some time to think
|
||||
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
|
||||
time.Sleep(time.Duration(i+1) * baseWaitTime)
|
||||
|
||||
// refetch lists
|
||||
binds = updateTable()
|
||||
}
|
||||
|
||||
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
|
||||
}
|
27
network/state/system_linux.go
Normal file
27
network/state/system_linux.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/network/proc"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4Table = proc.GetTCP4Table
|
||||
getTCP6Table = proc.GetTCP6Table
|
||||
getUDP4Table = proc.GetUDP4Table
|
||||
getUDP6Table = proc.GetUDP6Table
|
||||
)
|
||||
|
||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
if socketInfo.PID == proc.UnfetchedProcessID {
|
||||
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
}
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
||||
|
||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
if socketInfo.PID == proc.UnfetchedProcessID {
|
||||
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
|
||||
}
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
21
network/state/system_windows.go
Normal file
21
network/state/system_windows.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/network/iphelper"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4Table = iphelper.GetTCP4Table
|
||||
getTCP6Table = iphelper.GetTCP6Table
|
||||
getUDP4Table = iphelper.GetUDP4Table
|
||||
getUDP6Table = iphelper.GetUDP6Table
|
||||
)
|
||||
|
||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
||||
|
||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
||||
return socketInfo.PID, connInbound, nil
|
||||
}
|
68
network/state/tables.go
Normal file
68
network/state/tables.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
var (
|
||||
tcp4Connections []*socket.ConnectionInfo
|
||||
tcp4Listeners []*socket.BindInfo
|
||||
|
||||
tcp6Connections []*socket.ConnectionInfo
|
||||
tcp6Listeners []*socket.BindInfo
|
||||
|
||||
udp4Binds []*socket.BindInfo
|
||||
|
||||
udp6Binds []*socket.BindInfo
|
||||
)
|
||||
|
||||
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
|
||||
var err error
|
||||
connections, listeners, err = getTCP4Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get TCP4 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tcp4Connections = connections
|
||||
tcp4Listeners = listeners
|
||||
return
|
||||
}
|
||||
|
||||
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
|
||||
var err error
|
||||
connections, listeners, err = getTCP6Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get TCP6 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tcp6Connections = connections
|
||||
tcp6Listeners = listeners
|
||||
return
|
||||
}
|
||||
|
||||
func updateUDP4Table() (binds []*socket.BindInfo) {
|
||||
var err error
|
||||
binds, err = getUDP4Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get UDP4 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
udp4Binds = binds
|
||||
return
|
||||
}
|
||||
|
||||
func updateUDP6Table() (binds []*socket.BindInfo) {
|
||||
var err error
|
||||
binds, err = getUDP6Table()
|
||||
if err != nil {
|
||||
log.Warningf("state: failed to get UDP6 socket table: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
udp6Binds = binds
|
||||
return
|
||||
}
|
125
network/state/udp.go
Normal file
125
network/state/udp.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"github.com/safing/portmaster/network/socket"
|
||||
)
|
||||
|
||||
type udpState struct {
|
||||
inbound bool
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// UDPConnStateTTL is the maximum time a udp connection state is held.
|
||||
UDPConnStateTTL = 72 * time.Hour
|
||||
|
||||
// UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold.
|
||||
UDPConnStateShortenedTTL = 3 * time.Hour
|
||||
|
||||
// AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket.
|
||||
AggressiveCleaningThreshold = 256
|
||||
)
|
||||
|
||||
var (
|
||||
udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock
|
||||
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
|
||||
)
|
||||
|
||||
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
|
||||
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
|
||||
if ok {
|
||||
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
|
||||
return
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
|
||||
localKey := makeUDPStateKey(socketInfo.Local)
|
||||
|
||||
bindMap, ok := udpStates[localKey]
|
||||
if !ok {
|
||||
bindMap = make(map[string]*udpState)
|
||||
udpStates[localKey] = bindMap
|
||||
}
|
||||
|
||||
remoteKey := makeUDPStateKey(socket.Address{
|
||||
IP: pktInfo.RemoteIP(),
|
||||
Port: pktInfo.RemotePort(),
|
||||
})
|
||||
udpConnState, ok := bindMap[remoteKey]
|
||||
if !ok {
|
||||
bindMap[remoteKey] = &udpState{
|
||||
inbound: pktInfo.Inbound,
|
||||
lastSeen: time.Now().UTC(),
|
||||
}
|
||||
return pktInfo.Inbound
|
||||
}
|
||||
|
||||
udpConnState.lastSeen = time.Now().UTC()
|
||||
return udpConnState.inbound
|
||||
}
|
||||
|
||||
// CleanUDPStates cleans the udp connection states which save connection directions.
|
||||
func CleanUDPStates(_ context.Context) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
udp4Lock.Lock()
|
||||
updateUDP4Table()
|
||||
cleanStates(udp4Binds, udp4States, now)
|
||||
udp4Lock.Unlock()
|
||||
|
||||
udp6Lock.Lock()
|
||||
updateUDP6Table()
|
||||
cleanStates(udp6Binds, udp6States, now)
|
||||
udp6Lock.Unlock()
|
||||
}
|
||||
|
||||
func cleanStates(
|
||||
binds []*socket.BindInfo,
|
||||
udpStates map[string]map[string]*udpState,
|
||||
now time.Time,
|
||||
) {
|
||||
// compute thresholds
|
||||
threshold := now.Add(-UDPConnStateTTL)
|
||||
shortThreshhold := now.Add(-UDPConnStateShortenedTTL)
|
||||
|
||||
// make lookup map of all active keys
|
||||
bindKeys := make(map[string]struct{})
|
||||
for _, socketInfo := range binds {
|
||||
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
|
||||
}
|
||||
|
||||
// clean the udp state storage
|
||||
for localKey, bindMap := range udpStates {
|
||||
if _, active := bindKeys[localKey]; active {
|
||||
// clean old entries
|
||||
for remoteKey, udpConnState := range bindMap {
|
||||
if udpConnState.lastSeen.Before(threshold) {
|
||||
delete(bindMap, remoteKey)
|
||||
}
|
||||
}
|
||||
// if there are too many clean more aggressively
|
||||
if len(bindMap) > AggressiveCleaningThreshold {
|
||||
for remoteKey, udpConnState := range bindMap {
|
||||
if udpConnState.lastSeen.Before(shortThreshhold) {
|
||||
delete(bindMap, remoteKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// delete the whole thing
|
||||
delete(udpStates, localKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeUDPStateKey(address socket.Address) string {
|
||||
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
|
||||
return string(address.IP) + string(address.Port)
|
||||
}
|
|
@ -74,7 +74,7 @@ func finalizeLogFile(logFile *os.File, logFilePath string) {
|
|||
|
||||
func initControlLogFile() *os.File {
|
||||
// check logging dir
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "control")
|
||||
err := logsRoot.EnsureAbsPath(logFileBasePath)
|
||||
if err != nil {
|
||||
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)
|
||||
|
@ -93,7 +93,7 @@ func logControlError(cErr error) {
|
|||
}
|
||||
|
||||
// check logging dir
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "control")
|
||||
err := logsRoot.EnsureAbsPath(logFileBasePath)
|
||||
if err != nil {
|
||||
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)
|
||||
|
@ -114,7 +114,7 @@ func logControlError(cErr error) {
|
|||
//nolint:deadcode,unused // TODO
|
||||
func logControlStack() {
|
||||
// check logging dir
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "control")
|
||||
err := logsRoot.EnsureAbsPath(logFileBasePath)
|
||||
if err != nil {
|
||||
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)
|
||||
|
|
|
@ -227,7 +227,7 @@ func execute(opts *Options, args []string) (cont bool, err error) {
|
|||
|
||||
// log files
|
||||
var logFile, errorFile *os.File
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", opts.ShortIdentifier)
|
||||
logFileBasePath := filepath.Join(logsRoot.Path, opts.ShortIdentifier)
|
||||
err = logsRoot.EnsureAbsPath(logFileBasePath)
|
||||
if err != nil {
|
||||
log.Printf("failed to check/create log file dir %s: %s\n", logFileBasePath, err)
|
||||
|
|
133
process/find.go
133
process/find.go
|
@ -3,7 +3,8 @@ package process
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/safing/portmaster/network/state"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
|
@ -11,131 +12,28 @@ import (
|
|||
|
||||
// Errors
|
||||
var (
|
||||
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
|
||||
ErrProcessNotFound = errors.New("could not find process in system state tables")
|
||||
ErrProcessNotFound = errors.New("could not find process in system state tables")
|
||||
)
|
||||
|
||||
// GetPidByPacket returns the pid of the owner of the packet.
|
||||
func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
|
||||
|
||||
var localIP net.IP
|
||||
var localPort uint16
|
||||
var remoteIP net.IP
|
||||
var remotePort uint16
|
||||
if pkt.IsInbound() {
|
||||
localIP = pkt.Info().Dst
|
||||
remoteIP = pkt.Info().Src
|
||||
} else {
|
||||
localIP = pkt.Info().Src
|
||||
remoteIP = pkt.Info().Dst
|
||||
}
|
||||
if pkt.HasPorts() {
|
||||
if pkt.IsInbound() {
|
||||
localPort = pkt.Info().DstPort
|
||||
remotePort = pkt.Info().SrcPort
|
||||
} else {
|
||||
localPort = pkt.Info().SrcPort
|
||||
remotePort = pkt.Info().DstPort
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv4:
|
||||
return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
|
||||
case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv4:
|
||||
return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
|
||||
case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv6:
|
||||
return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
|
||||
case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6:
|
||||
return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
|
||||
default:
|
||||
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// GetProcessByPacket returns the process that owns the given packet.
|
||||
func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) {
|
||||
if !enableProcessDetection() {
|
||||
log.Tracer(pkt.Ctx()).Tracef("process: process detection disabled")
|
||||
return GetUnidentifiedProcess(pkt.Ctx()), pkt.Info().Direction, nil
|
||||
}
|
||||
|
||||
log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet")
|
||||
|
||||
var pid int
|
||||
pid, direction, err = GetPidByPacket(pkt)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err)
|
||||
return nil, direction, err
|
||||
}
|
||||
if pid < 0 {
|
||||
log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error())
|
||||
return nil, direction, ErrConnectionNotFound
|
||||
}
|
||||
|
||||
process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid)
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err)
|
||||
return nil, direction, err
|
||||
}
|
||||
|
||||
err = process.GetProfile(pkt.Ctx())
|
||||
if err != nil {
|
||||
log.Tracer(pkt.Ctx()).Errorf("process: failed to get profile for process %s: %s", process, err)
|
||||
}
|
||||
|
||||
return process, direction, nil
|
||||
|
||||
}
|
||||
|
||||
// GetPidByEndpoints returns the pid of the owner of the described link.
|
||||
func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) {
|
||||
|
||||
ipVersion := packet.IPv4
|
||||
if v4 := localIP.To4(); v4 == nil {
|
||||
ipVersion = packet.IPv6
|
||||
}
|
||||
|
||||
switch {
|
||||
case protocol == packet.TCP && ipVersion == packet.IPv4:
|
||||
return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, false)
|
||||
case protocol == packet.UDP && ipVersion == packet.IPv4:
|
||||
return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, false)
|
||||
case protocol == packet.TCP && ipVersion == packet.IPv6:
|
||||
return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, false)
|
||||
case protocol == packet.UDP && ipVersion == packet.IPv6:
|
||||
return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, false)
|
||||
default:
|
||||
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// GetProcessByEndpoints returns the process that owns the described link.
|
||||
func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) {
|
||||
// GetProcessByConnection returns the process that owns the described connection.
|
||||
func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) {
|
||||
if !enableProcessDetection() {
|
||||
log.Tracer(ctx).Tracef("process: process detection disabled")
|
||||
return GetUnidentifiedProcess(ctx), nil
|
||||
return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil
|
||||
}
|
||||
|
||||
log.Tracer(ctx).Tracef("process: getting process and profile by endpoints")
|
||||
|
||||
log.Tracer(ctx).Tracef("process: getting pid from system network state")
|
||||
var pid int
|
||||
pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol)
|
||||
pid, connInbound, err = state.Lookup(pktInfo)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
if pid < 0 {
|
||||
log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error())
|
||||
return nil, ErrConnectionNotFound
|
||||
log.Tracer(ctx).Debugf("process: failed to find PID of connection: %s", err)
|
||||
return nil, connInbound, err
|
||||
}
|
||||
|
||||
process, err = GetOrFindPrimaryProcess(ctx, pid)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err)
|
||||
return nil, err
|
||||
log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err)
|
||||
return nil, connInbound, err
|
||||
}
|
||||
|
||||
err = process.GetProfile(ctx)
|
||||
|
@ -143,10 +41,5 @@ func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16
|
|||
log.Tracer(ctx).Errorf("process: failed to get profile for process %s: %s", process, err)
|
||||
}
|
||||
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// GetActiveConnectionIDs returns a list of all active connection IDs.
|
||||
func GetActiveConnectionIDs() []string {
|
||||
return getActiveConnectionIDs()
|
||||
return process, connInbound, nil
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
package process
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/process/proc"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4PacketInfo = proc.GetTCP4PacketInfo
|
||||
getTCP6PacketInfo = proc.GetTCP6PacketInfo
|
||||
getUDP4PacketInfo = proc.GetUDP4PacketInfo
|
||||
getUDP6PacketInfo = proc.GetUDP6PacketInfo
|
||||
getActiveConnectionIDs = proc.GetActiveConnectionIDs
|
||||
)
|
|
@ -1,13 +0,0 @@
|
|||
package process
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/process/iphelper"
|
||||
)
|
||||
|
||||
var (
|
||||
getTCP4PacketInfo = iphelper.GetTCP4PacketInfo
|
||||
getTCP6PacketInfo = iphelper.GetTCP6PacketInfo
|
||||
getUDP4PacketInfo = iphelper.GetUDP4PacketInfo
|
||||
getUDP6PacketInfo = iphelper.GetUDP6PacketInfo
|
||||
getActiveConnectionIDs = iphelper.GetActiveConnectionIDs
|
||||
)
|
|
@ -1,260 +0,0 @@
|
|||
// +build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
unidentifiedProcessID = -1
|
||||
)
|
||||
|
||||
var (
|
||||
tcp4Connections []*ConnectionEntry
|
||||
tcp4Listeners []*ConnectionEntry
|
||||
tcp6Connections []*ConnectionEntry
|
||||
tcp6Listeners []*ConnectionEntry
|
||||
|
||||
udp4Connections []*ConnectionEntry
|
||||
udp4Listeners []*ConnectionEntry
|
||||
udp6Connections []*ConnectionEntry
|
||||
udp6Listeners []*ConnectionEntry
|
||||
|
||||
ipHelper *IPHelper
|
||||
lock sync.RWMutex
|
||||
|
||||
waitTime = 15 * time.Millisecond
|
||||
)
|
||||
|
||||
func checkIPHelper() (err error) {
|
||||
if ipHelper == nil {
|
||||
ipHelper, err = New()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection.
|
||||
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
|
||||
// search
|
||||
pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
|
||||
// if unable to find, refresh
|
||||
lock.Lock()
|
||||
err = checkIPHelper()
|
||||
if err == nil {
|
||||
tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4)
|
||||
}
|
||||
lock.Unlock()
|
||||
if err != nil {
|
||||
return unidentifiedProcessID, pktDirection, err
|
||||
}
|
||||
|
||||
// search
|
||||
pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, pktDirection, nil
|
||||
}
|
||||
|
||||
// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection.
|
||||
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
|
||||
// search
|
||||
pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
|
||||
// if unable to find, refresh
|
||||
lock.Lock()
|
||||
err = checkIPHelper()
|
||||
if err == nil {
|
||||
tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6)
|
||||
}
|
||||
lock.Unlock()
|
||||
if err != nil {
|
||||
return unidentifiedProcessID, pktDirection, err
|
||||
}
|
||||
|
||||
// search
|
||||
pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, pktDirection, nil
|
||||
}
|
||||
|
||||
// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection.
|
||||
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
|
||||
// search
|
||||
pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
|
||||
// if unable to find, refresh
|
||||
lock.Lock()
|
||||
err = checkIPHelper()
|
||||
if err == nil {
|
||||
udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4)
|
||||
}
|
||||
lock.Unlock()
|
||||
if err != nil {
|
||||
return unidentifiedProcessID, pktDirection, err
|
||||
}
|
||||
|
||||
// search
|
||||
pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, pktDirection, nil
|
||||
}
|
||||
|
||||
// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection.
|
||||
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
|
||||
// search
|
||||
pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
|
||||
// if unable to find, refresh
|
||||
lock.Lock()
|
||||
err = checkIPHelper()
|
||||
if err == nil {
|
||||
udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6)
|
||||
}
|
||||
lock.Unlock()
|
||||
if err != nil {
|
||||
return unidentifiedProcessID, pktDirection, err
|
||||
}
|
||||
|
||||
// search
|
||||
pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
|
||||
if pid >= 0 {
|
||||
return pid, pktDirection, nil
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, pktDirection, nil
|
||||
}
|
||||
|
||||
func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { //nolint:unparam // TODO: use direction, it may not be used because results caused problems, investigate.
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if pktDirection {
|
||||
// inbound
|
||||
pid = searchListeners(listeners, localIP, localPort)
|
||||
if pid >= 0 {
|
||||
return pid, true
|
||||
}
|
||||
pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort)
|
||||
if pid >= 0 {
|
||||
return pid, false
|
||||
}
|
||||
} else {
|
||||
// outbound
|
||||
pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort)
|
||||
if pid >= 0 {
|
||||
return pid, false
|
||||
}
|
||||
pid = searchListeners(listeners, localIP, localPort)
|
||||
if pid >= 0 {
|
||||
return pid, true
|
||||
}
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, pktDirection
|
||||
}
|
||||
|
||||
func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) {
|
||||
|
||||
for _, entry := range list {
|
||||
if localPort == entry.localPort &&
|
||||
remotePort == entry.remotePort &&
|
||||
remoteIP.Equal(entry.remoteIP) &&
|
||||
localIP.Equal(entry.localIP) {
|
||||
return entry.pid
|
||||
}
|
||||
}
|
||||
|
||||
return unidentifiedProcessID
|
||||
}
|
||||
|
||||
func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) {
|
||||
|
||||
for _, entry := range list {
|
||||
if localPort == entry.localPort &&
|
||||
(entry.localIP == nil || // nil IP means zero IP, see tables.go
|
||||
localIP.Equal(entry.localIP)) {
|
||||
return entry.pid
|
||||
}
|
||||
}
|
||||
|
||||
return unidentifiedProcessID
|
||||
}
|
||||
|
||||
// GetActiveConnectionIDs returns all currently active connection IDs.
|
||||
func GetActiveConnectionIDs() (connections []string) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
for _, entry := range tcp4Connections {
|
||||
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
|
||||
}
|
||||
for _, entry := range tcp6Connections {
|
||||
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
|
||||
}
|
||||
for _, entry := range udp4Connections {
|
||||
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
|
||||
}
|
||||
for _, entry := range udp6Connections {
|
||||
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
// +build windows
|
||||
|
||||
package iphelper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalid = errors.New("IPHelper not initialzed or broken")
|
||||
)
|
||||
|
||||
// IPHelper represents a subset of the Windows iphlpapi.dll.
|
||||
type IPHelper struct {
|
||||
dll *windows.LazyDLL
|
||||
|
||||
getExtendedTCPTable *windows.LazyProc
|
||||
getExtendedUDPTable *windows.LazyProc
|
||||
// getOwnerModuleFromTcpEntry *windows.LazyProc
|
||||
// getOwnerModuleFromTcp6Entry *windows.LazyProc
|
||||
// getOwnerModuleFromUdpEntry *windows.LazyProc
|
||||
// getOwnerModuleFromUdp6Entry *windows.LazyProc
|
||||
|
||||
valid *abool.AtomicBool
|
||||
}
|
||||
|
||||
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
|
||||
func New() (*IPHelper, error) {
|
||||
|
||||
new := &IPHelper{}
|
||||
new.valid = abool.NewBool(false)
|
||||
var err error
|
||||
|
||||
// load dll
|
||||
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
err = new.dll.Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load functions
|
||||
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
|
||||
err = new.getExtendedTCPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
|
||||
}
|
||||
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
|
||||
err = new.getExtendedUDPTable.Find()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
|
||||
}
|
||||
// new.getOwnerModuleFromTcpEntry = new.dll.NewProc("GetOwnerModuleFromTcpEntry")
|
||||
// err = new.getOwnerModuleFromTcpEntry.Find()
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcpEntry: %s", err)
|
||||
// }
|
||||
// new.getOwnerModuleFromTcp6Entry = new.dll.NewProc("GetOwnerModuleFromTcp6Entry")
|
||||
// err = new.getOwnerModuleFromTcp6Entry.Find()
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcp6Entry: %s", err)
|
||||
// }
|
||||
// new.getOwnerModuleFromUdpEntry = new.dll.NewProc("GetOwnerModuleFromUdpEntry")
|
||||
// err = new.getOwnerModuleFromUdpEntry.Find()
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdpEntry: %s", err)
|
||||
// }
|
||||
// new.getOwnerModuleFromUdp6Entry = new.dll.NewProc("GetOwnerModuleFromUdp6Entry")
|
||||
// err = new.getOwnerModuleFromUdp6Entry.Find()
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdp6Entry: %s", err)
|
||||
// }
|
||||
|
||||
new.valid.Set()
|
||||
return new, nil
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
// +build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/process/iphelper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
iph, err := iphelper.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("TCP4\n")
|
||||
conns, lConns, err := iph.GetTables(iphelper.TCP, iphelper.IPv4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Connections:\n")
|
||||
for _, conn := range conns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
fmt.Printf("Listeners:\n")
|
||||
for _, conn := range lConns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
|
||||
fmt.Printf("\nTCP6\n")
|
||||
conns, lConns, err = iph.GetTables(iphelper.TCP, iphelper.IPv6)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Connections:\n")
|
||||
for _, conn := range conns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
fmt.Printf("Listeners:\n")
|
||||
for _, conn := range lConns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
|
||||
fmt.Printf("\nUDP4\n")
|
||||
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, conn := range lConns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
|
||||
fmt.Printf("\nUDP6\n")
|
||||
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv6)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, conn := range lConns {
|
||||
fmt.Printf("%s\n", conn)
|
||||
}
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PID querying return codes
|
||||
const (
|
||||
Success uint8 = iota
|
||||
NoSocket
|
||||
NoProcess
|
||||
)
|
||||
|
||||
var (
|
||||
waitTime = 15 * time.Millisecond
|
||||
)
|
||||
|
||||
// GetPidOfConnection returns the PID of the given connection.
|
||||
func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
|
||||
uid, inode, ok := getConnectionSocket(localIP, localPort, protocol)
|
||||
if !ok {
|
||||
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
|
||||
for i := 0; i < 3 && !ok; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
time.Sleep(waitTime)
|
||||
uid, inode, ok = getConnectionSocket(localIP, localPort, protocol)
|
||||
if !ok {
|
||||
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return unidentifiedProcessID, NoSocket
|
||||
}
|
||||
}
|
||||
|
||||
pid, ok = GetPidOfInode(uid, inode)
|
||||
for i := 0; i < 3 && !ok; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
time.Sleep(waitTime)
|
||||
pid, ok = GetPidOfInode(uid, inode)
|
||||
}
|
||||
if !ok {
|
||||
return unidentifiedProcessID, NoProcess
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetPidOfIncomingConnection returns the PID of the given incoming connection.
|
||||
func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
|
||||
uid, inode, ok := getListeningSocket(localIP, localPort, protocol)
|
||||
if !ok {
|
||||
// for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version.
|
||||
switch protocol {
|
||||
case TCP4:
|
||||
uid, inode, ok = getListeningSocket(localIP, localPort, TCP6)
|
||||
case UDP4:
|
||||
uid, inode, ok = getListeningSocket(localIP, localPort, UDP6)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return unidentifiedProcessID, NoSocket
|
||||
}
|
||||
}
|
||||
|
||||
pid, ok = GetPidOfInode(uid, inode)
|
||||
for i := 0; i < 3 && !ok; i++ {
|
||||
// give kernel some time, then try again
|
||||
// log.Tracef("process: giving kernel some time to think")
|
||||
time.Sleep(waitTime)
|
||||
pid, ok = GetPidOfInode(uid, inode)
|
||||
}
|
||||
if !ok {
|
||||
return unidentifiedProcessID, NoProcess
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
unidentifiedProcessID = -1
|
||||
)
|
||||
|
||||
// GetTCP4PacketInfo searches the network state tables for a TCP4 connection
|
||||
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
return search(TCP4, localIP, localPort, pktDirection)
|
||||
}
|
||||
|
||||
// GetTCP6PacketInfo searches the network state tables for a TCP6 connection
|
||||
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
return search(TCP6, localIP, localPort, pktDirection)
|
||||
}
|
||||
|
||||
// GetUDP4PacketInfo searches the network state tables for a UDP4 connection
|
||||
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
return search(UDP4, localIP, localPort, pktDirection)
|
||||
}
|
||||
|
||||
// GetUDP6PacketInfo searches the network state tables for a UDP6 connection
|
||||
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
return search(UDP6, localIP, localPort, pktDirection)
|
||||
}
|
||||
|
||||
func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) {
|
||||
|
||||
var status uint8
|
||||
if pktDirection {
|
||||
pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
|
||||
if pid >= 0 {
|
||||
return pid, true, nil
|
||||
}
|
||||
// pid, status = GetPidOfConnection(localIP, localPort, protocol)
|
||||
// if pid >= 0 {
|
||||
// return pid, false, nil
|
||||
// }
|
||||
} else {
|
||||
pid, status = GetPidOfConnection(localIP, localPort, protocol)
|
||||
if pid >= 0 {
|
||||
return pid, false, nil
|
||||
}
|
||||
// pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
|
||||
// if pid >= 0 {
|
||||
// return pid, true, nil
|
||||
// }
|
||||
}
|
||||
|
||||
switch status {
|
||||
case NoSocket:
|
||||
return unidentifiedProcessID, direction, errors.New("could not find socket")
|
||||
case NoProcess:
|
||||
return unidentifiedProcessID, direction, errors.New("could not find PID")
|
||||
default:
|
||||
return unidentifiedProcessID, direction, nil
|
||||
}
|
||||
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessFinder(t *testing.T) {
|
||||
|
||||
updatePids()
|
||||
log.Printf("pidsByUser: %v", pidsByUser)
|
||||
|
||||
pid, _ := GetPidOfInode(1000, 112588)
|
||||
log.Printf("pid: %d", pid)
|
||||
|
||||
}
|
|
@ -1,370 +0,0 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
1. find socket inode
|
||||
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
|
||||
- /proc/net/{tcp|udp}[6]
|
||||
|
||||
2. get list of processes of uid
|
||||
|
||||
3. find socket inode in process fds
|
||||
- if not found, refresh map of uid->pids
|
||||
- if not found, check ALL pids: maybe euid != uid
|
||||
|
||||
4. gather process info
|
||||
|
||||
Cache every step!
|
||||
|
||||
*/
|
||||
|
||||
// Network Related Constants
|
||||
const (
|
||||
TCP4 uint8 = iota
|
||||
UDP4
|
||||
TCP6
|
||||
UDP6
|
||||
ICMP4
|
||||
ICMP6
|
||||
|
||||
TCP4Data = "/proc/net/tcp"
|
||||
UDP4Data = "/proc/net/udp"
|
||||
TCP6Data = "/proc/net/tcp6"
|
||||
UDP6Data = "/proc/net/udp6"
|
||||
ICMP4Data = "/proc/net/icmp"
|
||||
ICMP6Data = "/proc/net/icmp6"
|
||||
)
|
||||
|
||||
var (
|
||||
// connectionSocketsLock sync.Mutex
|
||||
// connectionTCP4 = make(map[string][]int)
|
||||
// connectionUDP4 = make(map[string][]int)
|
||||
// connectionTCP6 = make(map[string][]int)
|
||||
// connectionUDP6 = make(map[string][]int)
|
||||
|
||||
listeningSocketsLock sync.Mutex
|
||||
addressListeningTCP4 = make(map[string][]int)
|
||||
addressListeningUDP4 = make(map[string][]int)
|
||||
addressListeningTCP6 = make(map[string][]int)
|
||||
addressListeningUDP6 = make(map[string][]int)
|
||||
globalListeningTCP4 = make(map[uint16][]int)
|
||||
globalListeningUDP4 = make(map[uint16][]int)
|
||||
globalListeningTCP6 = make(map[uint16][]int)
|
||||
globalListeningUDP6 = make(map[uint16][]int)
|
||||
)
|
||||
|
||||
func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) {
|
||||
// listeningSocketsLock.Lock()
|
||||
// defer listeningSocketsLock.Unlock()
|
||||
|
||||
var procFile string
|
||||
var localIPHex string
|
||||
switch protocol {
|
||||
case TCP4:
|
||||
procFile = TCP4Data
|
||||
localIPBytes := []byte(localIP.To4())
|
||||
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
|
||||
case UDP4:
|
||||
procFile = UDP4Data
|
||||
localIPBytes := []byte(localIP.To4())
|
||||
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
|
||||
case TCP6:
|
||||
procFile = TCP6Data
|
||||
localIPHex = hex.EncodeToString([]byte(localIP))
|
||||
case UDP6:
|
||||
procFile = UDP6Data
|
||||
localIPHex = hex.EncodeToString([]byte(localIP))
|
||||
}
|
||||
|
||||
localPortHex := fmt.Sprintf("%04X", localPort)
|
||||
|
||||
// log.Tracef("process/proc: searching for PID of: %s:%d (%s:%s)", localIP, localPort, localIPHex, localPortHex)
|
||||
|
||||
// open file
|
||||
socketData, err := os.Open(procFile)
|
||||
if err != nil {
|
||||
log.Warningf("process/proc: could not read %s: %s", procFile, err)
|
||||
return unidentifiedProcessID, unidentifiedProcessID, false
|
||||
}
|
||||
defer socketData.Close()
|
||||
|
||||
// file scanner
|
||||
scanner := bufio.NewScanner(socketData)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
// parse
|
||||
scanner.Scan() // skip first line
|
||||
for scanner.Scan() {
|
||||
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
|
||||
// log.Tracef("line: %s", line)
|
||||
if len(line) < 14 {
|
||||
// log.Tracef("process: too short: %s", line)
|
||||
continue
|
||||
}
|
||||
|
||||
if line[1] != localIPHex {
|
||||
continue
|
||||
}
|
||||
if line[2] != localPortHex {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := true
|
||||
|
||||
uid, err := strconv.ParseInt(line[11], 10, 32)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse uid %s: %s", line[11], err)
|
||||
uid = -1
|
||||
ok = false
|
||||
}
|
||||
|
||||
inode, err := strconv.ParseInt(line[13], 10, 32)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse inode %s: %s", line[13], err)
|
||||
inode = -1
|
||||
ok = false
|
||||
}
|
||||
|
||||
// log.Tracef("process/proc: identified process of %s:%d: socket=%d uid=%d", localIP, localPort, int(inode), int(uid))
|
||||
return int(uid), int(inode), ok
|
||||
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, unidentifiedProcessID, false
|
||||
|
||||
}
|
||||
|
||||
func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
|
||||
listeningSocketsLock.Lock()
|
||||
defer listeningSocketsLock.Unlock()
|
||||
|
||||
var addressListening map[string][]int
|
||||
var globalListening map[uint16][]int
|
||||
switch protocol {
|
||||
case TCP4:
|
||||
addressListening = addressListeningTCP4
|
||||
globalListening = globalListeningTCP4
|
||||
case UDP4:
|
||||
addressListening = addressListeningUDP4
|
||||
globalListening = globalListeningUDP4
|
||||
case TCP6:
|
||||
addressListening = addressListeningTCP6
|
||||
globalListening = globalListeningTCP6
|
||||
case UDP6:
|
||||
addressListening = addressListeningUDP6
|
||||
globalListening = globalListeningUDP6
|
||||
}
|
||||
|
||||
data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
|
||||
if !ok {
|
||||
data, ok = globalListening[localPort]
|
||||
}
|
||||
if ok {
|
||||
return data[0], data[1], true
|
||||
}
|
||||
updateListeners(protocol)
|
||||
data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
|
||||
if !ok {
|
||||
data, ok = globalListening[localPort]
|
||||
}
|
||||
if ok {
|
||||
return data[0], data[1], true
|
||||
}
|
||||
|
||||
return unidentifiedProcessID, unidentifiedProcessID, false
|
||||
}
|
||||
|
||||
func procDelimiter(c rune) bool {
|
||||
return unicode.IsSpace(c) || c == ':'
|
||||
}
|
||||
|
||||
func convertIPv4(data string) net.IP {
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 4 {
|
||||
log.Warningf("process: decoded IPv4 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
|
||||
return ip
|
||||
}
|
||||
|
||||
func convertIPv6(data string) net.IP {
|
||||
decoded, err := hex.DecodeString(data)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
|
||||
return nil
|
||||
}
|
||||
if len(decoded) != 16 {
|
||||
log.Warningf("process: decoded IPv6 %s has wrong length", decoded)
|
||||
return nil
|
||||
}
|
||||
ip := net.IP(decoded)
|
||||
return ip
|
||||
}
|
||||
|
||||
func updateListeners(protocol uint8) {
|
||||
switch protocol {
|
||||
case TCP4:
|
||||
addressListeningTCP4, globalListeningTCP4 = getListenerMaps(TCP4Data, "00000000", "0A", convertIPv4)
|
||||
case UDP4:
|
||||
addressListeningUDP4, globalListeningUDP4 = getListenerMaps(UDP4Data, "00000000", "07", convertIPv4)
|
||||
case TCP6:
|
||||
addressListeningTCP6, globalListeningTCP6 = getListenerMaps(TCP6Data, "00000000000000000000000000000000", "0A", convertIPv6)
|
||||
case UDP6:
|
||||
addressListeningUDP6, globalListeningUDP6 = getListenerMaps(UDP6Data, "00000000000000000000000000000000", "07", convertIPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) {
|
||||
addressListening := make(map[string][]int)
|
||||
globalListening := make(map[uint16][]int)
|
||||
|
||||
// open file
|
||||
socketData, err := os.Open(procFile)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not read %s: %s", procFile, err)
|
||||
return addressListening, globalListening
|
||||
}
|
||||
defer socketData.Close()
|
||||
|
||||
// file scanner
|
||||
scanner := bufio.NewScanner(socketData)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
// parse
|
||||
scanner.Scan() // skip first line
|
||||
for scanner.Scan() {
|
||||
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
|
||||
if len(line) < 14 {
|
||||
// log.Tracef("process: too short: %s", line)
|
||||
continue
|
||||
}
|
||||
if line[5] != socketStatusListening {
|
||||
// skip if not listening
|
||||
// log.Tracef("process: not listening %s: %s", line, line[5])
|
||||
continue
|
||||
}
|
||||
|
||||
port, err := strconv.ParseUint(line[2], 16, 16)
|
||||
// log.Tracef("port: %s", line[2])
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse port %s: %s", line[2], err)
|
||||
continue
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseInt(line[11], 10, 32)
|
||||
// log.Tracef("uid: %s", line[11])
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse uid %s: %s", line[11], err)
|
||||
continue
|
||||
}
|
||||
|
||||
inode, err := strconv.ParseInt(line[13], 10, 32)
|
||||
// log.Tracef("inode: %s", line[13])
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse inode %s: %s", line[13], err)
|
||||
continue
|
||||
}
|
||||
|
||||
if line[1] == zeroIP {
|
||||
globalListening[uint16(port)] = []int{int(uid), int(inode)}
|
||||
} else {
|
||||
address := ipConverter(line[1])
|
||||
if address != nil {
|
||||
addressListening[fmt.Sprintf("%s:%d", address, port)] = []int{int(uid), int(inode)}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return addressListening, globalListening
|
||||
}
|
||||
|
||||
// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS.
|
||||
func GetActiveConnectionIDs() []string {
|
||||
var connections []string
|
||||
|
||||
connections = append(connections, getConnectionIDsFromSource(TCP4Data, 6, convertIPv4)...)
|
||||
connections = append(connections, getConnectionIDsFromSource(UDP4Data, 17, convertIPv4)...)
|
||||
connections = append(connections, getConnectionIDsFromSource(TCP6Data, 6, convertIPv6)...)
|
||||
connections = append(connections, getConnectionIDsFromSource(UDP6Data, 17, convertIPv6)...)
|
||||
|
||||
return connections
|
||||
}
|
||||
|
||||
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string {
|
||||
var connections []string
|
||||
|
||||
// open file
|
||||
socketData, err := os.Open(source)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not read %s: %s", source, err)
|
||||
return connections
|
||||
}
|
||||
defer socketData.Close()
|
||||
|
||||
// file scanner
|
||||
scanner := bufio.NewScanner(socketData)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
// parse
|
||||
scanner.Scan() // skip first line
|
||||
for scanner.Scan() {
|
||||
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
|
||||
if len(line) < 14 {
|
||||
// log.Tracef("process: too short: %s", line)
|
||||
continue
|
||||
}
|
||||
|
||||
// skip listeners and closed connections
|
||||
if line[5] == "0A" || line[5] == "07" {
|
||||
continue
|
||||
}
|
||||
|
||||
localIP := ipConverter(line[1])
|
||||
if localIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
localPort, err := strconv.ParseUint(line[2], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
remoteIP := ipConverter(line[3])
|
||||
if remoteIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
remotePort, err := strconv.ParseUint(line[4], 16, 16)
|
||||
if err != nil {
|
||||
log.Warningf("process: could not parse port: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", protocol, localIP, localPort, remoteIP, remotePort))
|
||||
}
|
||||
|
||||
return connections
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
// +build linux
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSockets(t *testing.T) {
|
||||
|
||||
updateListeners(TCP4)
|
||||
updateListeners(UDP4)
|
||||
updateListeners(TCP6)
|
||||
updateListeners(UDP6)
|
||||
t.Logf("addressListeningTCP4: %v", addressListeningTCP4)
|
||||
t.Logf("globalListeningTCP4: %v", globalListeningTCP4)
|
||||
t.Logf("addressListeningUDP4: %v", addressListeningUDP4)
|
||||
t.Logf("globalListeningUDP4: %v", globalListeningUDP4)
|
||||
t.Logf("addressListeningTCP6: %v", addressListeningTCP6)
|
||||
t.Logf("globalListeningTCP6: %v", globalListeningTCP6)
|
||||
t.Logf("addressListeningUDP6: %v", addressListeningUDP6)
|
||||
t.Logf("globalListeningUDP6: %v", globalListeningUDP6)
|
||||
|
||||
getListeningSocket(net.IPv4zero, 53, TCP4)
|
||||
getListeningSocket(net.IPv4zero, 53, UDP4)
|
||||
getListeningSocket(net.IPv6zero, 53, TCP6)
|
||||
getListeningSocket(net.IPv6zero, 53, UDP6)
|
||||
|
||||
// spotify: 192.168.0.102:5353 192.121.140.65:80
|
||||
localIP := net.IPv4(192, 168, 127, 10)
|
||||
uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4)
|
||||
t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok)
|
||||
|
||||
activeConnectionIDs := GetActiveConnectionIDs()
|
||||
for _, connID := range activeConnectionIDs {
|
||||
t.Logf("active: %s", connID)
|
||||
}
|
||||
|
||||
}
|
|
@ -49,7 +49,8 @@ type Process struct {
|
|||
FirstSeen int64
|
||||
LastSeen int64
|
||||
|
||||
Virtual bool // This process is either merged into another process or is not needed.
|
||||
Virtual bool // This process is either merged into another process or is not needed.
|
||||
Error string // Cache errors
|
||||
}
|
||||
|
||||
// Profile returns the assigned layered profile.
|
||||
|
@ -94,6 +95,7 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) {
|
|||
parentProcess, err := loadProcess(ctx, process.ParentPid)
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err)
|
||||
saveFailedProcess(process.ParentPid, err.Error())
|
||||
return process, nil
|
||||
}
|
||||
|
||||
|
@ -226,13 +228,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) {
|
|||
|
||||
pInfo, err := processInfo.NewProcess(int32(pid))
|
||||
if err != nil {
|
||||
// TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist
|
||||
// Issue: https://github.com/shirou/gopsutil/issues/729
|
||||
_, err = pInfo.Name()
|
||||
if err != nil {
|
||||
// process does not exists
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// UID
|
||||
|
@ -375,3 +371,14 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) {
|
|||
new.Save()
|
||||
return new, nil
|
||||
}
|
||||
|
||||
func saveFailedProcess(pid int, err string) {
|
||||
failed := &Process{
|
||||
Pid: pid,
|
||||
FirstSeen: time.Now().Unix(),
|
||||
Virtual: true, // not needed
|
||||
Error: err,
|
||||
}
|
||||
|
||||
failed.Save()
|
||||
}
|
||||
|
|
|
@ -94,6 +94,7 @@ func registerConfiguration() error {
|
|||
Description: `The default filter action when nothing else permits or blocks a connection.`,
|
||||
Order: cfgOptionDefaultActionOrder,
|
||||
OptType: config.OptTypeString,
|
||||
ReleaseLevel: config.ReleaseLevelExperimental,
|
||||
DefaultValue: "permit",
|
||||
ExternalOptType: "string list",
|
||||
ValidationRegex: "^(permit|ask|block)$",
|
||||
|
@ -121,17 +122,12 @@ func registerConfiguration() error {
|
|||
cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll))
|
||||
cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit
|
||||
|
||||
// Endpoint Filter List
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Endpoint Filter List",
|
||||
Key: CfgOptionEndpointsKey,
|
||||
Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.",
|
||||
Help: `Format:
|
||||
filterListHelp := `Format:
|
||||
Permission:
|
||||
"+": permit
|
||||
"-": block
|
||||
Host Matching:
|
||||
IP, CIDR, Country Code, ASN, Filterlist, "*" for any
|
||||
IP, CIDR, Country Code, ASN, Filterlist, Network Scope, "*" for any
|
||||
Domains:
|
||||
"example.com": exact match
|
||||
".example.com": exact match + subdomains
|
||||
|
@ -144,11 +140,20 @@ func registerConfiguration() error {
|
|||
Examples:
|
||||
+ .example.com */HTTP
|
||||
- .example.com
|
||||
+ 192.168.0.1/24
|
||||
+ 192.168.0.1
|
||||
+ 192.168.1.1/24
|
||||
+ Localhost,LAN
|
||||
- AS123456789
|
||||
- L:MAL
|
||||
- AS0
|
||||
+ AT
|
||||
- *`,
|
||||
- *`
|
||||
|
||||
// Endpoint Filter List
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Endpoint Filter List",
|
||||
Key: CfgOptionEndpointsKey,
|
||||
Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.",
|
||||
Help: filterListHelp,
|
||||
Order: cfgOptionEndpointsOrder,
|
||||
OptType: config.OptTypeStringArray,
|
||||
DefaultValue: []string{},
|
||||
|
@ -163,35 +168,13 @@ Examples:
|
|||
|
||||
// Service Endpoint Filter List
|
||||
err = config.Register(&config.Option{
|
||||
Name: "Service Endpoint Filter List",
|
||||
Key: CfgOptionServiceEndpointsKey,
|
||||
Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.",
|
||||
Help: `Format:
|
||||
Permission:
|
||||
"+": permit
|
||||
"-": block
|
||||
Host Matching:
|
||||
IP, CIDR, Country Code, ASN, Filterlist, "*" for any
|
||||
Domains:
|
||||
"example.com": exact match
|
||||
".example.com": exact match + subdomains
|
||||
"*xample.com": prefix wildcard
|
||||
"example.*": suffix wildcard
|
||||
"*example*": prefix and suffix wildcard
|
||||
Protocol and Port Matching (optional):
|
||||
<protocol>/<port>
|
||||
|
||||
Examples:
|
||||
+ .example.com */HTTP
|
||||
- .example.com
|
||||
+ 192.168.0.1/24
|
||||
- L:MAL
|
||||
- AS0
|
||||
+ AT
|
||||
- *`,
|
||||
Name: "Service Endpoint Filter List",
|
||||
Key: CfgOptionServiceEndpointsKey,
|
||||
Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.",
|
||||
Help: filterListHelp,
|
||||
Order: cfgOptionServiceEndpointsOrder,
|
||||
OptType: config.OptTypeStringArray,
|
||||
DefaultValue: []string{},
|
||||
DefaultValue: []string{"+ Localhost"},
|
||||
ExternalOptType: "endpoint list",
|
||||
ValidationRegex: `^(\+|\-) [A-z0-9\.:\-*/]+( [A-z0-9/]+)?$`,
|
||||
})
|
||||
|
@ -313,7 +296,7 @@ Examples:
|
|||
Order: cfgOptionBlockP2POrder,
|
||||
OptType: config.OptTypeInt,
|
||||
ExternalOptType: "security level",
|
||||
DefaultValue: status.SecurityLevelsAll,
|
||||
DefaultValue: status.SecurityLevelExtreme,
|
||||
ValidationRegex: "^(4|6|7)$",
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -326,7 +309,7 @@ Examples:
|
|||
err = config.Register(&config.Option{
|
||||
Name: "Block Inbound Connections",
|
||||
Key: CfgOptionBlockInboundKey,
|
||||
Description: "Connections initiated towards your device. This will usually only be the case if you are running a network service or are using peer to peer software.",
|
||||
Description: "Connections initiated towards your device from the LAN or Internet. This will usually only be the case if you are running a network service or are using peer to peer software.",
|
||||
Order: cfgOptionBlockInboundOrder,
|
||||
OptType: config.OptTypeInt,
|
||||
ExternalOptType: "security level",
|
||||
|
|
|
@ -31,7 +31,7 @@ type EndpointDomain struct {
|
|||
}
|
||||
|
||||
func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) {
|
||||
result, reason := ep.match(ep, entity, ep.Domain, "domain matches")
|
||||
result, reason := ep.match(ep, entity, ep.OriginalValue, "domain matches")
|
||||
|
||||
switch ep.MatchType {
|
||||
case domainMatchTypeExact:
|
||||
|
|
106
profile/endpoints/endpoint-scopes.go
Normal file
106
profile/endpoints/endpoint-scopes.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
|
||||
"github.com/safing/portmaster/intel"
|
||||
)
|
||||
|
||||
const (
|
||||
scopeLocalhost = 1
|
||||
scopeLocalhostName = "Localhost"
|
||||
scopeLocalhostMatcher = "localhost"
|
||||
|
||||
scopeLAN = 2
|
||||
scopeLANName = "LAN"
|
||||
scopeLANMatcher = "lan"
|
||||
|
||||
scopeInternet = 4
|
||||
scopeInternetName = "Internet"
|
||||
scopeInternetMatcher = "internet"
|
||||
)
|
||||
|
||||
// EndpointScope matches network scopes.
|
||||
type EndpointScope struct {
|
||||
EndpointBase
|
||||
|
||||
scopes uint8
|
||||
}
|
||||
|
||||
// Matches checks whether the given entity matches this endpoint definition.
|
||||
func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
|
||||
if entity.IP == nil {
|
||||
return Undeterminable, nil
|
||||
}
|
||||
|
||||
classification := netutils.ClassifyIP(entity.IP)
|
||||
var scope uint8
|
||||
switch classification {
|
||||
case netutils.HostLocal:
|
||||
scope = scopeLocalhost
|
||||
case netutils.LinkLocal:
|
||||
scope = scopeLAN
|
||||
case netutils.SiteLocal:
|
||||
scope = scopeLAN
|
||||
case netutils.Global:
|
||||
scope = scopeInternet
|
||||
case netutils.LocalMulticast:
|
||||
scope = scopeLAN
|
||||
case netutils.GlobalMulticast:
|
||||
scope = scopeInternet
|
||||
}
|
||||
|
||||
if ep.scopes&scope > 0 {
|
||||
return ep.match(ep, entity, ep.Scopes(), "scope matches")
|
||||
}
|
||||
return NoMatch, nil
|
||||
}
|
||||
|
||||
// Scopes returns the string representation of all scopes.
|
||||
func (ep *EndpointScope) Scopes() string {
|
||||
// single scope
|
||||
switch ep.scopes {
|
||||
case scopeLocalhost:
|
||||
return scopeLocalhostName
|
||||
case scopeLAN:
|
||||
return scopeLANName
|
||||
case scopeInternet:
|
||||
return scopeInternetName
|
||||
}
|
||||
|
||||
// multiple scopes
|
||||
var s []string
|
||||
if ep.scopes&scopeLocalhost > 0 {
|
||||
s = append(s, scopeLocalhostName)
|
||||
}
|
||||
if ep.scopes&scopeLAN > 0 {
|
||||
s = append(s, scopeLANName)
|
||||
}
|
||||
if ep.scopes&scopeInternet > 0 {
|
||||
s = append(s, scopeInternetName)
|
||||
}
|
||||
return strings.Join(s, ",")
|
||||
}
|
||||
|
||||
func (ep *EndpointScope) String() string {
|
||||
return ep.renderPPP(ep.Scopes())
|
||||
}
|
||||
|
||||
func parseTypeScope(fields []string) (Endpoint, error) {
|
||||
ep := &EndpointScope{}
|
||||
for _, val := range strings.Split(strings.ToLower(fields[1]), ",") {
|
||||
switch val {
|
||||
case scopeLocalhostMatcher:
|
||||
ep.scopes ^= scopeLocalhost
|
||||
case scopeLANMatcher:
|
||||
ep.scopes ^= scopeLAN
|
||||
case scopeInternetMatcher:
|
||||
ep.scopes ^= scopeInternet
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
return ep.parsePPP(ep, fields)
|
||||
}
|
|
@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error {
|
|||
return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg)
|
||||
}
|
||||
|
||||
func parseEndpoint(value string) (endpoint Endpoint, err error) {
|
||||
func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit
|
||||
fields := strings.Fields(value)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value)
|
||||
|
@ -231,6 +231,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) {
|
|||
if endpoint, err = parseTypeASN(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
// scopes
|
||||
if endpoint, err = parseTypeScope(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
}
|
||||
// lists
|
||||
if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil {
|
||||
return
|
||||
|
|
|
@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) {
|
|||
testParsing(t, "+ AS1234")
|
||||
testParsing(t, "+ AS12345")
|
||||
|
||||
// network scope
|
||||
testParsing(t, "+ Localhost")
|
||||
testParsing(t, "+ LAN")
|
||||
testParsing(t, "+ Internet")
|
||||
testParsing(t, "+ Localhost,LAN,Internet")
|
||||
|
||||
// protocol and ports
|
||||
testParsing(t, "+ * TCP/1-1024")
|
||||
testParsing(t, "+ * */DNS")
|
||||
|
|
|
@ -342,7 +342,26 @@ func TestEndpointMatching(t *testing.T) {
|
|||
IP: net.ParseIP("151.101.1.164"), // nytimes.com
|
||||
}).Init(), NoMatch)
|
||||
|
||||
// Scope
|
||||
|
||||
ep, err = parseEndpoint("+ Localhost,LAN")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testEndpointMatch(t, ep, (&intel.Entity{
|
||||
IP: net.ParseIP("192.168.0.1"),
|
||||
}).Init(), Permitted)
|
||||
testEndpointMatch(t, ep, (&intel.Entity{
|
||||
IP: net.ParseIP("151.101.1.164"), // nytimes.com
|
||||
}).Init(), NoMatch)
|
||||
|
||||
// Lists
|
||||
|
||||
_, err = parseEndpoint("+ L:A,B,C")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// TODO: write test for lists matcher
|
||||
|
||||
}
|
||||
|
|
|
@ -126,6 +126,18 @@ func (lp *LayeredProfile) getValidityFlag() *abool.AtomicBool {
|
|||
return lp.validityFlag
|
||||
}
|
||||
|
||||
// RevisionCnt returns the current profile revision counter.
|
||||
func (lp *LayeredProfile) RevisionCnt() (revisionCounter uint64) {
|
||||
if lp == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
lp.lock.Lock()
|
||||
defer lp.lock.Unlock()
|
||||
|
||||
return lp.revisionCounter
|
||||
}
|
||||
|
||||
// Update checks for updated profiles and replaces any outdated profiles.
|
||||
func (lp *LayeredProfile) Update() (revisionCounter uint64) {
|
||||
lp.lock.Lock()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -9,6 +10,13 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClientTTL = 5 * time.Minute
|
||||
defaultRequestTimeout = 3 * time.Second // dns query
|
||||
defaultConnectTimeout = 2 * time.Second // tcp/tls
|
||||
connectionEOLGracePeriod = 7 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
localAddrFactory func(network string) net.Addr
|
||||
)
|
||||
|
@ -27,21 +35,54 @@ func getLocalAddr(network string) net.Addr {
|
|||
return nil
|
||||
}
|
||||
|
||||
type clientManager struct {
|
||||
dnsClient *dns.Client
|
||||
factory func() *dns.Client
|
||||
type dnsClientManager struct {
|
||||
lock sync.Mutex
|
||||
|
||||
lock sync.Mutex
|
||||
refreshAfter time.Time
|
||||
ttl time.Duration // force refresh of connection to reduce traceability
|
||||
// set by creator
|
||||
serverAddress string
|
||||
ttl time.Duration // force refresh of connection to reduce traceability
|
||||
factory func() *dns.Client
|
||||
|
||||
// internal
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newDNSClientManager(_ *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // new client for every request, as we need to randomize the port
|
||||
type dnsClient struct {
|
||||
mgr *dnsClientManager
|
||||
client *dns.Client
|
||||
conn *dns.Conn
|
||||
useUntil time.Time
|
||||
}
|
||||
|
||||
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
||||
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
||||
if dc.conn == nil {
|
||||
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return dc.conn, true, nil
|
||||
}
|
||||
return dc.conn, false, nil
|
||||
}
|
||||
|
||||
func (dc *dnsClient) addToPool() {
|
||||
dc.mgr.pool.Put(dc)
|
||||
}
|
||||
|
||||
func (dc *dnsClient) destroy() {
|
||||
if dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: 0, // new client for every request, as we need to randomize the port
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("udp"),
|
||||
},
|
||||
|
@ -50,25 +91,28 @@ func newDNSClientManager(_ *Resolver) *clientManager {
|
|||
}
|
||||
}
|
||||
|
||||
func newTCPClientManager(_ *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
|
||||
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: defaultClientTTL,
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
KeepAlive: 15 * time.Second,
|
||||
Timeout: defaultConnectTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTLSClientManager(resolver *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
|
||||
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: defaultClientTTL,
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
|
@ -77,24 +121,68 @@ func newTLSClientManager(resolver *Resolver) *clientManager {
|
|||
ServerName: resolver.VerifyDomain,
|
||||
// TODO: use portbase rng
|
||||
},
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
KeepAlive: 15 * time.Second,
|
||||
Timeout: defaultConnectTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *clientManager) getDNSClient() *dns.Client {
|
||||
func (cm *dnsClientManager) getDNSClient() *dnsClient {
|
||||
cm.lock.Lock()
|
||||
defer cm.lock.Unlock()
|
||||
|
||||
if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) {
|
||||
cm.dnsClient = cm.factory()
|
||||
cm.refreshAfter = time.Now().Add(cm.ttl)
|
||||
// return new immediately if a new client should be used for every request
|
||||
if cm.ttl == 0 {
|
||||
return &dnsClient{
|
||||
mgr: cm,
|
||||
client: cm.factory(),
|
||||
}
|
||||
}
|
||||
|
||||
return cm.dnsClient
|
||||
// get cached client from pool
|
||||
now := time.Now().UTC()
|
||||
|
||||
poolLoop:
|
||||
for {
|
||||
dc, ok := cm.pool.Get().(*dnsClient)
|
||||
switch {
|
||||
case !ok || dc == nil: // cache empty (probably, pool may always return nil!)
|
||||
break poolLoop // create new
|
||||
case now.After(dc.useUntil):
|
||||
continue // get next
|
||||
default:
|
||||
return dc
|
||||
}
|
||||
}
|
||||
|
||||
// no available in pool, create new
|
||||
newClient := &dnsClient{
|
||||
mgr: cm,
|
||||
client: cm.factory(),
|
||||
useUntil: now.Add(cm.ttl),
|
||||
}
|
||||
newClient.startCleaner()
|
||||
|
||||
return newClient
|
||||
}
|
||||
|
||||
// startCleaner waits for EOL of the client and then removes it from the pool.
|
||||
func (dc *dnsClient) startCleaner() {
|
||||
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
|
||||
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
|
||||
select {
|
||||
case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod):
|
||||
// destroy
|
||||
case <-ctx.Done():
|
||||
// give a short time before kill for graceful request completion
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
dc.destroy()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -30,6 +31,7 @@ func start() error {
|
|||
// load resolvers from config and environment
|
||||
loadResolvers()
|
||||
|
||||
// reload after network change
|
||||
err := module.RegisterEventHook(
|
||||
"netenv",
|
||||
"network changed",
|
||||
|
@ -44,6 +46,27 @@ func start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// reload after config change
|
||||
prevNameservers := strings.Join(configuredNameServers(), " ")
|
||||
err = module.RegisterEventHook(
|
||||
"config",
|
||||
"config change",
|
||||
"update nameservers",
|
||||
func(_ context.Context, _ interface{}) error {
|
||||
newNameservers := strings.Join(configuredNameServers(), " ")
|
||||
if newNameservers != prevNameservers {
|
||||
prevNameservers = newNameservers
|
||||
|
||||
loadResolvers()
|
||||
log.Debug("resolver: reloaded nameservers due to config change")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
module.StartServiceWorker(
|
||||
"mdns handler",
|
||||
5*time.Second,
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -29,10 +31,11 @@ var (
|
|||
questionsLock sync.Mutex
|
||||
|
||||
mDNSResolver = &Resolver{
|
||||
Server: ServerSourceMDNS,
|
||||
ServerType: ServerTypeDNS,
|
||||
Source: ServerSourceMDNS,
|
||||
Conn: &mDNSResolverConn{},
|
||||
Server: ServerSourceMDNS,
|
||||
ServerType: ServerTypeDNS,
|
||||
ServerIPScope: netutils.SiteLocal,
|
||||
Source: ServerSourceMDNS,
|
||||
Conn: &mDNSResolverConn{},
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -42,10 +45,10 @@ func (mrc *mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, err
|
|||
return queryMulticastDNS(ctx, q)
|
||||
}
|
||||
|
||||
func (mrc *mDNSResolverConn) MarkFailed() {}
|
||||
func (mrc *mDNSResolverConn) ReportFailure() {}
|
||||
|
||||
func (mrc *mDNSResolverConn) LastFail() time.Time {
|
||||
return time.Time{}
|
||||
func (mrc *mDNSResolverConn) IsFailing() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type savedQuestion struct {
|
||||
|
@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
|
|||
|
||||
// get entry from database
|
||||
if saveFullRequest {
|
||||
// get from database
|
||||
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
|
||||
// if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired:
|
||||
// create new and do not append
|
||||
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
|
||||
rrCache = &RRCache{
|
||||
Domain: question.Name,
|
||||
Question: dns.Type(question.Qtype),
|
||||
Domain: question.Name,
|
||||
Question: dns.Type(question.Qtype),
|
||||
Server: mDNSResolver.Server,
|
||||
ServerScope: mDNSResolver.ServerIPScope,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add all entries to RRCache
|
||||
for _, entry := range message.Answer {
|
||||
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) {
|
||||
if saveFullRequest {
|
||||
|
@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
|
|||
continue
|
||||
}
|
||||
rrCache = &RRCache{
|
||||
Domain: v.Header().Name,
|
||||
Question: dns.Type(v.Header().Class),
|
||||
Answer: []dns.RR{v},
|
||||
Domain: v.Header().Name,
|
||||
Question: dns.Type(v.Header().Class),
|
||||
Answer: []dns.RR{v},
|
||||
Server: mDNSResolver.Server,
|
||||
ServerScope: mDNSResolver.ServerIPScope,
|
||||
}
|
||||
rrCache.Clean(60)
|
||||
err := rrCache.Save()
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
var (
|
||||
recordDatabase = database.NewInterface(&database.Options{
|
||||
AlwaysSetRelativateExpiry: 2592000, // 30 days
|
||||
CacheSize: 128,
|
||||
CacheSize: 256,
|
||||
})
|
||||
)
|
||||
|
||||
|
|
27
resolver/namerecord_test.go
Normal file
27
resolver/namerecord_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package resolver
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNameRecordStorage(t *testing.T) {
|
||||
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
|
||||
testQuestion := "A"
|
||||
|
||||
testNameRecord := &NameRecord{
|
||||
Domain: testDomain,
|
||||
Question: testQuestion,
|
||||
}
|
||||
|
||||
err := testNameRecord.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := GetNameRecord(testDomain, testQuestion)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Domain != testDomain || r.Question != testQuestion {
|
||||
t.Fatal("mismatch")
|
||||
}
|
||||
}
|
189
resolver/pooling_test.go
Normal file
189
resolver/pooling_test.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
domainFeed = make(chan string)
|
||||
)
|
||||
|
||||
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
|
||||
dnsClient := brc.clientManager.getDNSClient()
|
||||
|
||||
// create query
|
||||
dnsQuery := new(dns.Msg)
|
||||
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
|
||||
|
||||
// get connection
|
||||
conn, new, err := dnsClient.getConn()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
|
||||
}
|
||||
if new {
|
||||
atomic.AddUint32(newCnt, 1)
|
||||
}
|
||||
|
||||
// query server
|
||||
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
|
||||
if err != nil {
|
||||
t.Fatal(err) //nolint:staticcheck
|
||||
}
|
||||
if reply == nil {
|
||||
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
|
||||
}
|
||||
|
||||
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
|
||||
dnsClient.addToPool()
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func TestClientPooling(t *testing.T) {
|
||||
// skip if short - this test depends on the Internet and might fail randomly
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
go feedDomains()
|
||||
|
||||
// create separate resolver for this test
|
||||
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
brc := resolver.Conn.(*BasicResolverConn)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
var newCnt uint32
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
|
||||
FQDN: <-domainFeed,
|
||||
QType: dns.Type(dns.TypeA),
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
if newCnt > uint32(10+i) {
|
||||
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func feedDomains() {
|
||||
for {
|
||||
for _, domain := range poolingTestDomains {
|
||||
domainFeed <- domain
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data
|
||||
|
||||
var (
|
||||
poolingTestDomains = []string{
|
||||
"facebook.com.",
|
||||
"google.com.",
|
||||
"youtube.com.",
|
||||
"twitter.com.",
|
||||
"instagram.com.",
|
||||
"linkedin.com.",
|
||||
"microsoft.com.",
|
||||
"apple.com.",
|
||||
"wikipedia.org.",
|
||||
"plus.google.com.",
|
||||
"en.wikipedia.org.",
|
||||
"googletagmanager.com.",
|
||||
"youtu.be.",
|
||||
"adobe.com.",
|
||||
"vimeo.com.",
|
||||
"pinterest.com.",
|
||||
"itunes.apple.com.",
|
||||
"play.google.com.",
|
||||
"maps.google.com.",
|
||||
"goo.gl.",
|
||||
"wordpress.com.",
|
||||
"blogspot.com.",
|
||||
"bit.ly.",
|
||||
"github.com.",
|
||||
"player.vimeo.com.",
|
||||
"amazon.com.",
|
||||
"wordpress.org.",
|
||||
"docs.google.com.",
|
||||
"yahoo.com.",
|
||||
"mozilla.org.",
|
||||
"tumblr.com.",
|
||||
"godaddy.com.",
|
||||
"flickr.com.",
|
||||
"parked-content.godaddy.com.",
|
||||
"drive.google.com.",
|
||||
"support.google.com.",
|
||||
"apache.org.",
|
||||
"gravatar.com.",
|
||||
"europa.eu.",
|
||||
"qq.com.",
|
||||
"w3.org.",
|
||||
"nytimes.com.",
|
||||
"reddit.com.",
|
||||
"macromedia.com.",
|
||||
"get.adobe.com.",
|
||||
"soundcloud.com.",
|
||||
"sourceforge.net.",
|
||||
"sites.google.com.",
|
||||
"nih.gov.",
|
||||
"amazonaws.com.",
|
||||
"t.co.",
|
||||
"support.microsoft.com.",
|
||||
"forbes.com.",
|
||||
"theguardian.com.",
|
||||
"cnn.com.",
|
||||
"github.io.",
|
||||
"bbc.co.uk.",
|
||||
"dropbox.com.",
|
||||
"whatsapp.com.",
|
||||
"medium.com.",
|
||||
"creativecommons.org.",
|
||||
"www.ncbi.nlm.nih.gov.",
|
||||
"httpd.apache.org.",
|
||||
"archive.org.",
|
||||
"ec.europa.eu.",
|
||||
"php.net.",
|
||||
"apps.apple.com.",
|
||||
"weebly.com.",
|
||||
"support.apple.com.",
|
||||
"weibo.com.",
|
||||
"wixsite.com.",
|
||||
"issuu.com.",
|
||||
"who.int.",
|
||||
"paypal.com.",
|
||||
"m.facebook.com.",
|
||||
"oracle.com.",
|
||||
"msn.com.",
|
||||
"gnu.org.",
|
||||
"tinyurl.com.",
|
||||
"reuters.com.",
|
||||
"l.facebook.com.",
|
||||
"cloudflare.com.",
|
||||
"wsj.com.",
|
||||
"washingtonpost.com.",
|
||||
"domainmarket.com.",
|
||||
"imdb.com.",
|
||||
"bbc.com.",
|
||||
"bing.com.",
|
||||
"accounts.google.com.",
|
||||
"vk.com.",
|
||||
"api.whatsapp.com.",
|
||||
"opera.com.",
|
||||
"cdc.gov.",
|
||||
"slideshare.net.",
|
||||
"wpa.qq.com.",
|
||||
"harvard.edu.",
|
||||
"mit.edu.",
|
||||
"code.google.com.",
|
||||
"wikimedia.org.",
|
||||
}
|
||||
)
|
|
@ -14,8 +14,6 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
mtAsyncResolve = "async resolve"
|
||||
|
||||
// basic errors
|
||||
|
||||
// ErrNotFound is a basic error that will match all "not found" errors
|
||||
|
@ -114,6 +112,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) {
|
|||
rrCache.MixAnswers()
|
||||
return rrCache, nil
|
||||
}
|
||||
log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType)
|
||||
// if cache is still empty or non-compliant, go ahead and just query
|
||||
} else {
|
||||
// we are the first!
|
||||
|
@ -132,14 +131,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
if err != nil {
|
||||
if err != database.ErrNotFound {
|
||||
log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
|
||||
log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get resolver that rrCache was resolved with
|
||||
resolver := getResolverByIDWithLocking(rrCache.Server)
|
||||
resolver := getActiveResolverByIDWithLocking(rrCache.Server)
|
||||
if resolver == nil {
|
||||
log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -159,12 +158,13 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
log.Tracer(ctx).Trace("resolver: serving from cache, requesting new")
|
||||
|
||||
// resolve async
|
||||
module.StartMediumPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error {
|
||||
module.StartWorker("resolve async", func(ctx context.Context) error {
|
||||
_, _ = resolveAndCache(ctx, q)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
|
||||
return rrCache
|
||||
}
|
||||
|
||||
|
@ -218,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
|
|||
return nil, ErrNoCompliance
|
||||
}
|
||||
|
||||
// prep
|
||||
lastFailBoundary := time.Now().Add(
|
||||
-time.Duration(nameserverRetryRate()) * time.Second,
|
||||
)
|
||||
|
||||
// start resolving
|
||||
|
||||
var i int
|
||||
|
@ -231,7 +226,7 @@ resolveLoop:
|
|||
for i = 0; i < 2; i++ {
|
||||
for _, resolver := range resolvers {
|
||||
// check if resolver failed recently (on first run)
|
||||
if i == 0 && resolver.Conn.LastFail().After(lastFailBoundary) {
|
||||
if i == 0 && resolver.Conn.IsFailing() {
|
||||
log.Tracer(ctx).Tracef("resolver: skipping resolver %s, because it failed recently", resolver)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -23,6 +24,11 @@ const (
|
|||
ServerSourceMDNS = "mdns"
|
||||
)
|
||||
|
||||
var (
|
||||
// FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed.
|
||||
FailThreshold = 5
|
||||
)
|
||||
|
||||
// Resolver holds information about an active resolver.
|
||||
type Resolver struct {
|
||||
// Server config url (and ID)
|
||||
|
@ -83,8 +89,8 @@ func (resolver *Resolver) String() string {
|
|||
// ResolverConn is an interface to implement different types of query backends.
|
||||
type ResolverConn interface { //nolint:go-lint // TODO
|
||||
Query(ctx context.Context, q *Query) (*RRCache, error)
|
||||
MarkFailed()
|
||||
LastFail() time.Time
|
||||
ReportFailure()
|
||||
IsFailing() bool
|
||||
}
|
||||
|
||||
// BasicResolverConn implements ResolverConn for standard dns clients.
|
||||
|
@ -92,12 +98,14 @@ type BasicResolverConn struct {
|
|||
sync.Mutex // for lastFail
|
||||
|
||||
resolver *Resolver
|
||||
clientManager *clientManager
|
||||
lastFail time.Time
|
||||
clientManager *dnsClientManager
|
||||
|
||||
lastFail time.Time
|
||||
fails int
|
||||
}
|
||||
|
||||
// MarkFailed marks the resolver as failed.
|
||||
func (brc *BasicResolverConn) MarkFailed() {
|
||||
// ReportFailure reports that an error occurred with this resolver.
|
||||
func (brc *BasicResolverConn) ReportFailure() {
|
||||
if !netenv.Online() {
|
||||
// don't mark failed if we are offline
|
||||
return
|
||||
|
@ -105,14 +113,26 @@ func (brc *BasicResolverConn) MarkFailed() {
|
|||
|
||||
brc.Lock()
|
||||
defer brc.Unlock()
|
||||
brc.lastFail = time.Now()
|
||||
now := time.Now().UTC()
|
||||
failDuration := time.Duration(nameserverRetryRate()) * time.Second
|
||||
|
||||
// reset fail counter if currently not failing
|
||||
if now.Add(-failDuration).After(brc.lastFail) {
|
||||
brc.fails = 0
|
||||
}
|
||||
|
||||
// update
|
||||
brc.lastFail = now
|
||||
brc.fails++
|
||||
}
|
||||
|
||||
// LastFail returns the internal lastfail value while locking the Resolver.
|
||||
func (brc *BasicResolverConn) LastFail() time.Time {
|
||||
// IsFailing returns if this resolver is currently failing.
|
||||
func (brc *BasicResolverConn) IsFailing() bool {
|
||||
brc.Lock()
|
||||
defer brc.Unlock()
|
||||
return brc.lastFail
|
||||
|
||||
failDuration := time.Duration(nameserverRetryRate()) * time.Second
|
||||
return brc.fails >= FailThreshold && time.Now().UTC().Add(-failDuration).Before(brc.lastFail)
|
||||
}
|
||||
|
||||
// Query executes the given query against the resolver.
|
||||
|
@ -126,35 +146,78 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
|
||||
// start
|
||||
var reply *dns.Msg
|
||||
var ttl time.Duration
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
var conn *dns.Conn
|
||||
var new bool
|
||||
var tries int
|
||||
|
||||
// log query time
|
||||
// qStart := time.Now()
|
||||
reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress)
|
||||
// log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
|
||||
for ; tries < 3; tries++ {
|
||||
|
||||
// error handling
|
||||
// first get connection
|
||||
dc := brc.clientManager.getDNSClient()
|
||||
conn, new, err = dc.getConn()
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
|
||||
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
// report that resolver had an error
|
||||
brc.ReportFailure()
|
||||
// hint network environment at failed connection
|
||||
netenv.ReportFailedConnection()
|
||||
|
||||
// TODO: handle special cases
|
||||
// 1. connect: network is unreachable
|
||||
// 2. timeout
|
||||
|
||||
// hint network environment at failed connection
|
||||
netenv.ReportFailedConnection()
|
||||
// try again
|
||||
continue
|
||||
}
|
||||
if new {
|
||||
log.Tracer(ctx).Tracef("resolver: created new connection to %s (%s)", resolver.Name, resolver.ServerAddress)
|
||||
} else {
|
||||
log.Tracer(ctx).Tracef("resolver: reusing connection to %s (%s)", resolver.Name, resolver.ServerAddress)
|
||||
}
|
||||
|
||||
// query server
|
||||
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
|
||||
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
|
||||
|
||||
// error handling
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
|
||||
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
|
||||
// temporary error
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
|
||||
// try again
|
||||
continue
|
||||
}
|
||||
|
||||
// report failed if dns (nothing happens at getConn())
|
||||
if resolver.ServerType == ServerTypeDNS {
|
||||
// report that resolver had an error
|
||||
brc.ReportFailure()
|
||||
// hint network environment at failed connection
|
||||
netenv.ReportFailedConnection()
|
||||
}
|
||||
|
||||
// permanent error
|
||||
break
|
||||
} else if reply == nil {
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
|
||||
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
|
||||
return nil, errors.New("internal error")
|
||||
}
|
||||
|
||||
// make client available (again)
|
||||
dc.addToPool()
|
||||
|
||||
if resolver.IsBlockedUpstream(reply) {
|
||||
return nil, &BlockedUpstreamError{resolver.GetName()}
|
||||
}
|
||||
|
@ -166,12 +229,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
if err != nil {
|
||||
return nil, err
|
||||
// TODO: mark as failed
|
||||
} else if reply == nil {
|
||||
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), tries+1)
|
||||
return nil, errors.New("internal error")
|
||||
}
|
||||
|
||||
// hint network environment at successful connection
|
||||
netenv.ReportSuccessfulConnection()
|
||||
|
||||
new := &RRCache{
|
||||
newRecord := &RRCache{
|
||||
Domain: q.FQDN,
|
||||
Question: q.QType,
|
||||
Answer: reply.Answer,
|
||||
|
@ -182,5 +248,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
}
|
||||
|
||||
// TODO: check if reply.Answer is valid
|
||||
return new, nil
|
||||
return newRecord, nil
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ var (
|
|||
globalResolvers []*Resolver // all (global) resolvers
|
||||
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
|
||||
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
|
||||
allResolvers map[string]*Resolver // lookup map of all resolvers
|
||||
activeResolvers map[string]*Resolver // lookup map of all resolvers
|
||||
resolversLock sync.RWMutex
|
||||
|
||||
dupReqMap = make(map[string]*sync.WaitGroup)
|
||||
|
@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func getResolverByIDWithLocking(server string) *Resolver {
|
||||
resolversLock.Lock()
|
||||
defer resolversLock.Unlock()
|
||||
func getActiveResolverByIDWithLocking(server string) *Resolver {
|
||||
resolversLock.RLock()
|
||||
defer resolversLock.RUnlock()
|
||||
|
||||
resolver, ok := allResolvers[server]
|
||||
resolver, ok := activeResolvers[server]
|
||||
if ok {
|
||||
return resolver
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
|
|||
return address
|
||||
}
|
||||
|
||||
func clientManagerFactory(serverType string) func(*Resolver) *clientManager {
|
||||
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
|
||||
switch serverType {
|
||||
case ServerTypeDNS:
|
||||
return newDNSClientManager
|
||||
|
@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func getConfiguredResolvers() (resolvers []*Resolver) {
|
||||
for _, server := range configuredNameServers() {
|
||||
func getConfiguredResolvers(list []string) (resolvers []*Resolver) {
|
||||
for _, server := range list {
|
||||
resolver, skip, err := createResolver(server, "config")
|
||||
if err != nil {
|
||||
// TODO(ppacher): module error
|
||||
|
@ -199,19 +199,40 @@ func loadResolvers() {
|
|||
defer resolversLock.Unlock()
|
||||
|
||||
newResolvers := append(
|
||||
getConfiguredResolvers(),
|
||||
getConfiguredResolvers(configuredNameServers()),
|
||||
getSystemResolvers()...,
|
||||
)
|
||||
|
||||
// save resolvers
|
||||
globalResolvers = newResolvers
|
||||
if len(globalResolvers) == 0 {
|
||||
log.Criticalf("resolver: no (valid) dns servers found in configuration and system")
|
||||
// TODO(module error)
|
||||
if len(newResolvers) == 0 {
|
||||
msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults"
|
||||
log.Warningf("resolver: %s", msg)
|
||||
module.Warning("no-valid-user-resolvers", msg)
|
||||
|
||||
// load defaults directly, overriding config system
|
||||
newResolvers = getConfiguredResolvers(defaultNameServers)
|
||||
if len(newResolvers) == 0 {
|
||||
msg = "no (valid) dns servers found in configuration or system"
|
||||
log.Criticalf("resolver: %s", msg)
|
||||
module.Error("no-valid-default-resolvers", msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// save resolvers
|
||||
globalResolvers = newResolvers
|
||||
|
||||
// assing resolvers to scopes
|
||||
setLocalAndScopeResolvers(globalResolvers)
|
||||
|
||||
// set active resolvers (for cache validation)
|
||||
// reset
|
||||
activeResolvers = make(map[string]*Resolver)
|
||||
// add
|
||||
for _, resolver := range newResolvers {
|
||||
activeResolvers[resolver.Server] = resolver
|
||||
}
|
||||
activeResolvers[mDNSResolver.Server] = mDNSResolver
|
||||
|
||||
// log global resolvers
|
||||
if len(globalResolvers) > 0 {
|
||||
log.Trace("resolver: loaded global resolvers:")
|
||||
|
|
|
@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (
|
|||
for _, rr := range rrCache.Answer {
|
||||
switch v := rr.(type) {
|
||||
case *dns.A:
|
||||
log.Infof("A: %s", v.A.String())
|
||||
// log.Debugf("A: %s", v.A.String())
|
||||
if ip == v.A.String() {
|
||||
return ptrName, nil
|
||||
}
|
||||
case *dns.AAAA:
|
||||
log.Infof("AAAA: %s", v.AAAA.String())
|
||||
// log.Debugf("AAAA: %s", v.AAAA.String())
|
||||
if ip == v.AAAA.String() {
|
||||
return ptrName, nil
|
||||
}
|
||||
|
|
41
resolver/rrcache_test.go
Normal file
41
resolver/rrcache_test.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
|
||||
testQuestion := "A"
|
||||
|
||||
testNameRecord := &NameRecord{
|
||||
Domain: testDomain,
|
||||
Question: testQuestion,
|
||||
}
|
||||
|
||||
err := testNameRecord.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = rrCache.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if rrCache2.Domain != rrCache.Domain {
|
||||
t.Fatal("something very is wrong")
|
||||
}
|
||||
}
|
15
ui/module.go
15
ui/module.go
|
@ -3,6 +3,8 @@ package ui
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
|
||||
resources "github.com/cookieo9/resources-go"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
|
@ -27,6 +29,19 @@ func prep() error {
|
|||
}
|
||||
|
||||
func start() error {
|
||||
// Create a dummy directory to which processes change their working directory
|
||||
// to. Currently this includes the App and the Notifier. The aim is protect
|
||||
// all other directories and increase compatibility should any process want
|
||||
// to read or write something to the current working directory. This can also
|
||||
// be useful in the future to dump data to for debugging. The permission used
|
||||
// may seem dangerous, but proper permission on the parent directory provide
|
||||
// (some) protection.
|
||||
// Processes must _never_ read from this directory.
|
||||
err := dataroot.Root().ChildDir("exec", 0777).Ensure()
|
||||
if err != nil {
|
||||
log.Warningf("ui: failed to create safe exec dir: %s", err)
|
||||
}
|
||||
|
||||
return module.RegisterEventHook("ui", eventReload, "reload assets", reloadUI)
|
||||
}
|
||||
|
||||
|
|
|
@ -152,7 +152,7 @@ func start() error {
|
|||
|
||||
err = registry.LoadIndexes()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Warningf("updates: failed to load indexes: %s", err)
|
||||
}
|
||||
|
||||
err = registry.ScanStorage("")
|
||||
|
@ -235,8 +235,7 @@ func checkForUpdates(ctx context.Context) (err error) {
|
|||
}()
|
||||
|
||||
if err = registry.UpdateIndexes(); err != nil {
|
||||
err = fmt.Errorf("failed to update indexes: %w", err)
|
||||
return
|
||||
log.Warningf("updates: failed to update indexes: %s", err)
|
||||
}
|
||||
|
||||
err = registry.DownloadUpdates(ctx)
|
||||
|
|
|
@ -113,15 +113,9 @@ func upgradePortmasterControl() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// check if registry tmp dir is ok
|
||||
err := registry.TmpDir().Ensure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prep updates tmp dir: %s", err)
|
||||
}
|
||||
|
||||
// update portmaster-control in data root
|
||||
rootControlPath := filepath.Join(filepath.Dir(registry.StorageDir().Path), filename)
|
||||
err = upgradeFile(rootControlPath, pmCtrlUpdate)
|
||||
err := upgradeFile(rootControlPath, pmCtrlUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -130,11 +124,11 @@ func upgradePortmasterControl() error {
|
|||
// upgrade parent process, if it's portmaster-control
|
||||
parent, err := processInfo.NewProcess(int32(os.Getppid()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get parent process for upgrade checks: %s", err)
|
||||
return fmt.Errorf("could not get parent process for upgrade checks: %w", err)
|
||||
}
|
||||
parentName, err := parent.Name()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get parent process name for upgrade checks: %s", err)
|
||||
return fmt.Errorf("could not get parent process name for upgrade checks: %w", err)
|
||||
}
|
||||
if parentName != filename {
|
||||
log.Tracef("updates: parent process does not seem to be portmaster-control, name is %s", parentName)
|
||||
|
@ -142,7 +136,7 @@ func upgradePortmasterControl() error {
|
|||
}
|
||||
parentPath, err := parent.Exe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get parent process path for upgrade: %s", err)
|
||||
return fmt.Errorf("could not get parent process path for upgrade: %w", err)
|
||||
}
|
||||
err = upgradeFile(parentPath, pmCtrlUpdate)
|
||||
if err != nil {
|
||||
|
@ -190,7 +184,7 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
|
|||
// ensure tmp dir is here
|
||||
err = registry.TmpDir().Ensure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check updates tmp dir for moving file that needs upgrade: %s", err)
|
||||
return fmt.Errorf("could not prepare tmp directory for moving file that needs upgrade: %w", err)
|
||||
}
|
||||
|
||||
// maybe we're on windows and it's in use, try moving
|
||||
|
@ -204,17 +198,17 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
|
|||
),
|
||||
))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to move file that needs upgrade: %s", err)
|
||||
return fmt.Errorf("unable to move file that needs upgrade: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy upgrade
|
||||
err = copyFile(file.Path(), fileToUpgrade)
|
||||
err = CopyFile(file.Path(), fileToUpgrade)
|
||||
if err != nil {
|
||||
// try again
|
||||
time.Sleep(1 * time.Second)
|
||||
err = copyFile(file.Path(), fileToUpgrade)
|
||||
err = CopyFile(file.Path(), fileToUpgrade)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -224,23 +218,31 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
|
|||
if !onWindows {
|
||||
info, err := os.Stat(fileToUpgrade)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file info on %s: %s", fileToUpgrade, err)
|
||||
return fmt.Errorf("failed to get file info on %s: %w", fileToUpgrade, err)
|
||||
}
|
||||
if info.Mode() != 0755 {
|
||||
err := os.Chmod(fileToUpgrade, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set permissions on %s: %s", fileToUpgrade, err)
|
||||
return fmt.Errorf("failed to set permissions on %s: %w", fileToUpgrade, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(srcPath, dstPath string) (err error) {
|
||||
// CopyFile atomically copies a file using the update registry's tmp dir.
|
||||
func CopyFile(srcPath, dstPath string) (err error) {
|
||||
|
||||
// check tmp dir
|
||||
err = registry.TmpDir().Ensure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not prepare tmp directory for copying file: %w", err)
|
||||
}
|
||||
|
||||
// open file for writing
|
||||
atomicDstFile, err := renameio.TempFile(registry.TmpDir().Path, dstPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create temp file for atomic copy: %s", err)
|
||||
return fmt.Errorf("could not create temp file for atomic copy: %w", err)
|
||||
}
|
||||
defer atomicDstFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway
|
||||
|
||||
|
@ -260,7 +262,7 @@ func copyFile(srcPath, dstPath string) (err error) {
|
|||
// finalize file
|
||||
err = atomicDstFile.CloseAtomicallyReplace()
|
||||
if err != nil {
|
||||
return fmt.Errorf("updates: failed to finalize copy to file %s: %s", dstPath, err)
|
||||
return fmt.Errorf("updates: failed to finalize copy to file %s: %w", dstPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Reference in a new issue