Merge pull request #52 from safing/feature/firewall-resolver-improvements

Firewall and Resolver improvements
This commit is contained in:
Patrick Pacher 2020-05-20 18:07:48 +02:00 committed by GitHub
commit c59f680053
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
70 changed files with 2288 additions and 1549 deletions

107
Gopkg.lock generated
View file

@ -2,12 +2,12 @@
[[projects]]
digest = "1:f82b8ac36058904227087141017bb82f4b0fc58272990a4cdae3e2d6d222644e"
digest = "1:6146fda730c18186631e91e818d995e759e7cbe27644d6871ccd469f6865c686"
name = "github.com/StackExchange/wmi"
packages = ["."]
pruneopts = ""
revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338"
version = "1.0.0"
revision = "cbe66965904dbe8a6cd589e2298e5d8b986bd7dd"
version = "1.1.0"
[[projects]]
digest = "1:e010d6b45ee6c721df761eae89961c634ceb55feff166a48d15504729309f267"
@ -18,12 +18,12 @@
version = "v1.1.1"
[[projects]]
digest = "1:3c753679736345f50125ae993e0a2614da126859921ea7faeecda6d217501ce2"
digest = "1:21caed545a1c7ef7a2627bbb45989f689872ff6d5087d49c31340ce74c36de59"
name = "github.com/agext/levenshtein"
packages = ["."]
pruneopts = ""
revision = "0ded9c86537917af2ff89bc9c78de6bd58477894"
version = "v1.2.2"
revision = "52c14c47d03211d8ac1834e94601635e07c5a6ef"
version = "v1.2.3"
[[projects]]
branch = "v2.1"
@ -34,12 +34,12 @@
revision = "d27c04069d0d5dfe11c202dacbf745ae8d1ab181"
[[projects]]
digest = "1:166e24c91c2732657d2f791d3ee3897e7d85ece7cbb62ad991250e6b51fc1d4c"
digest = "1:f384a8b6f89c502229e9013aa4f89ce5b5b56f09f9a4d601d7f1f026d3564fbf"
name = "github.com/coreos/go-iptables"
packages = ["iptables"]
pruneopts = ""
revision = "78b5fff24e6df8886ef8eca9411f683a884349a5"
version = "v0.4.1"
revision = "f901d6c2a4f2a4df092b98c33366dfba1f93d7a0"
version = "v0.4.5"
[[projects]]
digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77"
@ -61,12 +61,12 @@
version = "v1.2.4"
[[projects]]
digest = "1:cc1255e2fef3819bfab3540277001e602892dd431ef9ab5499bcdbc425923d64"
digest = "1:f63933986e63230fc32512ed00bc18ea4dbb0f57b5da18561314928fd20c2ff0"
name = "github.com/godbus/dbus"
packages = ["."]
pruneopts = ""
revision = "2ff6f7ffd60f0f2410b3105864bdd12c7894f844"
version = "v5.0.1"
revision = "37bf87eef99d69c4f1d3528bd66e3a87dc201472"
version = "v5.0.3"
[[projects]]
digest = "1:e85e59c4152d8576341daf54f40d96c404c264e04941a4a36b97a0f427eb9e5e"
@ -113,20 +113,20 @@
revision = "2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384"
[[projects]]
digest = "1:0b6694f306890ddbb69c96a16776510bd24e07436fae3f9b0a4e5b650f1e6fb7"
branch = "master"
digest = "1:c140772b00f0c26cf6627aee32f62d9f9d89dffcda648861266c482c36a5344a"
name = "github.com/miekg/dns"
packages = ["."]
pruneopts = ""
revision = "b13675009d59c97f3721247d9efa8914e1866a5b"
version = "v1.1.15"
revision = "b7703d0fa022e159d01efa2de82e6173d5ec04c8"
[[projects]]
digest = "1:3819cd861b7abd7d12dc1ea52ecb998ad1171826a76ecf0aefa09545781091f9"
digest = "1:b962a528cbecf7662bee4d84a600f7a0a6a130368666d7d461757ba4d1341906"
name = "github.com/oschwald/maxminddb-golang"
packages = ["."]
pruneopts = ""
revision = "2905694a1b00c5574f1418a7dbf8a22a7d247559"
version = "v1.3.1"
revision = "6a033e62c03b7dab4c37f7c9eb2ebb3b10e8f13a"
version = "v1.6.0"
[[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
@ -145,7 +145,7 @@
version = "v1.2.0"
[[projects]]
digest = "1:8bf42eb2ded52ed2678b0716dbfbf30628765bc12b13222c4d5669ba4c1310e4"
digest = "1:16f319cf21ddf49f27b3a2093d68316840dc25ec5c2a0a431a4a4fc01ea707e2"
name = "github.com/shirou/gopsutil"
packages = [
"cpu",
@ -155,32 +155,24 @@
"process",
]
pruneopts = ""
revision = "4c8b404ee5c53b04b04f34b1744a26bf5d2910de"
version = "v2.19.6"
revision = "a81cf97fce2300934e6c625b9917103346c26ba3"
version = "v2.20.4"
[[projects]]
branch = "master"
digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b"
name = "github.com/shirou/w32"
packages = ["."]
pruneopts = ""
revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b"
[[projects]]
digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe"
digest = "1:bff75d4f1a2d2c4b8f4b46ff5ac230b80b5fa49276f615900cba09fe4c97e66e"
name = "github.com/spf13/cobra"
packages = ["."]
pruneopts = ""
revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5"
version = "v0.0.5"
revision = "a684a6d7f5e37385d954dd3b5a14fc6912c6ab9d"
version = "v1.0.0"
[[projects]]
digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66"
digest = "1:688428eeb1ca80d92599eb3254bdf91b51d7e232fead3a73844c1f201a281e51"
name = "github.com/spf13/pflag"
packages = ["."]
pruneopts = ""
revision = "298182f68c66c05229eb03ac171abe6e309ee79a"
version = "v1.0.3"
revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab"
version = "v1.0.5"
[[projects]]
digest = "1:cc4eb6813da8d08694e557fcafae8fcc24f47f61a0717f952da130ca9a486dfc"
@ -208,18 +200,18 @@
[[projects]]
branch = "master"
digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328"
digest = "1:bf61fa9b53be5ce096004599b957e5957b28a5e421b724250aa06ecb7ee6dc57"
name = "golang.org/x/crypto"
packages = [
"ed25519",
"ed25519/internal/edwards25519",
]
pruneopts = ""
revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285"
revision = "4b2356b1ed79e6be3deca3737a3db3d132d2847a"
[[projects]]
branch = "master"
digest = "1:31cd6e3c114e17c5f0c9e8b0bcaa3025ab3c221ce36323c7ce1acaa753d0d0aa"
digest = "1:ea84836e35d7a66c9b8944796295912509c80c921244bc4e098c5417219895f2"
name = "golang.org/x/net"
packages = [
"bpf",
@ -232,7 +224,7 @@
"publicsuffix",
]
pruneopts = ""
revision = "da137c7871d730100384dbcf36e6f8fa493aef5b"
revision = "7e3656a0809f6f95abd88ac65313578f80b00df2"
[[projects]]
branch = "master"
@ -244,9 +236,10 @@
[[projects]]
branch = "master"
digest = "1:2579a16d8afda9c9a475808c13324f5e572852e8927905ffa15bb14e71baba4f"
digest = "1:acb3b56e190190ac9497faf5f0c30c5da4d3e8278d6b7a7042f2aa3332ff7022"
name = "golang.org/x/sys"
packages = [
"internal/unsafeheader",
"unix",
"windows",
"windows/registry",
@ -256,7 +249,7 @@
"windows/svc/mgr",
]
pruneopts = ""
revision = "04f50cda93cbb67f2afa353c52f342100e80e625"
revision = "bc7a7d42d5c30f4d0fe808715c002826ce2c624e"
[[projects]]
digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0"
@ -283,6 +276,38 @@
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
version = "v0.3.2"
[[projects]]
branch = "master"
digest = "1:3416c611e00178b07c8fc347ba96398e4d6709fe7d3fab17f0b0fa6f933b4bd1"
name = "golang.org/x/tools"
packages = [
"go/ast/astutil",
"go/gcexportdata",
"go/internal/gcimporter",
"go/internal/packagesdriver",
"go/packages",
"go/types/typeutil",
"internal/event",
"internal/event/core",
"internal/event/keys",
"internal/event/label",
"internal/gocommand",
"internal/packagesinternal",
]
pruneopts = ""
revision = "b8469989bc69e50ec6dc4e4513fc3ff9ce48b8af"
[[projects]]
branch = "master"
digest = "1:9d4ac09a835404ae9306c6e1493cf800ecbb0f3f828f4333b3e055de4c962eea"
name = "golang.org/x/xerrors"
packages = [
".",
"internal",
]
pruneopts = ""
revision = "9bdfabe68543c54f90421aeb9a60ef8061b5b544"
[[projects]]
digest = "1:2efc9662a6a1ff28c65c84fc2f9030f13d3afecdb2ecad445f3b0c80e75fc281"
name = "gopkg.in/yaml.v2"

View file

@ -25,3 +25,7 @@
# unused-packages = true
ignored = ["github.com/safing/portbase/*"]
[[constraint]]
name = "github.com/miekg/dns"
branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released

View file

@ -60,7 +60,18 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er
var procsChecked []string
// get process
proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process
proc, _, err := process.GetProcessByConnection(
r.Context(),
&packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: packet.IPv4,
Protocol: packet.TCP,
Src: remoteIP, // source as in the process we are looking for
SrcPort: remotePort, // source as in the process we are looking for
Dst: localIP,
DstPort: localPort,
},
)
if err != nil {
return false, fmt.Errorf("failed to get process: %s", err)
}

View file

@ -48,7 +48,7 @@ func registerConfig() error {
Order: CfgOptionAskWithSystemNotificationsOrder,
OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelUser,
ReleaseLevel: config.ReleaseLevelStable,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: true,
})
if err != nil {
@ -62,7 +62,7 @@ func registerConfig() error {
Order: CfgOptionAskTimeoutOrder,
OptType: config.OptTypeInt,
ExpertiseLevel: config.ExpertiseLevelUser,
ReleaseLevel: config.ReleaseLevelStable,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: 60,
})
if err != nil {

View file

@ -138,7 +138,7 @@ func handlePacket(pkt packet.Packet) {
// pkt.RedirToNameserver()
// }
// allow ICMP, IGMP and DHCP
// allow ICMP and DHCP
// TODO: actually handle these
switch meta.Protocol {
case packet.ICMP:
@ -149,10 +149,6 @@ func handlePacket(pkt packet.Packet) {
log.Debugf("accepting ICMPv6: %s", pkt)
_ = pkt.PermanentAccept()
return
case packet.IGMP:
log.Debugf("accepting IGMP: %s", pkt)
_ = pkt.PermanentAccept()
return
case packet.UDP:
if meta.DstPort == 67 || meta.DstPort == 68 {
log.Debugf("accepting DHCP: %s", pkt)
@ -218,6 +214,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
// reroute dns requests to nameserver
if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) {
conn.Verdict = network.VerdictRerouteToNameserver
conn.Internal = true
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
return
@ -233,7 +230,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
}
log.Tracer(pkt.Ctx()).Trace("filter: starting decision process")
DecideOnConnection(conn, pkt)
DecideOnConnection(pkt.Ctx(), conn, pkt)
conn.Inspecting = false // TODO: enable inspecting again
switch {

View file

@ -62,7 +62,7 @@ func Handler(packets chan packet.Packet) {
}
info := new.Info()
info.Direction = packetInfo.direction > 0
info.Inbound = packetInfo.direction > 0
info.InTunnel = false
info.Protocol = packet.IPProtocol(packetInfo.protocol)
@ -76,7 +76,7 @@ func Handler(packets chan packet.Packet) {
// IPs
if info.Version == packet.IPv4 {
// IPv4
if info.Direction {
if info.Inbound {
// Inbound
info.Src = convertIPv4(packetInfo.remoteIP)
info.Dst = convertIPv4(packetInfo.localIP)
@ -87,7 +87,7 @@ func Handler(packets chan packet.Packet) {
}
} else {
// IPv6
if info.Direction {
if info.Inbound {
// Inbound
info.Src = convertIPv6(packetInfo.remoteIP)
info.Dst = convertIPv6(packetInfo.localIP)
@ -99,7 +99,7 @@ func Handler(packets chan packet.Packet) {
}
// Ports
if info.Direction {
if info.Inbound {
// Inbound
info.SrcPort = packetInfo.remotePort
info.DstPort = packetInfo.localPort
@ -113,19 +113,17 @@ func Handler(packets chan packet.Packet) {
}
}
// convertIPv4 as needed for data from the kernel
func convertIPv4(input [4]uint32) net.IP {
return net.IPv4(
uint8(input[0]>>24&0xFF),
uint8(input[0]>>16&0xFF),
uint8(input[0]>>8&0xFF),
uint8(input[0]&0xFF),
)
addressBuf := make([]byte, 4)
binary.BigEndian.PutUint32(addressBuf, input[0])
return net.IP(addressBuf)
}
func convertIPv6(input [4]uint32) net.IP {
addressBuf := make([]byte, 16)
for i := 0; i < 4; i++ {
binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
}
return net.IP(addressBuf)
}

View file

@ -1,6 +1,7 @@
package firewall
import (
"context"
"fmt"
"os"
"path/filepath"
@ -10,6 +11,7 @@ import (
"github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/state"
"github.com/safing/portmaster/process"
"github.com/safing/portmaster/profile"
"github.com/safing/portmaster/profile/endpoints"
@ -32,10 +34,10 @@ import (
// DecideOnConnection makes a decision about a connection.
// When called, the connection and profile is already locked.
func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) {
// update profiles and check if communication needs reevaluation
if conn.UpdateAndCheck() {
log.Infof("filter: re-evaluating verdict on %s", conn)
log.Tracer(ctx).Infof("filter: re-evaluating verdict on %s", conn)
conn.Verdict = network.VerdictUndecided
if conn.Entity != nil {
@ -43,7 +45,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
}
}
var deciders = []func(*network.Connection, packet.Packet) bool{
var deciders = []func(context.Context, *network.Connection, packet.Packet) bool{
checkPortmasterConnection,
checkSelfCommunication,
checkProfileExists,
@ -59,7 +61,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
}
for _, decider := range deciders {
if decider(conn, pkt) {
if decider(ctx, conn, pkt) {
return
}
}
@ -70,10 +72,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) {
// checkPortmasterConnection allows all connection that originate from
// portmaster itself.
func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool {
func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
// grant self
if conn.Process().Pid == os.Getpid() {
log.Infof("filter: granting own connection %s", conn)
log.Tracer(ctx).Infof("filter: granting own connection %s", conn)
conn.Verdict = network.VerdictAccept
conn.Internal = true
return true
@ -83,27 +85,29 @@ func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool {
}
// checkSelfCommunication checks if the process is communicating with itself.
func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool {
func checkSelfCommunication(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
// check if process is communicating with itself
if pkt != nil {
// TODO: evaluate the case where different IPs in the 127/8 net are used.
pktInfo := pkt.Info()
if conn.Process().Pid >= 0 && pktInfo.Src.Equal(pktInfo.Dst) {
// get PID
otherPid, _, err := process.GetPidByEndpoints(
pktInfo.RemoteIP(),
pktInfo.RemotePort(),
pktInfo.LocalIP(),
pktInfo.LocalPort(),
pktInfo.Protocol,
)
otherPid, _, err := state.Lookup(&packet.Info{
Inbound: !pktInfo.Inbound, // we want to know the process on the other end
Version: pktInfo.Version,
Protocol: pktInfo.Protocol,
Src: pktInfo.Src,
SrcPort: pktInfo.SrcPort,
Dst: pktInfo.Dst,
DstPort: pktInfo.DstPort,
})
if err != nil {
log.Warningf("filter: failed to find local peer process PID: %s", err)
log.Tracer(ctx).Warningf("filter: failed to find local peer process PID: %s", err)
} else {
// get primary process
otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid)
otherProcess, err := process.GetOrFindPrimaryProcess(ctx, otherPid)
if err != nil {
log.Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err)
log.Tracer(ctx).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err)
} else if otherProcess.Pid == conn.Process().Pid {
conn.Accept("connection to self")
conn.Internal = true
@ -116,7 +120,7 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool {
return false
}
func checkProfileExists(conn *network.Connection, _ packet.Packet) bool {
func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
if conn.Process().Profile() == nil {
conn.Block("unknown process or profile")
return true
@ -124,7 +128,7 @@ func checkProfileExists(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool {
func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
var result endpoints.EPResult
var reason endpoints.Reason
@ -149,12 +153,12 @@ func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkConnectionType(conn *network.Connection, _ packet.Packet) bool {
func checkConnectionType(ctx context.Context, conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
// check conn type
switch conn.Scope {
case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
case network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
if p.BlockInbound() {
if conn.Scope == network.IncomingHost {
conn.Block("inbound connections blocked")
@ -174,7 +178,7 @@ func checkConnectionType(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool {
func checkConnectionScope(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
// check scopes
@ -213,7 +217,7 @@ func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool {
func checkBypassPrevention(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
if conn.Process().Profile().PreventBypassing() {
// check for bypass protection
result, reason, reasonCtx := PreventBypassing(conn)
@ -230,7 +234,7 @@ func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkFilterLists(conn *network.Connection, _ packet.Packet) bool {
func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool {
// apply privacy filter lists
p := conn.Process().Profile()
@ -242,12 +246,12 @@ func checkFilterLists(conn *network.Connection, _ packet.Packet) bool {
case endpoints.NoMatch:
// nothing to do
default:
log.Debugf("filter: filter lists returned unsupported verdict: %s", result)
log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result)
}
return false
}
func checkInbound(conn *network.Connection, _ packet.Packet) bool {
func checkInbound(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
// implicit default=block for inbound
if conn.Inbound {
conn.Drop("endpoint is not whitelisted (incoming is always default=block)")
@ -256,7 +260,7 @@ func checkInbound(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool {
func checkDefaultPermit(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
// check default action
p := conn.Process().Profile()
if p.DefaultAction() == profile.DefaultActionPermit {
@ -266,7 +270,7 @@ func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool {
func checkAutoPermitRelated(_ context.Context, conn *network.Connection, _ packet.Packet) bool {
p := conn.Process().Profile()
if !p.DisableAutoPermit() {
related, reason := checkRelation(conn)
@ -278,7 +282,7 @@ func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool {
return false
}
func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool {
func checkDefaultAction(_ context.Context, conn *network.Connection, pkt packet.Packet) bool {
p := conn.Process().Profile()
if p.DefaultAction() == profile.DefaultActionAsk {
prompt(conn, pkt)

View file

@ -71,9 +71,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns.
for _, lm := range br {
blockedBy, err := dns.NewRR(fmt.Sprintf(
"%s-blockedBy. 0 IN TXT %q",
strings.TrimRight(lm.Entity, "."),
strings.Join(lm.ActiveLists, ","),
`%s 0 IN TXT "blocked by filter lists %s"`,
lm.Entity,
strings.Join(lm.ActiveLists, ", "),
))
if err == nil {
rrs = append(rrs, blockedBy)
@ -83,9 +83,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns.
if len(lm.InactiveLists) > 0 {
wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf(
"%s-wouldBeBlockedBy. 0 IN TXT %q",
strings.TrimRight(lm.Entity, "."),
strings.Join(lm.InactiveLists, ","),
`%s 0 IN TXT "would be blocked by filter lists %s"`,
lm.Entity,
strings.Join(lm.InactiveLists, ", "),
))
if err == nil {
rrs = append(rrs, wouldBeBlockedBy)

View file

@ -6,6 +6,8 @@ import (
"net"
"strings"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/log"
@ -22,9 +24,8 @@ import (
)
var (
module *modules.Module
dnsServer *dns.Server
mtDNSRequest = "dns request"
module *modules.Module
dnsServer *dns.Server
listenAddress = "0.0.0.0:53"
ipv4Localhost = net.IPv4(127, 0, 0, 1)
@ -61,7 +62,7 @@ func prep() error {
func start() error {
dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"}
dns.HandleFunc(".", handleRequestAsMicroTask)
dns.HandleFunc(".", handleRequestAsWorker)
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
err := dnsServer.ListenAndServe()
@ -95,8 +96,8 @@ func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) {
_ = w.WriteMsg(m)
}
func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) {
err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error {
func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
err := module.RunWorker("dns request", func(ctx context.Context) error {
return handleRequest(ctx, w, query)
})
if err != nil {
@ -168,12 +169,13 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
// start tracer
ctx, tracer := log.AddTracer(ctx)
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)
defer tracer.Submit()
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// get connection
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port))
conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port))
// once we decided on the connection we might need to save it to the database
// so we defer that check right now.
@ -191,7 +193,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
return
default:
log.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn)
tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn)
}
}()
@ -220,7 +222,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
}
// check profile before we even get intel and rr
firewall.DecideOnConnection(conn, nil)
firewall.DecideOnConnection(ctx, conn, nil)
switch conn.Verdict {
case network.VerdictBlock:
@ -242,7 +244,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
tracer.Infof("nameserver: %s handing over to reason-responder: %s", q.FQDN, conn.Reason)
reply := responder.ReplyWithDNS(query, conn.Reason, conn.ReasonContext)
if err := w.WriteMsg(reply); err != nil {
log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
} else {
tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process())
}
@ -269,6 +271,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
return nil
}
tracer.Trace("nameserver: deciding on resolved dns")
rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache)
if rrCache == nil {
sendResponse(w, query, conn.Verdict, conn.Reason, conn.ReasonContext)
@ -283,7 +286,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
m.Extra = rrCache.Extra
if err := w.WriteMsg(m); err != nil {
log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err)
} else {
tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process())
}

View file

@ -10,7 +10,7 @@ import (
"github.com/safing/portbase/modules"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/process"
"github.com/safing/portmaster/network/state"
)
var (
@ -58,7 +58,15 @@ func checkForConflictingService() error {
}
func takeover(resolverIP net.IP) (int, error) {
pid, _, err := process.GetPidByEndpoints(resolverIP, 53, resolverIP, 65535, packet.UDP)
pid, _, err := state.Lookup(&packet.Info{
Inbound: true,
Version: 0, // auto-detect
Protocol: packet.UDP,
Src: nil, // do not record direction
SrcPort: 0, // do not record direction
Dst: resolverIP,
DstPort: 53,
})
if err != nil {
// there may be nothing listening on :53
return 0, nil

View file

@ -201,20 +201,11 @@ func triggerOnlineStatusInvestigation() {
func monitorOnlineStatus(ctx context.Context) error {
for {
timeout := 5 * time.Minute
/*
if GetOnlineStatus() != StatusOnline {
timeout = time.Second
log.Debugf("checking online status again in %s because current status is %s", timeout, GetOnlineStatus())
}
*/
// wait for trigger
select {
case <-ctx.Done():
return nil
case <-onlineStatusInvestigationTrigger:
case <-time.After(timeout):
}
// enable waiting

View file

@ -4,6 +4,10 @@ import (
"context"
"time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/state"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/process"
)
@ -22,8 +26,12 @@ func connectionCleaner(ctx context.Context) error {
ticker.Stop()
return nil
case <-ticker.C:
// clean connections and processes
activePIDs := cleanConnections()
process.CleanProcessStorage(activePIDs)
// clean udp connection states
state.CleanUDPStates(ctx)
}
}
}
@ -33,13 +41,10 @@ func cleanConnections() (activePIDs map[int]struct{}) {
name := "clean connections" // TODO: change to new fn
_ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error {
activeIDs := make(map[string]struct{})
for _, cID := range process.GetActiveConnectionIDs() {
activeIDs[cID] = struct{}{}
}
now := time.Now().Unix()
deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix()
now := time.Now().UTC()
nowUnix := now.Unix()
deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix()
// lock both together because we cannot fully guarantee in which map a connection lands
// of course every connection should land in the correct map, but this increases resilience
@ -49,20 +54,27 @@ func cleanConnections() (activePIDs map[int]struct{}) {
defer dnsConnsLock.Unlock()
// network connections
for key, conn := range conns {
for _, conn := range conns {
conn.Lock()
// delete inactive connections
switch {
case conn.Ended == 0:
// Step 1: check if still active
_, ok := activeIDs[key]
if ok {
activePIDs[conn.process.Pid] = struct{}{}
} else {
exists := state.Exists(&packet.Info{
Inbound: false, // src == local
Version: conn.IPVersion,
Protocol: conn.IPProtocol,
Src: conn.LocalIP,
SrcPort: conn.LocalPort,
Dst: conn.Entity.IP,
DstPort: conn.Entity.Port,
}, now)
activePIDs[conn.process.Pid] = struct{}{}
if !exists {
// Step 2: mark end
activePIDs[conn.process.Pid] = struct{}{}
conn.Ended = now
conn.Ended = nowUnix
conn.Save()
}
case conn.Ended < deleteOlderThan:

View file

@ -25,11 +25,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
record.Base
sync.Mutex
ID string
Scope string
Inbound bool
Entity *intel.Entity // needs locking, instance is never shared
process *process.Process
ID string
Scope string
IPVersion packet.IPVersion
Inbound bool
// local endpoint
IPProtocol packet.IPProtocol
LocalIP net.IP
LocalPort uint16
process *process.Process
// remote endpoint
Entity *intel.Entity
Verdict Verdict
Reason string
@ -55,11 +63,22 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
}
// NewConnectionFromDNSRequest returns a new connection based on the given dns request.
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection {
func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection {
// get Process
proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP)
proc, _, err := process.GetProcessByConnection(
ctx,
&packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: ipVersion,
Protocol: packet.UDP,
Src: localIP, // source as in the process we are looking for
SrcPort: localPort, // source as in the process we are looking for
Dst: nil, // do not record direction
DstPort: 0, // do not record direction
},
)
if err != nil {
log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err)
log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err)
proc = process.GetUnidentifiedProcess(ctx)
}
@ -80,9 +99,9 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
// NewConnectionFromFirstPacket returns a new connection based on the given packet.
func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
// get Process
proc, inbound, err := process.GetProcessByPacket(pkt)
proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info())
if err != nil {
log.Warningf("network: failed to find process of packet %s: %s", pkt, err)
log.Debugf("network: failed to find process of packet %s: %s", pkt, err)
proc = process.GetUnidentifiedProcess(pkt.Ctx())
}
@ -147,12 +166,20 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
}
return &Connection{
ID: pkt.GetConnectionID(),
Scope: scope,
Inbound: inbound,
Entity: entity,
process: proc,
Started: time.Now().Unix(),
ID: pkt.GetConnectionID(),
Scope: scope,
IPVersion: pkt.Info().Version,
Inbound: inbound,
// local endpoint
IPProtocol: pkt.Info().Protocol,
LocalIP: pkt.Info().LocalIP(),
LocalPort: pkt.Info().LocalPort(),
process: proc,
// remote endpoint
Entity: entity,
// meta
Started: time.Now().Unix(),
profileRevisionCounter: proc.Profile().RevisionCnt(),
}
}

View file

@ -5,6 +5,8 @@ import (
"strings"
"sync"
"github.com/safing/portmaster/network/state"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
@ -57,6 +59,14 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
return conn, nil
}
}
case "system":
if len(splitted) >= 2 {
switch splitted[1] {
case "state":
return state.GetInfo(), nil
default:
}
}
}
return nil, storage.ErrNotFound

71
network/iphelper/get.go Normal file
View file

@ -0,0 +1,71 @@
// +build windows
package iphelper
import (
"sync"
"github.com/safing/portmaster/network/socket"
)
var (
ipHelper *IPHelper
// lock locks access to the whole DLL.
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
lock sync.RWMutex
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv4, TCP)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, nil, err
}
return ipHelper.getTable(IPv6, TCP)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv4, UDP)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
lock.Lock()
defer lock.Unlock()
err = checkIPHelper()
if err != nil {
return nil, err
}
_, binds, err = ipHelper.getTable(IPv6, UDP)
return
}

View file

@ -0,0 +1,63 @@
// +build windows
package iphelper
import (
"errors"
"fmt"
"github.com/tevino/abool"
"golang.org/x/sys/windows"
)
var (
errInvalid = errors.New("IPHelper not initialzed or broken")
)
// IPHelper represents a subset of the Windows iphlpapi.dll.
type IPHelper struct {
dll *windows.LazyDLL
getExtendedTCPTable *windows.LazyProc
getExtendedUDPTable *windows.LazyProc
valid *abool.AtomicBool
}
func checkIPHelper() (err error) {
if ipHelper == nil {
ipHelper, err = New()
return err
}
return nil
}
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
func New() (*IPHelper, error) {
new := &IPHelper{}
new.valid = abool.NewBool(false)
var err error
// load dll
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
err = new.dll.Load()
if err != nil {
return nil, err
}
// load functions
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
err = new.getExtendedTCPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
}
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
err = new.getExtendedUDPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
}
new.valid.Set()
return new, nil
}

View file

@ -3,12 +3,15 @@
package iphelper
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"unsafe"
"github.com/safing/portmaster/network/socket"
"golang.org/x/sys/windows"
)
@ -22,19 +25,6 @@ const (
winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER)
)
// ConnectionEntry describes a connection state table entry.
type ConnectionEntry struct {
localIP net.IP
remoteIP net.IP
localPort uint16
remotePort uint16
pid int
}
func (entry *ConnectionEntry) String() string {
return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)
}
type iphelperTCPTable struct {
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
numEntries uint32
@ -148,9 +138,9 @@ func increaseBufSize() int {
return bufSize
}
// GetTables returns the current connection state table of Windows of the given protocol and IP version.
func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) { //nolint:gocognit,gocycle // TODO
// docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx
// getTable returns the current connection state table of Windows of the given protocol and IP version.
func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO
// docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
if !ipHelper.valid.IsSet() {
return nil, nil, errInvalid
@ -220,26 +210,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table {
new := &ConnectionEntry{}
// PID
new.pid = int(row.owningPid)
// local
if row.localAddr != 0 {
new.localIP = convertIPv4(row.localAddr)
}
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
// remote
if row.state == iphelperTCPStateListen {
listeners = append(listeners, new)
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
} else {
new.remoteIP = convertIPv4(row.remoteAddr)
new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8)
connections = append(connections, new)
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: convertIPv4(row.remoteAddr),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
}
}
case protocol == TCP && ipVersion == IPv6:
@ -248,27 +239,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
table := tcpTable.table[:tcpTable.numEntries]
for _, row := range table {
new := &ConnectionEntry{}
// PID
new.pid = int(row.owningPid)
// local
new.localIP = net.IP(row.localAddr[:])
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
// remote
if row.state == iphelperTCPStateListen {
if new.localIP.Equal(net.IPv6zero) {
new.localIP = nil
}
listeners = append(listeners, new)
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
} else {
new.remoteIP = net.IP(row.remoteAddr[:])
new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8)
connections = append(connections, new)
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: net.IP(row.remoteAddr[:]),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
}
}
case protocol == UDP && ipVersion == IPv4:
@ -277,19 +268,13 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
table := udpTable.table[:udpTable.numEntries]
for _, row := range table {
new := &ConnectionEntry{}
// PID
new.pid = int(row.owningPid)
// local
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
if row.localAddr == 0 {
listeners = append(listeners, new)
} else {
new.localIP = convertIPv4(row.localAddr)
connections = append(connections, new)
}
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
}
case protocol == UDP && ipVersion == IPv6:
@ -298,32 +283,28 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection
table := udpTable.table[:udpTable.numEntries]
for _, row := range table {
new := &ConnectionEntry{}
// PID
new.pid = int(row.owningPid)
// local
new.localIP = net.IP(row.localAddr[:])
new.localPort = uint16(row.localPort>>8 | row.localPort<<8)
if new.localIP.Equal(net.IPv6zero) {
new.localIP = nil
listeners = append(listeners, new)
} else {
connections = append(connections, new)
}
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
}
}
return connections, listeners, nil
return connections, binds, nil
}
// convertIPv4 as needed for iphlpapi.dll
func convertIPv4(input uint32) net.IP {
return net.IPv4(
uint8(input&0xFF),
uint8(input>>8&0xFF),
uint8(input>>16&0xFF),
uint8(input>>24&0xFF),
)
addressBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(addressBuf, input)
return net.IP(addressBuf)
}
// convertPort converts ports received from iphlpapi.dll
func convertPort(input uint32) uint16 {
return uint16(input>>8 | input<<8)
}

View file

@ -0,0 +1,54 @@
// +build windows
package iphelper
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
fmt.Printf("%+v\n", connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
fmt.Printf("%+v\n", listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
fmt.Printf("%+v\n", bind)
}
}

View file

@ -1,17 +1,12 @@
package network
import (
"net"
"github.com/safing/portbase/modules"
)
var (
module *modules.Module
dnsAddress = net.IPv4(127, 0, 0, 1)
dnsPort uint16 = 53
defaultFirewallHandler FirewallHandler
)

View file

@ -2,13 +2,57 @@ package netutils
import (
"regexp"
"github.com/miekg/dns"
)
var (
cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`)
cleanDomainRegex = regexp.MustCompile(
`^` + // match beginning
`(` + // start subdomain group
`(xn--)?` + // idn prefix
`[a-z0-9_-]{1,63}` + // main chunk
`\.` + // ending with a dot
`)*` + // end subdomain group, allow any number of subdomains
`(xn--)?` + // TLD idn prefix
`[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters
`\.` + // ending with a dot
`$`, // match end
)
)
// IsValidFqdn returns whether the given string is a valid fqdn.
func IsValidFqdn(fqdn string) bool {
return cleanDomainRegex.MatchString(fqdn)
// root zone
if fqdn == "." {
return true
}
// check max length
if len(fqdn) > 256 {
return false
}
// check with regex
if !cleanDomainRegex.MatchString(fqdn) {
return false
}
// check with miegk/dns
// IsFqdn checks if a domain name is fully qualified.
if !dns.IsFqdn(fqdn) {
return false
}
// IsDomainName checks if s is a valid domain name, it returns the number of
// labels and true, when a domain name is valid. Note that non fully qualified
// domain name is considered valid, in this case the last label is counted in
// the number of labels. When false is returned the number of labels is not
// defined. Also note that this function is extremely liberal; almost any
// string is a valid domain name as the DNS is 8 bit protocol. It checks if each
// label fits in 63 characters and that the entire name will fit into the 255
// octet wire format limit.
_, ok := dns.IsDomainName(fqdn)
return ok
}

View file

@ -0,0 +1,43 @@
package netutils
import "testing"
func testDomainValidity(t *testing.T, domain string, isValid bool) {
if IsValidFqdn(domain) != isValid {
t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid)
}
}
func TestDNSValidation(t *testing.T) {
// valid
testDomainValidity(t, ".", true)
testDomainValidity(t, "at.", true)
testDomainValidity(t, "orf.at.", true)
testDomainValidity(t, "www.orf.at.", true)
testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true)
testDomainValidity(t, "a_a.com.", true)
testDomainValidity(t, "a-a.com.", true)
testDomainValidity(t, "a_a.com.", true)
testDomainValidity(t, "a-a.com.", true)
testDomainValidity(t, "xn--a.com.", true)
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true)
// maybe valid
testDomainValidity(t, "-.com.", true)
testDomainValidity(t, "_.com.", true)
testDomainValidity(t, "a_.com.", true)
testDomainValidity(t, "a-.com.", true)
testDomainValidity(t, "_a.com.", true)
testDomainValidity(t, "-a.com.", true)
// invalid
testDomainValidity(t, ".com.", false)
testDomainValidity(t, ".com.", false)
testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false)
testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false)
// real world examples
testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true)
}

View file

@ -36,22 +36,22 @@ func (pkt *Base) SetPacketInfo(packetInfo Info) {
// SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure.
func (pkt *Base) SetInbound() {
pkt.info.Direction = true
pkt.info.Inbound = true
}
// SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure.
func (pkt *Base) SetOutbound() {
pkt.info.Direction = false
pkt.info.Inbound = false
}
// IsInbound checks if the packet is inbound.
func (pkt *Base) IsInbound() bool {
return pkt.info.Direction
return pkt.info.Inbound
}
// IsOutbound checks if the packet is outbound.
func (pkt *Base) IsOutbound() bool {
return !pkt.info.Direction
return !pkt.info.Inbound
}
// HasPorts checks if the packet has a protocol that uses ports.
@ -80,13 +80,13 @@ func (pkt *Base) GetConnectionID() string {
func (pkt *Base) createConnectionID() {
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
if pkt.info.Direction {
if pkt.info.Inbound {
pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
} else {
pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
}
} else {
if pkt.info.Direction {
if pkt.info.Inbound {
pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
} else {
pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
@ -105,7 +105,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I
if pkt.info.Protocol != protocol {
return false
}
if pkt.info.Direction != remote {
if pkt.info.Inbound != remote {
if !network.Contains(pkt.info.Src) {
return false
}
@ -131,7 +131,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I
// Remote Src Dst
//
func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool {
if pkt.info.Direction != endpoint {
if pkt.info.Inbound != endpoint {
if network.Contains(pkt.info.Src) {
return true
}
@ -152,12 +152,12 @@ func (pkt *Base) String() string {
// FmtPacket returns the most important information about the packet as a string
func (pkt *Base) FmtPacket() string {
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
if pkt.info.Direction {
if pkt.info.Inbound {
return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
}
return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort)
}
if pkt.info.Direction {
if pkt.info.Inbound {
return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src)
}
return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
@ -170,7 +170,7 @@ func (pkt *Base) FmtProtocol() string {
// FmtRemoteIP returns the remote IP address as a string
func (pkt *Base) FmtRemoteIP() string {
if pkt.info.Direction {
if pkt.info.Inbound {
return pkt.info.Src.String()
}
return pkt.info.Dst.String()
@ -179,7 +179,7 @@ func (pkt *Base) FmtRemoteIP() string {
// FmtRemotePort returns the remote port as a string
func (pkt *Base) FmtRemotePort() string {
if pkt.info.SrcPort != 0 {
if pkt.info.Direction {
if pkt.info.Inbound {
return fmt.Sprintf("%d", pkt.info.SrcPort)
}
return fmt.Sprintf("%d", pkt.info.DstPort)

View file

@ -6,8 +6,8 @@ import (
// Info holds IP and TCP/UDP header information
type Info struct {
Direction bool
InTunnel bool
Inbound bool
InTunnel bool
Version IPVersion
Protocol IPProtocol
@ -17,7 +17,7 @@ type Info struct {
// LocalIP returns the local IP of the packet.
func (pi *Info) LocalIP() net.IP {
if pi.Direction {
if pi.Inbound {
return pi.Dst
}
return pi.Src
@ -25,7 +25,7 @@ func (pi *Info) LocalIP() net.IP {
// RemoteIP returns the remote IP of the packet.
func (pi *Info) RemoteIP() net.IP {
if pi.Direction {
if pi.Inbound {
return pi.Src
}
return pi.Dst
@ -33,7 +33,7 @@ func (pi *Info) RemoteIP() net.IP {
// LocalPort returns the local port of the packet.
func (pi *Info) LocalPort() uint16 {
if pi.Direction {
if pi.Inbound {
return pi.DstPort
}
return pi.SrcPort
@ -41,7 +41,7 @@ func (pi *Info) LocalPort() uint16 {
// RemotePort returns the remote port of the packet.
func (pi *Info) RemotePort() uint16 {
if pi.Direction {
if pi.Inbound {
return pi.SrcPort
}
return pi.DstPort

View file

@ -10,6 +10,8 @@ import (
"sync"
"syscall"
"github.com/safing/portmaster/network/socket"
"github.com/safing/portbase/log"
)
@ -18,8 +20,8 @@ var (
pidsByUser = make(map[int][]int)
)
// GetPidOfInode returns the pid of the given uid and socket inode.
func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
// FindPID returns the pid of the given uid and socket inode.
func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
@ -38,7 +40,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
var checkedUserPids []int
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
checkedUserPids = append(checkedUserPids, possiblePID)
}
@ -57,7 +59,7 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
// only check if not already checked
if sort.SearchInts(checkedUserPids, possiblePID) == len {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
}
}
@ -71,13 +73,13 @@ func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO
if possibleUID != uid {
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
}
}
}
return unidentifiedProcessID, false
return socket.UnidentifiedProcessID
}
func findSocketFromPid(pid, inode int) bool {

237
network/proc/tables.go Normal file
View file

@ -0,0 +1,237 @@
// +build linux
package proc
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
"unicode"
"github.com/safing/portmaster/network/socket"
"github.com/safing/portbase/log"
)
/*
1. find socket inode
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
- /proc/net/{tcp|udp}[6]
2. get list of processes of uid
3. find socket inode in process fds
- if not found, refresh map of uid->pids
- if not found, check ALL pids: maybe euid != uid
4. gather process info
Cache every step!
*/
// Network Related Constants
const (
TCP4 uint8 = iota
UDP4
TCP6
UDP6
ICMP4
ICMP6
tcp4ProcFile = "/proc/net/tcp"
tcp6ProcFile = "/proc/net/tcp6"
udp4ProcFile = "/proc/net/udp"
udp6ProcFile = "/proc/net/udp6"
UnfetchedProcessID = -2
tcpListenStateHex = "0A"
)
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, tcp4ProcFile)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, tcp6ProcFile)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
return
}
const (
// hint: we split fields by multiple delimiters, see procDelimiter
fieldIndexLocalIP = 1
fieldIndexLocalPort = 2
fieldIndexRemoteIP = 3
fieldIndexRemotePort = 4
fieldIndexUID = 11
fieldIndexInode = 13
)
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP
switch stack {
case TCP4, UDP4:
ipConverter = convertIPv4
case TCP6, UDP6:
ipConverter = convertIPv6
default:
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
}
// open file
socketData, err := os.Open(procFile)
if err != nil {
return nil, nil, err
}
defer socketData.Close()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first row
for scanner.Scan() {
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(fields) < 14 {
// log.Tracef("process: too short: %s", fields)
continue
}
localIP := ipConverter(fields[fieldIndexLocalIP])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
// log.Tracef("uid: %s", fields[fieldIndexUID])
if err != nil {
log.Warningf("process: could not parse uid %s: %s", fields[11], err)
continue
}
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
// log.Tracef("inode: %s", fields[fieldIndexInode])
if err != nil {
log.Warningf("process: could not parse inode %s: %s", fields[13], err)
continue
}
switch stack {
case UDP4, UDP6:
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
case TCP4, TCP6:
if fields[5] == tcpListenStateHex {
// listener
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
} else {
// connection
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: localIP,
Port: uint16(localPort),
},
Remote: socket.Address{
IP: remoteIP,
Port: uint16(remotePort),
},
PID: UnfetchedProcessID,
UID: int(uid),
Inode: int(inode),
})
}
}
}
return connections, binds, nil
}
func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
return nil
}
if len(decoded) != 4 {
log.Warningf("process: decoded IPv4 %s has wrong length", decoded)
return nil
}
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return ip
}
func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
return nil
}
if len(decoded) != 16 {
log.Warningf("process: decoded IPv6 %s has wrong length", decoded)
return nil
}
ip := net.IP(decoded)
return ip
}

View file

@ -0,0 +1,60 @@
// +build linux
package proc
import (
"fmt"
"testing"
)
func TestSockets(t *testing.T) {
connections, listeners, err := GetTCP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
connections, listeners, err = GetTCP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
binds, err := GetUDP4Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
binds, err = GetUDP6Table()
if err != nil {
t.Fatal(err)
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
}

31
network/socket/socket.go Normal file
View file

@ -0,0 +1,31 @@
package socket
import "net"
const (
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
UnidentifiedProcessID = -1
)
// ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct {
Local Address
Remote Address
PID int
UID int
Inode int
}
// BindInfo holds socket information returned by the system.
type BindInfo struct {
Local Address
PID int
UID int
Inode int
}
// Address is an IP + Port pair.
type Address struct {
IP net.IP
Port uint16
}

101
network/state/exists.go Normal file
View file

@ -0,0 +1,101 @@
package state
import (
"time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/socket"
)
const (
// UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended.
UDPConnectionTTL = 10 * time.Minute
)
// Exists checks if the given connection is present in the system state tables.
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// TODO: create lookup maps before running a flurry of Exists() checks.
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
tcp4Lock.Lock()
defer tcp4Lock.Unlock()
return existsTCP(tcp4Connections, pktInfo)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
tcp6Lock.Lock()
defer tcp6Lock.Unlock()
return existsTCP(tcp6Connections, pktInfo)
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
udp4Lock.Lock()
defer udp4Lock.Unlock()
return existsUDP(udp4Binds, udp4States, pktInfo, now)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
udp6Lock.Lock()
defer udp6Lock.Unlock()
return existsUDP(udp6Binds, udp6States, pktInfo, now)
default:
return false
}
}
func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
remoteIP := pktInfo.RemoteIP()
remotePort := pktInfo.RemotePort()
// search connections
for _, socketInfo := range connections {
if localPort == socketInfo.Local.Port &&
remotePort == socketInfo.Remote.Port &&
remoteIP.Equal(socketInfo.Remote.IP) &&
localIP.Equal(socketInfo.Local.IP) {
return true
}
}
return false
}
func existsUDP(
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
pktInfo *packet.Info,
now time.Time,
) (exists bool) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
remoteIP := pktInfo.RemoteIP()
remotePort := pktInfo.RemotePort()
connThreshhold := now.Add(-UDPConnectionTTL)
// search binds
for _, socketInfo := range binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
IP: remoteIP,
Port: remotePort,
})
switch {
case !ok:
return false
case udpConnState.lastSeen.After(connThreshhold):
return true
default:
return false
}
}
}
return false
}

52
network/state/info.go Normal file
View file

@ -0,0 +1,52 @@
package state
import (
"sync"
"github.com/safing/portbase/database/record"
"github.com/safing/portmaster/network/socket"
)
// Info holds network state information as provided by the system.
type Info struct {
record.Base
sync.Mutex
TCP4Connections []*socket.ConnectionInfo
TCP4Listeners []*socket.BindInfo
TCP6Connections []*socket.ConnectionInfo
TCP6Listeners []*socket.BindInfo
UDP4Binds []*socket.BindInfo
UDP6Binds []*socket.BindInfo
}
// GetInfo returns all system state tables. The returned data must not be modified.
func GetInfo() *Info {
info := &Info{}
tcp4Lock.Lock()
updateTCP4Tables()
info.TCP4Connections = tcp4Connections
info.TCP4Listeners = tcp4Listeners
tcp4Lock.Unlock()
tcp6Lock.Lock()
updateTCP6Tables()
info.TCP6Connections = tcp6Connections
info.TCP6Listeners = tcp6Listeners
tcp6Lock.Unlock()
udp4Lock.Lock()
updateUDP4Table()
info.UDP4Binds = udp4Binds
udp4Lock.Unlock()
udp6Lock.Lock()
updateUDP6Table()
info.UDP6Binds = udp6Binds
udp6Lock.Unlock()
info.UpdateMeta()
return info
}

171
network/state/lookup.go Normal file
View file

@ -0,0 +1,171 @@
package state
import (
"errors"
"sync"
"time"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/socket"
)
// - TCP
// - Outbound: Match listeners (in!), then connections (out!)
// - Inbound: Match listeners (in!), then connections (out!)
// - Clean via connections
// - UDP
// - Any connection: match specific local address or zero IP
// - In or out: save direction of first packet:
// - map[<local udp bind ip+port>]map[<remote ip+port>]{direction, lastSeen}
// - only clean if <local udp bind ip+port> is removed by OS
// - limit <remote ip+port> to 256 entries?
// - clean <remote ip+port> after 72hrs?
// - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrPIDNotFound = errors.New("could not find pid for socket inode")
)
var (
tcp4Lock sync.Mutex
tcp6Lock sync.Mutex
udp4Lock sync.Mutex
udp6Lock sync.Mutex
baseWaitTime = 3 * time.Millisecond
)
// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound.
func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
// auto-detect version
if pktInfo.Version == 0 {
if ip := pktInfo.LocalIP().To4(); ip != nil {
pktInfo.Version = packet.IPv4
} else {
pktInfo.Version = packet.IPv6
}
}
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
tcp4Lock.Lock()
defer tcp4Lock.Unlock()
return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, pktInfo)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP:
tcp6Lock.Lock()
defer tcp6Lock.Unlock()
return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, pktInfo)
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP:
udp4Lock.Lock()
defer udp4Lock.Unlock()
return searchUDP(udp4Binds, udp4States, updateUDP4Table, pktInfo)
case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP:
udp6Lock.Lock()
defer udp6Lock.Unlock()
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
default:
return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
func searchTCP(
connections []*socket.ConnectionInfo,
listeners []*socket.BindInfo,
updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo),
pktInfo *packet.Info,
) (
pid int,
inbound bool,
err error,
) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
// search until we find something
for i := 0; i < 7; i++ {
// always search listeners first
for _, socketInfo := range listeners {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
return checkBindPID(socketInfo, true)
}
}
// search connections
for _, socketInfo := range connections {
if localPort == socketInfo.Local.Port &&
localIP.Equal(socketInfo.Local.IP) {
return checkConnectionPID(socketInfo, false)
}
}
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
// refetch lists
connections, listeners = updateTables()
}
return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
}
func searchUDP(
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
updateTable func() []*socket.BindInfo,
pktInfo *packet.Info,
) (
pid int,
inbound bool,
err error,
) {
localIP := pktInfo.LocalIP()
localPort := pktInfo.LocalPort()
isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast
// TODO: Currently broadcast/multicast scopes are not checked, so we might
// attribute an incoming broadcast/multicast packet to the wrong process if
// there are multiple processes listening on the same local port, but
// binding to different addresses. This highly unusual for clients.
// search until we find something
for i := 0; i < 5; i++ {
// search binds
for _, socketInfo := range binds {
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || // zero IP
isInboundMulticast || // inbound broadcast, multicast
localIP.Equal(socketInfo.Local.IP)) {
// do not check direction if remoteIP/Port is not given
if pktInfo.RemotePort() == 0 {
return checkBindPID(socketInfo, pktInfo.Inbound)
}
// get direction and return
connInbound := getUDPDirection(socketInfo, udpStates, pktInfo)
return checkBindPID(socketInfo, connInbound)
}
}
// we found nothing, we could have been too fast, give the kernel some time to think
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
time.Sleep(time.Duration(i+1) * baseWaitTime)
// refetch lists
binds = updateTable()
}
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}

View file

@ -0,0 +1,27 @@
package state
import (
"github.com/safing/portmaster/network/proc"
"github.com/safing/portmaster/network/socket"
)
var (
getTCP4Table = proc.GetTCP4Table
getTCP6Table = proc.GetTCP6Table
getUDP4Table = proc.GetUDP4Table
getUDP6Table = proc.GetUDP6Table
)
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}

View file

@ -0,0 +1,21 @@
package state
import (
"github.com/safing/portmaster/network/iphelper"
"github.com/safing/portmaster/network/socket"
)
var (
getTCP4Table = iphelper.GetTCP4Table
getTCP6Table = iphelper.GetTCP6Table
getUDP4Table = iphelper.GetUDP4Table
getUDP6Table = iphelper.GetUDP6Table
)
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
return socketInfo.PID, connInbound, nil
}

68
network/state/tables.go Normal file
View file

@ -0,0 +1,68 @@
package state
import (
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/socket"
)
var (
tcp4Connections []*socket.ConnectionInfo
tcp4Listeners []*socket.BindInfo
tcp6Connections []*socket.ConnectionInfo
tcp6Listeners []*socket.BindInfo
udp4Binds []*socket.BindInfo
udp6Binds []*socket.BindInfo
)
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
var err error
connections, listeners, err = getTCP4Table()
if err != nil {
log.Warningf("state: failed to get TCP4 socket table: %s", err)
return
}
tcp4Connections = connections
tcp4Listeners = listeners
return
}
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
var err error
connections, listeners, err = getTCP6Table()
if err != nil {
log.Warningf("state: failed to get TCP6 socket table: %s", err)
return
}
tcp6Connections = connections
tcp6Listeners = listeners
return
}
func updateUDP4Table() (binds []*socket.BindInfo) {
var err error
binds, err = getUDP4Table()
if err != nil {
log.Warningf("state: failed to get UDP4 socket table: %s", err)
return
}
udp4Binds = binds
return
}
func updateUDP6Table() (binds []*socket.BindInfo) {
var err error
binds, err = getUDP6Table()
if err != nil {
log.Warningf("state: failed to get UDP6 socket table: %s", err)
return
}
udp6Binds = binds
return
}

125
network/state/udp.go Normal file
View file

@ -0,0 +1,125 @@
package state
import (
"context"
"time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/socket"
)
type udpState struct {
inbound bool
lastSeen time.Time
}
const (
// UDPConnStateTTL is the maximum time a udp connection state is held.
UDPConnStateTTL = 72 * time.Hour
// UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold.
UDPConnStateShortenedTTL = 3 * time.Hour
// AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket.
AggressiveCleaningThreshold = 256
)
var (
udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
)
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
return
}
return nil, false
}
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
localKey := makeUDPStateKey(socketInfo.Local)
bindMap, ok := udpStates[localKey]
if !ok {
bindMap = make(map[string]*udpState)
udpStates[localKey] = bindMap
}
remoteKey := makeUDPStateKey(socket.Address{
IP: pktInfo.RemoteIP(),
Port: pktInfo.RemotePort(),
})
udpConnState, ok := bindMap[remoteKey]
if !ok {
bindMap[remoteKey] = &udpState{
inbound: pktInfo.Inbound,
lastSeen: time.Now().UTC(),
}
return pktInfo.Inbound
}
udpConnState.lastSeen = time.Now().UTC()
return udpConnState.inbound
}
// CleanUDPStates cleans the udp connection states which save connection directions.
func CleanUDPStates(_ context.Context) {
now := time.Now().UTC()
udp4Lock.Lock()
updateUDP4Table()
cleanStates(udp4Binds, udp4States, now)
udp4Lock.Unlock()
udp6Lock.Lock()
updateUDP6Table()
cleanStates(udp6Binds, udp6States, now)
udp6Lock.Unlock()
}
func cleanStates(
binds []*socket.BindInfo,
udpStates map[string]map[string]*udpState,
now time.Time,
) {
// compute thresholds
threshold := now.Add(-UDPConnStateTTL)
shortThreshhold := now.Add(-UDPConnStateShortenedTTL)
// make lookup map of all active keys
bindKeys := make(map[string]struct{})
for _, socketInfo := range binds {
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
}
// clean the udp state storage
for localKey, bindMap := range udpStates {
if _, active := bindKeys[localKey]; active {
// clean old entries
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) {
delete(bindMap, remoteKey)
}
}
// if there are too many clean more aggressively
if len(bindMap) > AggressiveCleaningThreshold {
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(shortThreshhold) {
delete(bindMap, remoteKey)
}
}
}
} else {
// delete the whole thing
delete(udpStates, localKey)
}
}
}
func makeUDPStateKey(address socket.Address) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(address.IP) + string(address.Port)
}

View file

@ -74,7 +74,7 @@ func finalizeLogFile(logFile *os.File, logFilePath string) {
func initControlLogFile() *os.File {
// check logging dir
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
logFileBasePath := filepath.Join(logsRoot.Path, "control")
err := logsRoot.EnsureAbsPath(logFileBasePath)
if err != nil {
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)
@ -93,7 +93,7 @@ func logControlError(cErr error) {
}
// check logging dir
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
logFileBasePath := filepath.Join(logsRoot.Path, "control")
err := logsRoot.EnsureAbsPath(logFileBasePath)
if err != nil {
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)
@ -114,7 +114,7 @@ func logControlError(cErr error) {
//nolint:deadcode,unused // TODO
func logControlStack() {
// check logging dir
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
logFileBasePath := filepath.Join(logsRoot.Path, "control")
err := logsRoot.EnsureAbsPath(logFileBasePath)
if err != nil {
log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err)

View file

@ -227,7 +227,7 @@ func execute(opts *Options, args []string) (cont bool, err error) {
// log files
var logFile, errorFile *os.File
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", opts.ShortIdentifier)
logFileBasePath := filepath.Join(logsRoot.Path, opts.ShortIdentifier)
err = logsRoot.EnsureAbsPath(logFileBasePath)
if err != nil {
log.Printf("failed to check/create log file dir %s: %s\n", logFileBasePath, err)

View file

@ -3,7 +3,8 @@ package process
import (
"context"
"errors"
"net"
"github.com/safing/portmaster/network/state"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
@ -11,131 +12,28 @@ import (
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrProcessNotFound = errors.New("could not find process in system state tables")
ErrProcessNotFound = errors.New("could not find process in system state tables")
)
// GetPidByPacket returns the pid of the owner of the packet.
func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
var localIP net.IP
var localPort uint16
var remoteIP net.IP
var remotePort uint16
if pkt.IsInbound() {
localIP = pkt.Info().Dst
remoteIP = pkt.Info().Src
} else {
localIP = pkt.Info().Src
remoteIP = pkt.Info().Dst
}
if pkt.HasPorts() {
if pkt.IsInbound() {
localPort = pkt.Info().DstPort
remotePort = pkt.Info().SrcPort
} else {
localPort = pkt.Info().SrcPort
remotePort = pkt.Info().DstPort
}
}
switch {
case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv4:
return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv4:
return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv6:
return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6:
return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound())
default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
// GetProcessByPacket returns the process that owns the given packet.
func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) {
if !enableProcessDetection() {
log.Tracer(pkt.Ctx()).Tracef("process: process detection disabled")
return GetUnidentifiedProcess(pkt.Ctx()), pkt.Info().Direction, nil
}
log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet")
var pid int
pid, direction, err = GetPidByPacket(pkt)
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err)
return nil, direction, err
}
if pid < 0 {
log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error())
return nil, direction, ErrConnectionNotFound
}
process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid)
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err)
return nil, direction, err
}
err = process.GetProfile(pkt.Ctx())
if err != nil {
log.Tracer(pkt.Ctx()).Errorf("process: failed to get profile for process %s: %s", process, err)
}
return process, direction, nil
}
// GetPidByEndpoints returns the pid of the owner of the described link.
func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) {
ipVersion := packet.IPv4
if v4 := localIP.To4(); v4 == nil {
ipVersion = packet.IPv6
}
switch {
case protocol == packet.TCP && ipVersion == packet.IPv4:
return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, false)
case protocol == packet.UDP && ipVersion == packet.IPv4:
return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, false)
case protocol == packet.TCP && ipVersion == packet.IPv6:
return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, false)
case protocol == packet.UDP && ipVersion == packet.IPv6:
return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, false)
default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
// GetProcessByEndpoints returns the process that owns the described link.
func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) {
// GetProcessByConnection returns the process that owns the described connection.
func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) {
if !enableProcessDetection() {
log.Tracer(ctx).Tracef("process: process detection disabled")
return GetUnidentifiedProcess(ctx), nil
return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil
}
log.Tracer(ctx).Tracef("process: getting process and profile by endpoints")
log.Tracer(ctx).Tracef("process: getting pid from system network state")
var pid int
pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol)
pid, connInbound, err = state.Lookup(pktInfo)
if err != nil {
log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err)
return nil, err
}
if pid < 0 {
log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error())
return nil, ErrConnectionNotFound
log.Tracer(ctx).Debugf("process: failed to find PID of connection: %s", err)
return nil, connInbound, err
}
process, err = GetOrFindPrimaryProcess(ctx, pid)
if err != nil {
log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err)
return nil, err
log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err)
return nil, connInbound, err
}
err = process.GetProfile(ctx)
@ -143,10 +41,5 @@ func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16
log.Tracer(ctx).Errorf("process: failed to get profile for process %s: %s", process, err)
}
return process, nil
}
// GetActiveConnectionIDs returns a list of all active connection IDs.
func GetActiveConnectionIDs() []string {
return getActiveConnectionIDs()
return process, connInbound, nil
}

View file

@ -1,13 +0,0 @@
package process
import (
"github.com/safing/portmaster/process/proc"
)
var (
getTCP4PacketInfo = proc.GetTCP4PacketInfo
getTCP6PacketInfo = proc.GetTCP6PacketInfo
getUDP4PacketInfo = proc.GetUDP4PacketInfo
getUDP6PacketInfo = proc.GetUDP6PacketInfo
getActiveConnectionIDs = proc.GetActiveConnectionIDs
)

View file

@ -1,13 +0,0 @@
package process
import (
"github.com/safing/portmaster/process/iphelper"
)
var (
getTCP4PacketInfo = iphelper.GetTCP4PacketInfo
getTCP6PacketInfo = iphelper.GetTCP6PacketInfo
getUDP4PacketInfo = iphelper.GetUDP4PacketInfo
getUDP6PacketInfo = iphelper.GetUDP6PacketInfo
getActiveConnectionIDs = iphelper.GetActiveConnectionIDs
)

View file

@ -1,260 +0,0 @@
// +build windows
package iphelper
import (
"fmt"
"net"
"sync"
"time"
)
const (
unidentifiedProcessID = -1
)
var (
tcp4Connections []*ConnectionEntry
tcp4Listeners []*ConnectionEntry
tcp6Connections []*ConnectionEntry
tcp6Listeners []*ConnectionEntry
udp4Connections []*ConnectionEntry
udp4Listeners []*ConnectionEntry
udp6Connections []*ConnectionEntry
udp6Listeners []*ConnectionEntry
ipHelper *IPHelper
lock sync.RWMutex
waitTime = 15 * time.Millisecond
)
func checkIPHelper() (err error) {
if ipHelper == nil {
ipHelper, err = New()
return err
}
return nil
}
// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection.
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search
pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
for i := 0; i < 3; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
// if unable to find, refresh
lock.Lock()
err = checkIPHelper()
if err == nil {
tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4)
}
lock.Unlock()
if err != nil {
return unidentifiedProcessID, pktDirection, err
}
// search
pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
time.Sleep(waitTime)
}
return unidentifiedProcessID, pktDirection, nil
}
// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection.
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search
pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
for i := 0; i < 3; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
// if unable to find, refresh
lock.Lock()
err = checkIPHelper()
if err == nil {
tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6)
}
lock.Unlock()
if err != nil {
return unidentifiedProcessID, pktDirection, err
}
// search
pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
time.Sleep(waitTime)
}
return unidentifiedProcessID, pktDirection, nil
}
// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection.
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search
pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
for i := 0; i < 3; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
// if unable to find, refresh
lock.Lock()
err = checkIPHelper()
if err == nil {
udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4)
}
lock.Unlock()
if err != nil {
return unidentifiedProcessID, pktDirection, err
}
// search
pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
time.Sleep(waitTime)
}
return unidentifiedProcessID, pktDirection, nil
}
// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection.
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
// search
pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
for i := 0; i < 3; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
// if unable to find, refresh
lock.Lock()
err = checkIPHelper()
if err == nil {
udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6)
}
lock.Unlock()
if err != nil {
return unidentifiedProcessID, pktDirection, err
}
// search
pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection)
if pid >= 0 {
return pid, pktDirection, nil
}
time.Sleep(waitTime)
}
return unidentifiedProcessID, pktDirection, nil
}
func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { //nolint:unparam // TODO: use direction, it may not be used because results caused problems, investigate.
lock.RLock()
defer lock.RUnlock()
if pktDirection {
// inbound
pid = searchListeners(listeners, localIP, localPort)
if pid >= 0 {
return pid, true
}
pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort)
if pid >= 0 {
return pid, false
}
} else {
// outbound
pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort)
if pid >= 0 {
return pid, false
}
pid = searchListeners(listeners, localIP, localPort)
if pid >= 0 {
return pid, true
}
}
return unidentifiedProcessID, pktDirection
}
func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) {
for _, entry := range list {
if localPort == entry.localPort &&
remotePort == entry.remotePort &&
remoteIP.Equal(entry.remoteIP) &&
localIP.Equal(entry.localIP) {
return entry.pid
}
}
return unidentifiedProcessID
}
func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) {
for _, entry := range list {
if localPort == entry.localPort &&
(entry.localIP == nil || // nil IP means zero IP, see tables.go
localIP.Equal(entry.localIP)) {
return entry.pid
}
}
return unidentifiedProcessID
}
// GetActiveConnectionIDs returns all currently active connection IDs.
func GetActiveConnectionIDs() (connections []string) {
lock.Lock()
defer lock.Unlock()
for _, entry := range tcp4Connections {
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
}
for _, entry := range tcp6Connections {
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
}
for _, entry := range udp4Connections {
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
}
for _, entry := range udp6Connections {
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort))
}
return
}

