Merge from develop, update README.md

This commit is contained in:
Daniel 2019-01-24 15:40:18 +01:00
commit 3ba081404b
143 changed files with 6882 additions and 3125 deletions

4
.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
dnsonly
main
*.exe

View file

@ -1,3 +1,41 @@
# Portmaster # Portmaster
The Portmaster is currently being revamped. You can check out the latest changes in the `develop` branch. The Portmaster enables you to protect your data on your device. You are back in charge of your outgoing connections: you choose what data you share and what data stays private.
## Current Status
The Portmaster is currently in alpha. Expect dragons.
Supported platforms:
- linux_amd64
- windows_amd64 (_soon_)
- darwin_amd64 (_later_)
## Usage
Just download the portmaster from the releases page.
./portmaster -db=/opt/pm_db
# this will add some rules to iptables for traffic interception via nfqueue (and will clean up afterwards!)
# then start the ui
./portmaster -db=/opt/pm_db -ui
# missing files will be automatically download when first needed
## Documentation
Documentation _in progress_ can be found here: [http://docs.safing.io/](http://docs.safing.io/)
## Dependencies
#### Linux
- libnetfilter_queue
- debian/ubuntu: `sudo apt-get install libnetfilter-queue1`
- fedora: `sudo yum install libnetfilter_queue`
- arch: `sudo pacman -S libnetfilter_queue`
- [Network Manager](https://wiki.gnome.org/Projects/NetworkManager) (_optional_)
#### Windows
- Windows 7 (with update KB3033929) or up
- [KB3033929](https://docs.microsoft.com/en-us/security-updates/SecurityAdvisories/2015/3033929) (a 2015 security update) is required for correctly verifying the driver signature
- Windows Server 2016 systems must have secure boot disabled. (_clarification needed_)

52
build Executable file
View file

@ -0,0 +1,52 @@
#!/bin/bash
# get build data
if [[ "$BUILD_COMMIT" == "" ]]; then
BUILD_COMMIT=$(git describe --all --long --abbrev=99 --dirty 2>/dev/null)
fi
if [[ "$BUILD_USER" == "" ]]; then
BUILD_USER=$(id -un)
fi
if [[ "$BUILD_HOST" == "" ]]; then
BUILD_HOST=$(hostname -f)
fi
if [[ "$BUILD_DATE" == "" ]]; then
BUILD_DATE=$(date +%d.%m.%Y)
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
BUILD_SOURCE=$(git remote -v | grep origin | cut -f2 | cut -d" " -f1 | head -n 1)
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
BUILD_SOURCE=$(git remote -v | cut -f2 | cut -d" " -f1 | head -n 1)
fi
BUILD_BUILDOPTIONS=$(echo $* | sed "s/ /§/g")
# check
if [[ "$BUILD_COMMIT" == "" ]]; then
echo "could not automatically determine BUILD_COMMIT, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_USER" == "" ]]; then
echo "could not automatically determine BUILD_USER, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_HOST" == "" ]]; then
echo "could not automatically determine BUILD_HOST, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_DATE" == "" ]]; then
echo "could not automatically determine BUILD_DATE, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
echo "could not automatically determine BUILD_SOURCE, please supply manually as environment variable."
exit 1
fi
echo "Please notice, that this build script includes metadata into the build."
echo "This information is useful for debugging and license compliance."
echo "Run the compiled binary with the -version flag to see the information included."
# build
BUILD_PATH="github.com/Safing/portbase/info"
go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $*

55
dnsonly.go Normal file
View file

@ -0,0 +1,55 @@
package main
import (
"fmt"
"os"
"os/signal"
"syscall"
"github.com/Safing/portbase/info"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// include packages here
_ "github.com/Safing/portmaster/nameserver/only"
)
func main() {
// Set Info
info.Set("Portmaster (DNS only)", "0.2.0")
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
select {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
modules.Shutdown()
case <-modules.ShuttingDown():
}
}

26
firewall/config.go Normal file
View file

@ -0,0 +1,26 @@
package firewall
import (
"github.com/Safing/portbase/config"
)
var (
permanentVerdicts config.BoolOption
)
func registerConfig() error {
err := config.Register(&config.Option{
Name: "Permanent Verdicts",
Key: "firewall/permanentVerdicts",
Description: "With permanent verdicts, control of a connection is fully handed back to the OS after the initial decision. This brings a great performance increase, but makes it impossible to change the decision of a link later on.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeBool,
DefaultValue: true,
})
if err != nil {
return err
}
permanentVerdicts = config.Concurrent.GetAsBool("firewall/permanentVerdicts", true)
return nil
}

View file

@ -1,28 +1,22 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package firewall package firewall
import ( import (
"fmt"
"net" "net"
"os" "os"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Safing/safing-core/configuration" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/firewall/inspection" "github.com/Safing/portbase/modules"
"github.com/Safing/safing-core/firewall/interception" "github.com/Safing/portmaster/firewall/inspection"
"github.com/Safing/safing-core/log" "github.com/Safing/portmaster/firewall/interception"
"github.com/Safing/safing-core/modules" "github.com/Safing/portmaster/network"
"github.com/Safing/safing-core/network" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portmaster/process"
"github.com/Safing/safing-core/port17/entry"
"github.com/Safing/safing-core/port17/mode"
"github.com/Safing/safing-core/portmaster"
"github.com/Safing/safing-core/process"
) )
var ( var (
firewallModule *modules.Module
// localNet net.IPNet // localNet net.IPNet
localhost net.IP localhost net.IP
dnsServer net.IPNet dnsServer net.IPNet
@ -30,8 +24,6 @@ var (
packetsBlocked *uint64 packetsBlocked *uint64
packetsDropped *uint64 packetsDropped *uint64
config = configuration.Get()
localNet4 *net.IPNet localNet4 *net.IPNet
// Yes, this would normally be 127.0.0.0/8 // Yes, this would normally be 127.0.0.0/8
// TODO: figure out any side effects // TODO: figure out any side effects
@ -46,23 +38,30 @@ var (
) )
func init() { func init() {
modules.Register("firewall", prep, start, stop, "global", "network", "nameserver", "profile")
}
var err error func prep() (err error) {
err = registerConfig()
if err != nil {
return err
}
_, localNet4, err = net.ParseCIDR("127.0.0.0/24") _, localNet4, err = net.ParseCIDR("127.0.0.0/24")
// Yes, this would normally be 127.0.0.0/8 // Yes, this would normally be 127.0.0.0/8
// TODO: figure out any side effects // TODO: figure out any side effects
if err != nil { if err != nil {
log.Criticalf("firewall: failed to parse cidr 127.0.0.0/24: %s", err) return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err)
} }
_, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16") _, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16")
if err != nil { if err != nil {
log.Criticalf("firewall: failed to parse cidr 127.17.0.0/16: %s", err) return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err)
} }
_, tunnelNet6, err = net.ParseCIDR("fd17::/64") _, tunnelNet6, err = net.ParseCIDR("fd17::/64")
if err != nil { if err != nil {
log.Criticalf("firewall: failed to parse cidr fd17::/64: %s", err) return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err)
} }
var pA uint64 var pA uint64
@ -71,20 +70,22 @@ func init() {
packetsBlocked = &pB packetsBlocked = &pB
var pD uint64 var pD uint64
packetsDropped = &pD packetsDropped = &pD
return nil
} }
func Start() { func start() error {
firewallModule = modules.Register("Firewall", 128)
defer firewallModule.StopComplete()
// start interceptor
go interception.Start()
go statLogger() go statLogger()
go run()
// go run()
// go run()
// go run()
// go run() return interception.Start()
// go run() }
// go run()
run() func stop() error {
return interception.Stop()
} }
func handlePacket(pkt packet.Packet) { func handlePacket(pkt packet.Packet) {
@ -111,12 +112,6 @@ func handlePacket(pkt packet.Packet) {
return return
} }
// allow anything that goes to a tunnel entrypoint
if pkt.IsOutbound() && (pkt.GetIPHeader().Dst.Equal(tunnelEntry4) || pkt.GetIPHeader().Dst.Equal(tunnelEntry6)) {
pkt.PermanentAccept()
return
}
// log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID()) // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID())
// use this to time how long it takes process packet // use this to time how long it takes process packet
@ -124,16 +119,16 @@ func handlePacket(pkt packet.Packet) {
// defer log.Tracef("firewall: took %s to process packet %s", time.Now().Sub(timed).String(), pkt) // defer log.Tracef("firewall: took %s to process packet %s", time.Now().Sub(timed).String(), pkt)
// check if packet is destined for tunnel // check if packet is destined for tunnel
switch pkt.IPVersion() { // switch pkt.IPVersion() {
case packet.IPv4: // case packet.IPv4:
if portmaster.TunnelNet4 != nil && portmaster.TunnelNet4.Contains(pkt.GetIPHeader().Dst) { // if TunnelNet4 != nil && TunnelNet4.Contains(pkt.GetIPHeader().Dst) {
tunnelHandler(pkt) // tunnelHandler(pkt)
} // }
case packet.IPv6: // case packet.IPv6:
if portmaster.TunnelNet6 != nil && portmaster.TunnelNet6.Contains(pkt.GetIPHeader().Dst) { // if TunnelNet6 != nil && TunnelNet6.Contains(pkt.GetIPHeader().Dst) {
tunnelHandler(pkt) // tunnelHandler(pkt)
} // }
} // }
// associate packet to link and handle // associate packet to link and handle
link, created := network.GetOrCreateLinkByPacket(pkt) link, created := network.GetOrCreateLinkByPacket(pkt)
@ -146,7 +141,7 @@ func handlePacket(pkt packet.Packet) {
link.HandlePacket(pkt) link.HandlePacket(pkt)
return return
} }
verdict(pkt, link.Verdict) verdict(pkt, link.GetVerdict())
} }
@ -157,57 +152,68 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
if err != nil { if err != nil {
if err != process.ErrConnectionNotFound { if err != process.ErrConnectionNotFound {
log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err) log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err)
link.Deny(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err))
} else {
log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err)
link.Deny(fmt.Sprintf("internal error finding process: %s", err))
} }
link.UpdateVerdict(network.DROP)
verdict(pkt, network.DROP) if pkt.IsInbound() {
network.UnknownIncomingConnection.AddLink(link)
} else {
network.UnknownDirectConnection.AddLink(link)
}
verdict(pkt, link.GetVerdict())
link.StopFirewallHandler()
return return
} }
// add new Link to Connection (and save both)
connection.AddLink(link)
// reroute dns requests to nameserver // reroute dns requests to nameserver
if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 {
pkt.RerouteToNameserver() link.RerouteToNameserver()
verdict(pkt, link.GetVerdict())
link.StopFirewallHandler()
return return
} }
// persist connection
connection.CreateInProcessNamespace()
// add new Link to Connection
connection.AddLink(link, pkt)
// make a decision if not made already // make a decision if not made already
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
portmaster.DecideOnConnection(connection, pkt) DecideOnConnection(connection, pkt)
} }
if connection.Verdict != network.CANTSAY { if connection.GetVerdict() == network.ACCEPT {
link.UpdateVerdict(connection.Verdict) DecideOnLink(connection, link, pkt)
} else { } else {
portmaster.DecideOnLink(connection, link, pkt) link.UpdateVerdict(connection.GetVerdict())
} }
// log decision // log decision
logInitialVerdict(link) logInitialVerdict(link)
// TODO: link this to real status // TODO: link this to real status
port17Active := mode.Client() // port17Active := mode.Client()
switch { switch {
case port17Active && link.Inspect: // case port17Active && link.Inspect:
// tunnel link, but also inspect (after reroute) // // tunnel link, but also inspect (after reroute)
link.Tunneled = true // link.Tunneled = true
link.SetFirewallHandler(inspectThenVerdict) // link.SetFirewallHandler(inspectThenVerdict)
verdict(pkt, link.Verdict) // verdict(pkt, link.GetVerdict())
case port17Active: // case port17Active:
// tunnel link, don't inspect // // tunnel link, don't inspect
link.Tunneled = true // link.Tunneled = true
link.StopFirewallHandler() // link.StopFirewallHandler()
permanentVerdict(pkt, network.ACCEPT) // permanentVerdict(pkt, network.ACCEPT)
case link.Inspect: case link.Inspect:
link.SetFirewallHandler(inspectThenVerdict) link.SetFirewallHandler(inspectThenVerdict)
inspectThenVerdict(pkt, link) inspectThenVerdict(pkt, link)
default: default:
link.StopFirewallHandler() link.StopFirewallHandler()
verdict(pkt, link.Verdict) verdict(pkt, link.GetVerdict())
} }
} }
@ -216,10 +222,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) pktVerdict, continueInspection := inspection.RunInspectors(pkt, link)
if continueInspection { if continueInspection {
// do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link // do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link
if pktVerdict > link.Verdict { linkVerdict := link.GetVerdict()
if pktVerdict > linkVerdict {
verdict(pkt, pktVerdict) verdict(pkt, pktVerdict)
} else { } else {
verdict(pkt, link.Verdict) verdict(pkt, linkVerdict)
} }
return return
} }
@ -227,13 +234,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
// we are done with inspecting // we are done with inspecting
link.StopFirewallHandler() link.StopFirewallHandler()
config.Changed() link.Lock()
config.RLock() defer link.Unlock()
link.VerdictPermanent = config.PermanentVerdicts link.VerdictPermanent = permanentVerdicts()
config.RUnlock()
if link.VerdictPermanent { if link.VerdictPermanent {
link.Save() go link.Save()
permanentVerdict(pkt, link.Verdict) permanentVerdict(pkt, link.Verdict)
} else { } else {
verdict(pkt, link.Verdict) verdict(pkt, link.Verdict)
@ -254,6 +259,12 @@ func permanentVerdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
pkt.PermanentDrop() pkt.PermanentDrop()
return return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
} }
pkt.Drop() pkt.Drop()
} }
@ -272,36 +283,46 @@ func verdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
pkt.Drop() pkt.Drop()
return return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
} }
pkt.Drop() pkt.Drop()
} }
func tunnelHandler(pkt packet.Packet) { // func tunnelHandler(pkt packet.Packet) {
tunnelInfo := portmaster.GetTunnelInfo(pkt.GetIPHeader().Dst) // tunnelInfo := GetTunnelInfo(pkt.GetIPHeader().Dst)
if tunnelInfo == nil { // if tunnelInfo == nil {
pkt.Block() // pkt.Block()
return // return
} // }
//
entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords()) // entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords())
log.Tracef("firewall: rerouting %s to tunnel entry point", pkt) // log.Tracef("firewall: rerouting %s to tunnel entry point", pkt)
pkt.RerouteToTunnel() // pkt.RerouteToTunnel()
return // return
} // }
func logInitialVerdict(link *network.Link) { func logInitialVerdict(link *network.Link) {
// switch link.Verdict { // switch link.GetVerdict() {
// case network.ACCEPT: // case network.ACCEPT:
// log.Infof("firewall: accepting new link: %s", link.String()) // log.Infof("firewall: accepting new link: %s", link.String())
// case network.BLOCK: // case network.BLOCK:
// log.Infof("firewall: blocking new link: %s", link.String()) // log.Infof("firewall: blocking new link: %s", link.String())
// case network.DROP: // case network.DROP:
// log.Infof("firewall: dropping new link: %s", link.String()) // log.Infof("firewall: dropping new link: %s", link.String())
// case network.RerouteToNameserver:
// log.Infof("firewall: rerouting new link to nameserver: %s", link.String())
// case network.RerouteToTunnel:
// log.Infof("firewall: rerouting new link to tunnel: %s", link.String())
// } // }
} }
func logChangedVerdict(link *network.Link) { func logChangedVerdict(link *network.Link) {
// switch link.Verdict { // switch link.GetVerdict() {
// case network.ACCEPT: // case network.ACCEPT:
// log.Infof("firewall: change! - now accepting link: %s", link.String()) // log.Infof("firewall: change! - now accepting link: %s", link.String())
// case network.BLOCK: // case network.BLOCK:
@ -312,25 +333,26 @@ func logChangedVerdict(link *network.Link) {
} }
func run() { func run() {
packetProcessingLoop:
for { for {
select { select {
case <-firewallModule.Stop: case <-modules.ShuttingDown():
break packetProcessingLoop return
case pkt := <-interception.Packets: case pkt := <-interception.Packets:
handlePacket(pkt) handlePacket(pkt)
} }
} }
} }
func statLogger() { func statLogger() {
for { for {
time.Sleep(10 * time.Second) select {
case <-modules.ShuttingDown():
return
case <-time.After(10 * time.Second):
log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped)) log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped))
atomic.StoreUint64(packetsAccepted, 0) atomic.StoreUint64(packetsAccepted, 0)
atomic.StoreUint64(packetsBlocked, 0) atomic.StoreUint64(packetsBlocked, 0)
atomic.StoreUint64(packetsDropped, 0) atomic.StoreUint64(packetsDropped, 0)
} }
}
} }

View file

@ -3,9 +3,10 @@
package inspection package inspection
import ( import (
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/packet"
"sync" "sync"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet"
) )
const ( const (
@ -40,24 +41,28 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool
// inspectorsLock.Lock() // inspectorsLock.Lock()
// defer inspectorsLock.Unlock() // defer inspectorsLock.Unlock()
if link.ActiveInspectors == nil { activeInspectors := link.GetActiveInspectors()
link.ActiveInspectors = make([]bool, len(inspectors), len(inspectors)) if activeInspectors == nil {
activeInspectors = make([]bool, len(inspectors), len(inspectors))
link.SetActiveInspectors(activeInspectors)
} }
if link.InspectorData == nil { inspectorData := link.GetInspectorData()
link.InspectorData = make(map[uint8]interface{}) if inspectorData == nil {
inspectorData = make(map[uint8]interface{})
link.SetInspectorData(inspectorData)
} }
continueInspection := false continueInspection := false
verdict := network.UNDECIDED verdict := network.UNDECIDED
for key, skip := range link.ActiveInspectors { for key, skip := range activeInspectors {
if skip { if skip {
continue continue
} }
if link.Verdict > inspectVerdicts[key] { if link.Verdict > inspectVerdicts[key] {
link.ActiveInspectors[key] = true activeInspectors[key] = true
continue continue
} }
@ -78,16 +83,16 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool
continueInspection = true continueInspection = true
case BLOCK_LINK: case BLOCK_LINK:
link.UpdateVerdict(network.BLOCK) link.UpdateVerdict(network.BLOCK)
link.ActiveInspectors[key] = true activeInspectors[key] = true
if verdict < network.BLOCK { if verdict < network.BLOCK {
verdict = network.BLOCK verdict = network.BLOCK
} }
case DROP_LINK: case DROP_LINK:
link.UpdateVerdict(network.DROP) link.UpdateVerdict(network.DROP)
link.ActiveInspectors[key] = true activeInspectors[key] = true
verdict = network.DROP verdict = network.DROP
case STOP_INSPECTING: case STOP_INSPECTING:
link.ActiveInspectors[key] = true activeInspectors[key] = true
} }
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package tls package tls
var ( var (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package tls package tls
import ( import (
@ -12,14 +10,13 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/gopacket/tcpassembly" "github.com/google/gopacket/tcpassembly"
"github.com/Safing/safing-core/configuration" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/crypto/verify" "github.com/Safing/portmaster/firewall/inspection"
"github.com/Safing/safing-core/firewall/inspection" "github.com/Safing/portmaster/firewall/inspection/tls/tlslib"
"github.com/Safing/safing-core/firewall/inspection/tls/tlslib" "github.com/Safing/portmaster/firewall/inspection/tls/verify"
"github.com/Safing/safing-core/log" "github.com/Safing/portmaster/network"
"github.com/Safing/safing-core/network" "github.com/Safing/portmaster/network/netutils"
"github.com/Safing/safing-core/network/netutils" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/safing-core/network/packet"
) )
// TODO: // TODO:
@ -31,8 +28,6 @@ var (
tlsInspectorIndex int tlsInspectorIndex int
assemblerManager *netutils.SimpleStreamAssemblerManager assemblerManager *netutils.SimpleStreamAssemblerManager
assembler *tcpassembly.Assembler assembler *tcpassembly.Assembler
config = configuration.Get()
) )
const ( const (

View file

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/Safing/safing-core/firewall/inspection/tls/tlslib" "github.com/Safing/portmaster/firewall/inspection/tls/tlslib"
) )
var clientHelloSample = []byte{ var clientHelloSample = []byte{

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify package verify
import ( import (
@ -14,15 +12,15 @@ import (
"strings" "strings"
"github.com/cloudflare/cfssl/crypto/pkcs7" "github.com/cloudflare/cfssl/crypto/pkcs7"
datastore "github.com/ipfs/go-datastore"
"github.com/Safing/safing-core/crypto/hash" "github.com/Safing/portbase/crypto/hash"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
) )
// Cert saves a certificate. // Cert saves a certificate.
type Cert struct { type Cert struct {
database.Base record.Record
cert *x509.Certificate cert *x509.Certificate
Raw []byte Raw []byte
@ -120,7 +118,7 @@ func (m *Cert) CreateRevokedCert(caID string, serialNumber *big.Int) error {
} }
// CreateInNamespace saves Cert with the provided name in the provided namespace. // CreateInNamespace saves Cert with the provided name in the provided namespace.
func (m *Cert) CreateInNamespace(namespace *datastore.Key, name string) error { func (m *Cert) CreateInNamespace(namespace string, name string) error {
return m.CreateObject(namespace, name, m) return m.CreateObject(namespace, name, m)
} }
@ -140,7 +138,7 @@ func GetCertWithSPKI(spki []byte) (*Cert, error) {
} }
// GetCertFromNamespace gets Cert with the provided name from the provided namespace. // GetCertFromNamespace gets Cert with the provided name from the provided namespace.
func GetCertFromNamespace(namespace *datastore.Key, name string) (*Cert, error) { func GetCertFromNamespace(namespace string, name string) (*Cert, error) {
object, err := database.GetAndEnsureModel(namespace, name, certModel) object, err := database.GetAndEnsureModel(namespace, name, certModel)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify package verify
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify package verify
import ( import (
@ -14,16 +12,15 @@ import (
"sync" "sync"
"time" "time"
datastore "github.com/ipfs/go-datastore" "github.com/Safing/portbase/crypto/hash"
"github.com/Safing/portbase/database"
"github.com/Safing/safing-core/crypto/hash" "github.com/Safing/portbase/database/record"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/log"
) )
// CARevocationInfo saves Information on revokation of Certificates of a Certificate Authority. // CARevocationInfo saves Information on revokation of Certificates of a Certificate Authority.
type CARevocationInfo struct { type CARevocationInfo struct {
database.Base record.Record
CRLDistributionPoints []string CRLDistributionPoints []string
OCSPServers []string OCSPServers []string
@ -39,23 +36,17 @@ type CARevocationInfo struct {
} }
var ( var (
caRevocationInfoModel *CARevocationInfo // only use this as parameter for database.EnsureModel-like functions
dupCrlReqMap = make(map[string]*sync.Mutex) dupCrlReqMap = make(map[string]*sync.Mutex)
dupCrlReqLock sync.Mutex dupCrlReqLock sync.Mutex
) )
func init() {
database.RegisterModel(caRevocationInfoModel, func() database.Model { return new(CARevocationInfo) })
}
// Create saves CARevocationInfo with the provided name in the default namespace. // Create saves CARevocationInfo with the provided name in the default namespace.
func (m *CARevocationInfo) Create(name string) error { func (m *CARevocationInfo) Create(name string) error {
return m.CreateObject(&database.CARevocationInfoCache, name, m) return m.CreateObject(&database.CARevocationInfoCache, name, m)
} }
// CreateInNamespace saves CARevocationInfo with the provided name in the provided namespace. // CreateInNamespace saves CARevocationInfo with the provided name in the provided namespace.
func (m *CARevocationInfo) CreateInNamespace(namespace *datastore.Key, name string) error { func (m *CARevocationInfo) CreateInNamespace(namespace string, name string) error {
return m.CreateObject(namespace, name, m) return m.CreateObject(namespace, name, m)
} }
@ -78,7 +69,7 @@ func GetCARevocationInfo(name string) (*CARevocationInfo, error) {
} }
// GetCARevocationInfoFromNamespace fetches CARevocationInfo with the provided name from the provided namespace. // GetCARevocationInfoFromNamespace fetches CARevocationInfo with the provided name from the provided namespace.
func GetCARevocationInfoFromNamespace(namespace *datastore.Key, name string) (*CARevocationInfo, error) { func GetCARevocationInfoFromNamespace(namespace string, name string) (*CARevocationInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, caRevocationInfoModel) object, err := database.GetAndEnsureModel(namespace, name, caRevocationInfoModel)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify package verify
import ( import (
@ -16,8 +14,8 @@ import (
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"github.com/Safing/safing-core/crypto/hash" "github.com/Safing/portbase/crypto/hash"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
) )
var ( var (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify package verify
import ( import (
@ -8,9 +6,8 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Safing/safing-core/configuration" "github.com/Safing/portbase/crypto/hash"
"github.com/Safing/safing-core/crypto/hash" "github.com/Safing/portbase/database"
"github.com/Safing/safing-core/database"
) )
// useful references: // useful references:
@ -24,10 +21,6 @@ import (
// RE: https://www.grc.com/revocation/crlsets.htm // RE: https://www.grc.com/revocation/crlsets.htm
// RE: RE: https://www.imperialviolet.org/2014/04/29/revocationagain.html // RE: RE: https://www.imperialviolet.org/2014/04/29/revocationagain.html
var (
config = configuration.Get()
)
// FullCheckBytes does a full certificate check, certificates are provided as raw bytes. // FullCheckBytes does a full certificate check, certificates are provided as raw bytes.
// It parses the raw certificates and calls FullCheck. // It parses the raw certificates and calls FullCheck.
func FullCheckBytes(name string, certBytes [][]byte) (bool, error) { func FullCheckBytes(name string, certBytes [][]byte) (bool, error) {

View file

@ -2,10 +2,19 @@
package interception package interception
import "github.com/Safing/safing-core/network/packet" import "github.com/Safing/portmaster/network/packet"
var Packets chan packet.Packet var (
// Packets channel for feeding the firewall.
func init() {
Packets = make(chan packet.Packet, 1000) Packets = make(chan packet.Packet, 1000)
)
// Start starts the interception.
func Start() error {
return StartNfqueueInterception()
}
// Stop starts the interception.
func Stop() error {
return StopNfqueueInterception()
} }

View file

@ -1,31 +1,31 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package interception package interception
import ( import (
"github.com/Safing/safing-core/firewall/interception/windivert" "fmt"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules" "github.com/Safing/portmaster/firewall/interception/windivert"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portmaster/network/packet"
) )
var Packets chan packet.Packet var Packets chan packet.Packet
func init() { func init() {
// Packets channel for feeding the firewall.
Packets = make(chan packet.Packet, 1000) Packets = make(chan packet.Packet, 1000)
} }
func Start() { // Start starts the interception.
func Start() error {
windivertModule := modules.Register("Firewall:Interception:WinDivert", 192)
wd, err := windivert.New("/WinDivert.dll", "") wd, err := windivert.New("/WinDivert.dll", "")
if err != nil { if err != nil {
log.Criticalf("firewall/interception: could not init windivert: %s", err) return fmt.Errorf("firewall/interception: could not init windivert: %s", err)
} else {
wd.Packets(Packets)
} }
<-windivertModule.Stop return wd.Packets(Packets)
windivertModule.StopComplete() }
// Stop starts the interception.
func Stop() error {
return nil
} }

View file

@ -2,45 +2,48 @@
package nfqueue package nfqueue
import ( // suspended for now
"github.com/Safing/safing-core/network/packet"
"sync"
)
type multiQueue struct { // import (
qs []*nfQueue // "sync"
} //
// "github.com/Safing/portmaster/network/packet"
func NewMultiQueue(min, max uint16) (mq *multiQueue) { // )
mq = &multiQueue{make([]*nfQueue, 0, max-min)} //
for i := min; i < max; i++ { // type multiQueue struct {
mq.qs = append(mq.qs, NewNFQueue(i)) // qs []*NFQueue
} // }
return mq //
} // func NewMultiQueue(min, max uint16) (mq *multiQueue) {
// mq = &multiQueue{make([]*NFQueue, 0, max-min)}
func (mq *multiQueue) Process() <-chan packet.Packet { // for i := min; i < max; i++ {
var ( // mq.qs = append(mq.qs, NewNFQueue(i))
wg sync.WaitGroup // }
out = make(chan packet.Packet, len(mq.qs)) // return mq
) // }
for _, q := range mq.qs { //
wg.Add(1) // func (mq *multiQueue) Process() <-chan packet.Packet {
go func(ch <-chan packet.Packet) { // var (
for pkt := range ch { // wg sync.WaitGroup
out <- pkt // out = make(chan packet.Packet, len(mq.qs))
} // )
wg.Done() // for _, q := range mq.qs {
}(q.Process()) // wg.Add(1)
} // go func(ch <-chan packet.Packet) {
go func() { // for pkt := range ch {
wg.Wait() // out <- pkt
close(out) // }
}() // wg.Done()
return out // }(q.Process())
} // }
func (mq *multiQueue) Destroy() { // go func() {
for _, q := range mq.qs { // wg.Wait()
q.Destroy() // close(out)
} // }()
} // return out
// }
// func (mq *multiQueue) Destroy() {
// for _, q := range mq.qs {
// q.Destroy()
// }
// }

View file

@ -17,17 +17,19 @@ import (
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"errors"
"fmt"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portmaster/network/packet"
) )
var queues map[uint16]*nfQueue var queues map[uint16]*NFQueue
func init() { func init() {
queues = make(map[uint16]*nfQueue) queues = make(map[uint16]*NFQueue)
} }
type nfQueue struct { type NFQueue struct {
DefaultVerdict uint32 DefaultVerdict uint32
Timeout time.Duration Timeout time.Duration
qid uint16 qid uint16
@ -38,83 +40,77 @@ type nfQueue struct {
fd int fd int
lk sync.Mutex lk sync.Mutex
pktch chan packet.Packet Packets chan packet.Packet
} }
func NewNFQueue(qid uint16) (nfq *nfQueue) { func NewNFQueue(qid uint16) (nfq *NFQueue, err error) {
if os.Geteuid() != 0 { if os.Geteuid() != 0 {
panic("Must be ran by root.") return nil, errors.New("must be root to intercept packets")
} }
nfq = &nfQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid} nfq = &NFQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid}
queues[nfq.qid] = nfq queues[nfq.qid] = nfq
return nfq
}
/* err = nfq.init()
This returns a channel that will recieve packets, if err != nil {
the user then must call pkt.Accept() or pkt.Drop() return nil, err
*/
func (this *nfQueue) Process() <-chan packet.Packet {
if this.h != nil {
return this.pktch
} }
this.init()
go func() { go func() {
runtime.LockOSThread() runtime.LockOSThread()
C.loop_for_packets(this.h) C.loop_for_packets(nfq.h)
}() }()
return this.pktch return nfq, nil
} }
func (this *nfQueue) init() { func (this *NFQueue) init() error {
var err error var err error
if this.h, err = C.nfq_open(); err != nil || this.h == nil { if this.h, err = C.nfq_open(); err != nil || this.h == nil {
panic(err) fmt.Errorf("could not open nfqueue: %s", err)
} }
//if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil { //if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil {
this.pktch = make(chan packet.Packet, 1) this.Packets = make(chan packet.Packet, 1)
if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 { if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 {
this.Destroy() this.Destroy()
panic("nfq_unbind_pf(AF_INET) failed, are you running root?.") return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?")
} }
if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 { if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 {
this.Destroy() this.Destroy()
panic("nfq_unbind_pf(AF_INET6) failed.") return errors.New("nfq_unbind_pf(AF_INET6) failed")
} }
if C.nfq_bind_pf(this.h, C.AF_INET) < 0 { if C.nfq_bind_pf(this.h, C.AF_INET) < 0 {
this.Destroy() this.Destroy()
panic("nfq_bind_pf(AF_INET) failed.") return errors.New("nfq_bind_pf(AF_INET) failed")
} }
if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 { if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 {
this.Destroy() this.Destroy()
panic("nfq_bind_pf(AF_INET6) failed.") return errors.New("nfq_bind_pf(AF_INET6) failed")
} }
if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid)); err != nil || this.qh == nil { if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid)); err != nil || this.qh == nil {
C.nfq_close(this.h) C.nfq_close(this.h)
panic(err) return fmt.Errorf("could not create queue: %s", err)
} }
this.fd = int(C.nfq_fd(this.h)) this.fd = int(C.nfq_fd(this.h))
if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 { if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 {
this.Destroy() this.Destroy()
panic("nfq_set_mode(NFQNL_COPY_PACKET) failed.") return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed")
} }
if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 { if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 {
this.Destroy() this.Destroy()
panic("nfq_set_queue_maxlen(1024 * 8) failed.") return errors.New("nfq_set_queue_maxlen(1024 * 8) failed")
} }
return nil
} }
func (this *nfQueue) Destroy() { func (this *NFQueue) Destroy() {
this.lk.Lock() this.lk.Lock()
defer this.lk.Unlock() defer this.lk.Unlock()
@ -131,12 +127,12 @@ func (this *nfQueue) Destroy() {
} }
// TODO: don't close, we're exiting anyway // TODO: don't close, we're exiting anyway
// if this.pktch != nil { // if this.Packets != nil {
// close(this.pktch) // close(this.Packets)
// } // }
} }
func (this *nfQueue) Valid() bool { func (this *NFQueue) Valid() bool {
return this.h != nil && this.qh != nil return this.h != nil && this.qh != nil
} }
@ -148,7 +144,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
qidptr := (*uint16)(data) qidptr := (*uint16)(data)
qid := uint16(*qidptr) qid := uint16(*qidptr)
// nfq := (*nfQueue)(nfqptr) // nfq := (*NFQueue)(nfqptr)
new_version := version new_version := version
ipver := packet.IPVersion(new_version) ipver := packet.IPVersion(new_version)
ipsz := C.int(ipver.ByteSize()) ipsz := C.int(ipver.ByteSize())
@ -187,7 +183,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
// fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000")) // fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000"))
// BUG: "panic: send on closed channel" when shutting down // BUG: "panic: send on closed channel" when shutting down
queues[qid].pktch <- &pkt queues[qid].Packets <- &pkt
select { select {
case v = <-pkt.verdict: case v = <-pkt.verdict:

View file

@ -5,7 +5,7 @@ package nfqueue
import ( import (
"fmt" "fmt"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portmaster/network/packet"
) )
var ( var (

View file

@ -1,31 +1,33 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
// +build linux
package interception package interception
import ( import (
"fmt"
"sort" "sort"
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/Safing/safing-core/firewall/interception/nfqueue" "github.com/Safing/portmaster/firewall/interception/nfqueue"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules"
) )
// iptables -A OUTPUT -p icmp -j", "NFQUEUE", "--queue-num", "1", "--queue-bypass // iptables -A OUTPUT -p icmp -j", "NFQUEUE", "--queue-num", "1", "--queue-bypass
var nfqueueModule *modules.Module var (
v4chains []string
v4rules []string
v4once []string
var v4chains []string v6chains []string
var v4rules []string v6rules []string
var v4once []string v6once []string
var v6chains []string out4Queue *nfqueue.NFQueue
var v6rules []string in4Queue *nfqueue.NFQueue
var v6once []string out6Queue *nfqueue.NFQueue
in6Queue *nfqueue.NFQueue
shutdownSignal = make(chan struct{})
)
func init() { func init() {
@ -100,8 +102,8 @@ func init() {
} }
// Reverse because we'd like to insert in a loop // Reverse because we'd like to insert in a loop
sort.Reverse(sort.StringSlice(v4once)) _ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs)
sort.Reverse(sort.StringSlice(v6once)) _ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs)
} }
@ -127,9 +129,10 @@ func activateNfqueueFirewall() error {
} }
} }
var ok bool
for _, rule := range v4once { for _, rule := range v4once {
splittedRule := strings.Split(rule, " ") splittedRule := strings.Split(rule, " ")
ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil { if err != nil {
return err return err
} }
@ -183,9 +186,10 @@ func deactivateNfqueueFirewall() error {
return err return err
} }
var ok bool
for _, rule := range v4once { for _, rule := range v4once {
splittedRule := strings.Split(rule, " ") splittedRule := strings.Split(rule, " ")
ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...) ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil { if err != nil {
return err return err
} }
@ -198,10 +202,10 @@ func deactivateNfqueueFirewall() error {
for _, chain := range v4chains { for _, chain := range v4chains {
splittedRule := strings.Split(chain, " ") splittedRule := strings.Split(chain, " ")
if err := ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil { if err = ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
return err return err
} }
if err := ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil { if err = ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil {
return err return err
} }
} }
@ -238,70 +242,84 @@ func deactivateNfqueueFirewall() error {
return nil return nil
} }
func Start() { // StartNfqueueInterception starts the nfqueue interception.
func StartNfqueueInterception() (err error) {
nfqueueModule = modules.Register("Firewall:Interception:Nfqueue", 192) err = activateNfqueueFirewall()
if err != nil {
if err := activateNfqueueFirewall(); err != nil { Stop()
log.Criticalf("could not activate firewall for nfqueue: %q", err) return fmt.Errorf("could not initialize nfqueue: %s", err)
} }
out4Queue := nfqueue.NewNFQueue(17040) out4Queue, err = nfqueue.NewNFQueue(17040)
in4Queue := nfqueue.NewNFQueue(17140) if err != nil {
out6Queue := nfqueue.NewNFQueue(17060) Stop()
in6Queue := nfqueue.NewNFQueue(17160) return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
in4Queue, err = nfqueue.NewNFQueue(17140)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
out6Queue, err = nfqueue.NewNFQueue(17060)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
in6Queue, err = nfqueue.NewNFQueue(17160)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
out4Channel := out4Queue.Process() go handleInterception()
// if err != nil { return nil
// log.Criticalf("could not open nfqueue out4") }
// } else {
defer out4Queue.Destroy()
// }
in4Channel := in4Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue in4")
// } else {
defer in4Queue.Destroy()
// }
out6Channel := out6Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue out6")
// } else {
defer out6Queue.Destroy()
// }
in6Channel := in6Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue in6")
// } else {
defer in6Queue.Destroy()
// }
packetInterceptionLoop: // StopNfqueueInterception stops the nfqueue interception.
func StopNfqueueInterception() error {
defer close(shutdownSignal)
if out4Queue != nil {
out4Queue.Destroy()
}
if in4Queue != nil {
in4Queue.Destroy()
}
if out6Queue != nil {
out6Queue.Destroy()
}
if in6Queue != nil {
in6Queue.Destroy()
}
err := deactivateNfqueueFirewall()
if err != nil {
return fmt.Errorf("interception: error while deactivating nfqueue: %s", err)
}
return nil
}
func handleInterception() {
for { for {
select { select {
case <-nfqueueModule.Stop: case <-shutdownSignal:
break packetInterceptionLoop return
case pkt := <-out4Channel: case pkt := <-out4Queue.Packets:
pkt.SetOutbound() pkt.SetOutbound()
Packets <- pkt Packets <- pkt
case pkt := <-in4Channel: case pkt := <-in4Queue.Packets:
pkt.SetInbound() pkt.SetInbound()
Packets <- pkt Packets <- pkt
case pkt := <-out6Channel: case pkt := <-out6Queue.Packets:
pkt.SetOutbound() pkt.SetOutbound()
Packets <- pkt Packets <- pkt
case pkt := <-in6Channel: case pkt := <-in6Queue.Packets:
pkt.SetInbound() pkt.SetInbound()
Packets <- pkt Packets <- pkt
} }
} }
if err := deactivateNfqueueFirewall(); err != nil {
log.Criticalf("could not deactivate firewall for nfqueue: %q", err)
}
nfqueueModule.StopComplete()
} }
func stringInSlice(slice []string, value string) bool { func stringInSlice(slice []string, value string) bool {

340
firewall/master.go Normal file
View file

@ -0,0 +1,340 @@
package firewall
import (
"fmt"
"os"
"strings"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/profile"
"github.com/Safing/portmaster/status"
"github.com/agext/levenshtein"
)
// Call order:
//
// 1. DecideOnConnectionBeforeIntel (if connecting to domain)
// is called when a DNS query is made, before the query is resolved
// 2. DecideOnConnectionAfterIntel (if connecting to domain)
// is called when a DNS query is made, after the query is resolved
// 3. DecideOnConnection
// is called when the first packet of the first link of the connection arrives
// 4. DecideOnLink
// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY
// DecideOnConnectionBeforeIntel makes a decision about a connection before the dns query is resolved and intel is gathered.
func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) {
// check:
// Profile.DomainWhitelist
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile set")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// check for any network access
if !profileSet.CheckFlag(profile.Internet) && !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, accessing Internet or LAN not allowed", connection)
connection.Deny("accessing Internet or LAN not allowed")
return
}
// check domain list
permitted, reason, ok := profileSet.CheckEndpoint(fqdn, 0, 0, false)
if ok {
if permitted {
log.Infof("firewall: accepting connection %s, endpoint is whitelisted: %s", connection, reason)
connection.Accept(fmt.Sprintf("endpoint is whitelisted: %s", reason))
} else {
log.Infof("firewall: denying connection %s, endpoint is blacklisted", connection)
connection.Deny("endpoint is blacklisted")
}
return
}
switch profileSet.GetProfileMode() {
case profile.Whitelist:
log.Infof("firewall: denying connection %s, domain is not whitelisted", connection)
connection.Deny("domain is not whitelisted")
case profile.Prompt:
// check Related flag
// TODO: improve this!
if profileSet.CheckFlag(profile.Related) {
matched := false
pathElements := strings.Split(connection.Process().Path, "/") // FIXME: path seperator
// only look at the last two path segments
if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:]
}
domainElements := strings.Split(fqdn, ".")
var domainElement string
var processElement string
matchLoop:
for _, domainElement = range domainElements {
for _, pathElement := range pathElements {
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
matched = true
processElement = pathElement
break matchLoop
}
}
if levenshtein.Match(domainElement, profileSet.UserProfile().Name, nil) > 0.5 {
matched = true
processElement = profileSet.UserProfile().Name
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 {
matched = true
processElement = connection.Process().Name
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().ExecName, nil) > 0.5 {
matched = true
processElement = connection.Process().ExecName
break matchLoop
}
}
if matched {
log.Infof("firewall: accepting connection %s, match to domain was found: %s ~= %s", connection, domainElement, processElement)
connection.Accept("domain is related to process")
}
}
if connection.GetVerdict() != network.ACCEPT {
// TODO
log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection)
connection.Accept("domain permitted (prompting is not yet implemented)")
}
case profile.Blacklist:
log.Infof("firewall: accepting connection %s, domain is not blacklisted", connection)
connection.Accept("domain is not blacklisted")
}
}
// DecideOnConnectionAfterIntel makes a decision about a connection after the dns query is resolved and intel is gathered.
func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache {
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return rrCache
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile")
return rrCache
}
profileSet.Update(status.CurrentSecurityLevel())
// TODO: Stamp integration
// TODO: Gate17 integration
// tunnelInfo, err := AssignTunnelIP(fqdn)
rrCache.Duplicate().FilterEntries(profileSet.CheckFlag(profile.Internet), profileSet.CheckFlag(profile.LAN), false)
if len(rrCache.Answer) == 0 {
if profileSet.CheckFlag(profile.Internet) {
connection.Deny("server is located in the LAN, but LAN access is not permitted")
} else {
connection.Deny("server is located in the Internet, but Internet access is not permitted")
}
}
return rrCache
}
// DeciceOnConnection makes a decision about a connection with its first packet.
func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// check connection type
switch connection.Domain {
case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
if !profileSet.CheckFlag(profile.Service) {
log.Infof("firewall: denying connection %s, not a service", connection)
if connection.Domain == network.IncomingHost {
connection.Block("not a service")
} else {
connection.Drop("not a service")
}
return
}
case network.PeerLAN, network.PeerInternet, network.PeerInvalid: // Important: PeerHost is and should be missing!
if !profileSet.CheckFlag(profile.PeerToPeer) {
log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection)
connection.Deny("peer to peer connections (to an IP) not allowed")
return
}
default:
}
// check network scope
switch connection.Domain {
case network.IncomingHost:
if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, serving localhost not allowed", connection)
connection.Block("serving localhost not allowed")
return
}
case network.IncomingLAN:
if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, serving LAN not allowed", connection)
connection.Deny("serving LAN not allowed")
return
}
case network.IncomingInternet:
if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, serving Internet not allowed", connection)
connection.Deny("serving Internet not allowed")
return
}
case network.IncomingInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Drop("invalid IP address")
return
case network.PeerHost:
if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection)
connection.Block("accessing localhost not allowed")
return
}
case network.PeerLAN:
if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection)
connection.Deny("accessing the LAN not allowed")
return
}
case network.PeerInternet:
if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection)
connection.Deny("accessing the Internet not allowed")
return
}
case network.PeerInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Deny("invalid IP address")
return
}
log.Infof("firewall: accepting connection %s", connection)
connection.Accept("")
}
// DecideOnLink makes a decision about a link with the first packet.
func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) {
// check:
// Profile.Flags
// - network specific: Internet, LocalNet
// Profile.ConnectPorts
// Profile.ListenPorts
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own link %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Infof("firewall: no profile, denying %s", link)
link.Block("no profile")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// get host
var domainOrIP string
switch {
case strings.HasSuffix(connection.Domain, "."):
domainOrIP = connection.Domain
case connection.Direction:
domainOrIP = pkt.GetIPHeader().Src.String()
default:
domainOrIP = pkt.GetIPHeader().Dst.String()
}
// get protocol / destination port
protocol := pkt.GetIPHeader().Protocol
var dstPort uint16
tcpUDPHeader := pkt.GetTCPUDPHeader()
if tcpUDPHeader != nil {
dstPort = tcpUDPHeader.DstPort
}
// check endpoints list
permitted, reason, ok := profileSet.CheckEndpoint(domainOrIP, uint8(protocol), dstPort, connection.Direction)
if ok {
if permitted {
log.Infof("firewall: accepting link %s, endpoint is whitelisted: %s", link, reason)
link.Accept(fmt.Sprintf("port whitelisted: %s", reason))
} else {
log.Infof("firewall: denying link %s: port %d is blacklisted", link, dstPort)
link.Deny("port blacklisted")
}
return
}
switch profileSet.GetProfileMode() {
case profile.Whitelist:
log.Infof("firewall: denying link %s: endpoint %d is not whitelisted", link, dstPort)
link.Deny("endpoint is not whitelisted")
return
case profile.Prompt:
log.Infof("firewall: accepting link %s: endpoint %d is blacklisted", link, dstPort)
link.Accept("endpoint permitted (prompting is not yet implemented)")
return
case profile.Blacklist:
log.Infof("firewall: accepting link %s: endpoint %d is not blacklisted", link, dstPort)
link.Accept("endpoint is not blacklisted")
return
}
log.Infof("firewall: accepting link %s", link)
link.Accept("")
}

View file

@ -1,4 +1,4 @@
package portmaster package firewall
import ( import (
"errors" "errors"
@ -7,8 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/Safing/safing-core/crypto/random" "github.com/Safing/portbase/crypto/random"
"github.com/Safing/safing-core/intel" "github.com/Safing/portmaster/intel"
"github.com/miekg/dns" "github.com/miekg/dns"
) )

49
global/databases.go Normal file
View file

@ -0,0 +1,49 @@
package global
import (
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/modules"
// module dependencies
_ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/status"
)
func init() {
modules.Register("global", nil, start, nil, "database", "status")
}
func start() error {
_, err := database.Register(&database.Database{
Name: "core",
Description: "Holds core data, such as settings and profiles",
StorageType: "badger",
PrimaryAPI: "",
})
if err != nil {
return err
}
_, err = database.Register(&database.Database{
Name: "cache",
Description: "Cached data, such as Intelligence and DNS Records",
StorageType: "badger",
PrimaryAPI: "",
})
if err != nil {
return err
}
// _, err = database.Register(&database.Database{
// Name: "history",
// Description: "Historic event data",
// StorageType: "badger",
// PrimaryAPI: "",
// })
// if err != nil {
// return err
// }
return nil
}

94
intel/clients.go Normal file
View file

@ -0,0 +1,94 @@
package intel
import (
"crypto/tls"
"sync"
"time"
"github.com/miekg/dns"
)
type clientManager struct {
dnsClient *dns.Client
factory func() *dns.Client
lock sync.Mutex
refreshAfter time.Time
ttl time.Duration // force refresh of connection to reduce traceability
}
// ref: https://godoc.org/github.com/miekg/dns#Client
func newDNSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -1 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Timeout: 5 * time.Second,
}
},
}
}
func newTCPClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp",
Timeout: 5 * time.Second,
}
},
}
}
func newTLSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp-tls",
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: resolver.VerifyDomain,
// TODO: use custom random
// Rand: io.Reader,
},
Timeout: 5 * time.Second,
}
},
}
}
func newHTTPSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
new := &dns.Client{
Net: "https",
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
// TODO: use custom random
// Rand: io.Reader,
},
Timeout: 5 * time.Second,
}
if resolver.VerifyDomain != "" {
new.TLSConfig.ServerName = resolver.VerifyDomain
}
return new
},
}
}
func (cm *clientManager) getDNSClient() *dns.Client {
cm.lock.Lock()
defer cm.lock.Unlock()
if cm.dnsClient == nil || time.Now().After(cm.refreshAfter) {
cm.dnsClient = cm.factory()
cm.refreshAfter = time.Now().Add(cm.ttl)
}
return cm.dnsClient
}

100
intel/config.go Normal file
View file