View file

@ -1,79 +0,0 @@
// +build windows
package iphelper
import (
"errors"
"fmt"
"github.com/tevino/abool"
"golang.org/x/sys/windows"
)
var (
errInvalid = errors.New("IPHelper not initialzed or broken")
)
// IPHelper represents a subset of the Windows iphlpapi.dll.
type IPHelper struct {
dll *windows.LazyDLL
getExtendedTCPTable *windows.LazyProc
getExtendedUDPTable *windows.LazyProc
// getOwnerModuleFromTcpEntry *windows.LazyProc
// getOwnerModuleFromTcp6Entry *windows.LazyProc
// getOwnerModuleFromUdpEntry *windows.LazyProc
// getOwnerModuleFromUdp6Entry *windows.LazyProc
valid *abool.AtomicBool
}
// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded).
func New() (*IPHelper, error) {
new := &IPHelper{}
new.valid = abool.NewBool(false)
var err error
// load dll
new.dll = windows.NewLazySystemDLL("iphlpapi.dll")
err = new.dll.Load()
if err != nil {
return nil, err
}
// load functions
new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable")
err = new.getExtendedTCPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err)
}
new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable")
err = new.getExtendedUDPTable.Find()
if err != nil {
return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err)
}
// new.getOwnerModuleFromTcpEntry = new.dll.NewProc("GetOwnerModuleFromTcpEntry")
// err = new.getOwnerModuleFromTcpEntry.Find()
// if err != nil {
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcpEntry: %s", err)
// }
// new.getOwnerModuleFromTcp6Entry = new.dll.NewProc("GetOwnerModuleFromTcp6Entry")
// err = new.getOwnerModuleFromTcp6Entry.Find()
// if err != nil {
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcp6Entry: %s", err)
// }
// new.getOwnerModuleFromUdpEntry = new.dll.NewProc("GetOwnerModuleFromUdpEntry")
// err = new.getOwnerModuleFromUdpEntry.Find()
// if err != nil {
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdpEntry: %s", err)
// }
// new.getOwnerModuleFromUdp6Entry = new.dll.NewProc("GetOwnerModuleFromUdp6Entry")
// err = new.getOwnerModuleFromUdp6Entry.Find()
// if err != nil {
// return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdp6Entry: %s", err)
// }
new.valid.Set()
return new, nil
}