@ -0,0 +1,100 @@
package intel
import (
"github.com/Safing/portbase/config"
"github.com/Safing/portmaster/status"
)
var (
configuredNameServers config.StringArrayOption
defaultNameServers = []string{
"tls|1.1.1.1:853|cloudflare-dns.com", // Cloudflare
"tls|1.0.0.1:853|cloudflare-dns.com", // Cloudflare
"tls|9.9.9.9:853|dns.quad9.net", // Quad9
// "https|cloudflare-dns.com/dns-query", // HTTPS still experimental
"dns|1.1.1.1:53", // Cloudflare
"dns|1.0.0.1:53", // Cloudflare
"dns|9.9.9.9:53", // Quad9
}
nameserverRetryRate config.IntOption
doNotUseMulticastDNS status.SecurityLevelOption
doNotUseAssignedNameservers status.SecurityLevelOption
doNotResolveSpecialDomains status.SecurityLevelOption
)
func prep() error {
err := config.Register(&config.Option{
Name: "Nameservers (DNS)",
Key: "intel/nameservers",
Description: "Nameserver to use for resolving DNS requests.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeStringArray,
DefaultValue: defaultNameServers,
ValidationRegex: "^(dns|tcp|tls|https)$",
})
if err != nil {
return err
}
configuredNameServers = config.Concurrent.GetAsStringArray("intel/nameservers", defaultNameServers)
err = config.Register(&config.Option{
Name: "Nameserver Retry Rate",
Key: "intel/nameserverRetryRate",
Description: "Rate at which to retry failed nameservers, in seconds.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
DefaultValue: 600,
})
if err != nil {
return err
}
nameserverRetryRate = config.Concurrent.GetAsInt("intel/nameserverRetryRate", 0)
err = config.Register(&config.Option{
Name: "Do not use Multicast DNS",
Key: "intel/doNotUseMulticastDNS",
Description: "Multicast DNS queries other devices in the local network",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotUseMulticastDNS = status.ConfigIsActiveConcurrent("intel/doNotUseMulticastDNS")
err = config.Register(&config.Option{
Name: "Do not use assigned Nameservers",
Key: "intel/doNotUseAssignedNameservers",
Description: "that were acquired by the network (dhcp) or system",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotUseAssignedNameservers = status.ConfigIsActiveConcurrent("intel/doNotUseAssignedNameservers")
err = config.Register(&config.Option{
Name: "Do not resolve special domains",
Key: "intel/doNotResolveSpecialDomains",
Description: "Do not resolve special (top level) domains: example, example.com, example.net, example.org, invalid, test, onion. (RFC6761, RFC7686)",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotResolveSpecialDomains = status.ConfigIsActiveConcurrent("intel/doNotResolveSpecialDomains")
return nil
}

View file

@ -1,62 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"github.com/Safing/safing-core/database"
datastore "github.com/ipfs/go-datastore"
)
// EntityClassification holds classification information about an internet entity.
type EntityClassification struct {
lists []byte
}
// Intel holds intelligence data for a domain.
type Intel struct {
database.Base
Domain string
DomainOwner string
CertOwner string
Classification *EntityClassification
}
var intelModel *Intel // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(intelModel, func() database.Model { return new(Intel) })
}
// Create saves the Intel with the provided name in the default namespace.
func (m *Intel) Create(name string) error {
return m.CreateObject(&database.IntelCache, name, m)
}
// CreateInNamespace saves the Intel with the provided name in the provided namespace.
func (m *Intel) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves the Intel.
func (m *Intel) Save() error {
return m.SaveObject(m)
}
// getIntel fetches the Intel with the provided name in the default namespace.
func getIntel(name string) (*Intel, error) {
return getIntelFromNamespace(&database.IntelCache, name)
}
// getIntelFromNamespace fetches the Intel with the provided name in the provided namespace.
func getIntelFromNamespace(namespace *datastore.Key, name string) (*Intel, error) {
object, err := database.GetAndEnsureModel(namespace, name, intelModel)
if err != nil {
return nil, err
}
model, ok := object.(*Intel)
if !ok {
return nil, database.NewMismatchError(object, intelModel)
}
return model, nil
}

View file

@ -1,218 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"fmt"
"net"
"time"
"github.com/Safing/safing-core/database"
datastore "github.com/ipfs/go-datastore"
"github.com/miekg/dns"
)
// RRCache is used to cache DNS data
type RRCache struct {
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
Expires int64
Modified int64
servedFromCache bool
requestingNew bool
}
func (m *RRCache) Clean(minExpires uint32) {
var lowestTTL uint32 = 0xFFFFFFFF
var header *dns.RR_Header
// set TTLs to 17
// TODO: double append? is there something more elegant?
for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) {
header = rr.Header()
if lowestTTL > header.Ttl {
lowestTTL = header.Ttl
}
header.Ttl = 17
}
// TTL must be at least minExpires
if lowestTTL < minExpires {
lowestTTL = minExpires
}
m.Expires = time.Now().Unix() + int64(lowestTTL)
m.Modified = time.Now().Unix()
}
func (m *RRCache) ExportAllARecords() (ips []net.IP) {
for _, rr := range m.Answer {
if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeA {
aRecord, ok := rr.(*dns.A)
if ok {
ips = append(ips, aRecord.A)
}
} else if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeAAAA {
aRecord, ok := rr.(*dns.AAAA)
if ok {
ips = append(ips, aRecord.AAAA)
}
}
}
return
}
func (m *RRCache) ToRRSave() *RRSave {
var s RRSave
s.Expires = m.Expires
s.Modified = m.Modified
for _, entry := range m.Answer {
s.Answer = append(s.Answer, entry.String())
}
for _, entry := range m.Ns {
s.Ns = append(s.Ns, entry.String())
}
for _, entry := range m.Extra {
s.Extra = append(s.Extra, entry.String())
}
return &s
}
func (m *RRCache) Create(name string) error {
s := m.ToRRSave()
return s.CreateObject(&database.DNSCache, name, s)
}
func (m *RRCache) CreateWithType(name string, qtype dns.Type) error {
s := m.ToRRSave()
return s.Create(fmt.Sprintf("%s%s", name, qtype.String()))
}
func (m *RRCache) Save() error {
s := m.ToRRSave()
return s.SaveObject(s)
}
func GetRRCache(domain string, qtype dns.Type) (*RRCache, error) {
return GetRRCacheFromNamespace(&database.DNSCache, domain, qtype)
}
func GetRRCacheFromNamespace(namespace *datastore.Key, domain string, qtype dns.Type) (*RRCache, error) {
var m RRCache
rrSave, err := GetRRSaveFromNamespace(namespace, domain, qtype)
if err != nil {
return nil, err
}
m.Expires = rrSave.Expires
m.Modified = rrSave.Modified
for _, entry := range rrSave.Answer {
rr, err := dns.NewRR(entry)
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
for _, entry := range rrSave.Ns {
rr, err := dns.NewRR(entry)
if err == nil {
m.Ns = append(m.Ns, rr)
}
}
for _, entry := range rrSave.Extra {
rr, err := dns.NewRR(entry)
if err == nil {
m.Extra = append(m.Extra, rr)
}
}
m.servedFromCache = true
return &m, nil
}
// ServedFromCache marks the RRCache as served from cache.
func (m *RRCache) ServedFromCache() bool {
return m.servedFromCache
}
// RequestingNew informs that it has expired and new RRs are being fetched.
func (m *RRCache) RequestingNew() bool {
return m.requestingNew
}
// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format.
func (m *RRCache) Flags() string {
switch {
case m.servedFromCache && m.requestingNew:
return " [CR]"
case m.servedFromCache:
return " [C]"
case m.requestingNew:
return " [R]" // theoretically impossible, but let's leave it here, just in case
default:
return ""
}
}
// IsNXDomain returnes whether the result is nxdomain.
func (m *RRCache) IsNXDomain() bool {
return len(m.Answer) == 0
}
// RRSave is helper struct to RRCache to better save data to the database.
type RRSave struct {
database.Base
Answer []string
Ns []string
Extra []string
Expires int64
Modified int64
}
var rrSaveModel *RRSave // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(rrSaveModel, func() database.Model { return new(RRSave) })
}
// Create saves RRSave with the provided name in the default namespace.
func (m *RRSave) Create(name string) error {
return m.CreateObject(&database.DNSCache, name, m)
}
// CreateWithType saves RRSave with the provided name and type in the default namespace.
func (m *RRSave) CreateWithType(name string, qtype dns.Type) error {
return m.Create(fmt.Sprintf("%s%s", name, qtype.String()))
}
// CreateInNamespace saves RRSave with the provided name in the provided namespace.
func (m *RRSave) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves RRSave.
func (m *RRSave) Save() error {
return m.SaveObject(m)
}
// GetRRSave fetches RRSave with the provided name in the default namespace.
func GetRRSave(name string, qtype dns.Type) (*RRSave, error) {
return GetRRSaveFromNamespace(&database.DNSCache, name, qtype)
}
// GetRRSaveFromNamespace fetches RRSave with the provided name in the provided namespace.
func GetRRSaveFromNamespace(namespace *datastore.Key, name string, qtype dns.Type) (*RRSave, error) {
object, err := database.GetAndEnsureModel(namespace, fmt.Sprintf("%s%s", name, qtype.String()), rrSaveModel)
if err != nil {
return nil, err
}
model, ok := object.(*RRSave)
if !ok {
return nil, database.NewMismatchError(object, rrSaveModel)
}
return model, nil
}

View file

@ -1,48 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"github.com/Safing/safing-core/log"
"sync"
"github.com/miekg/dns"
)
var (
dfMap = make(map[string]string)
dfMapLock sync.RWMutex
)
func checkDomainFronting(hidden string, qtype dns.Type, securityLevel int8) (*RRCache, bool) {
dfMapLock.RLock()
front, ok := dfMap[hidden]
dfMapLock.RUnlock()
if !ok {
return nil, false
}
log.Tracef("intel: applying domain fronting %s -> %s", hidden, front)
// get domain name
rrCache := resolveAndCache(front, qtype, securityLevel)
if rrCache == nil {
return nil, true
}
// replace domain name
var header *dns.RR_Header
for _, rr := range rrCache.Answer {
header = rr.Header()
if header.Name == front {
header.Name = hidden
}
}
// save under front
rrCache.CreateWithType(hidden, qtype)
return rrCache, true
}
func addDomainFronting(hidden string, front string) {
dfMapLock.Lock()
dfMap[hidden] = front
dfMapLock.Unlock()
return
}

View file

@ -3,44 +3,66 @@
package intel package intel
import ( import (
"github.com/Safing/safing-core/database" "fmt"
"github.com/Safing/safing-core/modules" "sync"
"github.com/miekg/dns" "github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
) )
var ( var (
intelModule *modules.Module intelDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
})
) )
func init() { // Intel holds intelligence data for a domain.
intelModule = modules.Register("Intel", 128) type Intel struct {
go Start() record.Base
sync.Mutex
Domain string
} }
// GetIntel returns an Intel object of the given domain. The returned Intel object MUST not be modified. func makeIntelKey(domain string) string {
func GetIntel(domain string) *Intel { return fmt.Sprintf("cache:intel/domain/%s", domain)
fqdn := dns.Fqdn(domain) }
intel, err := getIntel(fqdn)
// GetIntelFromDB gets an Intel record from the database.
func GetIntelFromDB(domain string) (*Intel, error) {
key := makeIntelKey(domain)
r, err := intelDatabase.Get(key)
if err != nil { if err != nil {
if err == database.ErrNotFound { return nil, err
intel = &Intel{Domain: fqdn}
intel.Create(fqdn)
} else {
return nil
} }
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &Intel{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
} }
return intel return new, nil
}
// or adjust type
new, ok := r.(*Intel)
if !ok {
return nil, fmt.Errorf("record not of type *Intel, but %T", r)
}
return new, nil
} }
func GetIntelAndRRs(domain string, qtype dns.Type, securityLevel int8) (intel *Intel, rrs *RRCache) { // Save saves the Intel record to the database.
intel = GetIntel(domain) func (intel *Intel) Save() error {
rrs = Resolve(domain, qtype, securityLevel) intel.SetKey(makeIntelKey(intel.Domain))
return return intelDatabase.PutNew(intel)
} }
func Start() { // GetIntel fetches intelligence data for the given domain.
// mocking until intel has its own goroutines func GetIntel(domain string) (*Intel, error) {
defer intelModule.StopComplete() return &Intel{Domain: domain}, nil
<-intelModule.Stop
} }

View file

@ -1,61 +1,91 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel package intel
import ( import (
"fmt"
"strings" "strings"
"sync"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/utils"
)
datastore "github.com/ipfs/go-datastore" var (
ipInfoDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 86400, // 24 hours
})
) )
// IPInfo represents various information about an IP. // IPInfo represents various information about an IP.
type IPInfo struct { type IPInfo struct {
database.Base record.Base
sync.Mutex
IP string
Domains []string Domains []string
} }
var ipInfoModel *IPInfo // only use this as parameter for database.EnsureModel-like functions func makeIPInfoKey(ip string) string {
return fmt.Sprintf("cache:intel/ipInfo/%s", ip)
func init() {
database.RegisterModel(ipInfoModel, func() database.Model { return new(IPInfo) })
} }
// Create saves the IPInfo with the provided name in the default namespace. // GetIPInfo gets an IPInfo record from the database.
func (m *IPInfo) Create(name string) error { func GetIPInfo(ip string) (*IPInfo, error) {
return m.CreateObject(&database.IPInfoCache, name, m) key := makeIPInfoKey(ip)
}
// CreateInNamespace saves the IPInfo with the provided name in the provided namespace. r, err := ipInfoDatabase.Get(key)
func (m *IPInfo) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves the IPInfo.
func (m *IPInfo) Save() error {
return m.SaveObject(m)
}
// GetIPInfo fetches the IPInfo with the provided name in the default namespace.
func GetIPInfo(name string) (*IPInfo, error) {
return GetIPInfoFromNamespace(&database.IPInfoCache, name)
}
// GetIPInfoFromNamespace fetches the IPInfo with the provided name in the provided namespace.
func GetIPInfoFromNamespace(namespace *datastore.Key, name string) (*IPInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, ipInfoModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model, ok := object.(*IPInfo)
if !ok { // unwrap
return nil, database.NewMismatchError(object, ipInfoModel) if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &IPInfo{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
} }
return model, nil return new, nil
}
// or adjust type
new, ok := r.(*IPInfo)
if !ok {
return nil, fmt.Errorf("record not of type *IPInfo, but %T", r)
}
return new, nil
}
// AddDomain adds a domain to the list and reports back if it was added, or was already present.
func (ipi *IPInfo) AddDomain(domain string) (added bool) {
ipi.Lock()
defer ipi.Unlock()
if !utils.StringInSlice(ipi.Domains, domain) {
ipi.Domains = append([]string{domain}, ipi.Domains...)
return true
}
return false
}
// Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error {
ipi.Lock()
if !ipi.KeyIsSet() {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
return ipInfoDatabase.Put(ipi)
} }
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (m *IPInfo) FmtDomains() string { func (ipi *IPInfo) FmtDomains() string {
return strings.Join(m.Domains, " or ") return strings.Join(ipi.Domains, " or ")
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string {
ipi.Lock()
defer ipi.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.FmtDomains())
} }

25
intel/ipinfo_test.go Normal file
View file

@ -0,0 +1,25 @@
package intel
import "testing"
func testDomains(t *testing.T, ipi *IPInfo, expectedDomains string) {
if ipi.FmtDomains() != expectedDomains {
t.Errorf("unexpected domains '%s', expected '%s'", ipi.FmtDomains(), expectedDomains)
}
}
func TestIPInfo(t *testing.T) {
ipi := &IPInfo{
IP: "1.2.3.4",
Domains: []string{"example.com.", "sub.example.com."},
}
testDomains(t, ipi, "example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("sub.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
}

35
intel/main.go Normal file
View file

@ -0,0 +1,35 @@
package intel
import (
"github.com/miekg/dns"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// module dependencies
_ "github.com/Safing/portmaster/global"
)
func init() {
modules.Register("intel", prep, start, nil, "global")
}
func start() error {
// load resolvers from config and environment
loadResolvers(false)
go listenToMDNS()
return nil
}
// GetIntelAndRRs returns intel and DNS resource records for the given domain.
func GetIntelAndRRs(domain string, qtype dns.Type, securityLevel uint8) (intel *Intel, rrs *RRCache) {
intel, err := GetIntel(domain)
if err != nil {
log.Errorf("intel: failed to get intel: %s", err)
intel = nil
}
rrs = Resolve(domain, qtype, securityLevel)
return
}

38
intel/main_test.go Normal file
View file

@ -0,0 +1,38 @@
package intel
import (
"os"
"testing"
"github.com/Safing/portbase/database/dbmodule"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
)
func TestMain(m *testing.M) {
// setup
testDir := os.TempDir()
dbmodule.SetDatabaseLocation(testDir)
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// run tests
rv := m.Run()
// teardown
modules.Shutdown()
os.RemoveAll(testDir)
// exit with test run return value
os.Exit(rv)
}

View file

@ -6,14 +6,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"github.com/Safing/safing-core/log"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/Safing/portbase/log"
) )
// DNS Classes
const ( const (
DNSClassMulticast = dns.ClassINET | 1<<15 DNSClassMulticast = dns.ClassINET | 1<<15
) )
@ -33,10 +35,6 @@ type savedQuestion struct {
expires int64 expires int64
} }
func init() {
go listenToMDNS()
}
func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int { func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int {
for k, v := range *list { for k, v := range *list {
if entry.Name == v.Header().Name && entry.Rrtype == v.Header().Rrtype { if entry.Name == v.Header().Name && entry.Rrtype == v.Header().Rrtype {
@ -89,7 +87,7 @@ func listenToMDNS() {
var question *dns.Question var question *dns.Question
var saveFullRequest bool var saveFullRequest bool
scavengedRecords := make(map[string]*dns.RR) scavengedRecords := make(map[string]dns.RR)
var rrCache *RRCache var rrCache *RRCache
// save every received response // save every received response
@ -114,7 +112,7 @@ func listenToMDNS() {
continue continue
} }
// continue if no question // get question, some servers do not reply with question
if len(message.Question) == 0 { if len(message.Question) == 0 {
questionsLock.Lock() questionsLock.Lock()
savedQ, ok := questions[message.MsgHdr.Id] savedQ, ok := questions[message.MsgHdr.Id]
@ -138,8 +136,11 @@ func listenToMDNS() {
// get entry from database // get entry from database
if saveFullRequest { if saveFullRequest {
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
if err != nil || rrCache.Modified < time.Now().Add(-2*time.Second).Unix() || rrCache.Expires < time.Now().Unix() { if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
rrCache = &RRCache{} rrCache = &RRCache{
Domain: question.Name,
Question: dns.Type(question.Qtype),
}
} }
} }
@ -155,12 +156,12 @@ func listenToMDNS() {
} }
switch entry.(type) { switch entry.(type) {
case *dns.A: case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = entry
case *dns.AAAA: case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = entry
case *dns.PTR: case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") { if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = entry
} }
} }
} }
@ -177,17 +178,16 @@ func listenToMDNS() {
} }
switch entry.(type) { switch entry.(type) {
case *dns.A: case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%s_A", entry.Header().Name)] = entry
case *dns.AAAA: case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%s_AAAA", entry.Header().Name)] = entry
case *dns.PTR: case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") { if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%s_PTR", entry.Header().Name)] = entry
} }
} }
} }
} }
// TODO: scan Extra for A and AAAA records and save them seperately
for _, entry := range message.Extra { for _, entry := range message.Extra {
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) { if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) {
if saveFullRequest { if saveFullRequest {
@ -200,34 +200,35 @@ func listenToMDNS() {
} }
switch entry.(type) { switch entry.(type) {
case *dns.A: case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = entry
case *dns.AAAA: case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = entry
case *dns.PTR: case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") { if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = entry
} }
} }
} }
} }
var questionID string
if saveFullRequest { if saveFullRequest {
rrCache.Clean(60) rrCache.Clean(60)
rrCache.CreateWithType(question.Name, dns.Type(question.Qtype)) rrCache.Save()
// log.Tracef("intel: mdns saved full reply to %s%s", question.Name, dns.Type(question.Qtype).String()) questionID = fmt.Sprintf("%s%s", question.Name, dns.Type(question.Qtype).String())
} }
for k, v := range scavengedRecords { for k, v := range scavengedRecords {
if saveFullRequest { if saveFullRequest && k == questionID {
if k == fmt.Sprintf("%s%s", question.Name, dns.Type(question.Qtype).String()) {
continue continue
} }
}
rrCache = &RRCache{ rrCache = &RRCache{
Answer: []dns.RR{*v}, Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
} }
rrCache.Clean(60) rrCache.Clean(60)
rrCache.Create(k) rrCache.Save()
// log.Tracef("intel: mdns scavenged %s", k) // log.Tracef("intel: mdns scavenged %s", k)
} }
@ -261,7 +262,7 @@ func listenForDNSPackets(conn *net.UDPConn, messages chan *dns.Msg) {
} }
} }
func queryMulticastDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { func queryMulticastDNS(fqdn string, qtype dns.Type) (*RRCache, error) {
q := new(dns.Msg) q := new(dns.Msg)
q.SetQuestion(fqdn, uint16(qtype)) q.SetQuestion(fqdn, uint16(qtype))
// request unicast response // request unicast response

73
intel/namerecord.go Normal file
View file

@ -0,0 +1,73 @@
package intel
import (
"errors"
"fmt"
"sync"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
)
var (
recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
CacheSize: 128,
})
)
// NameRecord is helper struct to RRCache to better save data to the database.
type NameRecord struct {
record.Base
sync.Mutex
Domain string
Question string
Answer []string
Ns []string
Extra []string
TTL int64
Filtered bool
}
func makeNameRecordKey(domain string, question string) string {
return fmt.Sprintf("cache:intel/nameRecord/%s%s", domain, question)
}
// GetNameRecord gets a NameRecord from the database.
func GetNameRecord(domain string, question string) (*NameRecord, error) {
key := makeNameRecordKey(domain, question)
r, err := recordDatabase.Get(key)
if err != nil {
return nil, err
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &NameRecord{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*NameRecord)
if !ok {
return nil, fmt.Errorf("record not of type *NameRecord, but %T", r)
}
return new, nil
}
// Save saves the NameRecord to the database.
func (rec *NameRecord) Save() error {
if rec.Domain == "" || rec.Question == "" {
return errors.New("could not save NameRecord, missing Domain and/or Question")
}
rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question))
return recordDatabase.PutNew(rec)
}

View file

@ -3,30 +3,20 @@
package intel package intel
import ( import (
"crypto/tls"
"encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"net/http"
"net/url"
"sort"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/tevino/abool"
"github.com/Safing/safing-core/configuration" "github.com/Safing/portbase/database"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/log" "github.com/Safing/portmaster/network/netutils"
"github.com/Safing/safing-core/network/environment" "github.com/Safing/portmaster/status"
"github.com/Safing/safing-core/network/netutils"
) )
// TODO: make resolver interface for http package // TODO: make resolver interface for http package
@ -79,322 +69,14 @@ import (
// global -> local scopes, global // global -> local scopes, global
// special -> local scopes, local // special -> local scopes, local
type Resolver struct {
// static
Server string
ServerAddress string
IP *net.IP
Port uint16
Resolve func(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error)
Search *[]string
AllowedSecurityLevel int8
SkipFqdnBeforeInit string
HTTPClient *http.Client
Source string
// atomic
Initialized *abool.AtomicBool
InitLock sync.Mutex
LastFail *int64
Expires *int64
// must be locked
LockReason sync.Mutex
FailReason string
// TODO: add:
// Expiration (for server got from DHCP / ICMPv6)
// bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!)
// expanded server info: type, server address, server port, options - so we do not have to parse this every time!
}
func (r *Resolver) String() string {
return r.Server
}
func (r *Resolver) Address() string {
return urlFormatAddress(r.IP, r.Port)
}
type Scope struct {
Domain string
Resolvers []*Resolver
}
var (
config = configuration.Get()
globalResolvers []*Resolver // all resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []Scope // list of scopes with a list of local resolvers that can resolve the scope
mDNSResolver *Resolver // holds a reference to the mDNS resolver
resolversLock sync.RWMutex
env = environment.NewInterface()
dupReqMap = make(map[string]*sync.Mutex)
dupReqLock sync.Mutex
)
func init() {
loadResolvers(false)
}
func indexOfResolver(server string, list []*Resolver) int {
for k, v := range list {
if v.Server == server {
return k
}
}
return -1
}
func indexOfScope(domain string, list *[]Scope) int {
for k, v := range *list {
if v.Domain == domain {
return k
}
}
return -1
}
func parseAddress(server string) (*net.IP, uint16, error) {
delimiter := strings.LastIndex(server, ":")
if delimiter < 0 {
return nil, 0, errors.New("port missing")
}
ip := net.ParseIP(strings.Trim(server[:delimiter], "[]"))
if ip == nil {
return nil, 0, errors.New("invalid IP address")
}
port, err := strconv.Atoi(server[delimiter+1:])
if err != nil || port < 1 || port > 65536 {
return nil, 0, errors.New("invalid port")
}
return &ip, uint16(port), nil
}
func urlFormatAddress(ip *net.IP, port uint16) string {
var address string
if ipv4 := ip.To4(); ipv4 != nil {
address = fmt.Sprintf("%s:%d", ipv4.String(), port)
} else {
address = fmt.Sprintf("[%s]:%d", ip.String(), port)
}
return address
}
func loadResolvers(resetResolvers bool) {
// TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame.
resolversLock.Lock()
defer resolversLock.Unlock()
var newResolvers []*Resolver
configuredServersLoop:
for _, server := range config.DNSServers {
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue configuredServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
parts := strings.Split(server, "|")
if len(parts) < 2 {
log.Warningf("intel: invalid DNS server in config: %s (invalid format)", server)
continue configuredServersLoop
}
var lastFail int64
new := &Resolver{
Server: server,
ServerAddress: parts[1],
LastFail: &lastFail,
Source: "config",
Initialized: abool.NewBool(false),
}
ip, port, err := parseAddress(parts[1])
if err != nil {
new.IP = ip
new.Port = port
}
switch {
case strings.HasPrefix(server, "DNS|"):
new.Resolve = queryDNS
new.AllowedSecurityLevel = configuration.SecurityLevelFortress
case strings.HasPrefix(server, "DoH|"):
new.Resolve = queryDNSoverHTTPS
new.AllowedSecurityLevel = configuration.SecurityLevelFortress
new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0])
tls := &tls.Config{
// TODO: use custom random
// Rand: io.Reader,
}
tr := &http.Transport{
MaxIdleConnsPerHost: 100,
TLSClientConfig: tls,
// TODO: use custom resolver as of Go1.9
}
if len(parts) == 3 && strings.HasPrefix(parts[2], "df:") {
// activate domain fronting
tls.ServerName = parts[2][3:]
addDomainFronting(new.SkipFqdnBeforeInit, dns.Fqdn(tls.ServerName))
new.SkipFqdnBeforeInit = dns.Fqdn(tls.ServerName)
}
new.HTTPClient = &http.Client{Transport: tr}
default:
log.Warningf("intel: invalid DNS server in config: %s (not starting with a valid identifier)", server)
continue configuredServersLoop
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// add local resolvers
assignedNameservers := environment.Nameservers()
assignedServersLoop:
for _, nameserver := range assignedNameservers {
server := fmt.Sprintf("DNS|%s", urlFormatAddress(&nameserver.IP, 53))
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue assignedServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
var lastFail int64
new := &Resolver{
Server: server,
ServerAddress: urlFormatAddress(&nameserver.IP, 53),
IP: &nameserver.IP,
Port: 53,
LastFail: &lastFail,
Resolve: queryDNS,
AllowedSecurityLevel: configuration.SecurityLevelFortress,
Initialized: abool.NewBool(false),
Source: "dhcp",
}
if netutils.IPIsLocal(nameserver.IP) && len(nameserver.Search) > 0 {
// only allow searches for local resolvers
var newSearch []string
for _, value := range nameserver.Search {
newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, ".")))
}
new.Search = &newSearch
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("intel: no (valid) dns servers found in configuration and system")
}
// make list with local resolvers
localResolvers = make([]*Resolver, 0)
for _, resolver := range globalResolvers {
if resolver.IP != nil && netutils.IPIsLocal(*resolver.IP) {
localResolvers = append(localResolvers, resolver)
}
}
// add resolvers to every scope the cover
localScopes = make([]Scope, 0)
for _, resolver := range globalResolvers {
if resolver.Search != nil {
// add resolver to custom searches
for _, search := range *resolver.Search {
if search == "." {
continue
}
key := indexOfScope(search, &localScopes)
if key == -1 {
localScopes = append(localScopes, Scope{
Domain: search,
Resolvers: []*Resolver{resolver},
})
} else {
localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver)
}
}
}
}
// init mdns resolver
if mDNSResolver == nil {
cannotFail := int64(-1)
mDNSResolver = &Resolver{
Server: "mDNS",
Resolve: queryMulticastDNS,
AllowedSecurityLevel: config.DoNotUseMDNS.Level(),
Initialized: abool.NewBool(false),
Source: "static",
LastFail: &cannotFail,
}
}
// sort scopes by length
sort.Slice(localScopes,
func(i, j int) bool {
return len(localScopes[i].Domain) > len(localScopes[j].Domain)
},
)
log.Trace("intel: loaded global resolvers:")
for _, resolver := range globalResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded local resolvers:")
for _, resolver := range localResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded scopes:")
for _, scope := range localScopes {
var scopeServers []string
for _, resolver := range scope.Resolvers {
scopeServers = append(scopeServers, resolver.Server)
}
log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", "))
}
}
// Resolve resolves the given query for a domain and type and returns a RRCache object or nil, if the query failed. // Resolve resolves the given query for a domain and type and returns a RRCache object or nil, if the query failed.
func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { func Resolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache {
fqdn = dns.Fqdn(fqdn) fqdn = dns.Fqdn(fqdn)
// use this to time how long it takes resolve this domain // use this to time how long it takes resolve this domain
// timed := time.Now() // timed := time.Now()
// defer log.Tracef("intel: took %s to get resolve %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String()) // defer log.Tracef("intel: took %s to get resolve %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String())
// handle request for localhost
if fqdn == "localhost." {
var rr dns.RR
var err error
switch uint16(qtype) {
case dns.TypeA:
rr, err = dns.NewRR("localhost. 3600 IN A 127.0.0.1")
case dns.TypeAAAA:
rr, err = dns.NewRR("localhost. 3600 IN AAAA ::1")
default:
return nil
}
if err != nil {
return nil
}
return &RRCache{
Answer: []dns.RR{rr},
}
}
// check cache // check cache
rrCache, err := GetRRCache(fqdn, qtype) rrCache, err := GetRRCache(fqdn, qtype)
if err != nil { if err != nil {
@ -406,7 +88,8 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
return resolveAndCache(fqdn, qtype, securityLevel) return resolveAndCache(fqdn, qtype, securityLevel)
} }
if rrCache.Expires <= time.Now().Unix() { if rrCache.TTL <= time.Now().Unix() {
log.Tracef("intel: serving cache, requesting new. TTL=%d, now=%d", rrCache.TTL, time.Now().Unix())
rrCache.requestingNew = true rrCache.requestingNew = true
go resolveAndCache(fqdn, qtype, securityLevel) go resolveAndCache(fqdn, qtype, securityLevel)
} }
@ -420,17 +103,9 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
return rrCache return rrCache
} }
func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { func resolveAndCache(fqdn string, qtype dns.Type, securityLevel uint8) (rrCache *RRCache) {
// log.Tracef("intel: resolving %s%s", fqdn, qtype.String()) // log.Tracef("intel: resolving %s%s", fqdn, qtype.String())
rrCache, ok := checkDomainFronting(fqdn, qtype, securityLevel)
if ok {
if rrCache == nil {
return nil
}
return rrCache
}
// dedup requests // dedup requests
dupKey := fmt.Sprintf("%s%s", fqdn, qtype.String()) dupKey := fmt.Sprintf("%s%s", fqdn, qtype.String())
dupReqLock.Lock() dupReqLock.Lock()
@ -456,7 +131,7 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
} }
defer func() { defer func() {
dupReqLock.Lock() dupReqLock.Lock()
delete(dupReqMap, fqdn) delete(dupReqMap, dupKey)
dupReqLock.Unlock() dupReqLock.Unlock()
mutex.Unlock() mutex.Unlock()
}() }()
@ -469,29 +144,29 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
// persist to database // persist to database
rrCache.Clean(600) rrCache.Clean(600)
rrCache.CreateWithType(fqdn, qtype) rrCache.Save()
return rrCache return rrCache
} }
func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache { func intelligentResolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache {
// TODO: handle being offline // TODO: handle being offline
// TODO: handle multiple network connections // TODO: handle multiple network connections
if config.Changed() { // TODO: handle these in a separate goroutine
log.Info("intel: config changed, reloading resolvers") // if config.Changed() {
loadResolvers(false) // log.Info("intel: config changed, reloading resolvers")
} else if env.NetworkChanged() { // loadResolvers(false)
log.Info("intel: network changed, reloading resolvers") // } else if env.NetworkChanged() {
loadResolvers(true) // log.Info("intel: network changed, reloading resolvers")
} // loadResolvers(true)
config.RLock() // }
defer config.RUnlock()
resolversLock.RLock() resolversLock.RLock()
defer resolversLock.RUnlock() defer resolversLock.RUnlock()
lastFailBoundary := time.Now().Unix() - config.DNSServerRetryRate lastFailBoundary := time.Now().Unix() - nameserverRetryRate()
preDottedFqdn := "." + fqdn preDottedFqdn := "." + fqdn
// resolve: // resolve:
@ -510,11 +185,14 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
} }
} }
// check config // check config
if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) { if doNotUseMulticastDNS(securityLevel) {
return nil return nil
} }
// try mdns // try mdns
rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel) rrCache, err := queryMulticastDNS(fqdn, qtype)
if err != nil {
log.Errorf("intel: failed to query mdns: %s", err)
}
return rrCache return rrCache
} }
@ -533,15 +211,18 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
switch { switch {
case strings.HasSuffix(preDottedFqdn, ".local."): case strings.HasSuffix(preDottedFqdn, ".local."):
// check config // check config
if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) { if doNotUseMulticastDNS(securityLevel) {
return nil return nil
} }
// try mdns // try mdns
rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel) rrCache, err := queryMulticastDNS(fqdn, qtype)
if err != nil {
log.Errorf("intel: failed to query mdns: %s", err)
}
return rrCache return rrCache
case domainInScopes(preDottedFqdn, specialScopes): case domainInScopes(preDottedFqdn, specialScopes):
// check config // check config
if config.DoNotForwardSpecialDomains.IsSetWithLevel(securityLevel) { if doNotResolveSpecialDomains(securityLevel) {
return nil return nil
} }
// try local resolvers // try local resolvers
@ -568,15 +249,15 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
} }
func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel int8) (*RRCache, bool) { func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel uint8) (*RRCache, bool) {
// skip if not allowed in current security level // skip if not allowed in current security level
if resolver.AllowedSecurityLevel < config.SecurityLevel() || resolver.AllowedSecurityLevel < securityLevel { if resolver.AllowedSecurityLevel < status.CurrentSecurityLevel() || resolver.AllowedSecurityLevel < securityLevel {
log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, config.SecurityLevel(), securityLevel) log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel)
return nil, false return nil, false
} }
// skip if not security level denies assigned dns servers // skip if not security level denies assigned dns servers
if config.DoNotUseAssignedDNS.IsSetWithLevel(securityLevel) && resolver.Source == "dhcp" { if doNotUseAssignedNameservers(securityLevel) && resolver.Source == "dhcp" {
log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d (%d)", resolver, config.SecurityLevel(), securityLevel, int8(config.DoNotUseAssignedDNS)) log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel)
return nil, false return nil, false
} }
// check if failed recently // check if failed recently
@ -606,7 +287,7 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype
} }
// resolve // resolve
log.Tracef("intel: trying to resolve %s%s with %s", fqdn, qtype.String(), resolver.Server) log.Tracef("intel: trying to resolve %s%s with %s", fqdn, qtype.String(), resolver.Server)
rrCache, err := resolver.Resolve(resolver, fqdn, qtype) rrCache, err := query(resolver, fqdn, qtype)
if err != nil { if err != nil {
// check if failing is disabled // check if failing is disabled
if atomic.LoadInt64(resolver.LastFail) == -1 { if atomic.LoadInt64(resolver.LastFail) == -1 {
@ -622,39 +303,62 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype
return nil, false return nil, false
} }
resolver.Initialized.SetToIf(false, true) resolver.Initialized.SetToIf(false, true)
// remove localhost entries, remove LAN entries if server is in global IP space.
if resolver.ServerIPScope == netutils.Global {
rrCache.FilterEntries(true, false, false)
} else {
rrCache.FilterEntries(true, true, false)
}
return rrCache, true return rrCache, true
} }
func queryDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) { func query(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
q := new(dns.Msg) q := new(dns.Msg)
q.SetQuestion(fqdn, uint16(qtype)) q.SetQuestion(fqdn, uint16(qtype))
var reply *dns.Msg var reply *dns.Msg
var err error var err error
for i := 0; i < 5; i++ { for i := 0; i < 3; i++ {
client := new(dns.Client)
reply, _, err = client.Exchange(q, resolver.ServerAddress) // log query time
// qStart := time.Now()
reply, _, err = resolver.clientManager.getDNSClient().Exchange(q, resolver.ServerAddress)
// log.Tracef("intel: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
// error handling
if err != nil { if err != nil {
log.Tracef("intel: query to %s encountered error: %s", resolver.Server, err)
// TODO: handle special cases // TODO: handle special cases
// 1. connect: network is unreachable // 1. connect: network is unreachable
// 2. timeout // 2. timeout
// temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() { if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracef("intel: retrying to resolve %s%s with %s, error was: %s", fqdn, qtype.String(), resolver.Server, err) log.Tracef("intel: retrying to resolve %s%s with %s, error was: %s", fqdn, qtype.String(), resolver.Server, err)
continue continue
} }
// permanent error
break break
} }
// no error
break
} }
if err != nil { if err != nil {
log.Warningf("resolving %s%s failed: %s", fqdn, qtype.String(), err) err = fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err)
return nil, fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err) log.Warning(err.Error())
return nil, err
} }
new := &RRCache{ new := &RRCache{
Domain: fqdn,
Question: qtype,
Answer: reply.Answer, Answer: reply.Answer,
Ns: reply.Ns, Ns: reply.Ns,
Extra: reply.Extra, Extra: reply.Extra,
@ -663,85 +367,3 @@ func queryDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error)
// TODO: check if reply.Answer is valid // TODO: check if reply.Answer is valid
return new, nil return new, nil
} }
type DnsOverHttpsReply struct {
Status uint32
Truncated bool `json:"TC"`
Answer []DohRR
Additional []DohRR
}
type DohRR struct {
Name string `json:"name"`
Qtype uint16 `json:"type"`
TTL uint32 `json:"TTL"`
Data string `json:"data"`
}
func queryDNSoverHTTPS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
// API documentation: https://developers.google.com/speed/public-dns/docs/dns-over-https
payload := url.Values{}
payload.Add("name", fqdn)
payload.Add("type", strconv.Itoa(int(qtype)))
payload.Add("edns_client_subnet", "0.0.0.0/0")
// TODO: add random - only use upper- and lower-case letters, digits, hyphen, period, underscore and tilde
// payload.Add("random_padding", "")
resp, err := resolver.HTTPClient.Get(fmt.Sprintf("https://%s/resolve?%s", resolver.ServerAddress, payload.Encode()))
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: http error: %s", fqdn, qtype.String(), err)
// TODO: handle special cases
// 1. connect: network is unreachable
// intel: resolver DoH|dns.google.com:443|df:www.google.com failed (resolving discovery-v4-4.syncthing.net.A failed: http error: Get https://dns.google.com:443/resolve?edns_client_subnet=0.0.0.0%2F0&name=discovery-v4-4.syncthing.net.&type=1: dial tcp [2a00:1450:4001:819::2004]:443: connect: network is unreachable), moving to next
// 2. timeout
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("resolving %s%s failed: request was unsuccessful, got code %d", fqdn, qtype.String(), resp.StatusCode)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: error reading response body: %s", fqdn, qtype.String(), err)
}
var reply DnsOverHttpsReply
err = json.Unmarshal(body, &reply)
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: error parsing response body: %s", fqdn, qtype.String(), err)
}
if reply.Status != 0 {
// this happens if there is a server error (e.g. DNSSEC fail), ignore for now
// TODO: do something more intelligent
}
new := new(RRCache)
// TODO: handle TXT records
for _, entry := range reply.Answer {
rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data))
if err != nil {
log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)
continue
}
new.Answer = append(new.Answer, rr)
}
for _, entry := range reply.Additional {
rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data))
if err != nil {
log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)
continue
}
new.Extra = append(new.Extra, rr)
}
return new, nil
}
// TODO: implement T-DNS: DNS over TCP/TLS
// server list: https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers

View file

@ -2,14 +2,16 @@
package intel package intel
import ( // DISABLE TESTING FOR NOW: find a way to have tests with the module system
"testing"
"time"
"github.com/miekg/dns" // import (
) // "testing"
// "time"
//
// "github.com/miekg/dns"
// )
func TestResolve(t *testing.T) { // func TestResolve(t *testing.T) {
Resolve("google.com.", dns.Type(dns.TypeA), 0) // Resolve("google.com.", dns.Type(dns.TypeA), 0)
time.Sleep(200 * time.Millisecond) // time.Sleep(200 * time.Millisecond)
} // }

295
intel/resolver.go Normal file
View file

@ -0,0 +1,295 @@
package intel
import (
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"sync"
"github.com/miekg/dns"
"github.com/tevino/abool"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/environment"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/status"
)
// Resolver holds information about an active resolver.
type Resolver struct {
// static
Server string
ServerType string
ServerAddress string
ServerIP net.IP
ServerIPScope int8
ServerPort uint16
VerifyDomain string
Source string
clientManager *clientManager
Search *[]string
AllowedSecurityLevel uint8
SkipFqdnBeforeInit string
// atomic
Initialized *abool.AtomicBool
InitLock sync.Mutex
LastFail *int64
Expires *int64
// must be locked
LockReason sync.Mutex
FailReason string
// TODO: add:
// Expiration (for server got from DHCP / ICMPv6)
// bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!)
// expanded server info: type, server address, server port, options - so we do not have to parse this every time!
}
func (r *Resolver) String() string {
return r.Server
}
// Scope defines a domain scope and which resolvers can resolve it.
type Scope struct {
Domain string
Resolvers []*Resolver
}
var (
globalResolvers []*Resolver // all resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
resolversLock sync.RWMutex
env = environment.NewInterface()
dupReqMap = make(map[string]*sync.Mutex)
dupReqLock sync.Mutex
)
func indexOfResolver(server string, list []*Resolver) int {
for k, v := range list {
if v.Server == server {
return k
}
}
return -1
}
func indexOfScope(domain string, list []*Scope) int {
for k, v := range list {
if v.Domain == domain {
return k
}
}
return -1
}
func parseAddress(server string) (net.IP, uint16, error) {
delimiter := strings.LastIndex(server, ":")
if delimiter < 0 {
return nil, 0, errors.New("port missing")
}
ip := net.ParseIP(strings.Trim(server[:delimiter], "[]"))
if ip == nil {
return nil, 0, errors.New("invalid IP address")
}
port, err := strconv.Atoi(server[delimiter+1:])
if err != nil || port < 1 || port > 65536 {
return nil, 0, errors.New("invalid port")
}
return ip, uint16(port), nil
}
func urlFormatAddress(ip net.IP, port uint16) string {
var address string
if ipv4 := ip.To4(); ipv4 != nil {
address = fmt.Sprintf("%s:%d", ipv4.String(), port)
} else {
address = fmt.Sprintf("[%s]:%d", ip.String(), port)
}
return address
}
func loadResolvers(resetResolvers bool) {
// TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame.
resolversLock.Lock()
defer resolversLock.Unlock()
var newResolvers []*Resolver
configuredServersLoop:
for _, server := range configuredNameServers() {
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue configuredServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
parts := strings.Split(server, "|")
if len(parts) < 2 {
log.Warningf("intel: nameserver format invalid: %s", server)
continue configuredServersLoop
}
ip, port, err := parseAddress(parts[1])
if err != nil && strings.ToLower(parts[0]) != "https" {
log.Warningf("intel: nameserver (%s) address invalid: %s", server, err)
continue configuredServersLoop
}
var lastFail int64
new := &Resolver{
Server: server,
ServerType: parts[0],
ServerAddress: parts[1],
ServerIP: ip,
ServerIPScope: netutils.ClassifyIP(ip),
ServerPort: port,
LastFail: &lastFail,
Source: "config",
Initialized: abool.NewBool(false),
}
switch strings.ToLower(parts[0]) {
case "dns":
new.clientManager = newDNSClientManager(new)
case "tcp":
new.clientManager = newTCPClientManager(new)
case "tls":
new.AllowedSecurityLevel = status.SecurityLevelFortress
if len(parts) < 3 {
log.Warningf("intel: nameserver missing verification domain as third parameter: %s", server)
continue configuredServersLoop
}
new.VerifyDomain = parts[2]
new.clientManager = newTLSClientManager(new)
case "https":
new.AllowedSecurityLevel = status.SecurityLevelFortress
new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0])
if len(parts) > 2 {
new.VerifyDomain = parts[2]
}
new.clientManager = newHTTPSClientManager(new)
default:
log.Warningf("intel: nameserver (%s) type invalid: %s", server, parts[0])
continue configuredServersLoop
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// add local resolvers
assignedNameservers := environment.Nameservers()
assignedServersLoop:
for _, nameserver := range assignedNameservers {
server := fmt.Sprintf("dns|%s", urlFormatAddress(nameserver.IP, 53))
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue assignedServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
var lastFail int64
new := &Resolver{
Server: server,
ServerType: "dns",
ServerAddress: urlFormatAddress(nameserver.IP, 53),
ServerIP: nameserver.IP,
ServerIPScope: netutils.ClassifyIP(nameserver.IP),
ServerPort: 53,
LastFail: &lastFail,
Source: "dhcp",
Initialized: abool.NewBool(false),
AllowedSecurityLevel: status.SecurityLevelSecure,
}
new.clientManager = newDNSClientManager(new)
if netutils.IPIsLAN(nameserver.IP) && len(nameserver.Search) > 0 {
// only allow searches for local resolvers
var newSearch []string
for _, value := range nameserver.Search {
newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, ".")))
}
new.Search = &newSearch
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("intel: no (valid) dns servers found in configuration and system")
}
// make list with local resolvers
localResolvers = make([]*Resolver, 0)
for _, resolver := range globalResolvers {
if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) {
localResolvers = append(localResolvers, resolver)
}
}
// add resolvers to every scope the cover
localScopes = make([]*Scope, 0)
for _, resolver := range globalResolvers {
if resolver.Search != nil {
// add resolver to custom searches
for _, search := range *resolver.Search {
if search == "." {
continue
}
key := indexOfScope(search, localScopes)
if key == -1 {
localScopes = append(localScopes, &Scope{
Domain: search,
Resolvers: []*Resolver{resolver},
})
} else {
localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver)
}
}
}
}
// sort scopes by length
sort.Slice(localScopes,
func(i, j int) bool {
return len(localScopes[i].Domain) > len(localScopes[j].Domain)
},
)
log.Trace("intel: loaded global resolvers:")
for _, resolver := range globalResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded local resolvers:")
for _, resolver := range localResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded scopes:")
for _, scope := range localScopes {
var scopeServers []string
for _, resolver := range scope.Resolvers {
scopeServers = append(scopeServers, resolver.Server)
}
log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", "))
}
}

72
intel/reverse.go Normal file
View file

@ -0,0 +1,72 @@
package intel
import (
"errors"
"strings"
"github.com/Safing/portbase/log"
"github.com/miekg/dns"
)
// ResolveIPAndValidate finds (reverse DNS), validates (forward DNS) and returns the domain name assigned to the given IP.
func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err error) {
// get reversed DNS address
rQ, err := dns.ReverseAddr(ip)
if err != nil {
log.Tracef("intel: failed to get reverse address of %s: %s", ip, err)
return "", err
}
// get PTR record
rrCache := Resolve(rQ, dns.Type(dns.TypePTR), securityLevel)
if rrCache == nil {
return "", errors.New("querying for PTR record failed (may be NXDomain)")
}
// get result from record
var ptrName string
for _, rr := range rrCache.Answer {
ptrRec, ok := rr.(*dns.PTR)
if ok {
ptrName = ptrRec.Ptr
break
}
}
// check for nxDomain
if ptrName == "" {
return "", errors.New("no PTR record for IP (nxDomain)")
}
log.Infof("ptrName: %s", ptrName)
// get forward record
if strings.Contains(ip, ":") {
rrCache = Resolve(ptrName, dns.Type(dns.TypeAAAA), securityLevel)
} else {
rrCache = Resolve(ptrName, dns.Type(dns.TypeA), securityLevel)
}
if rrCache == nil {
return "", errors.New("querying for A/AAAA record failed (may be NXDomain)")
}
// check for matching A/AAAA record
log.Infof("rr: %s", rrCache)
for _, rr := range rrCache.Answer {
switch v := rr.(type) {
case *dns.A:
log.Infof("A: %s", v.A.String())
if ip == v.A.String() {
return ptrName, nil
}
case *dns.AAAA:
log.Infof("AAAA: %s", v.AAAA.String())
if ip == v.AAAA.String() {
return ptrName, nil
}
}
}
// no match
return "", errors.New("validation failed")
}

28
intel/reverse_test.go Normal file
View file

@ -0,0 +1,28 @@
package intel
import "testing"
func testReverse(t *testing.T, ip, result, expectedErr string) {
domain, err := ResolveIPAndValidate(ip, 0)
if err != nil {
if expectedErr == "" || err.Error() != expectedErr {
t.Errorf("reverse-validating %s: unexpected error: %s", ip, err)
}
return
}
if domain != result {
t.Errorf("reverse-validating %s: unexpected result: %s", ip, domain)
}
}
func TestResolveIPAndValidate(t *testing.T) {
testReverse(t, "198.41.0.4", "a.root-servers.net.", "")
testReverse(t, "9.9.9.9", "dns.quad9.net.", "")
testReverse(t, "2620:fe::fe", "dns.quad9.net.", "")
testReverse(t, "1.1.1.1", "one.one.one.one.", "")
testReverse(t, "2606:4700:4700::1111", "one.one.one.one.", "")
testReverse(t, "93.184.216.34", "example.com.", "no PTR record for IP (nxDomain)")
testReverse(t, "185.199.109.153", "sites.github.io.", "no PTR record for IP (nxDomain)")
}

256
intel/rrcache.go Normal file
View file

@ -0,0 +1,256 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"fmt"
"net"
"strings"
"time"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/netutils"
"github.com/miekg/dns"
)
// RRCache is used to cache DNS data
type RRCache struct {
Domain string
Question dns.Type
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
TTL int64
updated int64
servedFromCache bool
requestingNew bool
Filtered bool
}
// Clean sets all TTLs to 17 and sets cache expiry with specified minimum.
func (m *RRCache) Clean(minExpires uint32) {
var lowestTTL uint32 = 0xFFFFFFFF
var header *dns.RR_Header
// set TTLs to 17
// TODO: double append? is there something more elegant?
for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) {
header = rr.Header()
if lowestTTL > header.Ttl {
lowestTTL = header.Ttl
}
header.Ttl = 17
}
// TTL must be at least minExpires
if lowestTTL < minExpires {
lowestTTL = minExpires
}
// log.Tracef("lowest TTL is %d", lowestTTL)
m.TTL = time.Now().Unix() + int64(lowestTTL)
}
// ExportAllARecords return of a list of all A and AAAA IP addresses.
func (m *RRCache) ExportAllARecords() (ips []net.IP) {
for _, rr := range m.Answer {
if rr.Header().Class != dns.ClassINET {
continue
}
switch rr.Header().Rrtype {
case dns.TypeA:
aRecord, ok := rr.(*dns.A)
if ok {
ips = append(ips, aRecord.A)
}
case dns.TypeAAAA:
aaaaRecord, ok := rr.(*dns.AAAA)
if ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
}
return
}
// ToNameRecord converts the RRCache to a NameRecord for cleaner persistence.
func (m *RRCache) ToNameRecord() *NameRecord {
new := &NameRecord{
Domain: m.Domain,
Question: m.Question.String(),
TTL: m.TTL,
Filtered: m.Filtered,
}
// stringify RR entries
for _, entry := range m.Answer {
new.Answer = append(new.Answer, entry.String())
}
for _, entry := range m.Ns {
new.Ns = append(new.Ns, entry.String())
}
for _, entry := range m.Extra {
new.Extra = append(new.Extra, entry.String())
}
return new
}
// Save saves the RRCache to the database as a NameRecord.
func (m *RRCache) Save() error {
return m.ToNameRecord().Save()
}
// GetRRCache tries to load the corresponding NameRecord from the database and convert it.
func GetRRCache(domain string, question dns.Type) (*RRCache, error) {
rrCache := &RRCache{
Domain: domain,
Question: question,
}
nameRecord, err := GetNameRecord(domain, question.String())
if err != nil {
return nil, err
}
rrCache.TTL = nameRecord.TTL
for _, entry := range nameRecord.Answer {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Answer = append(rrCache.Answer, rr)
}
}
for _, entry := range nameRecord.Ns {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Ns = append(rrCache.Ns, rr)
}
}
for _, entry := range nameRecord.Extra {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Extra = append(rrCache.Extra, rr)
}
}
rrCache.Filtered = nameRecord.Filtered
rrCache.servedFromCache = true
return rrCache, nil
}
// ServedFromCache marks the RRCache as served from cache.
func (m *RRCache) ServedFromCache() bool {
return m.servedFromCache
}
// RequestingNew informs that it has expired and new RRs are being fetched.
func (m *RRCache) RequestingNew() bool {
return m.requestingNew
}
// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format.
func (m *RRCache) Flags() string {
var s string
if m.servedFromCache {
s += "C"
}
if m.requestingNew {
s += "R"
}
if m.Filtered {
s += "F"
}
if s != "" {
return fmt.Sprintf(" [%s]", s)
}
return ""
}
// IsNXDomain returnes whether the result is nxdomain.
func (m *RRCache) IsNXDomain() bool {
return len(m.Answer) == 0
}
// Duplicate returns a duplicate of the cache. slices are not copied, but referenced.
func (m *RRCache) Duplicate() *RRCache {
return &RRCache{
Domain: m.Domain,
Question: m.Question,
Answer: m.Answer,
Ns: m.Ns,
Extra: m.Extra,
TTL: m.TTL,
updated: m.updated,
servedFromCache: m.servedFromCache,
requestingNew: m.requestingNew,
Filtered: m.Filtered,
}
}
// FilterEntries filters resource records according to the given permission scope.
func (m *RRCache) FilterEntries(internet, lan, host bool) {
var filtered bool
m.Answer, filtered = filterEntries(m, m.Answer, internet, lan, host)
if filtered {
m.Filtered = true
}
m.Extra, filtered = filterEntries(m, m.Extra, internet, lan, host)
if filtered {
m.Filtered = true
}
}
func filterEntries(m *RRCache, entries []dns.RR, internet, lan, host bool) (filteredEntries []dns.RR, filtered bool) {
filteredEntries = make([]dns.RR, 0, len(entries))
var classification int8
var deletedEntries []string
entryLoop:
for _, rr := range entries {
classification = -1
switch v := rr.(type) {
case *dns.A:
classification = netutils.ClassifyIP(v.A)
case *dns.AAAA:
classification = netutils.ClassifyIP(v.AAAA)
}
if classification >= 0 {
switch {
case !internet && classification == netutils.Global:
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
case !lan && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
case !host && classification == netutils.HostLocal:
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
}
}
filteredEntries = append(filteredEntries, rr)
}
if len(deletedEntries) > 0 {
log.Infof("intel: filtered DNS replies for %s%s: %s (Settings: Int=%v LAN=%v Host=%v)",
m.Domain,
m.Question.String(),
strings.Join(deletedEntries, ", "),
internet,
lan,
host,
)
}
return
}

View file

@ -5,7 +5,7 @@ package intel
import "strings" import "strings"
var ( var (
localReverseScopes = &[]string{ localReverseScopes = []string{
".10.in-addr.arpa.", ".10.in-addr.arpa.",
".16.172.in-addr.arpa.", ".16.172.in-addr.arpa.",
".17.172.in-addr.arpa.", ".17.172.in-addr.arpa.",
@ -31,7 +31,8 @@ var (
".b.e.f.ip6.arpa.", ".b.e.f.ip6.arpa.",
} }
specialScopes = &[]string{ // RFC6761, RFC7686
specialScopes = []string{
".example.", ".example.",
".example.com.", ".example.com.",
".example.net.", ".example.net.",
@ -42,8 +43,8 @@ var (
} }
) )
func domainInScopes(fqdn string, list *[]string) bool { func domainInScopes(fqdn string, list []string) bool {
for _, scope := range *list { for _, scope := range list {
if strings.HasSuffix(fqdn, scope) { if strings.HasSuffix(fqdn, scope) {
return true return true
} }

87
main.go Normal file
View file

@ -0,0 +1,87 @@
package main
import (
"flag"
"fmt"
"os"
"os/signal"
"runtime/pprof"
"syscall"
"time"
"github.com/Safing/portbase/info"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// include packages here
_ "github.com/Safing/portbase/api"
_ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/firewall"
_ "github.com/Safing/portmaster/nameserver"
_ "github.com/Safing/portmaster/ui"
)
var (
printStackOnExit bool
)
func init() {
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
}
func main() {
// Set Info
info.Set("Portmaster", "0.2.0")
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
select {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
if printStackOnExit {
fmt.Println("=== PRINTING STACK ===")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
fmt.Println("=== END STACK ===")
}
go func() {
time.Sleep(3 * time.Second)
fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 2)
os.Exit(1)
}()
modules.Shutdown()
os.Exit(0)
case <-modules.ShuttingDown():
}
}

View file

@ -4,38 +4,60 @@ package nameserver
import ( import (
"net" "net"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/Safing/safing-core/analytics/algs" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/intel" "github.com/Safing/portbase/modules"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules" "github.com/Safing/portmaster/analytics/algs"
"github.com/Safing/safing-core/network" "github.com/Safing/portmaster/firewall"
"github.com/Safing/safing-core/network/netutils" "github.com/Safing/portmaster/intel"
"github.com/Safing/safing-core/portmaster" "github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/netutils"
) )
var ( var (
nameserverModule *modules.Module localhostIPs []dns.RR
) )
func init() { func init() {
nameserverModule = modules.Register("Nameserver", 128) modules.Register("nameserver", prep, start, nil, "intel")
} }
func Start() { func prep() error {
localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1")
if err != nil {
return err
}
localhostIPv6, err := dns.NewRR("localhost. 17 IN AAAA ::1")
if err != nil {
return err
}
localhostIPs = []dns.RR{localhostIPv4, localhostIPv6}
return nil
}
func start() error {
server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"} server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"}
dns.HandleFunc(".", handleRequest) dns.HandleFunc(".", handleRequest)
go func() { go run(server)
return nil
}
func run(server *dns.Server) {
for {
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
log.Errorf("nameserver: server failed: %s", err) log.Errorf("nameserver: server failed: %s", err)
log.Info("nameserver: restarting server in 10 seconds")
time.Sleep(10 * time.Second)
}
} }
}()
// TODO: stop mocking
defer nameserverModule.StopComplete()
<-nameserverModule.Stop
} }
func nxDomain(w dns.ResponseWriter, query *dns.Msg) { func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
@ -47,7 +69,6 @@ func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
func handleRequest(w dns.ResponseWriter, query *dns.Msg) { func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// TODO: handle securityLevelOff
// only process first question, that's how everyone does it. // only process first question, that's how everyone does it.
question := query.Question[0] question := query.Question[0]
@ -82,6 +103,14 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
return return
} }
// handle request for localhost
if fqdn == "localhost." {
m := new(dns.Msg)
m.SetReply(query)
m.Answer = localhostIPs
w.WriteMsg(m)
}
// get remote address // get remote address
// start := time.Now() // start := time.Now()
rAddr, ok := w.RemoteAddr().(*net.UDPAddr) rAddr, ok := w.RemoteAddr().(*net.UDPAddr)
@ -109,19 +138,19 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn) // log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn)
// check profile before we even get intel and rr // check profile before we even get intel and rr
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
// start = time.Now() // start = time.Now()
portmaster.DecideOnConnectionBeforeIntel(connection, fqdn) firewall.DecideOnConnectionBeforeIntel(connection, fqdn)
// log.Tracef("nameserver: took %s to make decision", time.Since(start)) // log.Tracef("nameserver: took %s to make decision", time.Since(start))
} }
if connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { if connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query) nxDomain(w, query)
return return
} }
// get intel and RRs // get intel and RRs
// start = time.Now() // start = time.Now()
domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().Profile.SecurityLevel) domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().ProfileSet().SecurityLevel())
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil { if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains // TODO: analyze nxdomain requests, malware could be trying DGA-domains
@ -131,14 +160,16 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
} }
// set intel // set intel
connection.Lock()
connection.Intel = domainIntel connection.Intel = domainIntel
connection.Unlock()
connection.Save() connection.Save()
// do a full check with intel // do a full check with intel
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
rrCache = portmaster.DecideOnConnectionAfterIntel(connection, fqdn, rrCache) rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache)
} }
if rrCache == nil || connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { if rrCache == nil || connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query) nxDomain(w, query)
return return
} }
@ -150,23 +181,27 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
ipInfo, err := intel.GetIPInfo(v.A.String()) ipInfo, err := intel.GetIPInfo(v.A.String())
if err != nil { if err != nil {
ipInfo = &intel.IPInfo{ ipInfo = &intel.IPInfo{
IP: v.A.String(),
Domains: []string{fqdn}, Domains: []string{fqdn},
} }
ipInfo.Create(v.A.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save() ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
} }
case *dns.AAAA: case *dns.AAAA:
ipInfo, err := intel.GetIPInfo(v.AAAA.String()) ipInfo, err := intel.GetIPInfo(v.AAAA.String())
if err != nil { if err != nil {
ipInfo = &intel.IPInfo{ ipInfo = &intel.IPInfo{
IP: v.AAAA.String(),
Domains: []string{fqdn}, Domains: []string{fqdn},
} }
ipInfo.Create(v.AAAA.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save() ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
} }
} }
} }