View file

@ -1,62 +0,0 @@
// +build windows
package main
import (
"fmt"
"github.com/safing/portmaster/process/iphelper"
)
func main() {
iph, err := iphelper.New()
if err != nil {
panic(err)
}
fmt.Printf("TCP4\n")
conns, lConns, err := iph.GetTables(iphelper.TCP, iphelper.IPv4)
if err != nil {
panic(err)
}
fmt.Printf("Connections:\n")
for _, conn := range conns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("Listeners:\n")
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nTCP6\n")
conns, lConns, err = iph.GetTables(iphelper.TCP, iphelper.IPv6)
if err != nil {
panic(err)
}
fmt.Printf("Connections:\n")
for _, conn := range conns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("Listeners:\n")
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nUDP4\n")
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv4)
if err != nil {
panic(err)
}
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
fmt.Printf("\nUDP6\n")
_, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv6)
if err != nil {
panic(err)
}
for _, conn := range lConns {
fmt.Printf("%s\n", conn)
}
}

View file

@ -1,83 +0,0 @@
// +build linux
package proc
import (
"net"
"time"
)
// PID querying return codes
const (
Success uint8 = iota
NoSocket
NoProcess
)
var (
waitTime = 15 * time.Millisecond
)
// GetPidOfConnection returns the PID of the given connection.
func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getConnectionSocket(localIP, localPort, protocol)
if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
uid, inode, ok = getConnectionSocket(localIP, localPort, protocol)
if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
}
}
if !ok {
return unidentifiedProcessID, NoSocket
}
}
pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok {
return unidentifiedProcessID, NoProcess
}
return
}
// GetPidOfIncomingConnection returns the PID of the given incoming connection.
func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getListeningSocket(localIP, localPort, protocol)
if !ok {
// for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version.
switch protocol {
case TCP4:
uid, inode, ok = getListeningSocket(localIP, localPort, TCP6)
case UDP4:
uid, inode, ok = getListeningSocket(localIP, localPort, UDP6)
}
if !ok {
return unidentifiedProcessID, NoSocket
}
}
pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok {
return unidentifiedProcessID, NoProcess
}
return
}

View file

@ -1,66 +0,0 @@
// +build linux
package proc
import (
"errors"
"net"
)
const (
unidentifiedProcessID = -1
)
// GetTCP4PacketInfo searches the network state tables for a TCP4 connection
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP4, localIP, localPort, pktDirection)
}
// GetTCP6PacketInfo searches the network state tables for a TCP6 connection
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP6, localIP, localPort, pktDirection)
}
// GetUDP4PacketInfo searches the network state tables for a UDP4 connection
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP4, localIP, localPort, pktDirection)
}
// GetUDP6PacketInfo searches the network state tables for a UDP6 connection
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP6, localIP, localPort, pktDirection)
}
func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) {
var status uint8
if pktDirection {
pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
if pid >= 0 {
return pid, true, nil
}
// pid, status = GetPidOfConnection(localIP, localPort, protocol)
// if pid >= 0 {
// return pid, false, nil
// }
} else {
pid, status = GetPidOfConnection(localIP, localPort, protocol)
if pid >= 0 {
return pid, false, nil
}
// pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
// if pid >= 0 {
// return pid, true, nil
// }
}
switch status {
case NoSocket:
return unidentifiedProcessID, direction, errors.New("could not find socket")
case NoProcess:
return unidentifiedProcessID, direction, errors.New("could not find PID")
default:
return unidentifiedProcessID, direction, nil
}
}