View file

@ -0,0 +1,100 @@
package only
import (
"time"
"github.com/miekg/dns"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
"github.com/Safing/portmaster/analytics/algs"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network/netutils"
)
func init() {
modules.Register("nameserver", nil, start, nil, "intel")
}
func start() error {
server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"}
dns.HandleFunc(".", handleRequest)
go run(server)
return nil
}
func run(server *dns.Server) {
for {
err := server.ListenAndServe()
if err != nil {
log.Errorf("nameserver: server failed: %s", err)
log.Info("nameserver: restarting server in 10 seconds")
time.Sleep(10 * time.Second)
}
}
}
func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(query, dns.RcodeNameError)
w.WriteMsg(m)
}
func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// TODO: handle securityLevelOff
// only process first question, that's how everyone does it.
question := query.Question[0]
fqdn := dns.Fqdn(question.Name)
qtype := dns.Type(question.Qtype)
// use this to time how long it takes process this request
// timed := time.Now()
// defer log.Tracef("nameserver: took %s to handle request for %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String())
// check if valid domain name
if !netutils.IsValidFqdn(fqdn) {
log.Tracef("nameserver: domain name %s is invalid, returning nxdomain", fqdn)
nxDomain(w, query)
return
}
// check for possible DNS tunneling / data transmission
// TODO: improve this
lms := algs.LmsScoreOfDomain(fqdn)
// log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms)
if lms < 10 {
log.Tracef("nameserver: possible data tunnel: %s has lms score of %f, returning nxdomain", fqdn, lms)
nxDomain(w, query)
return
}
// check class
if question.Qclass != dns.ClassINET {
// we only serve IN records, send NXDOMAIN
nxDomain(w, query)
return
}
// get intel and RRs
// start = time.Now()
_, rrCache := intel.GetIntelAndRRs(fqdn, qtype, 0)
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
log.Infof("nameserver: %s is nxdomain", fqdn)
nxDomain(w, query)
return
}
// reply to query
m := new(dns.Msg)
m.SetReply(query)
m.Answer = rrCache.Answer
m.Ns = rrCache.Ns
m.Extra = rrCache.Extra
w.WriteMsg(m)
}

View file

@ -5,34 +5,52 @@ package network
import ( import (
"time" "time"
"github.com/Safing/safing-core/process" "github.com/Safing/portmaster/process"
) )
func init() { var (
go cleaner() cleanerTickDuration = 10 * time.Second
} deadLinksTimeout = 3 * time.Minute
thresholdDuration = 3 * time.Minute
)
func cleaner() { func cleaner() {
time.Sleep(15 * time.Second)
for { for {
markDeadLinks() time.Sleep(cleanerTickDuration)
purgeDeadFor(5 * time.Minute)
time.Sleep(15 * time.Second) cleanLinks()
time.Sleep(2 * time.Second)
cleanConnections()
time.Sleep(2 * time.Second)
cleanProcesses()
} }
} }
func markDeadLinks() { func cleanLinks() {
activeIDs := process.GetActiveConnectionIDs() activeIDs := process.GetActiveConnectionIDs()
allLinksLock.RLock()
defer allLinksLock.RUnlock()
now := time.Now().Unix() now := time.Now().Unix()
var found bool deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix()
for key, link := range allLinks {
// skip dead links // log.Tracef("network.clean: now=%d", now)
// log.Tracef("network.clean: deleteOlderThan=%d", deleteOlderThan)
linksLock.RLock()
defer linksLock.RUnlock()
var found bool
for key, link := range links {
// delete dead links
if link.Ended > 0 { if link.Ended > 0 {
link.Lock()
deleteThis := link.Ended < deleteOlderThan
link.Unlock()
if deleteThis {
// log.Tracef("network.clean: deleted %s", link.DatabaseKey())
go link.Delete()
}
continue continue
} }
@ -48,56 +66,28 @@ func markDeadLinks() {
// mark end time // mark end time
if !found { if !found {
link.Ended = now link.Ended = now
link.Save() // log.Tracef("network.clean: marked %s as ended.", link.DatabaseKey())
go link.Save()
} }
} }
} }
func purgeDeadFor(age time.Duration) { func cleanConnections() {
connections := make(map[*Connection]bool) connectionsLock.RLock()
processes := make(map[*process.Process]bool) defer connectionsLock.RUnlock()
allLinksLock.Lock() threshold := time.Now().Add(-thresholdDuration).Unix()
defer allLinksLock.Unlock() for _, conn := range connections {
conn.Lock()
// delete old dead links if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 {
// make a list of connections without links // log.Tracef("network.clean: deleted %s", conn.DatabaseKey())
ageAgo := time.Now().Add(-1 * age).Unix() go conn.Delete()
for key, link := range allLinks {
if link.Ended != 0 && link.Ended < ageAgo {
link.Delete()
delete(allLinks, key)
_, ok := connections[link.Connection()]
if !ok {
connections[link.Connection()] = false
} }
} else { conn.Unlock()
connections[link.Connection()] = true
} }
} }
// delete connections without links func cleanProcesses() {
// make a list of processes without connections process.CleanProcessStorage(thresholdDuration)
for conn, active := range connections {
if conn != nil {
if !active {
conn.Delete()
_, ok := processes[conn.Process()]
if !ok {
processes[conn.Process()] = false
}
} else {
processes[conn.Process()] = true
}
}
}
// delete processes without connections
for proc, active := range processes {
if proc != nil && !active {
proc.Delete()
}
}
} }

View file

@ -3,21 +3,24 @@
package network package network
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/database/record"
"github.com/Safing/safing-core/intel" "github.com/Safing/portmaster/intel"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portmaster/network/netutils"
"github.com/Safing/safing-core/process" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/process"
datastore "github.com/ipfs/go-datastore"
) )
// Connection describes a connection between a process and a domain // Connection describes a connection between a process and a domain
type Connection struct { type Connection struct {
database.Base record.Base
sync.Mutex
Domain string Domain string
Direction bool Direction bool
Intel *intel.Intel Intel *intel.Intel
@ -25,125 +28,162 @@ type Connection struct {
Verdict Verdict Verdict Verdict
Reason string Reason string
Inspect bool Inspect bool
FirstLinkEstablished int64 FirstLinkEstablished int64
LastLinkEstablished int64
LinkCount uint
} }
var connectionModel *Connection // only use this as parameter for database.EnsureModel-like functions // Process returns the process that owns the connection.
func (conn *Connection) Process() *process.Process {
conn.Lock()
defer conn.Unlock()
func init() { return conn.process
database.RegisterModel(connectionModel, func() database.Model { return new(Connection) })
} }
func (m *Connection) Process() *process.Process { // GetVerdict returns the current verdict.
return m.process func (conn *Connection) GetVerdict() Verdict {
conn.Lock()
defer conn.Unlock()
return conn.Verdict
} }
// Create creates a new database entry in the database in the default namespace for this object // Accept accepts the connection and adds the given reason.
func (m *Connection) Create(name string) error { func (conn *Connection) Accept(reason string) {
return m.CreateObject(&database.OrphanedConnection, name, m) conn.AddReason(reason)
conn.UpdateVerdict(ACCEPT)
} }
// CreateInProcessNamespace creates a new database entry in the namespace of the connection's process // Deny blocks or drops the connection depending on the connection direction and adds the given reason.
func (m *Connection) CreateInProcessNamespace() error { func (conn *Connection) Deny(reason string) {
if m.process != nil { if conn.Direction {
return m.CreateObject(m.process.GetKey(), m.Domain, m) conn.Drop(reason)
} else {
conn.Block(reason)
} }
return m.CreateObject(&database.OrphanedConnection, m.Domain, m)
} }
// Save saves the object to the database (It must have been either already created or loaded from the database) // Block blocks the connection and adds the given reason.
func (m *Connection) Save() error { func (conn *Connection) Block(reason string) {
return m.SaveObject(m) conn.AddReason(reason)
conn.UpdateVerdict(BLOCK)
} }
func (m *Connection) CantSay() { // Drop drops the connection and adds the given reason.
if m.Verdict != CANTSAY { func (conn *Connection) Drop(reason string) {
m.Verdict = CANTSAY conn.AddReason(reason)
m.SaveObject(m) conn.UpdateVerdict(DROP)
}
// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts
func (conn *Connection) UpdateVerdict(newVerdict Verdict) {
conn.Lock()
defer conn.Unlock()
if newVerdict > conn.Verdict {
conn.Verdict = newVerdict
go conn.Save()
} }
return
}
func (m *Connection) Drop() {
if m.Verdict != DROP {
m.Verdict = DROP
m.SaveObject(m)
}
return
}
func (m *Connection) Block() {
if m.Verdict != BLOCK {
m.Verdict = BLOCK
m.SaveObject(m)
}
return
}
func (m *Connection) Accept() {
if m.Verdict != ACCEPT {
m.Verdict = ACCEPT
m.SaveObject(m)
}
return
} }
// AddReason adds a human readable string as to why a certain verdict was set in regard to this connection // AddReason adds a human readable string as to why a certain verdict was set in regard to this connection
func (m *Connection) AddReason(newReason string) { func (conn *Connection) AddReason(reason string) {
if m.Reason != "" { if reason == "" {
m.Reason += " | " return
} }
m.Reason += newReason
conn.Lock()
defer conn.Unlock()
if conn.Reason != "" {
conn.Reason += " | "
}
conn.Reason += reason
} }
// GetConnectionByFirstPacket returns the matching connection from the internal storage.
func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) { func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) {
// get Process // get Process
proc, direction, err := process.GetProcessByPacket(pkt) proc, direction, err := process.GetProcessByPacket(pkt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var domain string
// if INBOUND // Incoming
if direction { if direction {
connection, err := GetConnectionFromProcessNamespace(proc, "I") switch netutils.ClassifyIP(pkt.GetIPHeader().Src) {
if err != nil { case netutils.HostLocal:
domain = IncomingHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = IncomingLAN
case netutils.Global, netutils.GlobalMulticast:
domain = IncomingInternet
case netutils.Invalid:
domain = IncomingInvalid
}
connection, ok := GetConnection(proc.Pid, domain)
if !ok {
connection = &Connection{ connection = &Connection{
Domain: "I", Domain: domain,
Direction: true, Direction: Inbound,
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
} }
} }
connection.process.AddConnection()
return connection, nil return connection, nil
} }
// get domain // get domain
ipinfo, err := intel.GetIPInfo(pkt.FmtRemoteIP()) ipinfo, err := intel.GetIPInfo(pkt.FmtRemoteIP())
// PeerToPeer
if err != nil { if err != nil {
// if no domain could be found, it must be a direct connection // if no domain could be found, it must be a direct connection
connection, err := GetConnectionFromProcessNamespace(proc, "D")
if err != nil { switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) {
case netutils.HostLocal:
domain = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
domain = PeerInternet
case netutils.Invalid:
domain = PeerInvalid
}
connection, ok := GetConnection(proc.Pid, domain)
if !ok {
connection = &Connection{ connection = &Connection{
Domain: "D", Domain: domain,
Direction: Outbound,
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
} }
} }
connection.process.AddConnection()
return connection, nil return connection, nil
} }
// To Domain
// FIXME: how to handle multiple possible domains? // FIXME: how to handle multiple possible domains?
connection, err := GetConnectionFromProcessNamespace(proc, ipinfo.Domains[0]) connection, ok := GetConnection(proc.Pid, ipinfo.Domains[0])
if err != nil { if !ok {
connection = &Connection{ connection = &Connection{
Domain: ipinfo.Domains[0], Domain: ipinfo.Domains[0],
Direction: Outbound,
process: proc, process: proc,
Inspect: true, Inspect: true,
FirstLinkEstablished: time.Now().Unix(), FirstLinkEstablished: time.Now().Unix(),
} }
} }
connection.process.AddConnection()
return connection, nil return connection, nil
} }
@ -154,6 +194,7 @@ var (
dnsPort uint16 = 53 dnsPort uint16 = 53
) )
// GetConnectionByDNSRequest returns the matching connection from the internal storage.
func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection, error) { func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection, error) {
// get Process // get Process
proc, err := process.GetProcessByEndpoints(ip, port, dnsAddress, dnsPort, packet.UDP) proc, err := process.GetProcessByEndpoints(ip, port, dnsAddress, dnsPort, packet.UDP)
@ -161,70 +202,124 @@ func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection
return nil, err return nil, err
} }
connection, err := GetConnectionFromProcessNamespace(proc, fqdn) connection, ok := GetConnection(proc.Pid, fqdn)
if err != nil { if !ok {
connection = &Connection{ connection = &Connection{
Domain: fqdn, Domain: fqdn,
process: proc, process: proc,
Inspect: true, Inspect: true,
} }
connection.CreateInProcessNamespace() connection.process.AddConnection()
connection.Save()
} }
return connection, nil return connection, nil
} }
// GetConnection fetches a Connection from the database from the default namespace for this object // GetConnection fetches a connection object from the internal storage.
func GetConnection(name string) (*Connection, error) { func GetConnection(pid int, domain string) (conn *Connection, ok bool) {
return GetConnectionFromNamespace(&database.OrphanedConnection, name) connectionsLock.RLock()
defer connectionsLock.RUnlock()
conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)]
return
} }
// GetConnectionFromProcessNamespace fetches a Connection from the namespace of its process func (conn *Connection) makeKey() string {
func GetConnectionFromProcessNamespace(process *process.Process, domain string) (*Connection, error) { return fmt.Sprintf("%d/%s", conn.process.Pid, conn.Domain)
return GetConnectionFromNamespace(process.GetKey(), domain)
} }
// GetConnectionFromNamespace fetches a Connection form the database, but from a custom namespace // Save saves the connection object in the storage and propagates the change.
func GetConnectionFromNamespace(namespace *datastore.Key, name string) (*Connection, error) { func (conn *Connection) Save() error {
object, err := database.GetAndEnsureModel(namespace, name, connectionModel) conn.Lock()
if err != nil { defer conn.Unlock()
return nil, err
if conn.process == nil {
return errors.New("cannot save connection without process")
} }
model, ok := object.(*Connection)
if !conn.KeyIsSet() {
conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain))
conn.CreateMeta()
}
key := conn.makeKey()
connectionsLock.RLock()
_, ok := connections[key]
connectionsLock.RUnlock()
if !ok { if !ok {
return nil, database.NewMismatchError(object, connectionModel) connectionsLock.Lock()
connections[key] = conn
connectionsLock.Unlock()
} }
return model, nil
go dbController.PushUpdate(conn)
return nil
} }
func (m *Connection) AddLink(link *Link, pkt packet.Packet) { // Delete deletes a connection from the storage and propagates the change.
link.connection = m func (conn *Connection) Delete() {
link.Verdict = m.Verdict conn.Lock()
link.Inspect = m.Inspect defer conn.Unlock()
if m.FirstLinkEstablished == 0 {
m.FirstLinkEstablished = time.Now().Unix() connectionsLock.Lock()
m.Save() delete(connections, conn.makeKey())
} connectionsLock.Unlock()
link.CreateInConnectionNamespace(pkt.GetConnectionID())
conn.Meta().Delete()
go dbController.PushUpdate(conn)
conn.process.RemoveConnection()
go conn.process.Save()
} }
// FORMATTING // AddLink applies the connection to the link and increases sets counter and timestamps.
func (conn *Connection) AddLink(link *Link) {
link.Lock()
link.connection = conn
link.Verdict = conn.Verdict
link.Inspect = conn.Inspect
link.Unlock()
link.Save()
func (m *Connection) String() string { conn.Lock()
switch m.Domain { conn.LinkCount++
case "I": conn.LastLinkEstablished = time.Now().Unix()
if m.process == nil { if conn.FirstLinkEstablished == 0 {
conn.FirstLinkEstablished = conn.LastLinkEstablished
}
conn.Unlock()
conn.Save()
}
// RemoveLink lowers the link counter by one.
func (conn *Connection) RemoveLink() {
conn.Lock()
defer conn.Unlock()
if conn.LinkCount > 0 {
conn.LinkCount--
}
}
// String returns a string representation of Connection.
func (conn *Connection) String() string {
conn.Lock()
defer conn.Unlock()
switch conn.Domain {
case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid:
if conn.process == nil {
return "? <- *" return "? <- *"
} }
return fmt.Sprintf("%s <- *", m.process.String()) return fmt.Sprintf("%s <- *", conn.process.String())
case "D": case PeerHost, PeerLAN, PeerInternet, PeerInvalid:
if m.process == nil { if conn.process == nil {
return "? -> *" return "? -> *"
} }
return fmt.Sprintf("%s -> *", m.process.String()) return fmt.Sprintf("%s -> *", conn.process.String())
default: default:
if m.process == nil { if conn.process == nil {
return fmt.Sprintf("? -> %s", m.Domain) return fmt.Sprintf("? -> %s", conn.Domain)
} }
return fmt.Sprintf("%s -> %s", m.process.String(), m.Domain) return fmt.Sprintf("%s -> %s", conn.process.String(), conn.Domain)
} }
} }

122
network/database.go Normal file
View file

@ -0,0 +1,122 @@
package network
import (
"strconv"
"strings"
"sync"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/iterator"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/database/storage"
"github.com/Safing/portmaster/process"
)
var (
links = make(map[string]*Link)
linksLock sync.RWMutex
connections = make(map[string]*Connection)
connectionsLock sync.RWMutex
dbController *database.Controller
)
// StorageInterface provices a storage.Interface to the configuration manager.
type StorageInterface struct {
storage.InjectBase
}
// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
splitted := strings.Split(key, "/")
switch splitted[0] {
case "tree":
switch len(splitted) {
case 2:
pid, err := strconv.Atoi(splitted[1])
if err == nil {
proc, ok := process.GetProcessFromStorage(pid)
if ok {
return proc, nil
}
}
case 3:
connectionsLock.RLock()
defer connectionsLock.RUnlock()
conn, ok := connections[splitted[2]]
if ok {
return conn, nil
}
case 4:
linksLock.RLock()
defer linksLock.RUnlock()
link, ok := links[splitted[3]]
if ok {
return link, nil
}
}
}
return nil, storage.ErrNotFound
}
// Query returns a an iterator for the supplied query.
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
it := iterator.New()
go s.processQuery(q, it)
// TODO: check local and internal
return it, nil
}
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// processes
for _, proc := range process.All() {
if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- proc
}
}
// connections
connectionsLock.RLock()
for _, conn := range connections {
if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- conn
}
}
connectionsLock.RUnlock()
// links
linksLock.RLock()
for _, link := range links {
if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- link
}
}
linksLock.RUnlock()
it.Finish(nil)
}
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "network",
Description: "Network and Firewall Data",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("network", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
process.SetDBController(dbController)
return nil
}

View file

@ -4,7 +4,7 @@ import (
"net" "net"
"strings" "strings"
"github.com/Safing/safing-core/network/netutils" "github.com/Safing/portmaster/network/netutils"
) )
func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) {

View file

@ -24,6 +24,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
var nameservers []Nameserver var nameservers []Nameserver
var err error var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil { if dbusConn == nil {
dbusConn, err = dbus.SystemBus() dbusConn, err = dbus.SystemBus()
} }
@ -158,6 +161,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
func getConnectivityStateFromDbus() (uint8, error) { func getConnectivityStateFromDbus() (uint8, error) {
var err error var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil { if dbusConn == nil {
dbusConn, err = dbus.SystemBus() dbusConn, err = dbus.SystemBus()
} }

View file

@ -11,7 +11,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
) )
// TODO: find a good way to identify a network // TODO: find a good way to identify a network

View file

@ -0,0 +1,27 @@
package environment
import "net"
func Nameservers() []Nameserver {
return nil
}
func Gateways() []*net.IP {
return nil
}
// TODO: implement using
// ifconfig
// scutil --nwi
// scutil --proxy
// networksetup -listallnetworkservices
// networksetup -listnetworkserviceorder
// networksetup -getdnsservers "Wi-Fi"
// networksetup -getsearchdomains <networkservice>
// networksetup -getftpproxy <networkservice>
// networksetup -getwebproxy <networkservice>
// networksetup -getsecurewebproxy <networkservice>
// networksetup -getstreamingproxy <networkservice>
// networksetup -getgopherproxy <networkservice>
// networksetup -getsocksfirewallproxy <networkservice>
// route -n get default

View file

@ -12,8 +12,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/network/netutils" "github.com/Safing/portmaster/network/netutils"
) )
// Gateways returns the currently active gateways // Gateways returns the currently active gateways

View file

@ -1,4 +1,4 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. // +build linux
package environment package environment

View file

@ -8,9 +8,10 @@ import (
"log" "log"
"net" "net"
"os" "os"
"github.com/Safing/safing-core/network/netutils"
"time" "time"
"github.com/Safing/portmaster/network/netutils"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@ -99,7 +100,7 @@ next:
if ip == nil { if ip == nil {
return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String())) return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String()))
} }
if !netutils.IPIsLocal(ip) { if !netutils.IPIsLAN(ip) {
return ip, nil return ip, nil
} }
continue next continue next

View file