View file

@ -1,18 +0,0 @@
// +build linux
package proc
import (
"log"
"testing"
)
func TestProcessFinder(t *testing.T) {
updatePids()
log.Printf("pidsByUser: %v", pidsByUser)
pid, _ := GetPidOfInode(1000, 112588)
log.Printf("pid: %d", pid)
}

View file

@ -1,370 +0,0 @@
// +build linux
package proc
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"unicode"
"github.com/safing/portbase/log"
)
/*
1. find socket inode
- by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP?
- /proc/net/{tcp|udp}[6]
2. get list of processes of uid
3. find socket inode in process fds
- if not found, refresh map of uid->pids
- if not found, check ALL pids: maybe euid != uid
4. gather process info
Cache every step!
*/
// Network Related Constants
const (
TCP4 uint8 = iota
UDP4
TCP6
UDP6
ICMP4
ICMP6
TCP4Data = "/proc/net/tcp"
UDP4Data = "/proc/net/udp"
TCP6Data = "/proc/net/tcp6"
UDP6Data = "/proc/net/udp6"
ICMP4Data = "/proc/net/icmp"
ICMP6Data = "/proc/net/icmp6"
)
var (
// connectionSocketsLock sync.Mutex
// connectionTCP4 = make(map[string][]int)
// connectionUDP4 = make(map[string][]int)
// connectionTCP6 = make(map[string][]int)
// connectionUDP6 = make(map[string][]int)
listeningSocketsLock sync.Mutex
addressListeningTCP4 = make(map[string][]int)
addressListeningUDP4 = make(map[string][]int)
addressListeningTCP6 = make(map[string][]int)
addressListeningUDP6 = make(map[string][]int)
globalListeningTCP4 = make(map[uint16][]int)
globalListeningUDP4 = make(map[uint16][]int)
globalListeningTCP6 = make(map[uint16][]int)
globalListeningUDP6 = make(map[uint16][]int)
)
func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) {
// listeningSocketsLock.Lock()
// defer listeningSocketsLock.Unlock()
var procFile string
var localIPHex string
switch protocol {
case TCP4:
procFile = TCP4Data
localIPBytes := []byte(localIP.To4())
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
case UDP4:
procFile = UDP4Data
localIPBytes := []byte(localIP.To4())
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
case TCP6:
procFile = TCP6Data
localIPHex = hex.EncodeToString([]byte(localIP))
case UDP6:
procFile = UDP6Data
localIPHex = hex.EncodeToString([]byte(localIP))
}
localPortHex := fmt.Sprintf("%04X", localPort)
// log.Tracef("process/proc: searching for PID of: %s:%d (%s:%s)", localIP, localPort, localIPHex, localPortHex)
// open file
socketData, err := os.Open(procFile)
if err != nil {
log.Warningf("process/proc: could not read %s: %s", procFile, err)
return unidentifiedProcessID, unidentifiedProcessID, false
}
defer socketData.Close()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first line
for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
// log.Tracef("line: %s", line)
if len(line) < 14 {
// log.Tracef("process: too short: %s", line)
continue
}
if line[1] != localIPHex {
continue
}
if line[2] != localPortHex {
continue
}
ok := true
uid, err := strconv.ParseInt(line[11], 10, 32)
if err != nil {
log.Warningf("process: could not parse uid %s: %s", line[11], err)
uid = -1
ok = false
}
inode, err := strconv.ParseInt(line[13], 10, 32)
if err != nil {
log.Warningf("process: could not parse inode %s: %s", line[13], err)
inode = -1
ok = false
}
// log.Tracef("process/proc: identified process of %s:%d: socket=%d uid=%d", localIP, localPort, int(inode), int(uid))
return int(uid), int(inode), ok
}
return unidentifiedProcessID, unidentifiedProcessID, false
}
func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
listeningSocketsLock.Lock()
defer listeningSocketsLock.Unlock()
var addressListening map[string][]int
var globalListening map[uint16][]int
switch protocol {
case TCP4:
addressListening = addressListeningTCP4
globalListening = globalListeningTCP4
case UDP4:
addressListening = addressListeningUDP4
globalListening = globalListeningUDP4
case TCP6:
addressListening = addressListeningTCP6
globalListening = globalListeningTCP6
case UDP6:
addressListening = addressListeningUDP6
globalListening = globalListeningUDP6
}
data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok {
data, ok = globalListening[localPort]
}
if ok {
return data[0], data[1], true
}
updateListeners(protocol)
data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok {
data, ok = globalListening[localPort]
}
if ok {
return data[0], data[1], true
}
return unidentifiedProcessID, unidentifiedProcessID, false
}
func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
return nil
}
if len(decoded) != 4 {
log.Warningf("process: decoded IPv4 %s has wrong length", decoded)
return nil
}
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return ip
}
func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
return nil
}
if len(decoded) != 16 {
log.Warningf("process: decoded IPv6 %s has wrong length", decoded)
return nil
}
ip := net.IP(decoded)
return ip
}
func updateListeners(protocol uint8) {
switch protocol {
case TCP4:
addressListeningTCP4, globalListeningTCP4 = getListenerMaps(TCP4Data, "00000000", "0A", convertIPv4)
case UDP4:
addressListeningUDP4, globalListeningUDP4 = getListenerMaps(UDP4Data, "00000000", "07", convertIPv4)
case TCP6:
addressListeningTCP6, globalListeningTCP6 = getListenerMaps(TCP6Data, "00000000000000000000000000000000", "0A", convertIPv6)
case UDP6:
addressListeningUDP6, globalListeningUDP6 = getListenerMaps(UDP6Data, "00000000000000000000000000000000", "07", convertIPv6)
}
}
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) {
addressListening := make(map[string][]int)
globalListening := make(map[uint16][]int)
// open file
socketData, err := os.Open(procFile)
if err != nil {
log.Warningf("process: could not read %s: %s", procFile, err)
return addressListening, globalListening
}
defer socketData.Close()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first line
for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(line) < 14 {
// log.Tracef("process: too short: %s", line)
continue
}
if line[5] != socketStatusListening {
// skip if not listening
// log.Tracef("process: not listening %s: %s", line, line[5])
continue
}
port, err := strconv.ParseUint(line[2], 16, 16)
// log.Tracef("port: %s", line[2])
if err != nil {
log.Warningf("process: could not parse port %s: %s", line[2], err)
continue
}
uid, err := strconv.ParseInt(line[11], 10, 32)
// log.Tracef("uid: %s", line[11])
if err != nil {
log.Warningf("process: could not parse uid %s: %s", line[11], err)
continue
}
inode, err := strconv.ParseInt(line[13], 10, 32)
// log.Tracef("inode: %s", line[13])
if err != nil {
log.Warningf("process: could not parse inode %s: %s", line[13], err)
continue
}
if line[1] == zeroIP {
globalListening[uint16(port)] = []int{int(uid), int(inode)}
} else {
address := ipConverter(line[1])
if address != nil {
addressListening[fmt.Sprintf("%s:%d", address, port)] = []int{int(uid), int(inode)}
}
}
}
return addressListening, globalListening
}
// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS.
func GetActiveConnectionIDs() []string {
var connections []string
connections = append(connections, getConnectionIDsFromSource(TCP4Data, 6, convertIPv4)...)
connections = append(connections, getConnectionIDsFromSource(UDP4Data, 17, convertIPv4)...)
connections = append(connections, getConnectionIDsFromSource(TCP6Data, 6, convertIPv6)...)
connections = append(connections, getConnectionIDsFromSource(UDP6Data, 17, convertIPv6)...)
return connections
}
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string {
var connections []string
// open file
socketData, err := os.Open(source)
if err != nil {
log.Warningf("process: could not read %s: %s", source, err)
return connections
}
defer socketData.Close()
// file scanner
scanner := bufio.NewScanner(socketData)
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first line
for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(line) < 14 {
// log.Tracef("process: too short: %s", line)
continue
}
// skip listeners and closed connections
if line[5] == "0A" || line[5] == "07" {
continue
}
localIP := ipConverter(line[1])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(line[2], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
remoteIP := ipConverter(line[3])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(line[4], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", protocol, localIP, localPort, remoteIP, remotePort))
}
return connections
}

View file

@ -1,40 +0,0 @@
// +build linux
package proc
import (
"net"
"testing"
)
func TestSockets(t *testing.T) {
updateListeners(TCP4)
updateListeners(UDP4)
updateListeners(TCP6)
updateListeners(UDP6)
t.Logf("addressListeningTCP4: %v", addressListeningTCP4)
t.Logf("globalListeningTCP4: %v", globalListeningTCP4)
t.Logf("addressListeningUDP4: %v", addressListeningUDP4)
t.Logf("globalListeningUDP4: %v", globalListeningUDP4)
t.Logf("addressListeningTCP6: %v", addressListeningTCP6)
t.Logf("globalListeningTCP6: %v", globalListeningTCP6)
t.Logf("addressListeningUDP6: %v", addressListeningUDP6)
t.Logf("globalListeningUDP6: %v", globalListeningUDP6)
getListeningSocket(net.IPv4zero, 53, TCP4)
getListeningSocket(net.IPv4zero, 53, UDP4)
getListeningSocket(net.IPv6zero, 53, TCP6)
getListeningSocket(net.IPv6zero, 53, UDP6)
// spotify: 192.168.0.102:5353 192.121.140.65:80
localIP := net.IPv4(192, 168, 127, 10)
uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4)
t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok)
activeConnectionIDs := GetActiveConnectionIDs()
for _, connID := range activeConnectionIDs {
t.Logf("active: %s", connID)
}
}

View file

@ -49,7 +49,8 @@ type Process struct {
FirstSeen int64
LastSeen int64
Virtual bool // This process is either merged into another process or is not needed.
Virtual bool // This process is either merged into another process or is not needed.
Error string // Cache errors
}
// Profile returns the assigned layered profile.
@ -94,6 +95,7 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) {
parentProcess, err := loadProcess(ctx, process.ParentPid)
if err != nil {
log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err)
saveFailedProcess(process.ParentPid, err.Error())
return process, nil
}
@ -226,13 +228,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) {
pInfo, err := processInfo.NewProcess(int32(pid))
if err != nil {
// TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist
// Issue: https://github.com/shirou/gopsutil/issues/729
_, err = pInfo.Name()
if err != nil {
// process does not exists
return nil, err
}
return nil, err
}
// UID
@ -375,3 +371,14 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) {
new.Save()
return new, nil
}
func saveFailedProcess(pid int, err string) {
failed := &Process{
Pid: pid,
FirstSeen: time.Now().Unix(),
Virtual: true, // not needed
Error: err,
}
failed.Save()
}

View file

@ -94,6 +94,7 @@ func registerConfiguration() error {
Description: `The default filter action when nothing else permits or blocks a connection.`,
Order: cfgOptionDefaultActionOrder,
OptType: config.OptTypeString,
ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: "permit",
ExternalOptType: "string list",
ValidationRegex: "^(permit|ask|block)$",
@ -121,17 +122,12 @@ func registerConfiguration() error {
cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll))
cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit
// Endpoint Filter List
err = config.Register(&config.Option{
Name: "Endpoint Filter List",
Key: CfgOptionEndpointsKey,
Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.",
Help: `Format:
filterListHelp := `Format:
Permission:
"+": permit
"-": block
Host Matching:
IP, CIDR, Country Code, ASN, Filterlist, "*" for any
IP, CIDR, Country Code, ASN, Filterlist, Network Scope, "*" for any
Domains:
"example.com": exact match
".example.com": exact match + subdomains
@ -144,11 +140,20 @@ func registerConfiguration() error {
Examples:
+ .example.com */HTTP
- .example.com
+ 192.168.0.1/24
+ 192.168.0.1
+ 192.168.1.1/24
+ Localhost,LAN
- AS123456789
- L:MAL
- AS0
+ AT
- *`,
- *`
// Endpoint Filter List
err = config.Register(&config.Option{
Name: "Endpoint Filter List",
Key: CfgOptionEndpointsKey,
Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.",
Help: filterListHelp,
Order: cfgOptionEndpointsOrder,
OptType: config.OptTypeStringArray,
DefaultValue: []string{},
@ -163,35 +168,13 @@ Examples:
// Service Endpoint Filter List
err = config.Register(&config.Option{
Name: "Service Endpoint Filter List",
Key: CfgOptionServiceEndpointsKey,
Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.",
Help: `Format:
Permission:
"+": permit
"-": block
Host Matching:
IP, CIDR, Country Code, ASN, Filterlist, "*" for any
Domains:
"example.com": exact match
".example.com": exact match + subdomains
"*xample.com": prefix wildcard
"example.*": suffix wildcard
"*example*": prefix and suffix wildcard
Protocol and Port Matching (optional):
<protocol>/<port>
Examples:
+ .example.com */HTTP
- .example.com
+ 192.168.0.1/24
- L:MAL
- AS0
+ AT
- *`,
Name: "Service Endpoint Filter List",
Key: CfgOptionServiceEndpointsKey,
Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.",
Help: filterListHelp,
Order: cfgOptionServiceEndpointsOrder,
OptType: config.OptTypeStringArray,
DefaultValue: []string{},
DefaultValue: []string{"+ Localhost"},
ExternalOptType: "endpoint list",
ValidationRegex: `^(\+|\-) [A-z0-9\.:\-*/]+( [A-z0-9/]+)?$`,
})
@ -313,7 +296,7 @@ Examples:
Order: cfgOptionBlockP2POrder,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: status.SecurityLevelsAll,
DefaultValue: status.SecurityLevelExtreme,
ValidationRegex: "^(4|6|7)$",
})
if err != nil {
@ -326,7 +309,7 @@ Examples:
err = config.Register(&config.Option{
Name: "Block Inbound Connections",
Key: CfgOptionBlockInboundKey,
Description: "Connections initiated towards your device. This will usually only be the case if you are running a network service or are using peer to peer software.",
Description: "Connections initiated towards your device from the LAN or Internet. This will usually only be the case if you are running a network service or are using peer to peer software.",
Order: cfgOptionBlockInboundOrder,
OptType: config.OptTypeInt,
ExternalOptType: "security level",

View file

@ -31,7 +31,7 @@ type EndpointDomain struct {
}
func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) {
result, reason := ep.match(ep, entity, ep.Domain, "domain matches")
result, reason := ep.match(ep, entity, ep.OriginalValue, "domain matches")
switch ep.MatchType {
case domainMatchTypeExact:

View file

@ -0,0 +1,106 @@
package endpoints
import (
"strings"
"github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/intel"
)
const (
scopeLocalhost = 1
scopeLocalhostName = "Localhost"
scopeLocalhostMatcher = "localhost"
scopeLAN = 2
scopeLANName = "LAN"
scopeLANMatcher = "lan"
scopeInternet = 4
scopeInternetName = "Internet"
scopeInternetMatcher = "internet"
)
// EndpointScope matches network scopes.
type EndpointScope struct {
EndpointBase
scopes uint8
}
// Matches checks whether the given entity matches this endpoint definition.
func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) {
if entity.IP == nil {
return Undeterminable, nil
}
classification := netutils.ClassifyIP(entity.IP)
var scope uint8
switch classification {
case netutils.HostLocal:
scope = scopeLocalhost
case netutils.LinkLocal:
scope = scopeLAN
case netutils.SiteLocal:
scope = scopeLAN
case netutils.Global:
scope = scopeInternet
case netutils.LocalMulticast:
scope = scopeLAN
case netutils.GlobalMulticast:
scope = scopeInternet
}
if ep.scopes&scope > 0 {
return ep.match(ep, entity, ep.Scopes(), "scope matches")
}
return NoMatch, nil
}
// Scopes returns the string representation of all scopes.
func (ep *EndpointScope) Scopes() string {
// single scope
switch ep.scopes {
case scopeLocalhost:
return scopeLocalhostName
case scopeLAN:
return scopeLANName
case scopeInternet:
return scopeInternetName
}
// multiple scopes
var s []string
if ep.scopes&scopeLocalhost > 0 {
s = append(s, scopeLocalhostName)
}
if ep.scopes&scopeLAN > 0 {
s = append(s, scopeLANName)
}
if ep.scopes&scopeInternet > 0 {
s = append(s, scopeInternetName)
}
return strings.Join(s, ",")
}
func (ep *EndpointScope) String() string {
return ep.renderPPP(ep.Scopes())
}
func parseTypeScope(fields []string) (Endpoint, error) {
ep := &EndpointScope{}
for _, val := range strings.Split(strings.ToLower(fields[1]), ",") {
switch val {
case scopeLocalhostMatcher:
ep.scopes ^= scopeLocalhost
case scopeLANMatcher:
ep.scopes ^= scopeLAN
case scopeInternetMatcher:
ep.scopes ^= scopeInternet
default:
return nil, nil
}
}
return ep.parsePPP(ep, fields)
}

View file

@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error {
return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg)
}
func parseEndpoint(value string) (endpoint Endpoint, err error) {
func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit
fields := strings.Fields(value)
if len(fields) < 2 {
return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value)
@ -231,6 +231,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) {
if endpoint, err = parseTypeASN(fields); endpoint != nil || err != nil {
return
}
// scopes
if endpoint, err = parseTypeScope(fields); endpoint != nil || err != nil {
return
}
// lists
if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil {
return

View file

@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) {
testParsing(t, "+ AS1234")
testParsing(t, "+ AS12345")
// network scope
testParsing(t, "+ Localhost")
testParsing(t, "+ LAN")
testParsing(t, "+ Internet")
testParsing(t, "+ Localhost,LAN,Internet")
// protocol and ports
testParsing(t, "+ * TCP/1-1024")
testParsing(t, "+ * */DNS")

View file

@ -342,7 +342,26 @@ func TestEndpointMatching(t *testing.T) {
IP: net.ParseIP("151.101.1.164"), // nytimes.com
}).Init(), NoMatch)
// Scope
ep, err = parseEndpoint("+ Localhost,LAN")
if err != nil {
t.Fatal(err)
}
testEndpointMatch(t, ep, (&intel.Entity{
IP: net.ParseIP("192.168.0.1"),
}).Init(), Permitted)
testEndpointMatch(t, ep, (&intel.Entity{
IP: net.ParseIP("151.101.1.164"), // nytimes.com
}).Init(), NoMatch)
// Lists
_, err = parseEndpoint("+ L:A,B,C")
if err != nil {
t.Fatal(err)
}
// TODO: write test for lists matcher
}

View file

@ -126,6 +126,18 @@ func (lp *LayeredProfile) getValidityFlag() *abool.AtomicBool {
return lp.validityFlag
}
// RevisionCnt returns the current profile revision counter.
func (lp *LayeredProfile) RevisionCnt() (revisionCounter uint64) {
if lp == nil {
return 0
}
lp.lock.Lock()
defer lp.lock.Unlock()
return lp.revisionCounter
}
// Update checks for updated profiles and replaces any outdated profiles.
func (lp *LayeredProfile) Update() (revisionCounter uint64) {
lp.lock.Lock()

View file

@ -1,6 +1,7 @@
package resolver
import (
"context"
"crypto/tls"
"net"
"sync"
@ -9,6 +10,13 @@ import (
"github.com/miekg/dns"
)
const (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 2 * time.Second // tcp/tls
connectionEOLGracePeriod = 7 * time.Second
)
var (
localAddrFactory func(network string) net.Addr
)
@ -27,21 +35,54 @@ func getLocalAddr(network string) net.Addr {
return nil
}
type clientManager struct {
dnsClient *dns.Client
factory func() *dns.Client
type dnsClientManager struct {
lock sync.Mutex
lock sync.Mutex
refreshAfter time.Time
ttl time.Duration // force refresh of connection to reduce traceability
// set by creator
serverAddress string
ttl time.Duration // force refresh of connection to reduce traceability
factory func() *dns.Client
// internal
pool sync.Pool
}
func newDNSClientManager(_ *Resolver) *clientManager {
return &clientManager{
ttl: 0, // new client for every request, as we need to randomize the port
type dnsClient struct {
mgr *dnsClientManager
client *dns.Client
conn *dns.Conn
useUntil time.Time
}
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
if err != nil {
return nil, false, err
}
return dc.conn, true, nil
}
return dc.conn, false, nil
}
func (dc *dnsClient) addToPool() {
dc.mgr.pool.Put(dc)
}
func (dc *dnsClient) destroy() {
if dc.conn != nil {
_ = dc.conn.Close()
}
}
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: 0, // new client for every request, as we need to randomize the port
factory: func() *dns.Client {
return &dns.Client{
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("udp"),
},
@ -50,25 +91,28 @@ func newDNSClientManager(_ *Resolver) *clientManager {
}
}
func newTCPClientManager(_ *Resolver) *clientManager {
return &clientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp",
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second,
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func newTLSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp-tls",
@ -77,24 +121,68 @@ func newTLSClientManager(resolver *Resolver) *clientManager {
ServerName: resolver.VerifyDomain,
// TODO: use portbase rng
},
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second,
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func (cm *clientManager) getDNSClient() *dns.Client {
func (cm *dnsClientManager) getDNSClient() *dnsClient {
cm.lock.Lock()
defer cm.lock.Unlock()
if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) {
cm.dnsClient = cm.factory()
cm.refreshAfter = time.Now().Add(cm.ttl)
// return new immediately if a new client should be used for every request
if cm.ttl == 0 {
return &dnsClient{
mgr: cm,
client: cm.factory(),
}
}
return cm.dnsClient
// get cached client from pool
now := time.Now().UTC()
poolLoop:
for {
dc, ok := cm.pool.Get().(*dnsClient)
switch {
case !ok || dc == nil: // cache empty (probably, pool may always return nil!)
break poolLoop // create new
case now.After(dc.useUntil):
continue // get next
default:
return dc
}
}
// no available in pool, create new
newClient := &dnsClient{
mgr: cm,
client: cm.factory(),
useUntil: now.Add(cm.ttl),
}
newClient.startCleaner()
return newClient
}
// startCleaner waits for EOL of the client and then removes it from the pool.
func (dc *dnsClient) startCleaner() {
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
select {
case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod):
// destroy
case <-ctx.Done():
// give a short time before kill for graceful request completion
time.Sleep(100 * time.Millisecond)
}
dc.destroy()
return nil
})
}

View file

@ -2,6 +2,7 @@ package resolver
import (
"context"
"strings"
"time"
"github.com/safing/portbase/log"
@ -30,6 +31,7 @@ func start() error {
// load resolvers from config and environment
loadResolvers()
// reload after network change
err := module.RegisterEventHook(
"netenv",
"network changed",
@ -44,6 +46,27 @@ func start() error {
return err
}
// reload after config change
prevNameservers := strings.Join(configuredNameServers(), " ")
err = module.RegisterEventHook(
"config",
"config change",
"update nameservers",
func(_ context.Context, _ interface{}) error {
newNameservers := strings.Join(configuredNameServers(), " ")
if newNameservers != prevNameservers {
prevNameservers = newNameservers
loadResolvers()
log.Debug("resolver: reloaded nameservers due to config change")
}
return nil
},
)
if err != nil {
return err
}
module.StartServiceWorker(
"mdns handler",
5*time.Second,

View file

@ -9,6 +9,8 @@ import (
"sync"
"time"
"github.com/safing/portmaster/network/netutils"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
@ -29,10 +31,11 @@ var (
questionsLock sync.Mutex
mDNSResolver = &Resolver{
Server: ServerSourceMDNS,
ServerType: ServerTypeDNS,
Source: ServerSourceMDNS,
Conn: &mDNSResolverConn{},
Server: ServerSourceMDNS,
ServerType: ServerTypeDNS,
ServerIPScope: netutils.SiteLocal,
Source: ServerSourceMDNS,
Conn: &mDNSResolverConn{},
}
)
@ -42,10 +45,10 @@ func (mrc *mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, err
return queryMulticastDNS(ctx, q)
}
func (mrc *mDNSResolverConn) MarkFailed() {}
func (mrc *mDNSResolverConn) ReportFailure() {}
func (mrc *mDNSResolverConn) LastFail() time.Time {
return time.Time{}
func (mrc *mDNSResolverConn) IsFailing() bool {
return false
}
type savedQuestion struct {
@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
// get entry from database
if saveFullRequest {
// get from database
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
// if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired:
// create new and do not append
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
rrCache = &RRCache{
Domain: question.Name,
Question: dns.Type(question.Qtype),
Domain: question.Name,
Question: dns.Type(question.Qtype),
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
}
}
}
// add all entries to RRCache
for _, entry := range message.Answer {
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) {
if saveFullRequest {
@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
continue
}
rrCache = &RRCache{
Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
}
rrCache.Clean(60)
err := rrCache.Save()

View file

@ -12,7 +12,7 @@ import (
var (
recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
CacheSize: 128,
CacheSize: 256,
})
)

View file

@ -0,0 +1,27 @@
package resolver
import "testing"
func TestNameRecordStorage(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
r, err := GetNameRecord(testDomain, testQuestion)
if err != nil {
t.Fatal(err)
}
if r.Domain != testDomain || r.Question != testQuestion {
t.Fatal("mismatch")
}
}

189
resolver/pooling_test.go Normal file
View file

@ -0,0 +1,189 @@
package resolver
import (
"sync"
"sync/atomic"
"testing"
"github.com/miekg/dns"
)
var (
domainFeed = make(chan string)
)
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
}
if new {
atomic.AddUint32(newCnt, 1)
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Fatal(err) //nolint:staticcheck
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
dnsClient.addToPool()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
go feedDomains()
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
if err != nil {
t.Fatal(err)
}
brc := resolver.Conn.(*BasicResolverConn)
wg := &sync.WaitGroup{}
var newCnt uint32
for i := 0; i < 10; i++ {
wg.Add(10)
for i := 0; i < 10; i++ {
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if newCnt > uint32(10+i) {
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
}
}
}
func feedDomains() {
for {
for _, domain := range poolingTestDomains {
domainFeed <- domain
}
}
}
// Data
var (
poolingTestDomains = []string{
"facebook.com.",
"google.com.",
"youtube.com.",
"twitter.com.",
"instagram.com.",
"linkedin.com.",
"microsoft.com.",
"apple.com.",
"wikipedia.org.",
"plus.google.com.",
"en.wikipedia.org.",
"googletagmanager.com.",
"youtu.be.",
"adobe.com.",
"vimeo.com.",
"pinterest.com.",
"itunes.apple.com.",
"play.google.com.",
"maps.google.com.",
"goo.gl.",
"wordpress.com.",
"blogspot.com.",
"bit.ly.",
"github.com.",
"player.vimeo.com.",
"amazon.com.",
"wordpress.org.",
"docs.google.com.",
"yahoo.com.",
"mozilla.org.",
"tumblr.com.",
"godaddy.com.",
"flickr.com.",
"parked-content.godaddy.com.",
"drive.google.com.",
"support.google.com.",
"apache.org.",
"gravatar.com.",
"europa.eu.",
"qq.com.",
"w3.org.",
"nytimes.com.",
"reddit.com.",
"macromedia.com.",
"get.adobe.com.",
"soundcloud.com.",
"sourceforge.net.",
"sites.google.com.",
"nih.gov.",
"amazonaws.com.",
"t.co.",
"support.microsoft.com.",
"forbes.com.",
"theguardian.com.",
"cnn.com.",
"github.io.",
"bbc.co.uk.",
"dropbox.com.",
"whatsapp.com.",
"medium.com.",
"creativecommons.org.",
"www.ncbi.nlm.nih.gov.",
"httpd.apache.org.",
"archive.org.",
"ec.europa.eu.",
"php.net.",
"apps.apple.com.",
"weebly.com.",
"support.apple.com.",
"weibo.com.",
"wixsite.com.",
"issuu.com.",
"who.int.",
"paypal.com.",
"m.facebook.com.",
"oracle.com.",
"msn.com.",
"gnu.org.",
"tinyurl.com.",
"reuters.com.",
"l.facebook.com.",
"cloudflare.com.",
"wsj.com.",
"washingtonpost.com.",
"domainmarket.com.",
"imdb.com.",
"bbc.com.",
"bing.com.",
"accounts.google.com.",
"vk.com.",
"api.whatsapp.com.",
"opera.com.",
"cdc.gov.",
"slideshare.net.",
"wpa.qq.com.",
"harvard.edu.",
"mit.edu.",
"code.google.com.",
"wikimedia.org.",
}
)

View file

@ -14,8 +14,6 @@ import (
)
var (
mtAsyncResolve = "async resolve"
// basic errors
// ErrNotFound is a basic error that will match all "not found" errors
@ -114,6 +112,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) {
rrCache.MixAnswers()
return rrCache, nil
}
log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType)
// if cache is still empty or non-compliant, go ahead and just query
} else {
// we are the first!
@ -132,14 +131,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
if err != nil {
if err != database.ErrNotFound {
log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
}
return nil
}
// get resolver that rrCache was resolved with
resolver := getResolverByIDWithLocking(rrCache.Server)
resolver := getActiveResolverByIDWithLocking(rrCache.Server)
if resolver == nil {
log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server)
return nil
}
@ -159,12 +158,13 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
log.Tracer(ctx).Trace("resolver: serving from cache, requesting new")
// resolve async
module.StartMediumPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error {
module.StartWorker("resolve async", func(ctx context.Context) error {
_, _ = resolveAndCache(ctx, q)
return nil
})
}
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
return rrCache
}
@ -218,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
return nil, ErrNoCompliance
}
// prep
lastFailBoundary := time.Now().Add(
-time.Duration(nameserverRetryRate()) * time.Second,
)
// start resolving
var i int
@ -231,7 +226,7 @@ resolveLoop:
for i = 0; i < 2; i++ {
for _, resolver := range resolvers {
// check if resolver failed recently (on first run)
if i == 0 && resolver.Conn.LastFail().After(lastFailBoundary) {
if i == 0 && resolver.Conn.IsFailing() {
log.Tracer(ctx).Tracef("resolver: skipping resolver %s, because it failed recently", resolver)
continue
}

View file

@ -2,6 +2,7 @@ package resolver
import (
"context"
"errors"
"net"
"sync"
"time"
@ -23,6 +24,11 @@ const (
ServerSourceMDNS = "mdns"
)
var (
// FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed.
FailThreshold = 5
)
// Resolver holds information about an active resolver.
type Resolver struct {
// Server config url (and ID)
@ -83,8 +89,8 @@ func (resolver *Resolver) String() string {
// ResolverConn is an interface to implement different types of query backends.
type ResolverConn interface { //nolint:go-lint // TODO
Query(ctx context.Context, q *Query) (*RRCache, error)
MarkFailed()
LastFail() time.Time
ReportFailure()
IsFailing() bool
}
// BasicResolverConn implements ResolverConn for standard dns clients.
@ -92,12 +98,14 @@ type BasicResolverConn struct {
sync.Mutex // for lastFail
resolver *Resolver
clientManager *clientManager
lastFail time.Time
clientManager *dnsClientManager
lastFail time.Time
fails int
}
// MarkFailed marks the resolver as failed.
func (brc *BasicResolverConn) MarkFailed() {
// ReportFailure reports that an error occurred with this resolver.
func (brc *BasicResolverConn) ReportFailure() {
if !netenv.Online() {
// don't mark failed if we are offline
return
@ -105,14 +113,26 @@ func (brc *BasicResolverConn) MarkFailed() {
brc.Lock()
defer brc.Unlock()
brc.lastFail = time.Now()
now := time.Now().UTC()
failDuration := time.Duration(nameserverRetryRate()) * time.Second
// reset fail counter if currently not failing
if now.Add(-failDuration).After(brc.lastFail) {
brc.fails = 0
}
// update
brc.lastFail = now
brc.fails++
}
// LastFail returns the internal lastfail value while locking the Resolver.
func (brc *BasicResolverConn) LastFail() time.Time {
// IsFailing returns if this resolver is currently failing.
func (brc *BasicResolverConn) IsFailing() bool {
brc.Lock()
defer brc.Unlock()
return brc.lastFail
failDuration := time.Duration(nameserverRetryRate()) * time.Second
return brc.fails >= FailThreshold && time.Now().UTC().Add(-failDuration).Before(brc.lastFail)
}
// Query executes the given query against the resolver.
@ -126,35 +146,78 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
// start
var reply *dns.Msg
var ttl time.Duration
var err error
for i := 0; i < 3; i++ {
var conn *dns.Conn
var new bool
var tries int
// log query time
// qStart := time.Now()
reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress)
// log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
for ; tries < 3; tries++ {
// error handling
// first get connection
dc := brc.clientManager.getDNSClient()
conn, new, err = dc.getConn()
if err != nil {
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// report that resolver had an error
brc.ReportFailure()
// hint network environment at failed connection
netenv.ReportFailedConnection()
// TODO: handle special cases
// 1. connect: network is unreachable
// 2. timeout
// hint network environment at failed connection
netenv.ReportFailedConnection()
// try again
continue
}
if new {
log.Tracer(ctx).Tracef("resolver: created new connection to %s (%s)", resolver.Name, resolver.ServerAddress)
} else {
log.Tracer(ctx).Tracef("resolver: reusing connection to %s (%s)", resolver.Name, resolver.ServerAddress)
}
// query server
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
// error handling
if err != nil {
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
// try again
continue
}
// report failed if dns (nothing happens at getConn())
if resolver.ServerType == ServerTypeDNS {
// report that resolver had an error
brc.ReportFailure()
// hint network environment at failed connection
netenv.ReportFailedConnection()
}
// permanent error
break
} else if reply == nil {
// remove client from pool
dc.destroy()
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
return nil, errors.New("internal error")
}
// make client available (again)
dc.addToPool()
if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()}
}
@ -166,12 +229,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
if err != nil {
return nil, err
// TODO: mark as failed
} else if reply == nil {
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), tries+1)
return nil, errors.New("internal error")
}
// hint network environment at successful connection
netenv.ReportSuccessfulConnection()
new := &RRCache{
newRecord := &RRCache{
Domain: q.FQDN,
Question: q.QType,
Answer: reply.Answer,
@ -182,5 +248,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
}
// TODO: check if reply.Answer is valid
return new, nil
return newRecord, nil
}

View file

@ -25,7 +25,7 @@ var (
globalResolvers []*Resolver // all (global) resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
allResolvers map[string]*Resolver // lookup map of all resolvers
activeResolvers map[string]*Resolver // lookup map of all resolvers
resolversLock sync.RWMutex
dupReqMap = make(map[string]*sync.WaitGroup)
@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int {
return -1
}
func getResolverByIDWithLocking(server string) *Resolver {
resolversLock.Lock()
defer resolversLock.Unlock()
func getActiveResolverByIDWithLocking(server string) *Resolver {
resolversLock.RLock()
defer resolversLock.RUnlock()
resolver, ok := allResolvers[server]
resolver, ok := activeResolvers[server]
if ok {
return resolver
}
@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
return address
}
func clientManagerFactory(serverType string) func(*Resolver) *clientManager {
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType {
case ServerTypeDNS:
return newDNSClientManager
@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) {
}
}
func getConfiguredResolvers() (resolvers []*Resolver) {
for _, server := range configuredNameServers() {
func getConfiguredResolvers(list []string) (resolvers []*Resolver) {
for _, server := range list {
resolver, skip, err := createResolver(server, "config")
if err != nil {
// TODO(ppacher): module error
@ -199,19 +199,40 @@ func loadResolvers() {
defer resolversLock.Unlock()
newResolvers := append(
getConfiguredResolvers(),
getConfiguredResolvers(configuredNameServers()),
getSystemResolvers()...,
)
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("resolver: no (valid) dns servers found in configuration and system")
// TODO(module error)
if len(newResolvers) == 0 {
msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults"
log.Warningf("resolver: %s", msg)
module.Warning("no-valid-user-resolvers", msg)
// load defaults directly, overriding config system
newResolvers = getConfiguredResolvers(defaultNameServers)
if len(newResolvers) == 0 {
msg = "no (valid) dns servers found in configuration or system"
log.Criticalf("resolver: %s", msg)
module.Error("no-valid-default-resolvers", msg)
return
}
}
// save resolvers
globalResolvers = newResolvers
// assing resolvers to scopes
setLocalAndScopeResolvers(globalResolvers)
// set active resolvers (for cache validation)
// reset
activeResolvers = make(map[string]*Resolver)
// add
for _, resolver := range newResolvers {
activeResolvers[resolver.Server] = resolver
}
activeResolvers[mDNSResolver.Server] = mDNSResolver
// log global resolvers
if len(globalResolvers) > 0 {
log.Trace("resolver: loaded global resolvers:")

View file

@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (
for _, rr := range rrCache.Answer {
switch v := rr.(type) {
case *dns.A:
log.Infof("A: %s", v.A.String())
// log.Debugf("A: %s", v.A.String())
if ip == v.A.String() {
return ptrName, nil
}
case *dns.AAAA:
log.Infof("AAAA: %s", v.AAAA.String())
// log.Debugf("AAAA: %s", v.AAAA.String())
if ip == v.AAAA.String() {
return ptrName, nil
}

41
resolver/rrcache_test.go Normal file
View file

@ -0,0 +1,41 @@
package resolver
import (
"testing"
"github.com/miekg/dns"
)
func TestCaching(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
err = rrCache.Save()
if err != nil {
t.Fatal(err)
}
rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
if rrCache2.Domain != rrCache.Domain {
t.Fatal("something very is wrong")
}
}

View file

@ -3,6 +3,8 @@ package ui
import (
"context"
"github.com/safing/portbase/dataroot"
resources "github.com/cookieo9/resources-go"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
@ -27,6 +29,19 @@ func prep() error {
}
func start() error {
// Create a dummy directory to which processes change their working directory
// to. Currently this includes the App and the Notifier. The aim is protect
// all other directories and increase compatibility should any process want
// to read or write something to the current working directory. This can also
// be useful in the future to dump data to for debugging. The permission used
// may seem dangerous, but proper permission on the parent directory provide
// (some) protection.
// Processes must _never_ read from this directory.
err := dataroot.Root().ChildDir("exec", 0777).Ensure()
if err != nil {
log.Warningf("ui: failed to create safe exec dir: %s", err)
}
return module.RegisterEventHook("ui", eventReload, "reload assets", reloadUI)
}

View file

@ -152,7 +152,7 @@ func start() error {
err = registry.LoadIndexes()
if err != nil {
return err
log.Warningf("updates: failed to load indexes: %s", err)
}
err = registry.ScanStorage("")
@ -235,8 +235,7 @@ func checkForUpdates(ctx context.Context) (err error) {
}()
if err = registry.UpdateIndexes(); err != nil {
err = fmt.Errorf("failed to update indexes: %w", err)
return
log.Warningf("updates: failed to update indexes: %s", err)
}
err = registry.DownloadUpdates(ctx)

View file

@ -113,15 +113,9 @@ func upgradePortmasterControl() error {
return nil
}
// check if registry tmp dir is ok
err := registry.TmpDir().Ensure()
if err != nil {
return fmt.Errorf("failed to prep updates tmp dir: %s", err)
}
// update portmaster-control in data root
rootControlPath := filepath.Join(filepath.Dir(registry.StorageDir().Path), filename)
err = upgradeFile(rootControlPath, pmCtrlUpdate)
err := upgradeFile(rootControlPath, pmCtrlUpdate)
if err != nil {
return err
}
@ -130,11 +124,11 @@ func upgradePortmasterControl() error {
// upgrade parent process, if it's portmaster-control
parent, err := processInfo.NewProcess(int32(os.Getppid()))
if err != nil {
return fmt.Errorf("could not get parent process for upgrade checks: %s", err)
return fmt.Errorf("could not get parent process for upgrade checks: %w", err)
}
parentName, err := parent.Name()
if err != nil {
return fmt.Errorf("could not get parent process name for upgrade checks: %s", err)
return fmt.Errorf("could not get parent process name for upgrade checks: %w", err)
}
if parentName != filename {
log.Tracef("updates: parent process does not seem to be portmaster-control, name is %s", parentName)
@ -142,7 +136,7 @@ func upgradePortmasterControl() error {
}
parentPath, err := parent.Exe()
if err != nil {
return fmt.Errorf("could not get parent process path for upgrade: %s", err)
return fmt.Errorf("could not get parent process path for upgrade: %w", err)
}
err = upgradeFile(parentPath, pmCtrlUpdate)
if err != nil {
@ -190,7 +184,7 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
// ensure tmp dir is here
err = registry.TmpDir().Ensure()
if err != nil {
return fmt.Errorf("unable to check updates tmp dir for moving file that needs upgrade: %s", err)
return fmt.Errorf("could not prepare tmp directory for moving file that needs upgrade: %w", err)
}
// maybe we're on windows and it's in use, try moving
@ -204,17 +198,17 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
),
))
if err != nil {
return fmt.Errorf("unable to move file that needs upgrade: %s", err)
return fmt.Errorf("unable to move file that needs upgrade: %w", err)
}
}
}
// copy upgrade
err = copyFile(file.Path(), fileToUpgrade)
err = CopyFile(file.Path(), fileToUpgrade)
if err != nil {
// try again
time.Sleep(1 * time.Second)
err = copyFile(file.Path(), fileToUpgrade)
err = CopyFile(file.Path(), fileToUpgrade)
if err != nil {
return err
}
@ -224,23 +218,31 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error {
if !onWindows {
info, err := os.Stat(fileToUpgrade)
if err != nil {
return fmt.Errorf("failed to get file info on %s: %s", fileToUpgrade, err)
return fmt.Errorf("failed to get file info on %s: %w", fileToUpgrade, err)
}
if info.Mode() != 0755 {
err := os.Chmod(fileToUpgrade, 0755)
if err != nil {
return fmt.Errorf("failed to set permissions on %s: %s", fileToUpgrade, err)
return fmt.Errorf("failed to set permissions on %s: %w", fileToUpgrade, err)
}
}
}
return nil
}
func copyFile(srcPath, dstPath string) (err error) {
// CopyFile atomically copies a file using the update registry's tmp dir.
func CopyFile(srcPath, dstPath string) (err error) {
// check tmp dir
err = registry.TmpDir().Ensure()
if err != nil {
return fmt.Errorf("could not prepare tmp directory for copying file: %w", err)
}
// open file for writing
atomicDstFile, err := renameio.TempFile(registry.TmpDir().Path, dstPath)
if err != nil {
return fmt.Errorf("could not create temp file for atomic copy: %s", err)
return fmt.Errorf("could not create temp file for atomic copy: %w", err)
}
defer atomicDstFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway
@ -260,7 +262,7 @@ func copyFile(srcPath, dstPath string) (err error) {
// finalize file
err = atomicDstFile.CloseAtomicallyReplace()
if err != nil {
return fmt.Errorf("updates: failed to finalize copy to file %s: %s", dstPath, err)
return fmt.Errorf("updates: failed to finalize copy to file %s: %w", dstPath, err)
}
return nil