@ -3,28 +3,30 @@
package network package network
import ( import (
"errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
datastore "github.com/ipfs/go-datastore" "github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/log"
"github.com/Safing/safing-core/database" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network/packet"
) )
// FirewallHandler defines the function signature for a firewall handle function
type FirewallHandler func(pkt packet.Packet, link *Link) type FirewallHandler func(pkt packet.Packet, link *Link)
var ( var (
linkTimeout = 10 * time.Minute linkTimeout = 10 * time.Minute
allLinks = make(map[string]*Link)
allLinksLock sync.RWMutex
) )
// Link describes an distinct physical connection (e.g. TCP connection) - like an instance - of a Connection // Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection.
type Link struct { type Link struct {
database.Base record.Base
sync.Mutex
ID string
Verdict Verdict Verdict Verdict
Reason string Reason string
Tunneled bool Tunneled bool
@ -32,180 +34,322 @@ type Link struct {
Inspect bool Inspect bool
Started int64 Started int64
Ended int64 Ended int64
connection *Connection
RemoteAddress string RemoteAddress string
ActiveInspectors []bool `json:"-" bson:"-"`
InspectorData map[uint8]interface{} `json:"-" bson:"-"`
pktQueue chan packet.Packet pktQueue chan packet.Packet
firewallHandler FirewallHandler firewallHandler FirewallHandler
} connection *Connection
var linkModel *Link // only use this as parameter for database.EnsureModel-like functions activeInspectors []bool
inspectorData map[uint8]interface{}
func init() {
database.RegisterModel(linkModel, func() database.Model { return new(Link) })
} }
// Connection returns the Connection the Link is part of // Connection returns the Connection the Link is part of
func (m *Link) Connection() *Connection { func (link *Link) Connection() *Connection {
return m.connection link.Lock()
defer link.Unlock()
return link.connection
}
// GetVerdict returns the current verdict.
func (link *Link) GetVerdict() Verdict {
link.Lock()
defer link.Unlock()
return link.Verdict
} }
// FirewallHandlerIsSet returns whether a firewall handler is set or not // FirewallHandlerIsSet returns whether a firewall handler is set or not
func (m *Link) FirewallHandlerIsSet() bool { func (link *Link) FirewallHandlerIsSet() bool {
return m.firewallHandler != nil link.Lock()
defer link.Unlock()
return link.firewallHandler != nil
} }
// SetFirewallHandler sets the firewall handler for this link // SetFirewallHandler sets the firewall handler for this link
func (m *Link) SetFirewallHandler(handler FirewallHandler) { func (link *Link) SetFirewallHandler(handler FirewallHandler) {
if m.firewallHandler == nil { link.Lock()
m.firewallHandler = handler defer link.Unlock()
m.pktQueue = make(chan packet.Packet, 1000)
go m.packetHandler() if link.firewallHandler == nil {
link.firewallHandler = handler
link.pktQueue = make(chan packet.Packet, 1000)
go link.packetHandler()
return return
} }
m.firewallHandler = handler link.firewallHandler = handler
} }
// StopFirewallHandler unsets the firewall handler // StopFirewallHandler unsets the firewall handler
func (m *Link) StopFirewallHandler() { func (link *Link) StopFirewallHandler() {
m.pktQueue <- nil link.Lock()
link.firewallHandler = nil
link.Unlock()
link.pktQueue <- nil
} }
// HandlePacket queues packet of Link for handling // HandlePacket queues packet of Link for handling
func (m *Link) HandlePacket(pkt packet.Packet) { func (link *Link) HandlePacket(pkt packet.Packet) {
if m.firewallHandler != nil { link.Lock()
m.pktQueue <- pkt defer link.Unlock()
if link.firewallHandler != nil {
link.pktQueue <- pkt
return return
} }
log.Criticalf("network: link %s does not have a firewallHandler, maybe its a copy, dropping packet", m) log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link)
pkt.Drop() pkt.Drop()
} }
// Accept accepts the link and adds the given reason.
func (link *Link) Accept(reason string) {
link.AddReason(reason)
link.UpdateVerdict(ACCEPT)
}
// Deny blocks or drops the link depending on the connection direction and adds the given reason.
func (link *Link) Deny(reason string) {
if link.connection != nil && link.connection.Direction {
link.Drop(reason)
} else {
link.Block(reason)
}
}
// Block blocks the link and adds the given reason.
func (link *Link) Block(reason string) {
link.AddReason(reason)
link.UpdateVerdict(BLOCK)
}
// Drop drops the link and adds the given reason.
func (link *Link) Drop(reason string) {
link.AddReason(reason)
link.UpdateVerdict(DROP)
}
// RerouteToNameserver reroutes the link to the portmaster nameserver.
func (link *Link) RerouteToNameserver() {
link.UpdateVerdict(RerouteToNameserver)
}
// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection.
func (link *Link) RerouteToTunnel(reason string) {
link.AddReason(reason)
link.UpdateVerdict(RerouteToTunnel)
}
// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts // UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts
func (m *Link) UpdateVerdict(newVerdict Verdict) { func (link *Link) UpdateVerdict(newVerdict Verdict) {
if newVerdict > m.Verdict { link.Lock()
m.Verdict = newVerdict defer link.Unlock()
m.Save()
if newVerdict > link.Verdict {
link.Verdict = newVerdict
go link.Save()
} }
} }
// AddReason adds a human readable string as to why a certain verdict was set in regard to this link // AddReason adds a human readable string as to why a certain verdict was set in regard to this link
func (m *Link) AddReason(newReason string) { func (link *Link) AddReason(reason string) {
if m.Reason != "" { if reason == "" {
m.Reason += " | " return
} }
m.Reason += newReason
link.Lock()
defer link.Unlock()
if link.Reason != "" {
link.Reason += " | "
}
link.Reason += reason
} }
// packetHandler sequentially handles queued packets // packetHandler sequentially handles queued packets
func (m *Link) packetHandler() { func (link *Link) packetHandler() {
for { for {
pkt := <-m.pktQueue pkt := <-link.pktQueue
if pkt == nil { if pkt == nil {
break return
} }
m.firewallHandler(pkt, m) link.Lock()
fwH := link.firewallHandler
link.Unlock()
if fwH != nil {
fwH(pkt, link)
} else {
link.ApplyVerdict(pkt)
} }
m.firewallHandler = nil
}
// Create creates a new database entry in the database in the default namespace for this object
func (m *Link) Create(name string) error {
m.CreateShallow(name)
return m.CreateObject(&database.OrphanedLink, name, m)
}
// Create creates a new database entry in the database in the default namespace for this object
func (m *Link) CreateShallow(name string) {
allLinksLock.Lock()
allLinks[name] = m
allLinksLock.Unlock()
}
// CreateWithDefaultKey creates a new database entry in the database in the default namespace for this object using the default key
func (m *Link) CreateInConnectionNamespace(name string) error {
if m.connection != nil {
return m.CreateObject(m.connection.GetKey(), name, m)
} }
return m.CreateObject(&database.OrphanedLink, name, m)
} }
// Save saves the object to the database (It must have been either already created or loaded from the database) // ApplyVerdict appies the link verdict to a packet.
func (m *Link) Save() error { func (link *Link) ApplyVerdict(pkt packet.Packet) {
return m.SaveObject(m) link.Lock()
defer link.Unlock()
if link.VerdictPermanent {
switch link.Verdict {
case ACCEPT:
pkt.PermanentAccept()
case BLOCK:
pkt.PermanentBlock()
case DROP:
pkt.PermanentDrop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
}
} else {
switch link.Verdict {
case ACCEPT:
pkt.Accept()
case BLOCK:
pkt.Block()
case DROP:
pkt.Drop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
}
}
}
// Save saves the link object in the storage and propagates the change.
func (link *Link) Save() error {
link.Lock()
defer link.Unlock()
if link.connection == nil {
return errors.New("cannot save link without connection")
}
if !link.KeyIsSet() {
link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID))
link.CreateMeta()
}
linksLock.RLock()
_, ok := links[link.ID]
linksLock.RUnlock()
if !ok {
linksLock.Lock()
links[link.ID] = link
linksLock.Unlock()
}
go dbController.PushUpdate(link)
return nil
}
// Delete deletes a link from the storage and propagates the change.
func (link *Link) Delete() {
link.Lock()
defer link.Unlock()
linksLock.Lock()
delete(links, link.ID)
linksLock.Unlock()
link.Meta().Delete()
go dbController.PushUpdate(link)
link.connection.RemoveLink()
go link.connection.Save()
} }
// GetLink fetches a Link from the database from the default namespace for this object // GetLink fetches a Link from the database from the default namespace for this object
func GetLink(name string) (*Link, error) { func GetLink(id string) (*Link, bool) {
allLinksLock.RLock() linksLock.RLock()
link, ok := allLinks[name] defer linksLock.RUnlock()
allLinksLock.RUnlock()
if !ok {
return nil, database.ErrNotFound
}
return link, nil
// return GetLinkFromNamespace(&database.RunningLink, name)
}
func SaveInCache(link *Link) { link, ok := links[id]
return link, ok
}
// GetLinkFromNamespace fetches a Link form the database, but from a custom namespace
func GetLinkFromNamespace(namespace *datastore.Key, name string) (*Link, error) {
object, err := database.GetAndEnsureModel(namespace, name, linkModel)
if err != nil {
return nil, err
}
model, ok := object.(*Link)
if !ok {
return nil, database.NewMismatchError(object, linkModel)
}
return model, nil
} }
// GetOrCreateLinkByPacket returns the associated Link for a packet and a bool expressing if the Link was newly created // GetOrCreateLinkByPacket returns the associated Link for a packet and a bool expressing if the Link was newly created
func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) { func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) {
link, err := GetLink(pkt.GetConnectionID()) link, ok := GetLink(pkt.GetLinkID())
if err != nil { if ok {
return CreateLinkFromPacket(pkt), true
}
return link, false return link, false
}
return CreateLinkFromPacket(pkt), true
} }
// CreateLinkFromPacket creates a new Link based on Packet. The Link is shallowly saved and SHOULD be saved to the database as soon more information is available // CreateLinkFromPacket creates a new Link based on Packet.
func CreateLinkFromPacket(pkt packet.Packet) *Link { func CreateLinkFromPacket(pkt packet.Packet) *Link {
link := &Link{ link := &Link{
ID: pkt.GetLinkID(),
Verdict: UNDECIDED, Verdict: UNDECIDED,
Started: time.Now().Unix(), Started: time.Now().Unix(),
RemoteAddress: pkt.FmtRemoteAddress(), RemoteAddress: pkt.FmtRemoteAddress(),
} }
link.CreateShallow(pkt.GetConnectionID())
return link return link
} }
// FORMATTING // GetActiveInspectors returns the list of active inspectors.
func (m *Link) String() string { func (link *Link) GetActiveInspectors() []bool {
if m.connection == nil { link.Lock()
return fmt.Sprintf("? <-> %s", m.RemoteAddress) defer link.Unlock()
return link.activeInspectors
}
// SetActiveInspectors sets the list of active inspectors.
func (link *Link) SetActiveInspectors(new []bool) {
link.Lock()
defer link.Unlock()
link.activeInspectors = new
}
// GetInspectorData returns the list of inspector data.
func (link *Link) GetInspectorData() map[uint8]interface{} {
link.Lock()
defer link.Unlock()
return link.inspectorData
}
// SetInspectorData set the list of inspector data.
func (link *Link) SetInspectorData(new map[uint8]interface{}) {
link.Lock()
defer link.Unlock()
link.inspectorData = new
}
// String returns a string representation of Link.
func (link *Link) String() string {
link.Lock()
defer link.Unlock()
if link.connection == nil {
return fmt.Sprintf("? <-> %s", link.RemoteAddress)
} }
switch m.connection.Domain { switch link.connection.Domain {
case "I": case "I":
if m.connection.process == nil { if link.connection.process == nil {
return fmt.Sprintf("? <- %s", m.RemoteAddress) return fmt.Sprintf("? <- %s", link.RemoteAddress)
} }
return fmt.Sprintf("%s <- %s", m.connection.process.String(), m.RemoteAddress) return fmt.Sprintf("%s <- %s", link.connection.process.String(), link.RemoteAddress)
case "D": case "D":
if m.connection.process == nil { if link.connection.process == nil {
return fmt.Sprintf("? -> %s", m.RemoteAddress) return fmt.Sprintf("? -> %s", link.RemoteAddress)
} }
return fmt.Sprintf("%s -> %s", m.connection.process.String(), m.RemoteAddress) return fmt.Sprintf("%s -> %s", link.connection.process.String(), link.RemoteAddress)
default: default:
if m.connection.process == nil { if link.connection.process == nil {
return fmt.Sprintf("? -> %s (%s)", m.connection.Domain, m.RemoteAddress) return fmt.Sprintf("? -> %s (%s)", link.connection.Domain, link.RemoteAddress)
} }
return fmt.Sprintf("%s to %s (%s)", m.connection.process.String(), m.connection.Domain, m.RemoteAddress) return fmt.Sprintf("%s to %s (%s)", link.connection.process.String(), link.connection.Domain, link.RemoteAddress)
} }
} }

14
network/module.go Normal file
View file

@ -0,0 +1,14 @@
package network
import (
"github.com/Safing/portbase/modules"
)
func init() {
modules.Register("network", nil, start, nil, "database")
}
func start() error {
go cleaner()
return registerAsDatabase()
}

View file

@ -11,6 +11,7 @@ var (
cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$") cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$")
) )
// IsValidFqdn returns whether the given string is a valid fqdn.
func IsValidFqdn(fqdn string) bool { func IsValidFqdn(fqdn string) bool {
return cleanDomainRegex.MatchString(fqdn) return cleanDomainRegex.MatchString(fqdn)
} }

View file

@ -4,95 +4,101 @@ package netutils
import "net" import "net"
// IP types // IP classifications
const ( const (
hostLocal int8 = iota HostLocal int8 = iota
linkLocal LinkLocal
siteLocal SiteLocal
global Global
localMulticast LocalMulticast
globalMulticast GlobalMulticast
invalid Invalid
) )
func classifyAddress(ip net.IP) int8 { // ClassifyIP returns the classification for the given IP address.
func ClassifyIP(ip net.IP) int8 {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
// IPv4 // IPv4
switch { switch {
case ip4[0] == 127: case ip4[0] == 127:
// 127.0.0.0/8 // 127.0.0.0/8
return hostLocal return HostLocal
case ip4[0] == 169 && ip4[1] == 254: case ip4[0] == 169 && ip4[1] == 254:
// 169.254.0.0/16 // 169.254.0.0/16
return linkLocal return LinkLocal
case ip4[0] == 10: case ip4[0] == 10:
// 10.0.0.0/8 // 10.0.0.0/8
return siteLocal return SiteLocal
case ip4[0] == 172 && ip4[1]&0xf0 == 16: case ip4[0] == 172 && ip4[1]&0xf0 == 16:
// 172.16.0.0/12 // 172.16.0.0/12
return siteLocal return SiteLocal
case ip4[0] == 192 && ip4[1] == 168: case ip4[0] == 192 && ip4[1] == 168:
// 192.168.0.0/16 // 192.168.0.0/16
return siteLocal return SiteLocal
case ip4[0] == 224: case ip4[0] == 224:
// 224.0.0.0/8 // 224.0.0.0/8
return localMulticast return LocalMulticast
case ip4[0] >= 225 && ip4[0] <= 239: case ip4[0] >= 225 && ip4[0] <= 239:
// 225.0.0.0/8 - 239.0.0.0/8 // 225.0.0.0/8 - 239.0.0.0/8
return globalMulticast return GlobalMulticast
case ip4[0] >= 240: case ip4[0] >= 240:
// 240.0.0.0/8 - 255.0.0.0/8 // 240.0.0.0/8 - 255.0.0.0/8
return invalid return Invalid
default: default:
return global return Global
} }
} else if len(ip) == net.IPv6len { } else if len(ip) == net.IPv6len {
// IPv6 // IPv6
switch { switch {
case ip.Equal(net.IPv6loopback): case ip.Equal(net.IPv6loopback):
return hostLocal return HostLocal
case ip[0]&0xfe == 0xfc: case ip[0]&0xfe == 0xfc:
// fc00::/7 // fc00::/7
return siteLocal return SiteLocal
case ip[0] == 0xfe && ip[1]&0xc0 == 0x80: case ip[0] == 0xfe && ip[1]&0xc0 == 0x80:
// fe80::/10 // fe80::/10
return linkLocal return LinkLocal
case ip[0] == 0xff && ip[1] <= 0x05: case ip[0] == 0xff && ip[1] <= 0x05:
// ff00::/16 - ff05::/16 // ff00::/16 - ff05::/16
return localMulticast return LocalMulticast
case ip[0] == 0xff: case ip[0] == 0xff:
// other ff00::/8 // other ff00::/8
return globalMulticast return GlobalMulticast
default: default:
return global return Global
} }
} }
return invalid return Invalid
} }
// IPIsLocal returns true if the given IP is a site-local or link-local address // IPIsLocalhost returns whether the IP refers to the host itself.
func IPIsLocal(ip net.IP) bool { func IPIsLocalhost(ip net.IP) bool {
switch classifyAddress(ip) { return ClassifyIP(ip) == HostLocal
case siteLocal: }
// IPIsLAN returns true if the given IP is a site-local or link-local address.
func IPIsLAN(ip net.IP) bool {
switch ClassifyIP(ip) {
case SiteLocal:
return true return true
case linkLocal: case LinkLocal:
return true return true
default: default:
return false return false
} }
} }
// IPIsGlobal returns true if the given IP is a global address // IPIsGlobal returns true if the given IP is a global address.
func IPIsGlobal(ip net.IP) bool { func IPIsGlobal(ip net.IP) bool {
return classifyAddress(ip) == global return ClassifyIP(ip) == Global
} }
// IPIsLinkLocal returns true if the given IP is a link-local address // IPIsLinkLocal returns true if the given IP is a link-local address.
func IPIsLinkLocal(ip net.IP) bool { func IPIsLinkLocal(ip net.IP) bool {
return classifyAddress(ip) == linkLocal return ClassifyIP(ip) == LinkLocal
} }
// IPIsSiteLocal returns true if the given IP is a site-local address // IPIsSiteLocal returns true if the given IP is a site-local address.
func IPIsSiteLocal(ip net.IP) bool { func IPIsSiteLocal(ip net.IP) bool {
return classifyAddress(ip) == siteLocal return ClassifyIP(ip) == SiteLocal
} }

View file

@ -6,14 +6,14 @@ import (
) )
func TestIPClassification(t *testing.T) { func TestIPClassification(t *testing.T) {
testClassification(t, net.IPv4(71, 87, 113, 211), global) testClassification(t, net.IPv4(71, 87, 113, 211), Global)
testClassification(t, net.IPv4(127, 0, 0, 1), hostLocal) testClassification(t, net.IPv4(127, 0, 0, 1), HostLocal)
testClassification(t, net.IPv4(127, 255, 255, 1), hostLocal) testClassification(t, net.IPv4(127, 255, 255, 1), HostLocal)
testClassification(t, net.IPv4(192, 168, 172, 24), siteLocal) testClassification(t, net.IPv4(192, 168, 172, 24), SiteLocal)
} }
func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { func testClassification(t *testing.T, ip net.IP, expectedClassification int8) {
c := classifyAddress(ip) c := ClassifyIP(ip)
if c != expectedClassification { if c != expectedClassification {
t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification)) t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification))
} }
@ -21,19 +21,19 @@ func testClassification(t *testing.T, ip net.IP, expectedClassification int8) {
func classificationString(c int8) string { func classificationString(c int8) string {
switch c { switch c {
case hostLocal: case HostLocal:
return "hostLocal" return "hostLocal"
case linkLocal: case LinkLocal:
return "linkLocal" return "linkLocal"
case siteLocal: case SiteLocal:
return "siteLocal" return "siteLocal"
case global: case Global:
return "global" return "global"
case localMulticast: case LocalMulticast:
return "localMulticast" return "localMulticast"
case globalMulticast: case GlobalMulticast:
return "globalMulticast" return "globalMulticast"
case invalid: case Invalid:
return "invalid" return "invalid"
default: default:
return "unknown" return "unknown"

View file

@ -106,7 +106,7 @@ type TCPUDPHeader struct {
} }
type PacketBase struct { type PacketBase struct {
connectionID string linkID string
Direction bool Direction bool
InTunnel bool InTunnel bool
Payload []byte Payload []byte
@ -146,25 +146,25 @@ func (pkt *PacketBase) IPVersion() IPVersion {
return pkt.Version return pkt.Version
} }
func (pkt *PacketBase) GetConnectionID() string { func (pkt *PacketBase) GetLinkID() string {
if pkt.connectionID == "" { if pkt.linkID == "" {
pkt.createConnectionID() pkt.createLinkID()
} }
return pkt.connectionID return pkt.linkID
} }
func (pkt *PacketBase) createConnectionID() { func (pkt *PacketBase) createLinkID() {
if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP { if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP {
if pkt.Direction { if pkt.Direction {
pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort) pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort)
} else { } else {
pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort) pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort)
} }
} else { } else {
if pkt.Direction { if pkt.Direction {
pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src) pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src)
} else { } else {
pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst) pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst)
} }
} }
} }
@ -299,7 +299,7 @@ type Packet interface {
IsOutbound() bool IsOutbound() bool
SetInbound() SetInbound()
SetOutbound() SetOutbound()
GetConnectionID() string GetLinkID() string
IPVersion() IPVersion IPVersion() IPVersion
// MATCHING // MATCHING

View file

@ -0,0 +1,45 @@
package reference
import "strconv"
var (
protocolNames = map[uint8]string{
1: "ICMP",
2: "IGMP",
6: "TCP",
17: "UDP",
27: "RDP",
58: "ICMPv6",
33: "DCCP",
136: "UDPLite",
}
protocolNumbers = map[string]uint8{
"ICMP": 1,
"IGMP": 2,
"TCP": 6,
"UDP": 17,
"RDP": 27,
"DCCP": 33,
"ICMPv6": 58,
"UDPLite": 136,
}
)
// GetProtocolName returns the name of a IP protocol number.
func GetProtocolName(protocol uint8) (name string) {
name, ok := protocolNames[protocol]
if ok {
return name
}
return strconv.Itoa(int(protocol))
}
// GetProtocolNumber returns the number of a IP protocol name.
func GetProtocolNumber(protocol string) (number uint8, ok bool) {
number, ok = protocolNumbers[protocol]
if ok {
return number, true
}
return 0, false
}

View file

@ -2,20 +2,34 @@
package network package network
// Status describes the status of a connection. // Verdict describes the decision made about a connection or link.
type Verdict uint8 type Verdict uint8
// List of values a Status can have // List of values a Status can have
const ( const (
// UNDECIDED is the default status of new connections // UNDECIDED is the default status of new connections
UNDECIDED Verdict = iota UNDECIDED Verdict = iota
CANTSAY
ACCEPT ACCEPT
BLOCK BLOCK
DROP DROP
RerouteToNameserver
RerouteToTunnel
) )
// Packer Directions
const ( const (
Inbound = true Inbound = true
Outbound = false Outbound = false
) )
// Non-Domain Connections
const (
IncomingHost = "IH"
IncomingLAN = "IL"
IncomingInternet = "II"
IncomingInvalid = "IX"
PeerHost = "PH"
PeerLAN = "PL"
PeerInternet = "PI"
PeerInvalid = "PX"
)

31
network/unknown.go Normal file
View file

@ -0,0 +1,31 @@
package network
import "github.com/Safing/portmaster/process"
// Static reasons
const (
ReasonUnknownProcess = "unknown connection owner: process could not be found"
)
var (
UnknownDirectConnection = &Connection{
Domain: "PI",
Direction: Outbound,
Verdict: DROP,
Reason: ReasonUnknownProcess,
process: process.UnknownProcess,
}
UnknownIncomingConnection = &Connection{
Domain: "II",
Direction: Inbound,
Verdict: DROP,
Reason: ReasonUnknownProcess,
process: process.UnknownProcess,
}
)
func init() {
UnknownDirectConnection.Save()
UnknownIncomingConnection.Save()
}

View file

@ -1,395 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package portmaster
import (
"net"
"os"
"strings"
"github.com/Safing/safing-core/intel"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/safing-core/port17/mode"
"github.com/Safing/safing-core/profiles"
"github.com/agext/levenshtein"
)
// use https://github.com/agext/levenshtein
// Call order:
//
// 1. DecideOnConnectionBeforeIntel (if connecting to domain)
// is called when a DNS query is made, before the query is resolved
// 2. DecideOnConnectionAfterIntel (if connecting to domain)
// is called when a DNS query is made, after the query is resolved
// 3. DecideOnConnection
// is called when the first packet of the first link of the connection arrives
// 4. DecideOnLink
// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY
func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) {
// check:
// Profile.DomainWhitelist
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("sheriff: granting own connection %s", connection)
connection.Accept()
return
}
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return
}
// check user class
if profile.Flags.Has(profiles.System) {
if !connection.Process().IsSystem() {
log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection)
connection.AddReason("must be executed by system")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.Admin) {
if !connection.Process().IsAdmin() {
log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection)
connection.AddReason("must be executed by admin")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.User) {
if !connection.Process().IsUser() {
log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection)
connection.AddReason("must be executed by user")
connection.Block()
return
}
}
// check for any network access
if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection)
connection.Block()
return
}
// check domain whitelist/blacklist
if len(profile.DomainWhitelist) > 0 {
matched := false
for _, entry := range profile.DomainWhitelist {
if !strings.HasSuffix(entry, ".") {
entry += "."
}
if strings.HasPrefix(entry, "*") {
if strings.HasSuffix(fqdn, strings.Trim(entry, "*")) {
matched = true
break
}
} else {
if entry == fqdn {
matched = true
break
}
}
}
if matched {
if profile.DomainWhitelistIsBlacklist {
log.Infof("sheriff: denying connection %s, profile has %s in domain blacklist", connection, fqdn)
connection.AddReason("domain blacklisted")
connection.Block()
return
}
} else {
if !profile.DomainWhitelistIsBlacklist {
log.Infof("sheriff: denying connection %s, profile does not have %s in domain whitelist", connection, fqdn)
connection.AddReason("domain not in whitelist")
connection.Block()
return
}
}
}
}
func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache {
// check:
// TODO: Profile.ClassificationBlacklist
// TODO: Profile.ClassificationWhitelist
// Profile.Flags
// - network specific: Strict
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return rrCache
}
// check Strict flag
// TODO: drastically improve this!
if profile.Flags.Has(profiles.Strict) {
matched := false
pathElements := strings.Split(connection.Process().Path, "/")
if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:]
}
domainElements := strings.Split(fqdn, ".")
matchLoop:
for _, domainElement := range domainElements {
for _, pathElement := range pathElements {
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
matched = true
break matchLoop
}
}
if levenshtein.Match(domainElement, profile.Name, nil) > 0.5 {
matched = true
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 {
matched = true
break matchLoop
}
}
if !matched {
log.Infof("sheriff: denying connection %s, profile has declared Strict flag and no match to domain was found", connection)
connection.AddReason("domain does not relate to process")
connection.Block()
return rrCache
}
}
// tunneling
// TODO: link this to real status
port17Active := mode.Client()
if port17Active {
tunnelInfo, err := AssignTunnelIP(fqdn)
if err != nil {
log.Errorf("portmaster: could not get tunnel IP for routing %s: %s", connection, err)
return nil // return nxDomain
}
// save original reply
tunnelInfo.RRCache = rrCache
// return tunnel IP
return tunnelInfo.ExportTunnelIP()
}
return rrCache
}
func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// check:
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet, Service, Directconnect
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("sheriff: granting own connection %s", connection)
connection.Accept()
return
}
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return
}
// check user class
if profile.Flags.Has(profiles.System) {
if !connection.Process().IsSystem() {
log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection)
connection.AddReason("must be executed by system")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.Admin) {
if !connection.Process().IsAdmin() {
log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection)
connection.AddReason("must be executed by admin")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.User) {
if !connection.Process().IsUser() {
log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection)
connection.AddReason("must be executed by user")
connection.Block()
return
}
}
// check for any network access
if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection)
connection.AddReason("no network access allowed")
connection.Block()
return
}
switch connection.Domain {
case "I":
// check Service flag
if !profile.Flags.Has(profiles.Service) {
log.Infof("sheriff: denying connection %s, profile does not declare service", connection)
connection.AddReason("not a service")
connection.Drop()
return
}
// check if incoming connections are allowed on any port, but only if there no other restrictions
if !!profile.Flags.Has(profiles.Internet) && !!profile.Flags.Has(profiles.LocalNet) && len(profile.ListenPorts) == 0 {
log.Infof("sheriff: granting connection %s, profile allows incoming connections from anywhere and on any port", connection)
connection.Accept()
return
}
case "D":
// check Directconnect flag
if !profile.Flags.Has(profiles.Directconnect) {
log.Infof("sheriff: denying connection %s, profile does not declare direct connections", connection)
connection.AddReason("direct connections (without DNS) not allowed")
connection.Drop()
return
}
}
log.Infof("sheriff: could not decide on connection %s, deciding on per-link basis", connection)
connection.CantSay()
}
func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) {
// check:
// Profile.Flags
// - network specific: Internet, LocalNet
// Profile.ConnectPorts
// Profile.ListenPorts
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying %s", link)
link.AddReason("no profile")
link.UpdateVerdict(network.BLOCK)
return
}
// check LocalNet and Internet flags
var remoteIP net.IP
if connection.Direction {
remoteIP = pkt.GetIPHeader().Src
} else {
remoteIP = pkt.GetIPHeader().Dst
}
if netutils.IPIsLocal(remoteIP) {
if !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: dropping link %s, profile does not allow communication in the local network", link)
link.AddReason("profile does not allow access to local network")
link.UpdateVerdict(network.BLOCK)
return
}
} else {
if !profile.Flags.Has(profiles.Internet) {
log.Infof("sheriff: dropping link %s, profile does not allow communication with the Internet", link)
link.AddReason("profile does not allow access to the Internet")
link.UpdateVerdict(network.BLOCK)
return
}
}
// check connect ports
if connection.Domain != "I" && len(profile.ConnectPorts) > 0 {
tcpUdpHeader := pkt.GetTCPUDPHeader()
if tcpUdpHeader == nil {
log.Infof("sheriff: blocking link %s, profile has declared connect port whitelist, but link is not TCP/UDP", link)
link.AddReason("profile has declared connect port whitelist, but link is not TCP/UDP")
link.UpdateVerdict(network.BLOCK)
return
}
// packet *should* be outbound, but we could be deciding on an already active connection.
var remotePort uint16
if connection.Direction {
remotePort = tcpUdpHeader.SrcPort
} else {
remotePort = tcpUdpHeader.DstPort
}
matched := false
for _, port := range profile.ConnectPorts {
if remotePort == port {
matched = true
break
}
}
if !matched {
log.Infof("sheriff: blocking link %s, remote port %d not in profile connect port whitelist", link, remotePort)
link.AddReason("destination port not in whitelist")
link.UpdateVerdict(network.BLOCK)
return
}
}
// check listen ports
if connection.Domain == "I" && len(profile.ListenPorts) > 0 {
tcpUdpHeader := pkt.GetTCPUDPHeader()
if tcpUdpHeader == nil {
log.Infof("sheriff: dropping link %s, profile has declared listen port whitelist, but link is not TCP/UDP", link)
link.AddReason("profile has declared listen port whitelist, but link is not TCP/UDP")
link.UpdateVerdict(network.DROP)
return
}
// packet *should* be inbound, but we could be deciding on an already active connection.
var localPort uint16
if connection.Direction {
localPort = tcpUdpHeader.DstPort
} else {
localPort = tcpUdpHeader.SrcPort
}
matched := false
for _, port := range profile.ListenPorts {
if localPort == port {
matched = true
break
}
}
if !matched {
log.Infof("sheriff: blocking link %s, local port %d not in profile listen port whitelist", link, localPort)
link.AddReason("listen port not in whitelist")
link.UpdateVerdict(network.BLOCK)
return
}
}
log.Infof("sheriff: accepting link %s", link)
link.UpdateVerdict(network.ACCEPT)
}

107
process/database.go Normal file
View file

@ -0,0 +1,107 @@
package process
import (
"fmt"
"sync"
"time"
"github.com/Safing/portbase/database"
"github.com/Safing/portmaster/profile"
"github.com/tevino/abool"
)
var (
processes = make(map[int]*Process)
processesLock sync.RWMutex
dbController *database.Controller
dbControllerFlag = abool.NewBool(false)
)
// GetProcessFromStorage returns a process from the internal storage.
func GetProcessFromStorage(pid int) (*Process, bool) {
processesLock.RLock()
defer processesLock.RUnlock()
p, ok := processes[pid]
return p, ok
}
// All returns a copy of all process objects.
func All() []*Process {
processesLock.RLock()
defer processesLock.RUnlock()
all := make([]*Process, 0, len(processes))
for _, proc := range processes {
all = append(all, proc)
}
return all
}
// Save saves the process to the internal state and pushes an update.
func (p *Process) Save() {
p.Lock()
defer p.Unlock()
if !p.KeyIsSet() {
p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid))
p.CreateMeta()
}
processesLock.RLock()
_, ok := processes[p.Pid]
processesLock.RUnlock()
if !ok {
processesLock.Lock()
processes[p.Pid] = p
processesLock.Unlock()
}
if dbControllerFlag.IsSet() {
go dbController.PushUpdate(p)
}
}
// Delete deletes a process from the storage and propagates the change.
func (p *Process) Delete() {
p.Lock()
defer p.Unlock()
processesLock.Lock()
delete(processes, p.Pid)
processesLock.Unlock()
p.Meta().Delete()
if dbControllerFlag.IsSet() {
go dbController.PushUpdate(p)
}
// TODO: this should not be necessary, as processes should always have a profileSet.
if p.profileSet != nil {
profile.DeactivateProfileSet(p.profileSet)
}
}
// CleanProcessStorage cleans the storage from old processes.
func CleanProcessStorage(thresholdDuration time.Duration) {
processesLock.Lock()
defer processesLock.Unlock()
threshold := time.Now().Add(-thresholdDuration).Unix()
for _, p := range processes {
p.Lock()
if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 {
go p.Delete()
}
p.Unlock()
}
}
// SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database.
func SetDBController(controller *database.Controller) {
dbController = controller
dbControllerFlag.Set()
}

View file

@ -1,21 +1,7 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
/* /*
Package process fetches process and socket information from the operating system.
Profiles It can find the process owning a network connection.
Profiles describe the network behaviour
Profiles are found in 3 different paths:
- /Me/Profiles/: Profiles used for this system
- /Data/Profiles/: Profiles supplied by Safing
- /Company/Profiles/: Profiles supplied by the company
When a program wants to use the network for the first time, Safing first searches for a Profile in the Company namespace, then in the Data namespace. If neither is found, it searches for a default profile in the same order.
Default profiles are profiles with a path ending with a "/". The default profile with the longest matching path is chosen.
*/ */
package process package process

43
process/executable.go Normal file
View file

@ -0,0 +1,43 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package process
import (
"crypto"
"encoding/hex"
"hash"
"io"
"os"
)
// GetExecHash returns the hash of the executable with the given algorithm.
func (p *Process) GetExecHash(algorithm string) (string, error) {
sum, ok := p.ExecHashes[algorithm]
if ok {
return sum, nil
}
var hasher hash.Hash
switch algorithm {
case "md5":
hasher = crypto.MD5.New()
case "sha1":
hasher = crypto.SHA1.New()
case "sha256":
hasher = crypto.SHA256.New()
}
file, err := os.Open(p.Path)
if err != nil {
return "", err
}
_, err = io.Copy(hasher, file)
if err != nil {
return "", err
}
sum = hex.EncodeToString(hasher.Sum(nil))
p.ExecHashes[algorithm] = sum
return sum, nil
}

View file

@ -1,74 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package process
import (
"github.com/Safing/safing-core/database"
"strings"
"time"
datastore "github.com/ipfs/go-datastore"
)
// ExecutableSignature stores a signature of an executable.
type ExecutableSignature []byte
// FileInfo stores (security) information about a file.
type FileInfo struct {
database.Base
HumanName string
Owners []string
ApproxLastSeen int64
Signature *ExecutableSignature
}
var fileInfoModel *FileInfo // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(fileInfoModel, func() database.Model { return new(FileInfo) })
}
// Create saves FileInfo with the provided name in the default namespace.
func (m *FileInfo) Create(name string) error {
return m.CreateObject(&database.FileInfoCache, name, m)
}
// CreateInNamespace saves FileInfo with the provided name in the provided namespace.
func (m *FileInfo) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves FileInfo.
func (m *FileInfo) Save() error {
return m.SaveObject(m)
}
// getFileInfo fetches FileInfo with the provided name from the default namespace.
func getFileInfo(name string) (*FileInfo, error) {
return getFileInfoFromNamespace(&database.FileInfoCache, name)
}
// getFileInfoFromNamespace fetches FileInfo with the provided name from the provided namespace.
func getFileInfoFromNamespace(namespace *datastore.Key, name string) (*FileInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, fileInfoModel)
if err != nil {
return nil, err
}
model, ok := object.(*FileInfo)
if !ok {
return nil, database.NewMismatchError(object, fileInfoModel)
}
return model, nil
}
// GetFileInfo gathers information about a file and returns *FileInfo
func GetFileInfo(path string) *FileInfo {
// TODO: actually get file information
// TODO: try to load from DB
// TODO: save to DB (key: hash of some sorts)
splittedPath := strings.Split("/", path)
return &FileInfo{
HumanName: splittedPath[len(splittedPath)-1],
ApproxLastSeen: time.Now().Unix(),
}
}

View file

@ -4,14 +4,17 @@ import (
"errors" "errors"
"net" "net"
"github.com/Safing/safing-core/network/packet" "github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/packet"
) )
// Errors
var ( var (
ErrConnectionNotFound = errors.New("could not find connection") ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrProcessNotFound = errors.New("could not find process") ErrProcessNotFound = errors.New("could not find process in system state tables")
) )
// GetPidByPacket returns the pid of the owner of the packet.
func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
var localIP net.IP var localIP net.IP
@ -50,26 +53,33 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
} }
// GetProcessByPacket returns the process that owns the given packet.
func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) {
var pid int var pid int
pid, direction, err = GetPidByPacket(pkt) pid, direction, err = GetPidByPacket(pkt)
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
if err != nil { if err != nil {
return nil, direction, err return nil, direction, err
} }
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
process, err = GetOrFindProcess(pid) process, err = GetOrFindProcess(pid)
if err != nil { if err != nil {
return nil, direction, err return nil, direction, err
} }
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, direction, nil return process, direction, nil
} }
// GetPidByEndpoints returns the pid of the owner of the described link.
func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) { func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) {
ipVersion := packet.IPv4 ipVersion := packet.IPv4
@ -92,6 +102,7 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote
} }
// GetProcessByEndpoints returns the process that owns the described link.
func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) {
var pid int var pid int
@ -108,41 +119,16 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re
return nil, err return nil, err
} }
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, nil return process, nil
} }
// GetActiveConnectionIDs returns a list of all active connection IDs.
func GetActiveConnectionIDs() []string { func GetActiveConnectionIDs() []string {
return getActiveConnectionIDs() return getActiveConnectionIDs()
} }
// func GetProcessByPid(pid int) *Process {
// process, err := GetOrFindProcess(pid)
// if err != nil {
// log.Warningf("process: failed to get process %d: %s", pid, err)
// return nil
// }
// return process
// }
// func GetProcessOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (process *Process, status uint8) {
// pid, status := GetPidOfConnection(localIP, localPort, protocol)
// if status == Success {
// process = GetProcessByPid(pid)
// if process == nil {
// return nil, NoProcessInfo
// }
// }
// return
// }
// func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, status uint8) {
// pid, direction, status := GetPidByPacket(pkt)
// if status == Success {
// process = GetProcessByPid(pid)
// if process == nil {
// return nil, direction, NoProcessInfo
// }
// }
// return
// }

View file

@ -1,6 +1,8 @@
package process package process
import "github.com/Safing/safing-core/process/proc" import (
"github.com/Safing/portmaster/process/proc"
)
var ( var (
getTCP4PacketInfo = proc.GetTCP4PacketInfo getTCP4PacketInfo = proc.GetTCP4PacketInfo

Binary file not shown.

115
process/matching.go Normal file
View file

@ -0,0 +1,115 @@
package process
import (
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/profile"
)
var (
profileDB = database.NewInterface(nil)
)
// FindProfiles finds and assigns a profile set to the process.
func (p *Process) FindProfiles() error {
p.Lock()
defer p.Unlock()
// only find profiles if not already done.
if p.profileSet != nil {
return nil
}
// User Profile
it, err := profileDB.Query(query.New(profile.MakeProfileKey(profile.UserNamespace, "")).Where(query.Where("LinkedPath", query.SameAs, p.Path)))
if err != nil {
return err
}
var userProfile *profile.Profile
for r := range it.Next {
it.Cancel()
userProfile, err = profile.EnsureProfile(r)
if err != nil {
return err
}
break
}
if it.Err() != nil {
return it.Err()
}
// create new profile if it does not exist.
if userProfile == nil {
// create new profile
userProfile = profile.New()
userProfile.Name = p.ExecName
userProfile.LinkedPath = p.Path
}
if userProfile.MarkUsed() {
userProfile.Save(profile.UserNamespace)
}
// Stamp
// Find/Re-evaluate Stamp profile
// 1. check linked stamp profile
// 2. if last check is was more than a week ago, fetch from stamp:
// 3. send path identifier to stamp
// 4. evaluate all returned profiles
// 5. select best
// 6. link stamp profile to user profile
// FIXME: implement!
p.UserProfileKey = userProfile.Key()
p.profileSet = profile.NewSet(userProfile, nil)
go p.Save()
return nil
}
func selectProfile(p *Process, profs []*profile.Profile) (selectedProfile *profile.Profile) {
var highestScore int
for _, prof := range profs {
score := matchProfile(p, prof)
if score > highestScore {
selectedProfile = prof
}
}
return
}
func matchProfile(p *Process, prof *profile.Profile) (score int) {
for _, fp := range prof.Fingerprints {
score += matchFingerprint(p, fp)
}
return
}
func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) {
if !fp.MatchesOS() {
return 0
}
switch fp.Type {
case "full_path":
if p.Path == fp.Value {
}
return profile.GetFingerprintWeight(fp.Type)
case "partial_path":
// FIXME: if full_path matches, do not match partial paths
return profile.GetFingerprintWeight(fp.Type)
case "md5_sum", "sha1_sum", "sha256_sum":
// FIXME: one sum is enough, check sums in a grouped form, start with the best
sum, err := p.GetExecHash(fp.Type)
if err != nil {
log.Errorf("process: failed to get hash of executable: %s", err)
} else if sum == fp.Value {
return profile.GetFingerprintWeight(fp.Type)
}
}
return 0
}

View file

@ -13,14 +13,19 @@ const (
NoProcess NoProcess
) )
func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { var (
waitTime = 15 * time.Millisecond
)
// GetPidOfConnection returns the PID of the given connection.
func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) uid, inode, ok := getConnectionSocket(localIP, localPort, protocol)
if !ok { if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol) uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
for i := 0; i < 3 && !ok; i++ { for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again // give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think") // log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond) time.Sleep(waitTime)
uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) uid, inode, ok = getConnectionSocket(localIP, localPort, protocol)
if !ok { if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol) uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
@ -30,27 +35,48 @@ func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid
return -1, NoSocket return -1, NoSocket
} }
} }
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ { for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again // give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think") // log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond) time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
} }
if !ok { if !ok {
return -1, NoProcess return -1, NoProcess
} }
return return
} }
func GetPidOfIncomingConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { // GetPidOfConnection returns the PID of the given incoming connection.
func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getListeningSocket(localIP, localPort, protocol) uid, inode, ok := getListeningSocket(localIP, localPort, protocol)
if !ok {
// for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version.
switch protocol {
case TCP4:
uid, inode, ok = getListeningSocket(localIP, localPort, TCP6)
case UDP4:
uid, inode, ok = getListeningSocket(localIP, localPort, UDP6)
}
if !ok { if !ok {
return -1, NoSocket return -1, NoSocket
} }
}
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok { if !ok {
return -1, NoProcess return -1, NoProcess
} }
return return
} }

View file

@ -6,39 +6,39 @@ import (
) )
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP4, localIP, localPort, direction) return search(TCP4, localIP, localPort, pktDirection)
} }
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP6, localIP, localPort, direction) return search(TCP6, localIP, localPort, pktDirection)
} }
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP4, localIP, localPort, direction) return search(UDP4, localIP, localPort, pktDirection)
} }
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP6, localIP, localPort, direction) return search(UDP6, localIP, localPort, pktDirection)
} }
func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) {
var status uint8 var status uint8
if pktDirection { if pktDirection {
pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
if pid >= 0 { if pid >= 0 {
return pid, true, nil return pid, true, nil
} }
// pid, status = GetPidOfConnection(&localIP, localPort, protocol) // pid, status = GetPidOfConnection(localIP, localPort, protocol)
// if pid >= 0 { // if pid >= 0 {
// return pid, false, nil // return pid, false, nil
// } // }
} else { } else {
pid, status = GetPidOfConnection(&localIP, localPort, protocol) pid, status = GetPidOfConnection(localIP, localPort, protocol)
if pid >= 0 { if pid >= 0 {
return pid, false, nil return pid, false, nil
} }
// pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
// if pid >= 0 { // if pid >= 0 {
// return pid, true, nil // return pid, true, nil
// } // }

View file

@ -10,7 +10,7 @@ import (
"sync" "sync"
"syscall" "syscall"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
) )
var ( var (

View file

@ -13,7 +13,7 @@ import (
"sync" "sync"
"unicode" "unicode"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
) )
/* /*
@ -81,7 +81,7 @@ var (
globalListeningUDP6 = make(map[uint16][]int) globalListeningUDP6 = make(map[uint16][]int)
) )
func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int, int, bool) { func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) {
// listeningSocketsLock.Lock() // listeningSocketsLock.Lock()
// defer listeningSocketsLock.Unlock() // defer listeningSocketsLock.Unlock()
@ -98,10 +98,10 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
case TCP6: case TCP6:
procFile = TCP6Data procFile = TCP6Data
localIPHex = hex.EncodeToString([]byte(*localIP)) localIPHex = hex.EncodeToString([]byte(localIP))
case UDP6: case UDP6:
procFile = UDP6Data procFile = UDP6Data
localIPHex = hex.EncodeToString([]byte(*localIP)) localIPHex = hex.EncodeToString([]byte(localIP))
} }
localPortHex := fmt.Sprintf("%04X", localPort) localPortHex := fmt.Sprintf("%04X", localPort)
@ -162,38 +162,38 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
} }
func getListeningSocket(localIP *net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
listeningSocketsLock.Lock() listeningSocketsLock.Lock()
defer listeningSocketsLock.Unlock() defer listeningSocketsLock.Unlock()
var addressListening *map[string][]int var addressListening map[string][]int
var globalListening *map[uint16][]int var globalListening map[uint16][]int
switch protocol { switch protocol {
case TCP4: case TCP4:
addressListening = &addressListeningTCP4 addressListening = addressListeningTCP4
globalListening = &globalListeningTCP4 globalListening = globalListeningTCP4
case UDP4: case UDP4:
addressListening = &addressListeningUDP4 addressListening = addressListeningUDP4
globalListening = &globalListeningUDP4 globalListening = globalListeningUDP4
case TCP6: case TCP6:
addressListening = &addressListeningTCP6 addressListening = addressListeningTCP6
globalListening = &globalListeningTCP6 globalListening = globalListeningTCP6
case UDP6: case UDP6:
addressListening = &addressListeningUDP6 addressListening = addressListeningUDP6
globalListening = &globalListeningUDP6 globalListening = globalListeningUDP6
} }
data, ok := (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok { if !ok {
data, ok = (*globalListening)[localPort] data, ok = globalListening[localPort]
} }
if ok { if ok {
return data[0], data[1], true return data[0], data[1], true
} }
updateListeners(protocol) updateListeners(protocol)
data, ok = (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok { if !ok {
data, ok = (*globalListening)[localPort] data, ok = globalListening[localPort]
} }
if ok { if ok {
return data[0], data[1], true return data[0], data[1], true
@ -206,7 +206,7 @@ func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':' return unicode.IsSpace(c) || c == ':'
} }
func convertIPv4(data string) *net.IP { func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data) decoded, err := hex.DecodeString(data)
if err != nil { if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err) log.Warningf("process: could not parse IPv4 %s: %s", data, err)
@ -217,10 +217,10 @@ func convertIPv4(data string) *net.IP {
return nil return nil
} }
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return &ip return ip
} }
func convertIPv6(data string) *net.IP { func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data) decoded, err := hex.DecodeString(data)
if err != nil { if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err) log.Warningf("process: could not parse IPv6 %s: %s", data, err)
@ -231,7 +231,7 @@ func convertIPv6(data string) *net.IP {
return nil return nil
} }
ip := net.IP(decoded) ip := net.IP(decoded)
return &ip return ip
} }
func updateListeners(protocol uint8) { func updateListeners(protocol uint8) {
@ -247,7 +247,7 @@ func updateListeners(protocol uint8) {
} }
} }
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) *net.IP) (map[string][]int, map[uint16][]int) { func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) {
addressListening := make(map[string][]int) addressListening := make(map[string][]int)
globalListening := make(map[uint16][]int) globalListening := make(map[uint16][]int)
@ -312,6 +312,7 @@ func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter
return addressListening, globalListening return addressListening, globalListening
} }
// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS.
func GetActiveConnectionIDs() []string { func GetActiveConnectionIDs() []string {
var connections []string var connections []string
@ -323,7 +324,7 @@ func GetActiveConnectionIDs() []string {
return connections return connections
} }
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) *net.IP) []string { func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string {
var connections []string var connections []string
// open file // open file

View file

@ -22,14 +22,14 @@ func TestSockets(t *testing.T) {
t.Logf("addressListeningUDP6: %v", addressListeningUDP6) t.Logf("addressListeningUDP6: %v", addressListeningUDP6)
t.Logf("globalListeningUDP6: %v", globalListeningUDP6) t.Logf("globalListeningUDP6: %v", globalListeningUDP6)
getListeningSocket(&net.IPv4zero, 53, TCP4) getListeningSocket(net.IPv4zero, 53, TCP4)
getListeningSocket(&net.IPv4zero, 53, UDP4) getListeningSocket(net.IPv4zero, 53, UDP4)
getListeningSocket(&net.IPv6zero, 53, TCP6) getListeningSocket(net.IPv6zero, 53, TCP6)
getListeningSocket(&net.IPv6zero, 53, UDP6) getListeningSocket(net.IPv6zero, 53, UDP6)
// spotify: 192.168.0.102:5353 192.121.140.65:80 // spotify: 192.168.0.102:5353 192.121.140.65:80
localIP := net.IPv4(192, 168, 127, 10) localIP := net.IPv4(192, 168, 127, 10)
uid, inode, ok := getConnectionSocket(&localIP, 46634, TCP4) uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4)
t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok)
activeConnectionIDs := GetActiveConnectionIDs() activeConnectionIDs := GetActiveConnectionIDs()

View file

@ -5,20 +5,22 @@ package process
import ( import (
"fmt" "fmt"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync"
"time"
datastore "github.com/ipfs/go-datastore"
processInfo "github.com/shirou/gopsutil/process" processInfo "github.com/shirou/gopsutil/process"
"github.com/Safing/safing-core/database" "github.com/Safing/portbase/database/record"
"github.com/Safing/safing-core/log" "github.com/Safing/portbase/log"
"github.com/Safing/safing-core/profiles" "github.com/Safing/portmaster/profile"
) )
// A Process represents a process running on the operating system // A Process represents a process running on the operating system
type Process struct { type Process struct {
database.Base record.Base
sync.Mutex
UserID int UserID int
UserName string UserName string
UserHome string UserHome string
@ -26,68 +28,70 @@ type Process struct {
ParentPid int ParentPid int
Path string Path string
Cwd string Cwd string
FileInfo *FileInfo
CmdLine string CmdLine string
FirstArg string FirstArg string
ProfileKey string
Profile *profiles.Profile ExecName string
ExecHashes map[string]string
// ExecOwner ...
// ExecSignature ...
UserProfileKey string
profileSet *profile.Set
Name string Name string
Icon string Icon string
// Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache. // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache.
FirstConnectionEstablished int64
LastConnectionEstablished int64
ConnectionCount uint
} }
var processModel *Process // only use this as parameter for database.EnsureModel-like functions // ProfileSet returns the assigned profile set.
func (p *Process) ProfileSet() *profile.Set {
p.Lock()
defer p.Unlock()
func init() { return p.profileSet
database.RegisterModel(processModel, func() database.Model { return new(Process) })
} }
// Create saves Process with the provided name in the default namespace. // Strings returns a string represenation of process.
func (m *Process) Create(name string) error { func (p *Process) String() string {
return m.CreateObject(&database.Processes, name, m) p.Lock()
} defer p.Unlock()
// CreateInNamespace saves Process with the provided name in the provided namespace. if p == nil {
func (m *Process) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves Process.
func (m *Process) Save() error {
return m.SaveObject(m)
}
// GetProcess fetches Process with the provided name from the default namespace.
func GetProcess(name string) (*Process, error) {
return GetProcessFromNamespace(&database.Processes, name)
}
// GetProcessFromNamespace fetches Process with the provided name from the provided namespace.
func GetProcessFromNamespace(namespace *datastore.Key, name string) (*Process, error) {
object, err := database.GetAndEnsureModel(namespace, name, processModel)
if err != nil {
return nil, err
}
model, ok := object.(*Process)
if !ok {
return nil, database.NewMismatchError(object, processModel)
}
return model, nil
}
func (m *Process) String() string {
if m == nil {
return "?" return "?"
} }
if m.Profile != nil && !m.Profile.Default { return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid)
return fmt.Sprintf("%s:%s:%d", m.UserName, m.Profile, m.Pid)
}
return fmt.Sprintf("%s:%s:%d", m.UserName, m.Path, m.Pid)
} }
// AddConnection increases the connection counter and the last connection timestamp.
func (p *Process) AddConnection() {
p.Lock()
defer p.Unlock()
p.ConnectionCount++
p.LastConnectionEstablished = time.Now().Unix()
if p.FirstConnectionEstablished == 0 {
p.FirstConnectionEstablished = p.LastConnectionEstablished
}
}
// RemoveConnection lowers the connection counter by one.
func (p *Process) RemoveConnection() {
p.Lock()
defer p.Unlock()
if p.ConnectionCount > 0 {
p.ConnectionCount--
}
}
// GetOrFindProcess returns the process for the given PID.
func GetOrFindProcess(pid int) (*Process, error) { func GetOrFindProcess(pid int) (*Process, error) {
process, err := GetProcess(strconv.Itoa(pid)) process, ok := GetProcessFromStorage(pid)
if err == nil { if ok {
return process, nil return process, nil
} }
@ -96,13 +100,9 @@ func GetOrFindProcess(pid int) (*Process, error) {
} }
switch { switch {
case (pid == 0 && runtime.GOOS == "linux") || (pid == 4 && runtime.GOOS == "windows"): case new.IsKernel():
new.UserName = "Kernel" new.UserName = "Kernel"
new.Name = "Operating System" new.Name = "Operating System"
new.Profile = &profiles.Profile{
Name: "OS",
Flags: []int8{profiles.Internet, profiles.LocalNet, profiles.Directconnect, profiles.Service},
}
default: default:
pInfo, err := processInfo.NewProcess(int32(pid)) pInfo, err := processInfo.NewProcess(int32(pid))
@ -113,7 +113,8 @@ func GetOrFindProcess(pid int) (*Process, error) {
// UID // UID
// net yet implemented for windows // net yet implemented for windows
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
uids, err := pInfo.Uids() var uids []int32
uids, err = pInfo.Uids()
if err != nil { if err != nil {
log.Warningf("process: failed to get UID: %s", err) log.Warningf("process: failed to get UID: %s", err)
} else { } else {
@ -167,85 +168,87 @@ func GetOrFindProcess(pid int) (*Process, error) {
// new.Icon, err = // new.Icon, err =
// get Profile // get Profile
processPath := new.Path // processPath := new.Path
var applyProfile *profiles.Profile // var applyProfile *profiles.Profile
iterations := 0 // iterations := 0
for applyProfile == nil { // for applyProfile == nil {
//
iterations++ // iterations++
if iterations > 10 { // if iterations > 10 {
log.Warningf("process: got into loop while getting profile for %s", new) // log.Warningf("process: got into loop while getting profile for %s", new)
break // break
}
applyProfile, err = profiles.GetActiveProfileByPath(processPath)
if err == database.ErrNotFound {
applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome)
}
if err != nil {
log.Warningf("process: could not get profile for %s: %s", new, err)
} else if applyProfile == nil {
log.Warningf("process: no default profile found for %s", new)
} else {
// TODO: there is a lot of undefined behaviour if chaining framework profiles
// process framework
if applyProfile.Framework != nil {
if applyProfile.Framework.FindParent > 0 {
var ppid int32
for i := uint8(1); i < applyProfile.Framework.FindParent; i++ {
parent, err := pInfo.Parent()
if err != nil {
return nil, err
}
ppid = parent.Pid
}
if applyProfile.Framework.MergeWithParent {
return GetOrFindProcess(int(ppid))
}
// processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid))
// if err != nil {
// return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err)
// } // }
continue //
// applyProfile, err = profiles.GetActiveProfileByPath(processPath)
// if err == database.ErrNotFound {
// applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome)
// }
// if err != nil {
// log.Warningf("process: could not get profile for %s: %s", new, err)
// } else if applyProfile == nil {
// log.Warningf("process: no default profile found for %s", new)
// } else {
//
// // TODO: there is a lot of undefined behaviour if chaining framework profiles
//
// // process framework
// if applyProfile.Framework != nil {
// if applyProfile.Framework.FindParent > 0 {
// var ppid int32
// for i := uint8(1); i < applyProfile.Framework.FindParent; i++ {
// parent, err := pInfo.Parent()
// if err != nil {
// return nil, err
// }
// ppid = parent.Pid
// }
// if applyProfile.Framework.MergeWithParent {
// return GetOrFindProcess(int(ppid))
// }
// // processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid))
// // if err != nil {
// // return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err)
// // }
// continue
// }
//
// newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd)
// if err != nil {
// return nil, err
// }
//
// // assign
// new.CmdLine = newCommand
// new.Path = strings.SplitN(newCommand, " ", 2)[0]
// processPath = new.Path
//
// // make sure we loop
// applyProfile = nil
// continue
// }
//
// // apply profile to process
// log.Debugf("process: applied profile to %s: %s", new, applyProfile)
// new.Profile = applyProfile
// new.ProfileKey = applyProfile.GetKey().String()
//
// // update Profile with Process icon if Profile does not have one
// if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" {
// new.Profile.Icon = new.Icon
// new.Profile.Save()
// }
// }
// }
// Executable Information
// FIXME: use os specific path seperator
splittedPath := strings.Split(new.Path, "/")
new.ExecName = splittedPath[len(splittedPath)-1]
} }
newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd) // save to storage
if err != nil { new.Save()
return nil, err
}
// assign
new.CmdLine = newCommand
new.Path = strings.SplitN(newCommand, " ", 2)[0]
processPath = new.Path
// make sure we loop
applyProfile = nil
continue
}
// apply profile to process
log.Debugf("process: applied profile to %s: %s", new, applyProfile)
new.Profile = applyProfile
new.ProfileKey = applyProfile.GetKey().String()
// update Profile with Process icon if Profile does not have one
if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" {
new.Profile.Icon = new.Icon
new.Profile.Save()
}
}
}
// get FileInfo
new.FileInfo = GetFileInfo(new.Path)
}
// save to DB
new.Create(strconv.Itoa(new.Pid))
return new, nil return new, nil
} }

View file

@ -1,13 +1,21 @@
package process package process
// IsUser returns whether the process is run by a normal user.
func (m *Process) IsUser() bool { func (m *Process) IsUser() bool {
return m.UserID >= 1000 return m.UserID >= 1000
} }
// IsAdmin returns whether the process is run by an admin user.
func (m *Process) IsAdmin() bool { func (m *Process) IsAdmin() bool {
return m.UserID >= 0 return m.UserID >= 0
} }
// IsSystem returns whether the process is run by the operating system.
func (m *Process) IsSystem() bool { func (m *Process) IsSystem() bool {
return m.UserID == 0 return m.UserID == 0
} }
// IsKernel returns whether the process is the Kernel.
func (m *Process) IsKernel() bool {
return m.Pid == 0
}

View file

@ -2,15 +2,23 @@ package process
import "strings" import "strings"
// IsUser returns whether the process is run by a normal user.
func (m *Process) IsUser() bool { func (m *Process) IsUser() bool {
return m.Pid != 4 && // Kernel return m.Pid != 4 && // Kernel
!strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) !strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!)
} }
// IsAdmin returns whether the process is run by an admin user.
func (m *Process) IsAdmin() bool { func (m *Process) IsAdmin() bool {
return strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!) return strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!)
} }
// IsSystem returns whether the process is run by the operating system.
func (m *Process) IsSystem() bool { func (m *Process) IsSystem() bool {
return m.Pid == 4 return m.Pid == 4
} }
// IsKernel returns whether the process is the Kernel.
func (m *Process) IsKernel() bool {
return m.Pid == 4
}

16
process/unknown.go Normal file
View file

@ -0,0 +1,16 @@
package process
var (
// UnknownProcess is used when a process cannot be found.
UnknownProcess = &Process{
UserID: -1,
UserName: "Unknown",
Pid: -1,
ParentPid: -1,
Name: "Unknown Processes",
}
)
func init() {
UnknownProcess.Save()
}

54
profile/active.go Normal file
View file

@ -0,0 +1,54 @@
package profile
import "sync"
var (
activeProfileSets = make(map[string]*Set)
activeProfileSetsLock sync.RWMutex
)
func activateProfileSet(set *Set) {
set.Lock()
defer set.Unlock()
activeProfileSetsLock.Lock()
defer activeProfileSetsLock.Unlock()
activeProfileSets[set.profiles[0].ID] = set
}
// DeactivateProfileSet marks a profile set as not active.
func DeactivateProfileSet(set *Set) {
set.Lock()
defer set.Unlock()
activeProfileSetsLock.Lock()
defer activeProfileSetsLock.Unlock()
delete(activeProfileSets, set.profiles[0].ID)
}
func updateActiveUserProfile(profile *Profile) {
activeProfileSetsLock.RLock()
defer activeProfileSetsLock.RUnlock()
activeSet, ok := activeProfileSets[profile.ID]
if ok {
activeSet.Lock()
defer activeSet.Unlock()
activeSet.profiles[0] = profile
}
}
func updateActiveStampProfile(profile *Profile) {
activeProfileSetsLock.RLock()
defer activeProfileSetsLock.RUnlock()
for _, activeSet := range activeProfileSets {
activeSet.Lock()
activeProfile := activeSet.profiles[2]
if activeProfile != nil {
activeProfile.Lock()
if activeProfile.ID == profile.ID {
activeSet.profiles[2] = profile
}
activeProfile.Unlock()
}
activeSet.Unlock()
}
}

8
profile/const.go Normal file
View file

@ -0,0 +1,8 @@
package profile
// Platform identifiers
const (
PlatformLinux = "linux"
PlatformWindows = "windows"
PlatformMac = "macos"
)

6
profile/const_darwin.go Normal file
View file

@ -0,0 +1,6 @@
package profile
// OS Identifier
const (
osIdentifier = PlatformMac
)

6
profile/const_linux.go Normal file
View file

@ -0,0 +1,6 @@
package profile
// OS Identifier
const (
osIdentifier = PlatformLinux
)

22
profile/database.go Normal file
View file

@ -0,0 +1,22 @@
package profile
import (
"github.com/Safing/portbase/database"
)
// core:profiles/user/12345-1234-125-1234-1235
// core:profiles/special/default
// /global
// core:profiles/stamp/12334-1235-1234-5123-1234
// core:profiles/identifier/base64
// Namespaces
const (
UserNamespace = "user"
StampNamespace = "stamp"
SpecialNamespace = "special"
)
var (
profileDB = database.NewInterface(nil)
)

44
profile/defaults.go Normal file
View file

@ -0,0 +1,44 @@
package profile
import (
"github.com/Safing/portmaster/status"
)
func makeDefaultGlobalProfile() *Profile {
return &Profile{
ID: "global",
Name: "Global Profile",
}
}
func makeDefaultFallbackProfile() *Profile {
return &Profile{
ID: "fallback",
Name: "Fallback Profile",
Flags: map[uint8]uint8{
// Profile Modes
Blacklist: status.SecurityLevelDynamic,
Prompt: status.SecurityLevelSecure,
Whitelist: status.SecurityLevelFortress,
// Network Locations
Internet: status.SecurityLevelsDynamicAndSecure,
LAN: status.SecurityLevelsDynamicAndSecure,
Localhost: status.SecurityLevelsAll,
// Specials
Related: status.SecurityLevelDynamic,
PeerToPeer: status.SecurityLevelDynamic,
},
ServiceEndpoints: []*EndpointPermission{
&EndpointPermission{
DomainOrIP: "",
Wildcard: true,
Protocol: 0,
StartPort: 0,
EndPort: 0,
Permit: false,
},
},
}
}

138
profile/endpoints.go Normal file
View file

@ -0,0 +1,138 @@
package profile
import (
"fmt"
"strconv"
"strings"
"github.com/Safing/portmaster/intel"
)
// Endpoints is a list of permitted or denied endpoints.
type Endpoints []*EndpointPermission
// EndpointPermission holds a decision about an endpoint.
type EndpointPermission struct {
DomainOrIP string
Wildcard bool
Protocol uint8
StartPort uint16
EndPort uint16
Permit bool
Created int64
}
// IsSet returns whether the Endpoints object is "set".
func (e Endpoints) IsSet() bool {
if len(e) > 0 {
return true
}
return false
}
// Check checks if the given domain is governed in the list of domains and returns whether it is permitted.
// If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, an IP will be resolved to a domain, if necessary.
func (e Endpoints) Check(domainOrIP string, protocol uint8, port uint16, checkReverseIP bool, securityLevel uint8) (permit bool, reason string, ok bool) {
// ip resolving
var cachedGetDomainOfIP func() string
if checkReverseIP {
var ipResolved bool
var ipName string
// setup caching wrapper
cachedGetDomainOfIP = func() string {
if !ipResolved {
result, err := intel.ResolveIPAndValidate(domainOrIP, securityLevel)
if err != nil {
// log.Debug()
ipName = result
}
ipResolved = true
}
return ipName
}
}
isDomain := strings.HasSuffix(domainOrIP, ".")
for _, entry := range e {
if entry != nil {
if ok, reason := entry.Matches(domainOrIP, protocol, port, isDomain, cachedGetDomainOfIP); ok {
return entry.Permit, reason, true
}
}
}
return false, "", false
}
func isSubdomainOf(domain, subdomain string) bool {
dotPrefixedDomain := "." + domain
return strings.HasSuffix(subdomain, dotPrefixedDomain)
}
// Matches checks whether the given endpoint has a managed permission. If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, this declares an incoming connection.
func (ep EndpointPermission) Matches(domainOrIP string, protocol uint8, port uint16, isDomain bool, getDomainOfIP func() string) (match bool, reason string) {
if ep.Protocol > 0 && protocol != ep.Protocol {
return false, ""
}
if ep.StartPort > 0 && (port < ep.StartPort || port > ep.EndPort) {
return false, ""
}
switch {
case ep.Wildcard && len(ep.DomainOrIP) == 0:
// host wildcard
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case domainOrIP == ep.DomainOrIP:
// host match
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case isDomain && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, domainOrIP):
// subdomain match
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case !isDomain && getDomainOfIP != nil && getDomainOfIP() == ep.DomainOrIP:
// resolved IP match
return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep)
case !isDomain && getDomainOfIP != nil && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, getDomainOfIP()):
// resolved IP subdomain match
return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep)
default:
// no match
return false, ""
}
}
func (e Endpoints) String() string {
var s []string
for _, entry := range e {
s = append(s, entry.String())
}
return fmt.Sprintf("[%s]", strings.Join(s, ", "))
}
func (ep EndpointPermission) String() string {
s := ep.DomainOrIP
s += " "
if ep.Protocol > 0 {
s += strconv.Itoa(int(ep.Protocol))
} else {
s += "*"
}
s += "/"
if ep.StartPort > 0 {
if ep.StartPort == ep.EndPort {
s += strconv.Itoa(int(ep.StartPort))
} else {
s += fmt.Sprintf("%d-%d", ep.StartPort, ep.EndPort)
}
} else {
s += "*"
}
return s
}

61
profile/endpoints_test.go Normal file
View file

@ -0,0 +1,61 @@
package profile
import (
"testing"
)
// TODO: RETIRED
// func testdeMatcher(t *testing.T, value string, expectedResult bool) {
// if domainEndingMatcher.MatchString(value) != expectedResult {
// if expectedResult {
// t.Errorf("domainEndingMatcher should match %s", value)
// } else {
// t.Errorf("domainEndingMatcher should not match %s", value)
// }
// }
// }
//
// func TestdomainEndingMatcher(t *testing.T) {
// testdeMatcher(t, "example.com", true)
// testdeMatcher(t, "com", true)
// testdeMatcher(t, "example.xn--lgbbat1ad8j", true)
// testdeMatcher(t, "xn--lgbbat1ad8j", true)
// testdeMatcher(t, "fe80::beef", false)
// testdeMatcher(t, "fe80::dead:beef", false)
// testdeMatcher(t, "10.2.3.4", false)
// testdeMatcher(t, "4", false)
// }
func TestEPString(t *testing.T) {
var endpoints Endpoints
endpoints = []*EndpointPermission{
&EndpointPermission{
DomainOrIP: "example.com",
Wildcard: false,
Protocol: 6,
Permit: true,
},
&EndpointPermission{
DomainOrIP: "8.8.8.8",
Protocol: 17, // TCP
StartPort: 53, // DNS
EndPort: 53,
Permit: false,
},
&EndpointPermission{
DomainOrIP: "google.com",
Wildcard: true,
Permit: false,
},
}
if endpoints.String() != "[example.com 6/*, 8.8.8.8 17/53, google.com */*]" {
t.Errorf("unexpected result: %s", endpoints.String())
}
var noEndpoints Endpoints
noEndpoints = []*EndpointPermission{}
if noEndpoints.String() != "[]" {
t.Errorf("unexpected result: %s", noEndpoints.String())
}
}

48
profile/fingerprint.go Normal file
View file

@ -0,0 +1,48 @@
package profile
import "time"
var (
fingerprintWeights = map[string]int{
"full_path": 2,
"partial_path": 1,
"md5_sum": 4,
"sha1_sum": 5,
"sha256_sum": 6,
}
)
// Fingerprint links processes to profiles.
type Fingerprint struct {
OS string
Type string
Value string
Comment string
LastUsed int64
}
// MatchesOS returns whether the Fingerprint is applicable for the current OS.
func (fp *Fingerprint) MatchesOS() bool {
return fp.OS == osIdentifier
}
// GetFingerprintWeight returns the weight of the given fingerprint type.
func GetFingerprintWeight(fpType string) (weight int) {
weight, ok := fingerprintWeights[fpType]
if ok {
return weight
}
return 0
}
// AddFingerprint adds the given fingerprint to the profile.
func (p *Profile) AddFingerprint(fp *Fingerprint) {
if fp.OS == "" {
fp.OS = osIdentifier
}
if fp.LastUsed == 0 {
fp.LastUsed = time.Now().Unix()
}
p.Fingerprints = append(p.Fingerprints, fp)
}

130
profile/flags.go Normal file
View file

@ -0,0 +1,130 @@
package profile
import (
"errors"
"fmt"
"strings"
"github.com/Safing/portmaster/status"
)
// Flags are used to quickly add common attributes to profiles
type Flags map[uint8]uint8
// Profile Flags
const (
// Profile Modes
Prompt uint8 = 0 // Prompt first-seen connections
Blacklist uint8 = 1 // Allow everything not explicitly denied
Whitelist uint8 = 2 // Only allow everything explicitly allowed
// Network Locations
Internet uint8 = 16 // Allow connections to the Internet
LAN uint8 = 17 // Allow connections to the local area network
Localhost uint8 = 18 // Allow connections on the local host
// Specials
Related uint8 = 32 // If and before prompting, allow domains that are related to the program
PeerToPeer uint8 = 33 // Allow program to directly communicate with peers, without resolving DNS first
Service uint8 = 34 // Allow program to accept incoming connections
Independent uint8 = 35 // Ignore profile settings coming from the Community
RequireGate17 uint8 = 36 // Require all connections to go over Gate17
)
var (
// ErrFlagsParseFailed is returned if a an invalid flag is encountered while parsing
ErrFlagsParseFailed = errors.New("profiles: failed to parse flags")
sortedFlags = []uint8{
Prompt,
Blacklist,
Whitelist,
Internet,
LAN,
Localhost,
Related,
PeerToPeer,
Service,
Independent,
RequireGate17,
}
flagIDs = map[string]uint8{
"Prompt": Prompt,
"Blacklist": Blacklist,
"Whitelist": Whitelist,
"Internet": Internet,
"LAN": LAN,
"Localhost": Localhost,
"Related": Related,
"PeerToPeer": PeerToPeer,
"Service": Service,
"Independent": Independent,
"RequireGate17": RequireGate17,
}
flagNames = map[uint8]string{
Prompt: "Prompt",
Blacklist: "Blacklist",
Whitelist: "Whitelist",
Internet: "Internet",
LAN: "LAN",
Localhost: "Localhost",
Related: "Related",
PeerToPeer: "PeerToPeer",
Service: "Service",
Independent: "Independent",
RequireGate17: "RequireGate17",
}
)
// Check checks if a flag is set at all and if it's active in the given security level.
func (flags Flags) Check(flag, level uint8) (active bool, ok bool) {
if flags == nil {
return false, false
}
setting, ok := flags[flag]
if ok {
if setting&level > 0 {
return true, true
}
return false, true
}
return false, false
}
func getLevelMarker(levels, level uint8) string {
if levels&level > 0 {
return "+"
}
return "-"
}
// String return a string representation of Flags
func (flags Flags) String() string {
var markedFlags []string
for _, flag := range sortedFlags {
levels, ok := flags[flag]
if ok {
s := flagNames[flag]
if levels != status.SecurityLevelsAll {
s += getLevelMarker(levels, status.SecurityLevelDynamic)
s += getLevelMarker(levels, status.SecurityLevelSecure)
s += getLevelMarker(levels, status.SecurityLevelFortress)
}
markedFlags = append(markedFlags, s)
}
}
return fmt.Sprintf("[%s]", strings.Join(markedFlags, ", "))
}
// Add adds a flag to the Flags with the given level.
func (flags Flags) Add(flag, levels uint8) {
flags[flag] = levels
}
// Remove removes a flag from the Flags.
func (flags Flags) Remove(flag uint8) {
delete(flags, flag)
}

69
profile/flags_test.go Normal file
View file

@ -0,0 +1,69 @@
package profile
import (
"testing"
"github.com/Safing/portmaster/status"
)
func TestProfileFlags(t *testing.T) {
// check if all IDs have a name
for key, entry := range flagIDs {
if _, ok := flagNames[entry]; !ok {
t.Errorf("could not find entry for %s in flagNames", key)
}
}
// check if all names have an ID
for key, entry := range flagNames {
if _, ok := flagIDs[entry]; !ok {
t.Errorf("could not find entry for %d in flagNames", key)
}
}
testFlags := Flags{
Prompt: status.SecurityLevelsAll,
Internet: status.SecurityLevelsDynamicAndSecure,
LAN: status.SecurityLevelsDynamicAndSecure,
Localhost: status.SecurityLevelsAll,
Related: status.SecurityLevelDynamic,
RequireGate17: status.SecurityLevelsSecureAndFortress,
}
if testFlags.String() != "[Prompt, Internet++-, LAN++-, Localhost, Related+--, RequireGate17-++]" {
t.Errorf("unexpected output: %s", testFlags.String())
}
// // check Has
// emptyFlags := ProfileFlags{}
// for flag, name := range flagNames {
// if !sortedFlags.Has(flag) {
// t.Errorf("sortedFlags should have flag %s (%d)", name, flag)
// }
// if emptyFlags.Has(flag) {
// t.Errorf("emptyFlags should not have flag %s (%d)", name, flag)
// }
// }
//
// // check ProfileFlags creation from strings
// var allFlagStrings []string
// for _, flag := range *sortedFlags {
// allFlagStrings = append(allFlagStrings, flagNames[flag])
// }
// newFlags, err := FlagsFromNames(allFlagStrings)
// if err != nil {
// t.Errorf("error while parsing flags: %s", err)
// }
// if newFlags.String() != sortedFlags.String() {
// t.Errorf("parsed flags are not correct (or tests have not been updated to reflect the right number), expected %v, got %v", *sortedFlags, *newFlags)
// }
//
// // check ProfileFlags Stringer
// flagString := newFlags.String()
// check := strings.Join(allFlagStrings, ",")
// if flagString != check {
// t.Errorf("flag string is not correct, expected %s, got %s", check, flagString)
// }
}

76
profile/framework.go Normal file
View file

@ -0,0 +1,76 @@
package profile
// DEACTIVATED
// import (
// "fmt"
// "os"
// "path/filepath"
// "regexp"
// "strings"
//
// "github.com/Safing/portbase/log"
// )
//
// type Framework struct {
// // go hirarchy up
// FindParent uint8 `json:",omitempty bson:",omitempty"`
// // get path from parent, amount of levels to go up the tree (1 means parent, 2 means parent of parents, and so on)
// MergeWithParent bool `json:",omitempty bson:",omitempty"`
// // instead of getting the path of the parent, merge with it by presenting connections as if they were from that parent
//
// // go hirarchy down
// Find string `json:",omitempty bson:",omitempty"`
// // Regular expression for finding path elements
// Build string `json:",omitempty bson:",omitempty"`
// // Path definitions for building path
// Virtual bool `json:",omitempty bson:",omitempty"`
// // Treat resulting path as virtual, do not check if valid
// }
//
// func (f *Framework) GetNewPath(command string, cwd string) (string, error) {
// // "/usr/bin/python script"
// // to
// // "/path/to/script"
// regex, err := regexp.Compile(f.Find)
// if err != nil {
// return "", fmt.Errorf("profiles(framework): failed to compile framework regex: %s", err)
// }
// matched := regex.FindAllStringSubmatch(command, -1)
// if len(matched) == 0 || len(matched[0]) < 2 {
// return "", fmt.Errorf("profiles(framework): regex \"%s\" for constructing path did not match command \"%s\"", f.Find, command)
// }
//
// var lastError error
// var buildPath string
// for _, buildPath = range strings.Split(f.Build, "|") {
//
// buildPath = strings.Replace(buildPath, "{CWD}", cwd, -1)
// for i := 1; i < len(matched[0]); i++ {
// buildPath = strings.Replace(buildPath, fmt.Sprintf("{%d}", i), matched[0][i], -1)
// }
//
// buildPath = filepath.Clean(buildPath)
//
// if !f.Virtual {
// if !strings.HasPrefix(buildPath, "~/") && !filepath.IsAbs(buildPath) {
// lastError = fmt.Errorf("constructed path \"%s\" from framework is not absolute", buildPath)
// continue
// }
// if _, err := os.Stat(buildPath); os.IsNotExist(err) {
// lastError = fmt.Errorf("constructed path \"%s\" does not exist", buildPath)
// continue
// }
// }
//
// lastError = nil
// break
//
// }
//
// if lastError != nil {
// return "", fmt.Errorf("profiles(framework): failed to construct valid path, last error: %s", lastError)
// }
// log.Tracef("profiles(framework): transformed \"%s\" (%s) to \"%s\"", command, cwd, buildPath)
// return buildPath, nil
// }

30
profile/framework_test.go Normal file
View file

@ -0,0 +1,30 @@
package profile
// DEACTIVATED
// import (
// "testing"
// )
//
// func testGetNewPath(t *testing.T, f *Framework, command, cwd, expect string) {
// newPath, err := f.GetNewPath(command, cwd)
// if err != nil {
// t.Errorf("GetNewPath failed: %s", err)
// }
// if newPath != expect {
// t.Errorf("GetNewPath return unexpected result: got %s, expected %s", newPath, expect)
// }
// }
//
// func TestFramework(t *testing.T) {
// f1 := &Framework{
// Find: "([^ ]+)$",
// Build: "{CWD}/{1}",
// }
// testGetNewPath(t, f1, "/usr/bin/python bash", "/bin", "/bin/bash")
// f2 := &Framework{
// Find: "([^ ]+)$",
// Build: "{1}|{CWD}/{1}",
// }
// testGetNewPath(t, f2, "/usr/bin/python /bin/bash", "/tmp", "/bin/bash")
// }

View file

@ -0,0 +1,47 @@
package profile
import (
"path/filepath"
"strings"
"github.com/Safing/portbase/utils"
)
// GetPathIdentifier returns the identifier from the given path
func GetPathIdentifier(path string) string {
// clean path
// TODO: is this necessary?
cleanedPath, err := filepath.EvalSymlinks(path)
if err == nil {
path = cleanedPath
} else {
path = filepath.Clean(path)
}
splittedPath := strings.Split(path, "/")
// strip sensitive data
switch {
case strings.HasPrefix(path, "/home/"):
splittedPath = splittedPath[3:]
case strings.HasPrefix(path, "/root/"):
splittedPath = splittedPath[2:]
}
// common directories with executable
if i := utils.IndexOfString(splittedPath, "bin"); i > 0 {
splittedPath = splittedPath[i:]
return strings.Join(splittedPath, "/")
}
if i := utils.IndexOfString(splittedPath, "sbin"); i > 0 {
splittedPath = splittedPath[i:]
return strings.Join(splittedPath, "/")
}
// shorten to max 3
if len(splittedPath) > 3 {
splittedPath = splittedPath[len(splittedPath)-3:]
}
return strings.Join(splittedPath, "/")
}

View file

@ -0,0 +1,23 @@
package profile
import "testing"
func testPathID(t *testing.T, execPath, identifierPath string) {
result := GetPathIdentifier(execPath)
if result != identifierPath {
t.Errorf("unexpected identifier path for %s: got %s, expected %s", execPath, result, identifierPath)
}
}
func TestGetPathIdentifier(t *testing.T) {
testPathID(t, "/bin/bash", "bin/bash")
testPathID(t, "/home/user/bin/bash", "bin/bash")
testPathID(t, "/home/user/project/main", "project/main")
testPathID(t, "/root/project/main", "project/main")
testPathID(t, "/tmp/a/b/c/d/install.sh", "c/d/install.sh")
testPathID(t, "/sbin/init", "sbin/init")
testPathID(t, "/lib/systemd/systemd-udevd", "lib/systemd/systemd-udevd")
testPathID(t, "/bundle/ruby/2.4.0/bin/passenger", "bin/passenger")
testPathID(t, "/usr/sbin/cron", "sbin/cron")
testPathID(t, "/usr/local/bin/python", "bin/python")
}

102
profile/index/index.go Normal file
View file

@ -0,0 +1,102 @@
package index
import (
"encoding/base64"
"errors"
"fmt"
"sync"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/utils"
)
// ProfileIndex links an Identifier to Profiles
type ProfileIndex struct {
record.Base
sync.Mutex
ID string
UserProfiles []string
StampProfiles []string
}
func makeIndexRecordKey(fpType, id string) string {
return fmt.Sprintf("index:profiles/%s:%s", fpType, base64.RawURLEncoding.EncodeToString([]byte(id)))
}
// NewIndex returns a new ProfileIndex.
func NewIndex(id string) *ProfileIndex {
return &ProfileIndex{
ID: id,
}
}
// AddUserProfile adds a User Profile to the index.
func (pi *ProfileIndex) AddUserProfile(identifier string) (changed bool) {
if !utils.StringInSlice(pi.UserProfiles, identifier) {
pi.UserProfiles = append(pi.UserProfiles, identifier)
return true
}
return false
}
// AddStampProfile adds a Stamp Profile to the index.
func (pi *ProfileIndex) AddStampProfile(identifier string) (changed bool) {
if !utils.StringInSlice(pi.StampProfiles, identifier) {
pi.StampProfiles = append(pi.StampProfiles, identifier)
return true
}
return false
}
// RemoveUserProfile removes a profile from the index.
func (pi *ProfileIndex) RemoveUserProfile(id string) {
pi.UserProfiles = utils.RemoveFromStringSlice(pi.UserProfiles, id)
}
// RemoveStampProfile removes a profile from the index.
func (pi *ProfileIndex) RemoveStampProfile(id string) {
pi.StampProfiles = utils.RemoveFromStringSlice(pi.StampProfiles, id)
}
// Get gets a ProfileIndex from the database.
func Get(fpType, id string) (*ProfileIndex, error) {
key := makeIndexRecordKey(fpType, id)
r, err := indexDB.Get(key)
if err != nil {
return nil, err
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &ProfileIndex{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*ProfileIndex)
if !ok {
return nil, fmt.Errorf("record not of type *ProfileIndex, but %T", r)
}
return new, nil
}
// Save saves the Identifiers to the database
func (pi *ProfileIndex) Save() error {
if !pi.KeyIsSet() {
if pi.ID != "" {
pi.SetKey(makeIndexRecordKey(pi.ID))
} else {
return errors.New("missing identification Key")
}
}
return indexDB.Put(pi)
}

Some files were not shown because too many files have changed in this diff Show more