Merge from develop, update README.md

This commit is contained in:
Daniel 2019-01-24 15:40:18 +01:00
commit 3ba081404b
143 changed files with 6882 additions and 3125 deletions

4
.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
dnsonly
main
*.exe

View file

@ -1,3 +1,41 @@
# Portmaster
The Portmaster is currently being revamped. You can check out the latest changes in the `develop` branch.
The Portmaster enables you to protect your data on your device. You are back in charge of your outgoing connections: you choose what data you share and what data stays private.
## Current Status
The Portmaster is currently in alpha. Expect dragons.
Supported platforms:
- linux_amd64
- windows_amd64 (_soon_)
- darwin_amd64 (_later_)
## Usage
Just download the portmaster from the releases page.
./portmaster -db=/opt/pm_db
# this will add some rules to iptables for traffic interception via nfqueue (and will clean up afterwards!)
# then start the ui
./portmaster -db=/opt/pm_db -ui
# missing files will be automatically download when first needed
## Documentation
Documentation _in progress_ can be found here: [http://docs.safing.io/](http://docs.safing.io/)
## Dependencies
#### Linux
- libnetfilter_queue
- debian/ubuntu: `sudo apt-get install libnetfilter-queue1`
- fedora: `sudo yum install libnetfilter_queue`
- arch: `sudo pacman -S libnetfilter_queue`
- [Network Manager](https://wiki.gnome.org/Projects/NetworkManager) (_optional_)
#### Windows
- Windows 7 (with update KB3033929) or up
- [KB3033929](https://docs.microsoft.com/en-us/security-updates/SecurityAdvisories/2015/3033929) (a 2015 security update) is required for correctly verifying the driver signature
- Windows Server 2016 systems must have secure boot disabled. (_clarification needed_)

52
build Executable file
View file

@ -0,0 +1,52 @@
#!/bin/bash
# get build data
if [[ "$BUILD_COMMIT" == "" ]]; then
BUILD_COMMIT=$(git describe --all --long --abbrev=99 --dirty 2>/dev/null)
fi
if [[ "$BUILD_USER" == "" ]]; then
BUILD_USER=$(id -un)
fi
if [[ "$BUILD_HOST" == "" ]]; then
BUILD_HOST=$(hostname -f)
fi
if [[ "$BUILD_DATE" == "" ]]; then
BUILD_DATE=$(date +%d.%m.%Y)
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
BUILD_SOURCE=$(git remote -v | grep origin | cut -f2 | cut -d" " -f1 | head -n 1)
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
BUILD_SOURCE=$(git remote -v | cut -f2 | cut -d" " -f1 | head -n 1)
fi
BUILD_BUILDOPTIONS=$(echo $* | sed "s/ /§/g")
# check
if [[ "$BUILD_COMMIT" == "" ]]; then
echo "could not automatically determine BUILD_COMMIT, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_USER" == "" ]]; then
echo "could not automatically determine BUILD_USER, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_HOST" == "" ]]; then
echo "could not automatically determine BUILD_HOST, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_DATE" == "" ]]; then
echo "could not automatically determine BUILD_DATE, please supply manually as environment variable."
exit 1
fi
if [[ "$BUILD_SOURCE" == "" ]]; then
echo "could not automatically determine BUILD_SOURCE, please supply manually as environment variable."
exit 1
fi
echo "Please notice, that this build script includes metadata into the build."
echo "This information is useful for debugging and license compliance."
echo "Run the compiled binary with the -version flag to see the information included."
# build
BUILD_PATH="github.com/Safing/portbase/info"
go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $*

55
dnsonly.go Normal file
View file

@ -0,0 +1,55 @@
package main
import (
"fmt"
"os"
"os/signal"
"syscall"
"github.com/Safing/portbase/info"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// include packages here
_ "github.com/Safing/portmaster/nameserver/only"
)
func main() {
// Set Info
info.Set("Portmaster (DNS only)", "0.2.0")
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
select {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
modules.Shutdown()
case <-modules.ShuttingDown():
}
}

26
firewall/config.go Normal file
View file

@ -0,0 +1,26 @@
package firewall
import (
"github.com/Safing/portbase/config"
)
var (
permanentVerdicts config.BoolOption
)
func registerConfig() error {
err := config.Register(&config.Option{
Name: "Permanent Verdicts",
Key: "firewall/permanentVerdicts",
Description: "With permanent verdicts, control of a connection is fully handed back to the OS after the initial decision. This brings a great performance increase, but makes it impossible to change the decision of a link later on.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeBool,
DefaultValue: true,
})
if err != nil {
return err
}
permanentVerdicts = config.Concurrent.GetAsBool("firewall/permanentVerdicts", true)
return nil
}

View file

@ -1,28 +1,22 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package firewall
import (
"fmt"
"net"
"os"
"sync/atomic"
"time"
"github.com/Safing/safing-core/configuration"
"github.com/Safing/safing-core/firewall/inspection"
"github.com/Safing/safing-core/firewall/interception"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules"
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/safing-core/port17/entry"
"github.com/Safing/safing-core/port17/mode"
"github.com/Safing/safing-core/portmaster"
"github.com/Safing/safing-core/process"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
"github.com/Safing/portmaster/firewall/inspection"
"github.com/Safing/portmaster/firewall/interception"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/process"
)
var (
firewallModule *modules.Module
// localNet net.IPNet
localhost net.IP
dnsServer net.IPNet
@ -30,8 +24,6 @@ var (
packetsBlocked *uint64
packetsDropped *uint64
config = configuration.Get()
localNet4 *net.IPNet
// Yes, this would normally be 127.0.0.0/8
// TODO: figure out any side effects
@ -46,23 +38,30 @@ var (
)
func init() {
modules.Register("firewall", prep, start, stop, "global", "network", "nameserver", "profile")
}
var err error
func prep() (err error) {
err = registerConfig()
if err != nil {
return err
}
_, localNet4, err = net.ParseCIDR("127.0.0.0/24")
// Yes, this would normally be 127.0.0.0/8
// TODO: figure out any side effects
if err != nil {
log.Criticalf("firewall: failed to parse cidr 127.0.0.0/24: %s", err)
return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err)
}
_, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16")
if err != nil {
log.Criticalf("firewall: failed to parse cidr 127.17.0.0/16: %s", err)
return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err)
}
_, tunnelNet6, err = net.ParseCIDR("fd17::/64")
if err != nil {
log.Criticalf("firewall: failed to parse cidr fd17::/64: %s", err)
return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err)
}
var pA uint64
@ -71,20 +70,22 @@ func init() {
packetsBlocked = &pB
var pD uint64
packetsDropped = &pD
return nil
}
func Start() {
firewallModule = modules.Register("Firewall", 128)
defer firewallModule.StopComplete()
// start interceptor
go interception.Start()
func start() error {
go statLogger()
go run()
// go run()
// go run()
// go run()
// go run()
// go run()
// go run()
run()
return interception.Start()
}
func stop() error {
return interception.Stop()
}
func handlePacket(pkt packet.Packet) {
@ -111,12 +112,6 @@ func handlePacket(pkt packet.Packet) {
return
}
// allow anything that goes to a tunnel entrypoint
if pkt.IsOutbound() && (pkt.GetIPHeader().Dst.Equal(tunnelEntry4) || pkt.GetIPHeader().Dst.Equal(tunnelEntry6)) {
pkt.PermanentAccept()
return
}
// log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID())
// use this to time how long it takes process packet
@ -124,16 +119,16 @@ func handlePacket(pkt packet.Packet) {
// defer log.Tracef("firewall: took %s to process packet %s", time.Now().Sub(timed).String(), pkt)
// check if packet is destined for tunnel
switch pkt.IPVersion() {
case packet.IPv4:
if portmaster.TunnelNet4 != nil && portmaster.TunnelNet4.Contains(pkt.GetIPHeader().Dst) {
tunnelHandler(pkt)
}
case packet.IPv6:
if portmaster.TunnelNet6 != nil && portmaster.TunnelNet6.Contains(pkt.GetIPHeader().Dst) {
tunnelHandler(pkt)
}
}
// switch pkt.IPVersion() {
// case packet.IPv4:
// if TunnelNet4 != nil && TunnelNet4.Contains(pkt.GetIPHeader().Dst) {
// tunnelHandler(pkt)
// }
// case packet.IPv6:
// if TunnelNet6 != nil && TunnelNet6.Contains(pkt.GetIPHeader().Dst) {
// tunnelHandler(pkt)
// }
// }
// associate packet to link and handle
link, created := network.GetOrCreateLinkByPacket(pkt)
@ -146,7 +141,7 @@ func handlePacket(pkt packet.Packet) {
link.HandlePacket(pkt)
return
}
verdict(pkt, link.Verdict)
verdict(pkt, link.GetVerdict())
}
@ -157,57 +152,68 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
if err != nil {
if err != process.ErrConnectionNotFound {
log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err)
link.Deny(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err))
} else {
log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err)
link.Deny(fmt.Sprintf("internal error finding process: %s", err))
}
link.UpdateVerdict(network.DROP)
verdict(pkt, network.DROP)
if pkt.IsInbound() {
network.UnknownIncomingConnection.AddLink(link)
} else {
network.UnknownDirectConnection.AddLink(link)
}
verdict(pkt, link.GetVerdict())
link.StopFirewallHandler()
return
}
// add new Link to Connection (and save both)
connection.AddLink(link)
// reroute dns requests to nameserver
if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 {
pkt.RerouteToNameserver()
link.RerouteToNameserver()
verdict(pkt, link.GetVerdict())
link.StopFirewallHandler()
return
}
// persist connection
connection.CreateInProcessNamespace()
// add new Link to Connection
connection.AddLink(link, pkt)
// make a decision if not made already
if connection.Verdict == network.UNDECIDED {
portmaster.DecideOnConnection(connection, pkt)
if connection.GetVerdict() == network.UNDECIDED {
DecideOnConnection(connection, pkt)
}
if connection.Verdict != network.CANTSAY {
link.UpdateVerdict(connection.Verdict)
if connection.GetVerdict() == network.ACCEPT {
DecideOnLink(connection, link, pkt)
} else {
portmaster.DecideOnLink(connection, link, pkt)
link.UpdateVerdict(connection.GetVerdict())
}
// log decision
logInitialVerdict(link)
// TODO: link this to real status
port17Active := mode.Client()
// port17Active := mode.Client()
switch {
case port17Active && link.Inspect:
// tunnel link, but also inspect (after reroute)
link.Tunneled = true
link.SetFirewallHandler(inspectThenVerdict)
verdict(pkt, link.Verdict)
case port17Active:
// tunnel link, don't inspect
link.Tunneled = true
link.StopFirewallHandler()
permanentVerdict(pkt, network.ACCEPT)
// case port17Active && link.Inspect:
// // tunnel link, but also inspect (after reroute)
// link.Tunneled = true
// link.SetFirewallHandler(inspectThenVerdict)
// verdict(pkt, link.GetVerdict())
// case port17Active:
// // tunnel link, don't inspect
// link.Tunneled = true
// link.StopFirewallHandler()
// permanentVerdict(pkt, network.ACCEPT)
case link.Inspect:
link.SetFirewallHandler(inspectThenVerdict)
inspectThenVerdict(pkt, link)
default:
link.StopFirewallHandler()
verdict(pkt, link.Verdict)
verdict(pkt, link.GetVerdict())
}
}
@ -216,10 +222,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
pktVerdict, continueInspection := inspection.RunInspectors(pkt, link)
if continueInspection {
// do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link
if pktVerdict > link.Verdict {
linkVerdict := link.GetVerdict()
if pktVerdict > linkVerdict {
verdict(pkt, pktVerdict)
} else {
verdict(pkt, link.Verdict)
verdict(pkt, linkVerdict)
}
return
}
@ -227,13 +234,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
// we are done with inspecting
link.StopFirewallHandler()
config.Changed()
config.RLock()
link.VerdictPermanent = config.PermanentVerdicts
config.RUnlock()
link.Lock()
defer link.Unlock()
link.VerdictPermanent = permanentVerdicts()
if link.VerdictPermanent {
link.Save()
go link.Save()
permanentVerdict(pkt, link.Verdict)
} else {
verdict(pkt, link.Verdict)
@ -254,6 +259,12 @@ func permanentVerdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1)
pkt.PermanentDrop()
return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
}
pkt.Drop()
}
@ -272,36 +283,46 @@ func verdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1)
pkt.Drop()
return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
}
pkt.Drop()
}
func tunnelHandler(pkt packet.Packet) {
tunnelInfo := portmaster.GetTunnelInfo(pkt.GetIPHeader().Dst)
if tunnelInfo == nil {
pkt.Block()
return
}
entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords())
log.Tracef("firewall: rerouting %s to tunnel entry point", pkt)
pkt.RerouteToTunnel()
return
}
// func tunnelHandler(pkt packet.Packet) {
// tunnelInfo := GetTunnelInfo(pkt.GetIPHeader().Dst)
// if tunnelInfo == nil {
// pkt.Block()
// return
// }
//
// entry.CreateTunnel(pkt, tunnelInfo.Domain, tunnelInfo.RRCache.ExportAllARecords())
// log.Tracef("firewall: rerouting %s to tunnel entry point", pkt)
// pkt.RerouteToTunnel()
// return
// }
func logInitialVerdict(link *network.Link) {
// switch link.Verdict {
// switch link.GetVerdict() {
// case network.ACCEPT:
// log.Infof("firewall: accepting new link: %s", link.String())
// case network.BLOCK:
// log.Infof("firewall: blocking new link: %s", link.String())
// case network.DROP:
// log.Infof("firewall: dropping new link: %s", link.String())
// case network.RerouteToNameserver:
// log.Infof("firewall: rerouting new link to nameserver: %s", link.String())
// case network.RerouteToTunnel:
// log.Infof("firewall: rerouting new link to tunnel: %s", link.String())
// }
}
func logChangedVerdict(link *network.Link) {
// switch link.Verdict {
// switch link.GetVerdict() {
// case network.ACCEPT:
// log.Infof("firewall: change! - now accepting link: %s", link.String())
// case network.BLOCK:
@ -312,25 +333,26 @@ func logChangedVerdict(link *network.Link) {
}
func run() {
packetProcessingLoop:
for {
select {
case <-firewallModule.Stop:
break packetProcessingLoop
case <-modules.ShuttingDown():
return
case pkt := <-interception.Packets:
handlePacket(pkt)
}
}
}
func statLogger() {
for {
time.Sleep(10 * time.Second)
log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped))
atomic.StoreUint64(packetsAccepted, 0)
atomic.StoreUint64(packetsBlocked, 0)
atomic.StoreUint64(packetsDropped, 0)
select {
case <-modules.ShuttingDown():
return
case <-time.After(10 * time.Second):
log.Tracef("firewall: packets accepted %d, blocked %d, dropped %d", atomic.LoadUint64(packetsAccepted), atomic.LoadUint64(packetsBlocked), atomic.LoadUint64(packetsDropped))
atomic.StoreUint64(packetsAccepted, 0)
atomic.StoreUint64(packetsBlocked, 0)
atomic.StoreUint64(packetsDropped, 0)
}
}
}

View file

@ -3,9 +3,10 @@
package inspection
import (
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/packet"
"sync"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet"
)
const (
@ -40,24 +41,28 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool
// inspectorsLock.Lock()
// defer inspectorsLock.Unlock()
if link.ActiveInspectors == nil {
link.ActiveInspectors = make([]bool, len(inspectors), len(inspectors))
activeInspectors := link.GetActiveInspectors()
if activeInspectors == nil {
activeInspectors = make([]bool, len(inspectors), len(inspectors))
link.SetActiveInspectors(activeInspectors)
}
if link.InspectorData == nil {
link.InspectorData = make(map[uint8]interface{})
inspectorData := link.GetInspectorData()
if inspectorData == nil {
inspectorData = make(map[uint8]interface{})
link.SetInspectorData(inspectorData)
}
continueInspection := false
verdict := network.UNDECIDED
for key, skip := range link.ActiveInspectors {
for key, skip := range activeInspectors {
if skip {
continue
}
if link.Verdict > inspectVerdicts[key] {
link.ActiveInspectors[key] = true
activeInspectors[key] = true
continue
}
@ -78,16 +83,16 @@ func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool
continueInspection = true
case BLOCK_LINK:
link.UpdateVerdict(network.BLOCK)
link.ActiveInspectors[key] = true
activeInspectors[key] = true
if verdict < network.BLOCK {
verdict = network.BLOCK
}
case DROP_LINK:
link.UpdateVerdict(network.DROP)
link.ActiveInspectors[key] = true
activeInspectors[key] = true
verdict = network.DROP
case STOP_INSPECTING:
link.ActiveInspectors[key] = true
activeInspectors[key] = true
}
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package tls
var (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package tls
import (
@ -12,14 +10,13 @@ import (
"github.com/google/gopacket/layers"
"github.com/google/gopacket/tcpassembly"
"github.com/Safing/safing-core/configuration"
"github.com/Safing/safing-core/crypto/verify"
"github.com/Safing/safing-core/firewall/inspection"
"github.com/Safing/safing-core/firewall/inspection/tls/tlslib"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/firewall/inspection"
"github.com/Safing/portmaster/firewall/inspection/tls/tlslib"
"github.com/Safing/portmaster/firewall/inspection/tls/verify"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/network/packet"
)
// TODO:
@ -31,8 +28,6 @@ var (
tlsInspectorIndex int
assemblerManager *netutils.SimpleStreamAssemblerManager
assembler *tcpassembly.Assembler
config = configuration.Get()
)
const (

View file

@ -6,7 +6,7 @@ import (
"fmt"
"testing"
"github.com/Safing/safing-core/firewall/inspection/tls/tlslib"
"github.com/Safing/portmaster/firewall/inspection/tls/tlslib"
)
var clientHelloSample = []byte{

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify
import (
@ -14,15 +12,15 @@ import (
"strings"
"github.com/cloudflare/cfssl/crypto/pkcs7"
datastore "github.com/ipfs/go-datastore"
"github.com/Safing/safing-core/crypto/hash"
"github.com/Safing/safing-core/database"
"github.com/Safing/portbase/crypto/hash"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
)
// Cert saves a certificate.
type Cert struct {
database.Base
record.Record
cert *x509.Certificate
Raw []byte
@ -120,7 +118,7 @@ func (m *Cert) CreateRevokedCert(caID string, serialNumber *big.Int) error {
}
// CreateInNamespace saves Cert with the provided name in the provided namespace.
func (m *Cert) CreateInNamespace(namespace *datastore.Key, name string) error {
func (m *Cert) CreateInNamespace(namespace string, name string) error {
return m.CreateObject(namespace, name, m)
}
@ -140,7 +138,7 @@ func GetCertWithSPKI(spki []byte) (*Cert, error) {
}
// GetCertFromNamespace gets Cert with the provided name from the provided namespace.
func GetCertFromNamespace(namespace *datastore.Key, name string) (*Cert, error) {
func GetCertFromNamespace(namespace string, name string) (*Cert, error) {
object, err := database.GetAndEnsureModel(namespace, name, certModel)
if err != nil {
return nil, err

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify
import (
@ -14,16 +12,15 @@ import (
"sync"
"time"
datastore "github.com/ipfs/go-datastore"
"github.com/Safing/safing-core/crypto/hash"
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/log"
"github.com/Safing/portbase/crypto/hash"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/log"
)
// CARevocationInfo saves Information on revokation of Certificates of a Certificate Authority.
type CARevocationInfo struct {
database.Base
record.Record
CRLDistributionPoints []string
OCSPServers []string
@ -39,23 +36,17 @@ type CARevocationInfo struct {
}
var (
caRevocationInfoModel *CARevocationInfo // only use this as parameter for database.EnsureModel-like functions
dupCrlReqMap = make(map[string]*sync.Mutex)
dupCrlReqLock sync.Mutex
)
func init() {
database.RegisterModel(caRevocationInfoModel, func() database.Model { return new(CARevocationInfo) })
}
// Create saves CARevocationInfo with the provided name in the default namespace.
func (m *CARevocationInfo) Create(name string) error {
return m.CreateObject(&database.CARevocationInfoCache, name, m)
}
// CreateInNamespace saves CARevocationInfo with the provided name in the provided namespace.
func (m *CARevocationInfo) CreateInNamespace(namespace *datastore.Key, name string) error {
func (m *CARevocationInfo) CreateInNamespace(namespace string, name string) error {
return m.CreateObject(namespace, name, m)
}
@ -78,7 +69,7 @@ func GetCARevocationInfo(name string) (*CARevocationInfo, error) {
}
// GetCARevocationInfoFromNamespace fetches CARevocationInfo with the provided name from the provided namespace.
func GetCARevocationInfoFromNamespace(namespace *datastore.Key, name string) (*CARevocationInfo, error) {
func GetCARevocationInfoFromNamespace(namespace string, name string) (*CARevocationInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, caRevocationInfoModel)
if err != nil {
return nil, err

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify
import (
@ -16,8 +14,8 @@ import (
"golang.org/x/crypto/ocsp"
"github.com/Safing/safing-core/crypto/hash"
"github.com/Safing/safing-core/log"
"github.com/Safing/portbase/crypto/hash"
"github.com/Safing/portbase/log"
)
var (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package verify
import (
@ -8,9 +6,8 @@ import (
"fmt"
"time"
"github.com/Safing/safing-core/configuration"
"github.com/Safing/safing-core/crypto/hash"
"github.com/Safing/safing-core/database"
"github.com/Safing/portbase/crypto/hash"
"github.com/Safing/portbase/database"
)
// useful references:
@ -24,10 +21,6 @@ import (
// RE: https://www.grc.com/revocation/crlsets.htm
// RE: RE: https://www.imperialviolet.org/2014/04/29/revocationagain.html
var (
config = configuration.Get()
)
// FullCheckBytes does a full certificate check, certificates are provided as raw bytes.
// It parses the raw certificates and calls FullCheck.
func FullCheckBytes(name string, certBytes [][]byte) (bool, error) {

View file

@ -2,10 +2,19 @@
package interception
import "github.com/Safing/safing-core/network/packet"
import "github.com/Safing/portmaster/network/packet"
var Packets chan packet.Packet
func init() {
var (
// Packets channel for feeding the firewall.
Packets = make(chan packet.Packet, 1000)
)
// Start starts the interception.
func Start() error {
return StartNfqueueInterception()
}
// Stop starts the interception.
func Stop() error {
return StopNfqueueInterception()
}

View file

@ -1,31 +1,31 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package interception
import (
"github.com/Safing/safing-core/firewall/interception/windivert"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules"
"github.com/Safing/safing-core/network/packet"
"fmt"
"github.com/Safing/portmaster/firewall/interception/windivert"
"github.com/Safing/portmaster/network/packet"
)
var Packets chan packet.Packet
func init() {
// Packets channel for feeding the firewall.
Packets = make(chan packet.Packet, 1000)
}
func Start() {
windivertModule := modules.Register("Firewall:Interception:WinDivert", 192)
// Start starts the interception.
func Start() error {
wd, err := windivert.New("/WinDivert.dll", "")
if err != nil {
log.Criticalf("firewall/interception: could not init windivert: %s", err)
} else {
wd.Packets(Packets)
return fmt.Errorf("firewall/interception: could not init windivert: %s", err)
}
<-windivertModule.Stop
windivertModule.StopComplete()
return wd.Packets(Packets)
}
// Stop starts the interception.
func Stop() error {
return nil
}

View file

@ -2,45 +2,48 @@
package nfqueue
import (
"github.com/Safing/safing-core/network/packet"
"sync"
)
// suspended for now
type multiQueue struct {
qs []*nfQueue
}
func NewMultiQueue(min, max uint16) (mq *multiQueue) {
mq = &multiQueue{make([]*nfQueue, 0, max-min)}
for i := min; i < max; i++ {
mq.qs = append(mq.qs, NewNFQueue(i))
}
return mq
}
func (mq *multiQueue) Process() <-chan packet.Packet {
var (
wg sync.WaitGroup
out = make(chan packet.Packet, len(mq.qs))
)
for _, q := range mq.qs {
wg.Add(1)
go func(ch <-chan packet.Packet) {
for pkt := range ch {
out <- pkt
}
wg.Done()
}(q.Process())
}
go func() {
wg.Wait()
close(out)
}()
return out
}
func (mq *multiQueue) Destroy() {
for _, q := range mq.qs {
q.Destroy()
}
}
// import (
// "sync"
//
// "github.com/Safing/portmaster/network/packet"
// )
//
// type multiQueue struct {
// qs []*NFQueue
// }
//
// func NewMultiQueue(min, max uint16) (mq *multiQueue) {
// mq = &multiQueue{make([]*NFQueue, 0, max-min)}
// for i := min; i < max; i++ {
// mq.qs = append(mq.qs, NewNFQueue(i))
// }
// return mq
// }
//
// func (mq *multiQueue) Process() <-chan packet.Packet {
// var (
// wg sync.WaitGroup
// out = make(chan packet.Packet, len(mq.qs))
// )
// for _, q := range mq.qs {
// wg.Add(1)
// go func(ch <-chan packet.Packet) {
// for pkt := range ch {
// out <- pkt
// }
// wg.Done()
// }(q.Process())
// }
// go func() {
// wg.Wait()
// close(out)
// }()
// return out
// }
// func (mq *multiQueue) Destroy() {
// for _, q := range mq.qs {
// q.Destroy()
// }
// }

View file

@ -17,17 +17,19 @@ import (
"syscall"
"time"
"unsafe"
"errors"
"fmt"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/portmaster/network/packet"
)
var queues map[uint16]*nfQueue
var queues map[uint16]*NFQueue
func init() {
queues = make(map[uint16]*nfQueue)
queues = make(map[uint16]*NFQueue)
}
type nfQueue struct {
type NFQueue struct {
DefaultVerdict uint32
Timeout time.Duration
qid uint16
@ -38,83 +40,77 @@ type nfQueue struct {
fd int
lk sync.Mutex
pktch chan packet.Packet
Packets chan packet.Packet
}
func NewNFQueue(qid uint16) (nfq *nfQueue) {
func NewNFQueue(qid uint16) (nfq *NFQueue, err error) {
if os.Geteuid() != 0 {
panic("Must be ran by root.")
return nil, errors.New("must be root to intercept packets")
}
nfq = &nfQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid}
nfq = &NFQueue{DefaultVerdict: NFQ_ACCEPT, Timeout: 100 * time.Millisecond, qid: qid, qidptr: &qid}
queues[nfq.qid] = nfq
return nfq
}
/*
This returns a channel that will recieve packets,
the user then must call pkt.Accept() or pkt.Drop()
*/
func (this *nfQueue) Process() <-chan packet.Packet {
if this.h != nil {
return this.pktch
err = nfq.init()
if err != nil {
return nil, err
}
this.init()
go func() {
runtime.LockOSThread()
C.loop_for_packets(this.h)
C.loop_for_packets(nfq.h)
}()
return this.pktch
return nfq, nil
}
func (this *nfQueue) init() {
func (this *NFQueue) init() error {
var err error
if this.h, err = C.nfq_open(); err != nil || this.h == nil {
panic(err)
fmt.Errorf("could not open nfqueue: %s", err)
}
//if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil {
this.pktch = make(chan packet.Packet, 1)
this.Packets = make(chan packet.Packet, 1)
if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 {
this.Destroy()
panic("nfq_unbind_pf(AF_INET) failed, are you running root?.")
return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?")
}
if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 {
this.Destroy()
panic("nfq_unbind_pf(AF_INET6) failed.")
return errors.New("nfq_unbind_pf(AF_INET6) failed")
}
if C.nfq_bind_pf(this.h, C.AF_INET) < 0 {
this.Destroy()
panic("nfq_bind_pf(AF_INET) failed.")
return errors.New("nfq_bind_pf(AF_INET) failed")
}
if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 {
this.Destroy()
panic("nfq_bind_pf(AF_INET6) failed.")
return errors.New("nfq_bind_pf(AF_INET6) failed")
}
if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid)); err != nil || this.qh == nil {
C.nfq_close(this.h)
panic(err)
return fmt.Errorf("could not create queue: %s", err)
}
this.fd = int(C.nfq_fd(this.h))
if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 {
this.Destroy()
panic("nfq_set_mode(NFQNL_COPY_PACKET) failed.")
return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed")
}
if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 {
this.Destroy()
panic("nfq_set_queue_maxlen(1024 * 8) failed.")
return errors.New("nfq_set_queue_maxlen(1024 * 8) failed")
}
return nil
}
func (this *nfQueue) Destroy() {
func (this *NFQueue) Destroy() {
this.lk.Lock()
defer this.lk.Unlock()
@ -131,12 +127,12 @@ func (this *nfQueue) Destroy() {
}
// TODO: don't close, we're exiting anyway
// if this.pktch != nil {
// close(this.pktch)
// if this.Packets != nil {
// close(this.Packets)
// }
}
func (this *nfQueue) Valid() bool {
func (this *NFQueue) Valid() bool {
return this.h != nil && this.qh != nil
}
@ -148,7 +144,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
qidptr := (*uint16)(data)
qid := uint16(*qidptr)
// nfq := (*nfQueue)(nfqptr)
// nfq := (*NFQueue)(nfqptr)
new_version := version
ipver := packet.IPVersion(new_version)
ipsz := C.int(ipver.ByteSize())
@ -187,7 +183,7 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
// fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000"))
// BUG: "panic: send on closed channel" when shutting down
queues[qid].pktch <- &pkt
queues[qid].Packets <- &pkt
select {
case v = <-pkt.verdict:

View file

@ -5,7 +5,7 @@ package nfqueue
import (
"fmt"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/portmaster/network/packet"
)
var (

View file

@ -1,31 +1,33 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
// +build linux
package interception
import (
"fmt"
"sort"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/Safing/safing-core/firewall/interception/nfqueue"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules"
"github.com/Safing/portmaster/firewall/interception/nfqueue"
)
// iptables -A OUTPUT -p icmp -j", "NFQUEUE", "--queue-num", "1", "--queue-bypass
var nfqueueModule *modules.Module
var (
v4chains []string
v4rules []string
v4once []string
var v4chains []string
var v4rules []string
var v4once []string
v6chains []string
v6rules []string
v6once []string
var v6chains []string
var v6rules []string
var v6once []string
out4Queue *nfqueue.NFQueue
in4Queue *nfqueue.NFQueue
out6Queue *nfqueue.NFQueue
in6Queue *nfqueue.NFQueue
shutdownSignal = make(chan struct{})
)
func init() {
@ -100,8 +102,8 @@ func init() {
}
// Reverse because we'd like to insert in a loop
sort.Reverse(sort.StringSlice(v4once))
sort.Reverse(sort.StringSlice(v6once))
_ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs)
_ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs)
}
@ -127,9 +129,10 @@ func activateNfqueueFirewall() error {
}
}
var ok bool
for _, rule := range v4once {
splittedRule := strings.Split(rule, " ")
ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil {
return err
}
@ -183,9 +186,10 @@ func deactivateNfqueueFirewall() error {
return err
}
var ok bool
for _, rule := range v4once {
splittedRule := strings.Split(rule, " ")
ok, err := ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
ok, err = ip4tables.Exists(splittedRule[0], splittedRule[1], splittedRule[2:]...)
if err != nil {
return err
}
@ -198,10 +202,10 @@ func deactivateNfqueueFirewall() error {
for _, chain := range v4chains {
splittedRule := strings.Split(chain, " ")
if err := ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
if err = ip4tables.ClearChain(splittedRule[0], splittedRule[1]); err != nil {
return err
}
if err := ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil {
if err = ip4tables.DeleteChain(splittedRule[0], splittedRule[1]); err != nil {
return err
}
}
@ -238,70 +242,84 @@ func deactivateNfqueueFirewall() error {
return nil
}
func Start() {
// StartNfqueueInterception starts the nfqueue interception.
func StartNfqueueInterception() (err error) {
nfqueueModule = modules.Register("Firewall:Interception:Nfqueue", 192)
if err := activateNfqueueFirewall(); err != nil {
log.Criticalf("could not activate firewall for nfqueue: %q", err)
err = activateNfqueueFirewall()
if err != nil {
Stop()
return fmt.Errorf("could not initialize nfqueue: %s", err)
}
out4Queue := nfqueue.NewNFQueue(17040)
in4Queue := nfqueue.NewNFQueue(17140)
out6Queue := nfqueue.NewNFQueue(17060)
in6Queue := nfqueue.NewNFQueue(17160)
out4Queue, err = nfqueue.NewNFQueue(17040)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
in4Queue, err = nfqueue.NewNFQueue(17140)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
out6Queue, err = nfqueue.NewNFQueue(17060)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
in6Queue, err = nfqueue.NewNFQueue(17160)
if err != nil {
Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
}
out4Channel := out4Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue out4")
// } else {
defer out4Queue.Destroy()
// }
in4Channel := in4Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue in4")
// } else {
defer in4Queue.Destroy()
// }
out6Channel := out6Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue out6")
// } else {
defer out6Queue.Destroy()
// }
in6Channel := in6Queue.Process()
// if err != nil {
// log.Criticalf("could not open nfqueue in6")
// } else {
defer in6Queue.Destroy()
// }
go handleInterception()
return nil
}
packetInterceptionLoop:
// StopNfqueueInterception stops the nfqueue interception.
func StopNfqueueInterception() error {
defer close(shutdownSignal)
if out4Queue != nil {
out4Queue.Destroy()
}
if in4Queue != nil {
in4Queue.Destroy()
}
if out6Queue != nil {
out6Queue.Destroy()
}
if in6Queue != nil {
in6Queue.Destroy()
}
err := deactivateNfqueueFirewall()
if err != nil {
return fmt.Errorf("interception: error while deactivating nfqueue: %s", err)
}
return nil
}
func handleInterception() {
for {
select {
case <-nfqueueModule.Stop:
break packetInterceptionLoop
case pkt := <-out4Channel:
case <-shutdownSignal:
return
case pkt := <-out4Queue.Packets:
pkt.SetOutbound()
Packets <- pkt
case pkt := <-in4Channel:
case pkt := <-in4Queue.Packets:
pkt.SetInbound()
Packets <- pkt
case pkt := <-out6Channel:
case pkt := <-out6Queue.Packets:
pkt.SetOutbound()
Packets <- pkt
case pkt := <-in6Channel:
case pkt := <-in6Queue.Packets:
pkt.SetInbound()
Packets <- pkt
}
}
if err := deactivateNfqueueFirewall(); err != nil {
log.Criticalf("could not deactivate firewall for nfqueue: %q", err)
}
nfqueueModule.StopComplete()
}
func stringInSlice(slice []string, value string) bool {

340
firewall/master.go Normal file
View file

@ -0,0 +1,340 @@
package firewall
import (
"fmt"
"os"
"strings"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/profile"
"github.com/Safing/portmaster/status"
"github.com/agext/levenshtein"
)
// Call order:
//
// 1. DecideOnConnectionBeforeIntel (if connecting to domain)
// is called when a DNS query is made, before the query is resolved
// 2. DecideOnConnectionAfterIntel (if connecting to domain)
// is called when a DNS query is made, after the query is resolved
// 3. DecideOnConnection
// is called when the first packet of the first link of the connection arrives
// 4. DecideOnLink
// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY
// DecideOnConnectionBeforeIntel makes a decision about a connection before the dns query is resolved and intel is gathered.
func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) {
// check:
// Profile.DomainWhitelist
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile set")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// check for any network access
if !profileSet.CheckFlag(profile.Internet) && !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, accessing Internet or LAN not allowed", connection)
connection.Deny("accessing Internet or LAN not allowed")
return
}
// check domain list
permitted, reason, ok := profileSet.CheckEndpoint(fqdn, 0, 0, false)
if ok {
if permitted {
log.Infof("firewall: accepting connection %s, endpoint is whitelisted: %s", connection, reason)
connection.Accept(fmt.Sprintf("endpoint is whitelisted: %s", reason))
} else {
log.Infof("firewall: denying connection %s, endpoint is blacklisted", connection)
connection.Deny("endpoint is blacklisted")
}
return
}
switch profileSet.GetProfileMode() {
case profile.Whitelist:
log.Infof("firewall: denying connection %s, domain is not whitelisted", connection)
connection.Deny("domain is not whitelisted")
case profile.Prompt:
// check Related flag
// TODO: improve this!
if profileSet.CheckFlag(profile.Related) {
matched := false
pathElements := strings.Split(connection.Process().Path, "/") // FIXME: path seperator
// only look at the last two path segments
if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:]
}
domainElements := strings.Split(fqdn, ".")
var domainElement string
var processElement string
matchLoop:
for _, domainElement = range domainElements {
for _, pathElement := range pathElements {
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
matched = true
processElement = pathElement
break matchLoop
}
}
if levenshtein.Match(domainElement, profileSet.UserProfile().Name, nil) > 0.5 {
matched = true
processElement = profileSet.UserProfile().Name
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 {
matched = true
processElement = connection.Process().Name
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().ExecName, nil) > 0.5 {
matched = true
processElement = connection.Process().ExecName
break matchLoop
}
}
if matched {
log.Infof("firewall: accepting connection %s, match to domain was found: %s ~= %s", connection, domainElement, processElement)
connection.Accept("domain is related to process")
}
}
if connection.GetVerdict() != network.ACCEPT {
// TODO
log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection)
connection.Accept("domain permitted (prompting is not yet implemented)")
}
case profile.Blacklist:
log.Infof("firewall: accepting connection %s, domain is not blacklisted", connection)
connection.Accept("domain is not blacklisted")
}
}
// DecideOnConnectionAfterIntel makes a decision about a connection after the dns query is resolved and intel is gathered.
func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache {
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return rrCache
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile")
return rrCache
}
profileSet.Update(status.CurrentSecurityLevel())
// TODO: Stamp integration
// TODO: Gate17 integration
// tunnelInfo, err := AssignTunnelIP(fqdn)
rrCache.Duplicate().FilterEntries(profileSet.CheckFlag(profile.Internet), profileSet.CheckFlag(profile.LAN), false)
if len(rrCache.Answer) == 0 {
if profileSet.CheckFlag(profile.Internet) {
connection.Deny("server is located in the LAN, but LAN access is not permitted")
} else {
connection.Deny("server is located in the Internet, but Internet access is not permitted")
}
}
return rrCache
}
// DeciceOnConnection makes a decision about a connection with its first packet.
func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own connection %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// check connection type
switch connection.Domain {
case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
if !profileSet.CheckFlag(profile.Service) {
log.Infof("firewall: denying connection %s, not a service", connection)
if connection.Domain == network.IncomingHost {
connection.Block("not a service")
} else {
connection.Drop("not a service")
}
return
}
case network.PeerLAN, network.PeerInternet, network.PeerInvalid: // Important: PeerHost is and should be missing!
if !profileSet.CheckFlag(profile.PeerToPeer) {
log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection)
connection.Deny("peer to peer connections (to an IP) not allowed")
return
}
default:
}
// check network scope
switch connection.Domain {
case network.IncomingHost:
if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, serving localhost not allowed", connection)
connection.Block("serving localhost not allowed")
return
}
case network.IncomingLAN:
if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, serving LAN not allowed", connection)
connection.Deny("serving LAN not allowed")
return
}
case network.IncomingInternet:
if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, serving Internet not allowed", connection)
connection.Deny("serving Internet not allowed")
return
}
case network.IncomingInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Drop("invalid IP address")
return
case network.PeerHost:
if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection)
connection.Block("accessing localhost not allowed")
return
}
case network.PeerLAN:
if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection)
connection.Deny("accessing the LAN not allowed")
return
}
case network.PeerInternet:
if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection)
connection.Deny("accessing the Internet not allowed")
return
}
case network.PeerInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Deny("invalid IP address")
return
}
log.Infof("firewall: accepting connection %s", connection)
connection.Accept("")
}
// DecideOnLink makes a decision about a link with the first packet.
func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) {
// check:
// Profile.Flags
// - network specific: Internet, LocalNet
// Profile.ConnectPorts
// Profile.ListenPorts
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("firewall: granting own link %s", connection)
connection.Accept("")
return
}
// check if there is a profile
profileSet := connection.Process().ProfileSet()
if profileSet == nil {
log.Infof("firewall: no profile, denying %s", link)
link.Block("no profile")
return
}
profileSet.Update(status.CurrentSecurityLevel())
// get host
var domainOrIP string
switch {
case strings.HasSuffix(connection.Domain, "."):
domainOrIP = connection.Domain
case connection.Direction:
domainOrIP = pkt.GetIPHeader().Src.String()
default:
domainOrIP = pkt.GetIPHeader().Dst.String()
}
// get protocol / destination port
protocol := pkt.GetIPHeader().Protocol
var dstPort uint16
tcpUDPHeader := pkt.GetTCPUDPHeader()
if tcpUDPHeader != nil {
dstPort = tcpUDPHeader.DstPort
}
// check endpoints list
permitted, reason, ok := profileSet.CheckEndpoint(domainOrIP, uint8(protocol), dstPort, connection.Direction)
if ok {
if permitted {
log.Infof("firewall: accepting link %s, endpoint is whitelisted: %s", link, reason)
link.Accept(fmt.Sprintf("port whitelisted: %s", reason))
} else {
log.Infof("firewall: denying link %s: port %d is blacklisted", link, dstPort)
link.Deny("port blacklisted")
}
return
}
switch profileSet.GetProfileMode() {
case profile.Whitelist:
log.Infof("firewall: denying link %s: endpoint %d is not whitelisted", link, dstPort)
link.Deny("endpoint is not whitelisted")
return
case profile.Prompt:
log.Infof("firewall: accepting link %s: endpoint %d is blacklisted", link, dstPort)
link.Accept("endpoint permitted (prompting is not yet implemented)")
return
case profile.Blacklist:
log.Infof("firewall: accepting link %s: endpoint %d is not blacklisted", link, dstPort)
link.Accept("endpoint is not blacklisted")
return
}
log.Infof("firewall: accepting link %s", link)
link.Accept("")
}

View file

@ -1,4 +1,4 @@
package portmaster
package firewall
import (
"errors"
@ -7,8 +7,8 @@ import (
"sync"
"time"
"github.com/Safing/safing-core/crypto/random"
"github.com/Safing/safing-core/intel"
"github.com/Safing/portbase/crypto/random"
"github.com/Safing/portmaster/intel"
"github.com/miekg/dns"
)

49
global/databases.go Normal file
View file

@ -0,0 +1,49 @@
package global
import (
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/modules"
// module dependencies
_ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/status"
)
func init() {
modules.Register("global", nil, start, nil, "database", "status")
}
func start() error {
_, err := database.Register(&database.Database{
Name: "core",
Description: "Holds core data, such as settings and profiles",
StorageType: "badger",
PrimaryAPI: "",
})
if err != nil {
return err
}
_, err = database.Register(&database.Database{
Name: "cache",
Description: "Cached data, such as Intelligence and DNS Records",
StorageType: "badger",
PrimaryAPI: "",
})
if err != nil {
return err
}
// _, err = database.Register(&database.Database{
// Name: "history",
// Description: "Historic event data",
// StorageType: "badger",
// PrimaryAPI: "",
// })
// if err != nil {
// return err
// }
return nil
}

94
intel/clients.go Normal file
View file

@ -0,0 +1,94 @@
package intel
import (
"crypto/tls"
"sync"
"time"
"github.com/miekg/dns"
)
type clientManager struct {
dnsClient *dns.Client
factory func() *dns.Client
lock sync.Mutex
refreshAfter time.Time
ttl time.Duration // force refresh of connection to reduce traceability
}
// ref: https://godoc.org/github.com/miekg/dns#Client
func newDNSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -1 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Timeout: 5 * time.Second,
}
},
}
}
func newTCPClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp",
Timeout: 5 * time.Second,
}
},
}
}
func newTLSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp-tls",
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: resolver.VerifyDomain,
// TODO: use custom random
// Rand: io.Reader,
},
Timeout: 5 * time.Second,
}
},
}
}
func newHTTPSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: -15 * time.Minute,
factory: func() *dns.Client {
new := &dns.Client{
Net: "https",
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
// TODO: use custom random
// Rand: io.Reader,
},
Timeout: 5 * time.Second,
}
if resolver.VerifyDomain != "" {
new.TLSConfig.ServerName = resolver.VerifyDomain
}
return new
},
}
}
func (cm *clientManager) getDNSClient() *dns.Client {
cm.lock.Lock()
defer cm.lock.Unlock()
if cm.dnsClient == nil || time.Now().After(cm.refreshAfter) {
cm.dnsClient = cm.factory()
cm.refreshAfter = time.Now().Add(cm.ttl)
}
return cm.dnsClient
}

100
intel/config.go Normal file
View file

@ -0,0 +1,100 @@
package intel
import (
"github.com/Safing/portbase/config"
"github.com/Safing/portmaster/status"
)
var (
configuredNameServers config.StringArrayOption
defaultNameServers = []string{
"tls|1.1.1.1:853|cloudflare-dns.com", // Cloudflare
"tls|1.0.0.1:853|cloudflare-dns.com", // Cloudflare
"tls|9.9.9.9:853|dns.quad9.net", // Quad9
// "https|cloudflare-dns.com/dns-query", // HTTPS still experimental
"dns|1.1.1.1:53", // Cloudflare
"dns|1.0.0.1:53", // Cloudflare
"dns|9.9.9.9:53", // Quad9
}
nameserverRetryRate config.IntOption
doNotUseMulticastDNS status.SecurityLevelOption
doNotUseAssignedNameservers status.SecurityLevelOption
doNotResolveSpecialDomains status.SecurityLevelOption
)
func prep() error {
err := config.Register(&config.Option{
Name: "Nameservers (DNS)",
Key: "intel/nameservers",
Description: "Nameserver to use for resolving DNS requests.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeStringArray,
DefaultValue: defaultNameServers,
ValidationRegex: "^(dns|tcp|tls|https)$",
})
if err != nil {
return err
}
configuredNameServers = config.Concurrent.GetAsStringArray("intel/nameservers", defaultNameServers)
err = config.Register(&config.Option{
Name: "Nameserver Retry Rate",
Key: "intel/nameserverRetryRate",
Description: "Rate at which to retry failed nameservers, in seconds.",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
DefaultValue: 600,
})
if err != nil {
return err
}
nameserverRetryRate = config.Concurrent.GetAsInt("intel/nameserverRetryRate", 0)
err = config.Register(&config.Option{
Name: "Do not use Multicast DNS",
Key: "intel/doNotUseMulticastDNS",
Description: "Multicast DNS queries other devices in the local network",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotUseMulticastDNS = status.ConfigIsActiveConcurrent("intel/doNotUseMulticastDNS")
err = config.Register(&config.Option{
Name: "Do not use assigned Nameservers",
Key: "intel/doNotUseAssignedNameservers",
Description: "that were acquired by the network (dhcp) or system",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotUseAssignedNameservers = status.ConfigIsActiveConcurrent("intel/doNotUseAssignedNameservers")
err = config.Register(&config.Option{
Name: "Do not resolve special domains",
Key: "intel/doNotResolveSpecialDomains",
Description: "Do not resolve special (top level) domains: example, example.com, example.net, example.org, invalid, test, onion. (RFC6761, RFC7686)",
ExpertiseLevel: config.ExpertiseLevelExpert,
OptType: config.OptTypeInt,
ExternalOptType: "security level",
DefaultValue: 3,
ValidationRegex: "^(1|2|3)$",
})
if err != nil {
return err
}
doNotResolveSpecialDomains = status.ConfigIsActiveConcurrent("intel/doNotResolveSpecialDomains")
return nil
}

View file

@ -1,62 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"github.com/Safing/safing-core/database"
datastore "github.com/ipfs/go-datastore"
)
// EntityClassification holds classification information about an internet entity.
type EntityClassification struct {
lists []byte
}
// Intel holds intelligence data for a domain.
type Intel struct {
database.Base
Domain string
DomainOwner string
CertOwner string
Classification *EntityClassification
}
var intelModel *Intel // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(intelModel, func() database.Model { return new(Intel) })
}
// Create saves the Intel with the provided name in the default namespace.
func (m *Intel) Create(name string) error {
return m.CreateObject(&database.IntelCache, name, m)
}
// CreateInNamespace saves the Intel with the provided name in the provided namespace.
func (m *Intel) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves the Intel.
func (m *Intel) Save() error {
return m.SaveObject(m)
}
// getIntel fetches the Intel with the provided name in the default namespace.
func getIntel(name string) (*Intel, error) {
return getIntelFromNamespace(&database.IntelCache, name)
}
// getIntelFromNamespace fetches the Intel with the provided name in the provided namespace.
func getIntelFromNamespace(namespace *datastore.Key, name string) (*Intel, error) {
object, err := database.GetAndEnsureModel(namespace, name, intelModel)
if err != nil {
return nil, err
}
model, ok := object.(*Intel)
if !ok {
return nil, database.NewMismatchError(object, intelModel)
}
return model, nil
}

View file

@ -1,218 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"fmt"
"net"
"time"
"github.com/Safing/safing-core/database"
datastore "github.com/ipfs/go-datastore"
"github.com/miekg/dns"
)
// RRCache is used to cache DNS data
type RRCache struct {
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
Expires int64
Modified int64
servedFromCache bool
requestingNew bool
}
func (m *RRCache) Clean(minExpires uint32) {
var lowestTTL uint32 = 0xFFFFFFFF
var header *dns.RR_Header
// set TTLs to 17
// TODO: double append? is there something more elegant?
for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) {
header = rr.Header()
if lowestTTL > header.Ttl {
lowestTTL = header.Ttl
}
header.Ttl = 17
}
// TTL must be at least minExpires
if lowestTTL < minExpires {
lowestTTL = minExpires
}
m.Expires = time.Now().Unix() + int64(lowestTTL)
m.Modified = time.Now().Unix()
}
func (m *RRCache) ExportAllARecords() (ips []net.IP) {
for _, rr := range m.Answer {
if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeA {
aRecord, ok := rr.(*dns.A)
if ok {
ips = append(ips, aRecord.A)
}
} else if rr.Header().Class == dns.ClassINET && rr.Header().Rrtype == dns.TypeAAAA {
aRecord, ok := rr.(*dns.AAAA)
if ok {
ips = append(ips, aRecord.AAAA)
}
}
}
return
}
func (m *RRCache) ToRRSave() *RRSave {
var s RRSave
s.Expires = m.Expires
s.Modified = m.Modified
for _, entry := range m.Answer {
s.Answer = append(s.Answer, entry.String())
}
for _, entry := range m.Ns {
s.Ns = append(s.Ns, entry.String())
}
for _, entry := range m.Extra {
s.Extra = append(s.Extra, entry.String())
}
return &s
}
func (m *RRCache) Create(name string) error {
s := m.ToRRSave()
return s.CreateObject(&database.DNSCache, name, s)
}
func (m *RRCache) CreateWithType(name string, qtype dns.Type) error {
s := m.ToRRSave()
return s.Create(fmt.Sprintf("%s%s", name, qtype.String()))
}
func (m *RRCache) Save() error {
s := m.ToRRSave()
return s.SaveObject(s)
}
func GetRRCache(domain string, qtype dns.Type) (*RRCache, error) {
return GetRRCacheFromNamespace(&database.DNSCache, domain, qtype)
}
func GetRRCacheFromNamespace(namespace *datastore.Key, domain string, qtype dns.Type) (*RRCache, error) {
var m RRCache
rrSave, err := GetRRSaveFromNamespace(namespace, domain, qtype)
if err != nil {
return nil, err
}
m.Expires = rrSave.Expires
m.Modified = rrSave.Modified
for _, entry := range rrSave.Answer {
rr, err := dns.NewRR(entry)
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
for _, entry := range rrSave.Ns {
rr, err := dns.NewRR(entry)
if err == nil {
m.Ns = append(m.Ns, rr)
}
}
for _, entry := range rrSave.Extra {
rr, err := dns.NewRR(entry)
if err == nil {
m.Extra = append(m.Extra, rr)
}
}
m.servedFromCache = true
return &m, nil
}
// ServedFromCache marks the RRCache as served from cache.
func (m *RRCache) ServedFromCache() bool {
return m.servedFromCache
}
// RequestingNew informs that it has expired and new RRs are being fetched.
func (m *RRCache) RequestingNew() bool {
return m.requestingNew
}
// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format.
func (m *RRCache) Flags() string {
switch {
case m.servedFromCache && m.requestingNew:
return " [CR]"
case m.servedFromCache:
return " [C]"
case m.requestingNew:
return " [R]" // theoretically impossible, but let's leave it here, just in case
default:
return ""
}
}
// IsNXDomain returnes whether the result is nxdomain.
func (m *RRCache) IsNXDomain() bool {
return len(m.Answer) == 0
}
// RRSave is helper struct to RRCache to better save data to the database.
type RRSave struct {
database.Base
Answer []string
Ns []string
Extra []string
Expires int64
Modified int64
}
var rrSaveModel *RRSave // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(rrSaveModel, func() database.Model { return new(RRSave) })
}
// Create saves RRSave with the provided name in the default namespace.
func (m *RRSave) Create(name string) error {
return m.CreateObject(&database.DNSCache, name, m)
}
// CreateWithType saves RRSave with the provided name and type in the default namespace.
func (m *RRSave) CreateWithType(name string, qtype dns.Type) error {
return m.Create(fmt.Sprintf("%s%s", name, qtype.String()))
}
// CreateInNamespace saves RRSave with the provided name in the provided namespace.
func (m *RRSave) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves RRSave.
func (m *RRSave) Save() error {
return m.SaveObject(m)
}
// GetRRSave fetches RRSave with the provided name in the default namespace.
func GetRRSave(name string, qtype dns.Type) (*RRSave, error) {
return GetRRSaveFromNamespace(&database.DNSCache, name, qtype)
}
// GetRRSaveFromNamespace fetches RRSave with the provided name in the provided namespace.
func GetRRSaveFromNamespace(namespace *datastore.Key, name string, qtype dns.Type) (*RRSave, error) {
object, err := database.GetAndEnsureModel(namespace, fmt.Sprintf("%s%s", name, qtype.String()), rrSaveModel)
if err != nil {
return nil, err
}
model, ok := object.(*RRSave)
if !ok {
return nil, database.NewMismatchError(object, rrSaveModel)
}
return model, nil
}

View file

@ -1,48 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"github.com/Safing/safing-core/log"
"sync"
"github.com/miekg/dns"
)
var (
dfMap = make(map[string]string)
dfMapLock sync.RWMutex
)
func checkDomainFronting(hidden string, qtype dns.Type, securityLevel int8) (*RRCache, bool) {
dfMapLock.RLock()
front, ok := dfMap[hidden]
dfMapLock.RUnlock()
if !ok {
return nil, false
}
log.Tracef("intel: applying domain fronting %s -> %s", hidden, front)
// get domain name
rrCache := resolveAndCache(front, qtype, securityLevel)
if rrCache == nil {
return nil, true
}
// replace domain name
var header *dns.RR_Header
for _, rr := range rrCache.Answer {
header = rr.Header()
if header.Name == front {
header.Name = hidden
}
}
// save under front
rrCache.CreateWithType(hidden, qtype)
return rrCache, true
}
func addDomainFronting(hidden string, front string) {
dfMapLock.Lock()
dfMap[hidden] = front
dfMapLock.Unlock()
return
}

View file

@ -3,44 +3,66 @@
package intel
import (
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/modules"
"fmt"
"sync"
"github.com/miekg/dns"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
)
var (
intelModule *modules.Module
intelDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
})
)
func init() {
intelModule = modules.Register("Intel", 128)
go Start()
// Intel holds intelligence data for a domain.
type Intel struct {
record.Base
sync.Mutex
Domain string
}
// GetIntel returns an Intel object of the given domain. The returned Intel object MUST not be modified.
func GetIntel(domain string) *Intel {
fqdn := dns.Fqdn(domain)
intel, err := getIntel(fqdn)
func makeIntelKey(domain string) string {
return fmt.Sprintf("cache:intel/domain/%s", domain)
}
// GetIntelFromDB gets an Intel record from the database.
func GetIntelFromDB(domain string) (*Intel, error) {
key := makeIntelKey(domain)
r, err := intelDatabase.Get(key)
if err != nil {
if err == database.ErrNotFound {
intel = &Intel{Domain: fqdn}
intel.Create(fqdn)
} else {
return nil
}
return nil, err
}
return intel
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &Intel{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*Intel)
if !ok {
return nil, fmt.Errorf("record not of type *Intel, but %T", r)
}
return new, nil
}
func GetIntelAndRRs(domain string, qtype dns.Type, securityLevel int8) (intel *Intel, rrs *RRCache) {
intel = GetIntel(domain)
rrs = Resolve(domain, qtype, securityLevel)
return
// Save saves the Intel record to the database.
func (intel *Intel) Save() error {
intel.SetKey(makeIntelKey(intel.Domain))
return intelDatabase.PutNew(intel)
}
func Start() {
// mocking until intel has its own goroutines
defer intelModule.StopComplete()
<-intelModule.Stop
// GetIntel fetches intelligence data for the given domain.
func GetIntel(domain string) (*Intel, error) {
return &Intel{Domain: domain}, nil
}

View file

@ -1,61 +1,91 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"fmt"
"strings"
"sync"
"github.com/Safing/safing-core/database"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/utils"
)
datastore "github.com/ipfs/go-datastore"
var (
ipInfoDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 86400, // 24 hours
})
)
// IPInfo represents various information about an IP.
type IPInfo struct {
database.Base
record.Base
sync.Mutex
IP string
Domains []string
}
var ipInfoModel *IPInfo // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(ipInfoModel, func() database.Model { return new(IPInfo) })
func makeIPInfoKey(ip string) string {
return fmt.Sprintf("cache:intel/ipInfo/%s", ip)
}
// Create saves the IPInfo with the provided name in the default namespace.
func (m *IPInfo) Create(name string) error {
return m.CreateObject(&database.IPInfoCache, name, m)
}
// GetIPInfo gets an IPInfo record from the database.
func GetIPInfo(ip string) (*IPInfo, error) {
key := makeIPInfoKey(ip)
// CreateInNamespace saves the IPInfo with the provided name in the provided namespace.
func (m *IPInfo) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves the IPInfo.
func (m *IPInfo) Save() error {
return m.SaveObject(m)
}
// GetIPInfo fetches the IPInfo with the provided name in the default namespace.
func GetIPInfo(name string) (*IPInfo, error) {
return GetIPInfoFromNamespace(&database.IPInfoCache, name)
}
// GetIPInfoFromNamespace fetches the IPInfo with the provided name in the provided namespace.
func GetIPInfoFromNamespace(namespace *datastore.Key, name string) (*IPInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, ipInfoModel)
r, err := ipInfoDatabase.Get(key)
if err != nil {
return nil, err
}
model, ok := object.(*IPInfo)
if !ok {
return nil, database.NewMismatchError(object, ipInfoModel)
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &IPInfo{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
return model, nil
// or adjust type
new, ok := r.(*IPInfo)
if !ok {
return nil, fmt.Errorf("record not of type *IPInfo, but %T", r)
}
return new, nil
}
// AddDomain adds a domain to the list and reports back if it was added, or was already present.
func (ipi *IPInfo) AddDomain(domain string) (added bool) {
ipi.Lock()
defer ipi.Unlock()
if !utils.StringInSlice(ipi.Domains, domain) {
ipi.Domains = append([]string{domain}, ipi.Domains...)
return true
}
return false
}
// Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error {
ipi.Lock()
if !ipi.KeyIsSet() {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
return ipInfoDatabase.Put(ipi)
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (m *IPInfo) FmtDomains() string {
return strings.Join(m.Domains, " or ")
func (ipi *IPInfo) FmtDomains() string {
return strings.Join(ipi.Domains, " or ")
}
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string {
ipi.Lock()
defer ipi.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.FmtDomains())
}

25
intel/ipinfo_test.go Normal file
View file

@ -0,0 +1,25 @@
package intel
import "testing"
func testDomains(t *testing.T, ipi *IPInfo, expectedDomains string) {
if ipi.FmtDomains() != expectedDomains {
t.Errorf("unexpected domains '%s', expected '%s'", ipi.FmtDomains(), expectedDomains)
}
}
func TestIPInfo(t *testing.T) {
ipi := &IPInfo{
IP: "1.2.3.4",
Domains: []string{"example.com.", "sub.example.com."},
}
testDomains(t, ipi, "example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("sub.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
}

35
intel/main.go Normal file
View file

@ -0,0 +1,35 @@
package intel
import (
"github.com/miekg/dns"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// module dependencies
_ "github.com/Safing/portmaster/global"
)
func init() {
modules.Register("intel", prep, start, nil, "global")
}
func start() error {
// load resolvers from config and environment
loadResolvers(false)
go listenToMDNS()
return nil
}
// GetIntelAndRRs returns intel and DNS resource records for the given domain.
func GetIntelAndRRs(domain string, qtype dns.Type, securityLevel uint8) (intel *Intel, rrs *RRCache) {
intel, err := GetIntel(domain)
if err != nil {
log.Errorf("intel: failed to get intel: %s", err)
intel = nil
}
rrs = Resolve(domain, qtype, securityLevel)
return
}

38
intel/main_test.go Normal file
View file

@ -0,0 +1,38 @@
package intel
import (
"os"
"testing"
"github.com/Safing/portbase/database/dbmodule"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
)
func TestMain(m *testing.M) {
// setup
testDir := os.TempDir()
dbmodule.SetDatabaseLocation(testDir)
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// run tests
rv := m.Run()
// teardown
modules.Shutdown()
os.RemoveAll(testDir)
// exit with test run return value
os.Exit(rv)
}

View file

@ -6,14 +6,16 @@ import (
"errors"
"fmt"
"net"
"github.com/Safing/safing-core/log"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/Safing/portbase/log"
)
// DNS Classes
const (
DNSClassMulticast = dns.ClassINET | 1<<15
)
@ -33,10 +35,6 @@ type savedQuestion struct {
expires int64
}
func init() {
go listenToMDNS()
}
func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int {
for k, v := range *list {
if entry.Name == v.Header().Name && entry.Rrtype == v.Header().Rrtype {
@ -89,7 +87,7 @@ func listenToMDNS() {
var question *dns.Question
var saveFullRequest bool
scavengedRecords := make(map[string]*dns.RR)
scavengedRecords := make(map[string]dns.RR)
var rrCache *RRCache
// save every received response
@ -114,7 +112,7 @@ func listenToMDNS() {
continue
}
// continue if no question
// get question, some servers do not reply with question
if len(message.Question) == 0 {
questionsLock.Lock()
savedQ, ok := questions[message.MsgHdr.Id]
@ -138,8 +136,11 @@ func listenToMDNS() {
// get entry from database
if saveFullRequest {
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
if err != nil || rrCache.Modified < time.Now().Add(-2*time.Second).Unix() || rrCache.Expires < time.Now().Unix() {
rrCache = &RRCache{}
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
rrCache = &RRCache{
Domain: question.Name,
Question: dns.Type(question.Qtype),
}
}
}
@ -155,12 +156,12 @@ func listenToMDNS() {
}
switch entry.(type) {
case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = entry
case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = entry
case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = entry
}
}
}
@ -177,17 +178,16 @@ func listenToMDNS() {
}
switch entry.(type) {
case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%s_A", entry.Header().Name)] = entry
case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%s_AAAA", entry.Header().Name)] = entry
case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%s_PTR", entry.Header().Name)] = entry
}
}
}
}
// TODO: scan Extra for A and AAAA records and save them seperately
for _, entry := range message.Extra {
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScopes(entry.Header().Name, localReverseScopes) {
if saveFullRequest {
@ -200,34 +200,35 @@ func listenToMDNS() {
}
switch entry.(type) {
case *dns.A:
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sA", entry.Header().Name)] = entry
case *dns.AAAA:
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sAAAA", entry.Header().Name)] = entry
case *dns.PTR:
if !strings.HasPrefix(entry.Header().Name, "_") {
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = &entry
scavengedRecords[fmt.Sprintf("%sPTR", entry.Header().Name)] = entry
}
}
}
}
var questionID string
if saveFullRequest {
rrCache.Clean(60)
rrCache.CreateWithType(question.Name, dns.Type(question.Qtype))
// log.Tracef("intel: mdns saved full reply to %s%s", question.Name, dns.Type(question.Qtype).String())
rrCache.Save()
questionID = fmt.Sprintf("%s%s", question.Name, dns.Type(question.Qtype).String())
}
for k, v := range scavengedRecords {
if saveFullRequest {
if k == fmt.Sprintf("%s%s", question.Name, dns.Type(question.Qtype).String()) {
continue
}
if saveFullRequest && k == questionID {
continue
}
rrCache = &RRCache{
Answer: []dns.RR{*v},
Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
}
rrCache.Clean(60)
rrCache.Create(k)
rrCache.Save()
// log.Tracef("intel: mdns scavenged %s", k)
}
@ -261,7 +262,7 @@ func listenForDNSPackets(conn *net.UDPConn, messages chan *dns.Msg) {
}
}
func queryMulticastDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
func queryMulticastDNS(fqdn string, qtype dns.Type) (*RRCache, error) {
q := new(dns.Msg)
q.SetQuestion(fqdn, uint16(qtype))
// request unicast response

73
intel/namerecord.go Normal file
View file

@ -0,0 +1,73 @@
package intel
import (
"errors"
"fmt"
"sync"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/record"
)
var (
recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
CacheSize: 128,
})
)
// NameRecord is helper struct to RRCache to better save data to the database.
type NameRecord struct {
record.Base
sync.Mutex
Domain string
Question string
Answer []string
Ns []string
Extra []string
TTL int64
Filtered bool
}
func makeNameRecordKey(domain string, question string) string {
return fmt.Sprintf("cache:intel/nameRecord/%s%s", domain, question)
}
// GetNameRecord gets a NameRecord from the database.
func GetNameRecord(domain string, question string) (*NameRecord, error) {
key := makeNameRecordKey(domain, question)
r, err := recordDatabase.Get(key)
if err != nil {
return nil, err
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &NameRecord{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*NameRecord)
if !ok {
return nil, fmt.Errorf("record not of type *NameRecord, but %T", r)
}
return new, nil
}
// Save saves the NameRecord to the database.
func (rec *NameRecord) Save() error {
if rec.Domain == "" || rec.Question == "" {
return errors.New("could not save NameRecord, missing Domain and/or Question")
}
rec.SetKey(makeNameRecordKey(rec.Domain, rec.Question))
return recordDatabase.PutNew(rec)
}

View file

@ -3,30 +3,20 @@
package intel
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/tevino/abool"
"github.com/Safing/safing-core/configuration"
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network/environment"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/status"
)
// TODO: make resolver interface for http package
@ -79,322 +69,14 @@ import (
// global -> local scopes, global
// special -> local scopes, local
type Resolver struct {
// static
Server string
ServerAddress string
IP *net.IP
Port uint16
Resolve func(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error)
Search *[]string
AllowedSecurityLevel int8
SkipFqdnBeforeInit string
HTTPClient *http.Client
Source string
// atomic
Initialized *abool.AtomicBool
InitLock sync.Mutex
LastFail *int64
Expires *int64
// must be locked
LockReason sync.Mutex
FailReason string
// TODO: add:
// Expiration (for server got from DHCP / ICMPv6)
// bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!)
// expanded server info: type, server address, server port, options - so we do not have to parse this every time!
}
func (r *Resolver) String() string {
return r.Server
}
func (r *Resolver) Address() string {
return urlFormatAddress(r.IP, r.Port)
}
type Scope struct {
Domain string
Resolvers []*Resolver
}
var (
config = configuration.Get()
globalResolvers []*Resolver // all resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []Scope // list of scopes with a list of local resolvers that can resolve the scope
mDNSResolver *Resolver // holds a reference to the mDNS resolver
resolversLock sync.RWMutex
env = environment.NewInterface()
dupReqMap = make(map[string]*sync.Mutex)
dupReqLock sync.Mutex
)
func init() {
loadResolvers(false)
}
func indexOfResolver(server string, list []*Resolver) int {
for k, v := range list {
if v.Server == server {
return k
}
}
return -1
}
func indexOfScope(domain string, list *[]Scope) int {
for k, v := range *list {
if v.Domain == domain {
return k
}
}
return -1
}
func parseAddress(server string) (*net.IP, uint16, error) {
delimiter := strings.LastIndex(server, ":")
if delimiter < 0 {
return nil, 0, errors.New("port missing")
}
ip := net.ParseIP(strings.Trim(server[:delimiter], "[]"))
if ip == nil {
return nil, 0, errors.New("invalid IP address")
}
port, err := strconv.Atoi(server[delimiter+1:])
if err != nil || port < 1 || port > 65536 {
return nil, 0, errors.New("invalid port")
}
return &ip, uint16(port), nil
}
func urlFormatAddress(ip *net.IP, port uint16) string {
var address string
if ipv4 := ip.To4(); ipv4 != nil {
address = fmt.Sprintf("%s:%d", ipv4.String(), port)
} else {
address = fmt.Sprintf("[%s]:%d", ip.String(), port)
}
return address
}
func loadResolvers(resetResolvers bool) {
// TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame.
resolversLock.Lock()
defer resolversLock.Unlock()
var newResolvers []*Resolver
configuredServersLoop:
for _, server := range config.DNSServers {
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue configuredServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
parts := strings.Split(server, "|")
if len(parts) < 2 {
log.Warningf("intel: invalid DNS server in config: %s (invalid format)", server)
continue configuredServersLoop
}
var lastFail int64
new := &Resolver{
Server: server,
ServerAddress: parts[1],
LastFail: &lastFail,
Source: "config",
Initialized: abool.NewBool(false),
}
ip, port, err := parseAddress(parts[1])
if err != nil {
new.IP = ip
new.Port = port
}
switch {
case strings.HasPrefix(server, "DNS|"):
new.Resolve = queryDNS
new.AllowedSecurityLevel = configuration.SecurityLevelFortress
case strings.HasPrefix(server, "DoH|"):
new.Resolve = queryDNSoverHTTPS
new.AllowedSecurityLevel = configuration.SecurityLevelFortress
new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0])
tls := &tls.Config{
// TODO: use custom random
// Rand: io.Reader,
}
tr := &http.Transport{
MaxIdleConnsPerHost: 100,
TLSClientConfig: tls,
// TODO: use custom resolver as of Go1.9
}
if len(parts) == 3 && strings.HasPrefix(parts[2], "df:") {
// activate domain fronting
tls.ServerName = parts[2][3:]
addDomainFronting(new.SkipFqdnBeforeInit, dns.Fqdn(tls.ServerName))
new.SkipFqdnBeforeInit = dns.Fqdn(tls.ServerName)
}
new.HTTPClient = &http.Client{Transport: tr}
default:
log.Warningf("intel: invalid DNS server in config: %s (not starting with a valid identifier)", server)
continue configuredServersLoop
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// add local resolvers
assignedNameservers := environment.Nameservers()
assignedServersLoop:
for _, nameserver := range assignedNameservers {
server := fmt.Sprintf("DNS|%s", urlFormatAddress(&nameserver.IP, 53))
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue assignedServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
var lastFail int64
new := &Resolver{
Server: server,
ServerAddress: urlFormatAddress(&nameserver.IP, 53),
IP: &nameserver.IP,
Port: 53,
LastFail: &lastFail,
Resolve: queryDNS,
AllowedSecurityLevel: configuration.SecurityLevelFortress,
Initialized: abool.NewBool(false),
Source: "dhcp",
}
if netutils.IPIsLocal(nameserver.IP) && len(nameserver.Search) > 0 {
// only allow searches for local resolvers
var newSearch []string
for _, value := range nameserver.Search {
newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, ".")))
}
new.Search = &newSearch
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("intel: no (valid) dns servers found in configuration and system")
}
// make list with local resolvers
localResolvers = make([]*Resolver, 0)
for _, resolver := range globalResolvers {
if resolver.IP != nil && netutils.IPIsLocal(*resolver.IP) {
localResolvers = append(localResolvers, resolver)
}
}
// add resolvers to every scope the cover
localScopes = make([]Scope, 0)
for _, resolver := range globalResolvers {
if resolver.Search != nil {
// add resolver to custom searches
for _, search := range *resolver.Search {
if search == "." {
continue
}
key := indexOfScope(search, &localScopes)
if key == -1 {
localScopes = append(localScopes, Scope{
Domain: search,
Resolvers: []*Resolver{resolver},
})
} else {
localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver)
}
}
}
}
// init mdns resolver
if mDNSResolver == nil {
cannotFail := int64(-1)
mDNSResolver = &Resolver{
Server: "mDNS",
Resolve: queryMulticastDNS,
AllowedSecurityLevel: config.DoNotUseMDNS.Level(),
Initialized: abool.NewBool(false),
Source: "static",
LastFail: &cannotFail,
}
}
// sort scopes by length
sort.Slice(localScopes,
func(i, j int) bool {
return len(localScopes[i].Domain) > len(localScopes[j].Domain)
},
)
log.Trace("intel: loaded global resolvers:")
for _, resolver := range globalResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded local resolvers:")
for _, resolver := range localResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded scopes:")
for _, scope := range localScopes {
var scopeServers []string
for _, resolver := range scope.Resolvers {
scopeServers = append(scopeServers, resolver.Server)
}
log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", "))
}
}
// Resolve resolves the given query for a domain and type and returns a RRCache object or nil, if the query failed.
func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
func Resolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache {
fqdn = dns.Fqdn(fqdn)
// use this to time how long it takes resolve this domain
// timed := time.Now()
// defer log.Tracef("intel: took %s to get resolve %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String())
// handle request for localhost
if fqdn == "localhost." {
var rr dns.RR
var err error
switch uint16(qtype) {
case dns.TypeA:
rr, err = dns.NewRR("localhost. 3600 IN A 127.0.0.1")
case dns.TypeAAAA:
rr, err = dns.NewRR("localhost. 3600 IN AAAA ::1")
default:
return nil
}
if err != nil {
return nil
}
return &RRCache{
Answer: []dns.RR{rr},
}
}
// check cache
rrCache, err := GetRRCache(fqdn, qtype)
if err != nil {
@ -406,7 +88,8 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
return resolveAndCache(fqdn, qtype, securityLevel)
}
if rrCache.Expires <= time.Now().Unix() {
if rrCache.TTL <= time.Now().Unix() {
log.Tracef("intel: serving cache, requesting new. TTL=%d, now=%d", rrCache.TTL, time.Now().Unix())
rrCache.requestingNew = true
go resolveAndCache(fqdn, qtype, securityLevel)
}
@ -420,17 +103,9 @@ func Resolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
return rrCache
}
func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
func resolveAndCache(fqdn string, qtype dns.Type, securityLevel uint8) (rrCache *RRCache) {
// log.Tracef("intel: resolving %s%s", fqdn, qtype.String())
rrCache, ok := checkDomainFronting(fqdn, qtype, securityLevel)
if ok {
if rrCache == nil {
return nil
}
return rrCache
}
// dedup requests
dupKey := fmt.Sprintf("%s%s", fqdn, qtype.String())
dupReqLock.Lock()
@ -456,7 +131,7 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
}
defer func() {
dupReqLock.Lock()
delete(dupReqMap, fqdn)
delete(dupReqMap, dupKey)
dupReqLock.Unlock()
mutex.Unlock()
}()
@ -469,29 +144,29 @@ func resolveAndCache(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
// persist to database
rrCache.Clean(600)
rrCache.CreateWithType(fqdn, qtype)
rrCache.Save()
return rrCache
}
func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCache {
func intelligentResolve(fqdn string, qtype dns.Type, securityLevel uint8) *RRCache {
// TODO: handle being offline
// TODO: handle multiple network connections
if config.Changed() {
log.Info("intel: config changed, reloading resolvers")
loadResolvers(false)
} else if env.NetworkChanged() {
log.Info("intel: network changed, reloading resolvers")
loadResolvers(true)
}
config.RLock()
defer config.RUnlock()
// TODO: handle these in a separate goroutine
// if config.Changed() {
// log.Info("intel: config changed, reloading resolvers")
// loadResolvers(false)
// } else if env.NetworkChanged() {
// log.Info("intel: network changed, reloading resolvers")
// loadResolvers(true)
// }
resolversLock.RLock()
defer resolversLock.RUnlock()
lastFailBoundary := time.Now().Unix() - config.DNSServerRetryRate
lastFailBoundary := time.Now().Unix() - nameserverRetryRate()
preDottedFqdn := "." + fqdn
// resolve:
@ -510,11 +185,14 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
}
}
// check config
if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) {
if doNotUseMulticastDNS(securityLevel) {
return nil
}
// try mdns
rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel)
rrCache, err := queryMulticastDNS(fqdn, qtype)
if err != nil {
log.Errorf("intel: failed to query mdns: %s", err)
}
return rrCache
}
@ -533,15 +211,18 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
switch {
case strings.HasSuffix(preDottedFqdn, ".local."):
// check config
if config.DoNotUseMDNS.IsSetWithLevel(securityLevel) {
if doNotUseMulticastDNS(securityLevel) {
return nil
}
// try mdns
rrCache, _ := tryResolver(mDNSResolver, lastFailBoundary, fqdn, qtype, securityLevel)
rrCache, err := queryMulticastDNS(fqdn, qtype)
if err != nil {
log.Errorf("intel: failed to query mdns: %s", err)
}
return rrCache
case domainInScopes(preDottedFqdn, specialScopes):
// check config
if config.DoNotForwardSpecialDomains.IsSetWithLevel(securityLevel) {
if doNotResolveSpecialDomains(securityLevel) {
return nil
}
// try local resolvers
@ -568,15 +249,15 @@ func intelligentResolve(fqdn string, qtype dns.Type, securityLevel int8) *RRCach
}
func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel int8) (*RRCache, bool) {
func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype dns.Type, securityLevel uint8) (*RRCache, bool) {
// skip if not allowed in current security level
if resolver.AllowedSecurityLevel < config.SecurityLevel() || resolver.AllowedSecurityLevel < securityLevel {
log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, config.SecurityLevel(), securityLevel)
if resolver.AllowedSecurityLevel < status.CurrentSecurityLevel() || resolver.AllowedSecurityLevel < securityLevel {
log.Tracef("intel: skipping resolver %s, because it isn't allowed to operate on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel)
return nil, false
}
// skip if not security level denies assigned dns servers
if config.DoNotUseAssignedDNS.IsSetWithLevel(securityLevel) && resolver.Source == "dhcp" {
log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d (%d)", resolver, config.SecurityLevel(), securityLevel, int8(config.DoNotUseAssignedDNS))
if doNotUseAssignedNameservers(securityLevel) && resolver.Source == "dhcp" {
log.Tracef("intel: skipping resolver %s, because assigned nameservers are not allowed on the current security level: %d|%d", resolver, status.CurrentSecurityLevel(), securityLevel)
return nil, false
}
// check if failed recently
@ -606,7 +287,7 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype
}
// resolve
log.Tracef("intel: trying to resolve %s%s with %s", fqdn, qtype.String(), resolver.Server)
rrCache, err := resolver.Resolve(resolver, fqdn, qtype)
rrCache, err := query(resolver, fqdn, qtype)
if err != nil {
// check if failing is disabled
if atomic.LoadInt64(resolver.LastFail) == -1 {
@ -622,126 +303,67 @@ func tryResolver(resolver *Resolver, lastFailBoundary int64, fqdn string, qtype
return nil, false
}
resolver.Initialized.SetToIf(false, true)
// remove localhost entries, remove LAN entries if server is in global IP space.
if resolver.ServerIPScope == netutils.Global {
rrCache.FilterEntries(true, false, false)
} else {
rrCache.FilterEntries(true, true, false)
}
return rrCache, true
}
func queryDNS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
func query(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
q := new(dns.Msg)
q.SetQuestion(fqdn, uint16(qtype))
var reply *dns.Msg
var err error
for i := 0; i < 5; i++ {
client := new(dns.Client)
reply, _, err = client.Exchange(q, resolver.ServerAddress)
for i := 0; i < 3; i++ {
// log query time
// qStart := time.Now()
reply, _, err = resolver.clientManager.getDNSClient().Exchange(q, resolver.ServerAddress)
// log.Tracef("intel: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
// error handling
if err != nil {
log.Tracef("intel: query to %s encountered error: %s", resolver.Server, err)
// TODO: handle special cases
// 1. connect: network is unreachable
// 2. timeout
// temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracef("intel: retrying to resolve %s%s with %s, error was: %s", fqdn, qtype.String(), resolver.Server, err)
continue
}
// permanent error
break
}
// no error
break
}
if err != nil {
log.Warningf("resolving %s%s failed: %s", fqdn, qtype.String(), err)
return nil, fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err)
err = fmt.Errorf("resolving %s%s failed: %s", fqdn, qtype.String(), err)
log.Warning(err.Error())
return nil, err
}
new := &RRCache{
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
Domain: fqdn,
Question: qtype,
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
}
// TODO: check if reply.Answer is valid
return new, nil
}
type DnsOverHttpsReply struct {
Status uint32
Truncated bool `json:"TC"`
Answer []DohRR
Additional []DohRR
}
type DohRR struct {
Name string `json:"name"`
Qtype uint16 `json:"type"`
TTL uint32 `json:"TTL"`
Data string `json:"data"`
}
func queryDNSoverHTTPS(resolver *Resolver, fqdn string, qtype dns.Type) (*RRCache, error) {
// API documentation: https://developers.google.com/speed/public-dns/docs/dns-over-https
payload := url.Values{}
payload.Add("name", fqdn)
payload.Add("type", strconv.Itoa(int(qtype)))
payload.Add("edns_client_subnet", "0.0.0.0/0")
// TODO: add random - only use upper- and lower-case letters, digits, hyphen, period, underscore and tilde
// payload.Add("random_padding", "")
resp, err := resolver.HTTPClient.Get(fmt.Sprintf("https://%s/resolve?%s", resolver.ServerAddress, payload.Encode()))
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: http error: %s", fqdn, qtype.String(), err)
// TODO: handle special cases
// 1. connect: network is unreachable
// intel: resolver DoH|dns.google.com:443|df:www.google.com failed (resolving discovery-v4-4.syncthing.net.A failed: http error: Get https://dns.google.com:443/resolve?edns_client_subnet=0.0.0.0%2F0&name=discovery-v4-4.syncthing.net.&type=1: dial tcp [2a00:1450:4001:819::2004]:443: connect: network is unreachable), moving to next
// 2. timeout
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("resolving %s%s failed: request was unsuccessful, got code %d", fqdn, qtype.String(), resp.StatusCode)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: error reading response body: %s", fqdn, qtype.String(), err)
}
var reply DnsOverHttpsReply
err = json.Unmarshal(body, &reply)
if err != nil {
return nil, fmt.Errorf("resolving %s%s failed: error parsing response body: %s", fqdn, qtype.String(), err)
}
if reply.Status != 0 {
// this happens if there is a server error (e.g. DNSSEC fail), ignore for now
// TODO: do something more intelligent
}
new := new(RRCache)
// TODO: handle TXT records
for _, entry := range reply.Answer {
rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data))
if err != nil {
log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)
continue
}
new.Answer = append(new.Answer, rr)
}
for _, entry := range reply.Additional {
rr, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data))
if err != nil {
log.Warningf("intel: resolving %s%s failed: failed to parse record to DNS: %s %d IN %s %s", fqdn, qtype.String(), entry.Name, entry.TTL, dns.Type(entry.Qtype).String(), entry.Data)
continue
}
new.Extra = append(new.Extra, rr)
}
return new, nil
}
// TODO: implement T-DNS: DNS over TCP/TLS
// server list: https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers

View file

@ -2,14 +2,16 @@
package intel
import (
"testing"
"time"
// DISABLE TESTING FOR NOW: find a way to have tests with the module system
"github.com/miekg/dns"
)
// import (
// "testing"
// "time"
//
// "github.com/miekg/dns"
// )
func TestResolve(t *testing.T) {
Resolve("google.com.", dns.Type(dns.TypeA), 0)
time.Sleep(200 * time.Millisecond)
}
// func TestResolve(t *testing.T) {
// Resolve("google.com.", dns.Type(dns.TypeA), 0)
// time.Sleep(200 * time.Millisecond)
// }

295
intel/resolver.go Normal file
View file

@ -0,0 +1,295 @@
package intel
import (
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"sync"
"github.com/miekg/dns"
"github.com/tevino/abool"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/environment"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/status"
)
// Resolver holds information about an active resolver.
type Resolver struct {
// static
Server string
ServerType string
ServerAddress string
ServerIP net.IP
ServerIPScope int8
ServerPort uint16
VerifyDomain string
Source string
clientManager *clientManager
Search *[]string
AllowedSecurityLevel uint8
SkipFqdnBeforeInit string
// atomic
Initialized *abool.AtomicBool
InitLock sync.Mutex
LastFail *int64
Expires *int64
// must be locked
LockReason sync.Mutex
FailReason string
// TODO: add:
// Expiration (for server got from DHCP / ICMPv6)
// bootstrapping (first query is already sent, wait for it to either succeed or fail - think about http bootstrapping here!)
// expanded server info: type, server address, server port, options - so we do not have to parse this every time!
}
func (r *Resolver) String() string {
return r.Server
}
// Scope defines a domain scope and which resolvers can resolve it.
type Scope struct {
Domain string
Resolvers []*Resolver
}
var (
globalResolvers []*Resolver // all resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
resolversLock sync.RWMutex
env = environment.NewInterface()
dupReqMap = make(map[string]*sync.Mutex)
dupReqLock sync.Mutex
)
func indexOfResolver(server string, list []*Resolver) int {
for k, v := range list {
if v.Server == server {
return k
}
}
return -1
}
func indexOfScope(domain string, list []*Scope) int {
for k, v := range list {
if v.Domain == domain {
return k
}
}
return -1
}
func parseAddress(server string) (net.IP, uint16, error) {
delimiter := strings.LastIndex(server, ":")
if delimiter < 0 {
return nil, 0, errors.New("port missing")
}
ip := net.ParseIP(strings.Trim(server[:delimiter], "[]"))
if ip == nil {
return nil, 0, errors.New("invalid IP address")
}
port, err := strconv.Atoi(server[delimiter+1:])
if err != nil || port < 1 || port > 65536 {
return nil, 0, errors.New("invalid port")
}
return ip, uint16(port), nil
}
func urlFormatAddress(ip net.IP, port uint16) string {
var address string
if ipv4 := ip.To4(); ipv4 != nil {
address = fmt.Sprintf("%s:%d", ipv4.String(), port)
} else {
address = fmt.Sprintf("[%s]:%d", ip.String(), port)
}
return address
}
func loadResolvers(resetResolvers bool) {
// TODO: what happens when a lot of processes want to reload at once? we do not need to run this multiple times in a short time frame.
resolversLock.Lock()
defer resolversLock.Unlock()
var newResolvers []*Resolver
configuredServersLoop:
for _, server := range configuredNameServers() {
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue configuredServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
parts := strings.Split(server, "|")
if len(parts) < 2 {
log.Warningf("intel: nameserver format invalid: %s", server)
continue configuredServersLoop
}
ip, port, err := parseAddress(parts[1])
if err != nil && strings.ToLower(parts[0]) != "https" {
log.Warningf("intel: nameserver (%s) address invalid: %s", server, err)
continue configuredServersLoop
}
var lastFail int64
new := &Resolver{
Server: server,
ServerType: parts[0],
ServerAddress: parts[1],
ServerIP: ip,
ServerIPScope: netutils.ClassifyIP(ip),
ServerPort: port,
LastFail: &lastFail,
Source: "config",
Initialized: abool.NewBool(false),
}
switch strings.ToLower(parts[0]) {
case "dns":
new.clientManager = newDNSClientManager(new)
case "tcp":
new.clientManager = newTCPClientManager(new)
case "tls":
new.AllowedSecurityLevel = status.SecurityLevelFortress
if len(parts) < 3 {
log.Warningf("intel: nameserver missing verification domain as third parameter: %s", server)
continue configuredServersLoop
}
new.VerifyDomain = parts[2]
new.clientManager = newTLSClientManager(new)
case "https":
new.AllowedSecurityLevel = status.SecurityLevelFortress
new.SkipFqdnBeforeInit = dns.Fqdn(strings.Split(parts[1], ":")[0])
if len(parts) > 2 {
new.VerifyDomain = parts[2]
}
new.clientManager = newHTTPSClientManager(new)
default:
log.Warningf("intel: nameserver (%s) type invalid: %s", server, parts[0])
continue configuredServersLoop
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// add local resolvers
assignedNameservers := environment.Nameservers()
assignedServersLoop:
for _, nameserver := range assignedNameservers {
server := fmt.Sprintf("dns|%s", urlFormatAddress(nameserver.IP, 53))
key := indexOfResolver(server, newResolvers)
if key >= 0 {
continue assignedServersLoop
}
key = indexOfResolver(server, globalResolvers)
if resetResolvers || key == -1 {
var lastFail int64
new := &Resolver{
Server: server,
ServerType: "dns",
ServerAddress: urlFormatAddress(nameserver.IP, 53),
ServerIP: nameserver.IP,
ServerIPScope: netutils.ClassifyIP(nameserver.IP),
ServerPort: 53,
LastFail: &lastFail,
Source: "dhcp",
Initialized: abool.NewBool(false),
AllowedSecurityLevel: status.SecurityLevelSecure,
}
new.clientManager = newDNSClientManager(new)
if netutils.IPIsLAN(nameserver.IP) && len(nameserver.Search) > 0 {
// only allow searches for local resolvers
var newSearch []string
for _, value := range nameserver.Search {
newSearch = append(newSearch, fmt.Sprintf(".%s.", strings.Trim(value, ".")))
}
new.Search = &newSearch
}
newResolvers = append(newResolvers, new)
} else {
newResolvers = append(newResolvers, globalResolvers[key])
}
}
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("intel: no (valid) dns servers found in configuration and system")
}
// make list with local resolvers
localResolvers = make([]*Resolver, 0)
for _, resolver := range globalResolvers {
if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) {
localResolvers = append(localResolvers, resolver)
}
}
// add resolvers to every scope the cover
localScopes = make([]*Scope, 0)
for _, resolver := range globalResolvers {
if resolver.Search != nil {
// add resolver to custom searches
for _, search := range *resolver.Search {
if search == "." {
continue
}
key := indexOfScope(search, localScopes)
if key == -1 {
localScopes = append(localScopes, &Scope{
Domain: search,
Resolvers: []*Resolver{resolver},
})
} else {
localScopes[key].Resolvers = append(localScopes[key].Resolvers, resolver)
}
}
}
}
// sort scopes by length
sort.Slice(localScopes,
func(i, j int) bool {
return len(localScopes[i].Domain) > len(localScopes[j].Domain)
},
)
log.Trace("intel: loaded global resolvers:")
for _, resolver := range globalResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded local resolvers:")
for _, resolver := range localResolvers {
log.Tracef("intel: %s", resolver.Server)
}
log.Trace("intel: loaded scopes:")
for _, scope := range localScopes {
var scopeServers []string
for _, resolver := range scope.Resolvers {
scopeServers = append(scopeServers, resolver.Server)
}
log.Tracef("intel: %s: %s", scope.Domain, strings.Join(scopeServers, ", "))
}
}

72
intel/reverse.go Normal file
View file

@ -0,0 +1,72 @@
package intel
import (
"errors"
"strings"
"github.com/Safing/portbase/log"
"github.com/miekg/dns"
)
// ResolveIPAndValidate finds (reverse DNS), validates (forward DNS) and returns the domain name assigned to the given IP.
func ResolveIPAndValidate(ip string, securityLevel uint8) (domain string, err error) {
// get reversed DNS address
rQ, err := dns.ReverseAddr(ip)
if err != nil {
log.Tracef("intel: failed to get reverse address of %s: %s", ip, err)
return "", err
}
// get PTR record
rrCache := Resolve(rQ, dns.Type(dns.TypePTR), securityLevel)
if rrCache == nil {
return "", errors.New("querying for PTR record failed (may be NXDomain)")
}
// get result from record
var ptrName string
for _, rr := range rrCache.Answer {
ptrRec, ok := rr.(*dns.PTR)
if ok {
ptrName = ptrRec.Ptr
break
}
}
// check for nxDomain
if ptrName == "" {
return "", errors.New("no PTR record for IP (nxDomain)")
}
log.Infof("ptrName: %s", ptrName)
// get forward record
if strings.Contains(ip, ":") {
rrCache = Resolve(ptrName, dns.Type(dns.TypeAAAA), securityLevel)
} else {
rrCache = Resolve(ptrName, dns.Type(dns.TypeA), securityLevel)
}
if rrCache == nil {
return "", errors.New("querying for A/AAAA record failed (may be NXDomain)")
}
// check for matching A/AAAA record
log.Infof("rr: %s", rrCache)
for _, rr := range rrCache.Answer {
switch v := rr.(type) {
case *dns.A:
log.Infof("A: %s", v.A.String())
if ip == v.A.String() {
return ptrName, nil
}
case *dns.AAAA:
log.Infof("AAAA: %s", v.AAAA.String())
if ip == v.AAAA.String() {
return ptrName, nil
}
}
}
// no match
return "", errors.New("validation failed")
}

28
intel/reverse_test.go Normal file
View file

@ -0,0 +1,28 @@
package intel
import "testing"
func testReverse(t *testing.T, ip, result, expectedErr string) {
domain, err := ResolveIPAndValidate(ip, 0)
if err != nil {
if expectedErr == "" || err.Error() != expectedErr {
t.Errorf("reverse-validating %s: unexpected error: %s", ip, err)
}
return
}
if domain != result {
t.Errorf("reverse-validating %s: unexpected result: %s", ip, domain)
}
}
func TestResolveIPAndValidate(t *testing.T) {
testReverse(t, "198.41.0.4", "a.root-servers.net.", "")
testReverse(t, "9.9.9.9", "dns.quad9.net.", "")
testReverse(t, "2620:fe::fe", "dns.quad9.net.", "")
testReverse(t, "1.1.1.1", "one.one.one.one.", "")
testReverse(t, "2606:4700:4700::1111", "one.one.one.one.", "")
testReverse(t, "93.184.216.34", "example.com.", "no PTR record for IP (nxDomain)")
testReverse(t, "185.199.109.153", "sites.github.io.", "no PTR record for IP (nxDomain)")
}

256
intel/rrcache.go Normal file
View file

@ -0,0 +1,256 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package intel
import (
"fmt"
"net"
"strings"
"time"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/netutils"
"github.com/miekg/dns"
)
// RRCache is used to cache DNS data
type RRCache struct {
Domain string
Question dns.Type
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
TTL int64
updated int64
servedFromCache bool
requestingNew bool
Filtered bool
}
// Clean sets all TTLs to 17 and sets cache expiry with specified minimum.
func (m *RRCache) Clean(minExpires uint32) {
var lowestTTL uint32 = 0xFFFFFFFF
var header *dns.RR_Header
// set TTLs to 17
// TODO: double append? is there something more elegant?
for _, rr := range append(m.Answer, append(m.Ns, m.Extra...)...) {
header = rr.Header()
if lowestTTL > header.Ttl {
lowestTTL = header.Ttl
}
header.Ttl = 17
}
// TTL must be at least minExpires
if lowestTTL < minExpires {
lowestTTL = minExpires
}
// log.Tracef("lowest TTL is %d", lowestTTL)
m.TTL = time.Now().Unix() + int64(lowestTTL)
}
// ExportAllARecords return of a list of all A and AAAA IP addresses.
func (m *RRCache) ExportAllARecords() (ips []net.IP) {
for _, rr := range m.Answer {
if rr.Header().Class != dns.ClassINET {
continue
}
switch rr.Header().Rrtype {
case dns.TypeA:
aRecord, ok := rr.(*dns.A)
if ok {
ips = append(ips, aRecord.A)
}
case dns.TypeAAAA:
aaaaRecord, ok := rr.(*dns.AAAA)
if ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
}
return
}
// ToNameRecord converts the RRCache to a NameRecord for cleaner persistence.
func (m *RRCache) ToNameRecord() *NameRecord {
new := &NameRecord{
Domain: m.Domain,
Question: m.Question.String(),
TTL: m.TTL,
Filtered: m.Filtered,
}
// stringify RR entries
for _, entry := range m.Answer {
new.Answer = append(new.Answer, entry.String())
}
for _, entry := range m.Ns {
new.Ns = append(new.Ns, entry.String())
}
for _, entry := range m.Extra {
new.Extra = append(new.Extra, entry.String())
}
return new
}
// Save saves the RRCache to the database as a NameRecord.
func (m *RRCache) Save() error {
return m.ToNameRecord().Save()
}
// GetRRCache tries to load the corresponding NameRecord from the database and convert it.
func GetRRCache(domain string, question dns.Type) (*RRCache, error) {
rrCache := &RRCache{
Domain: domain,
Question: question,
}
nameRecord, err := GetNameRecord(domain, question.String())
if err != nil {
return nil, err
}
rrCache.TTL = nameRecord.TTL
for _, entry := range nameRecord.Answer {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Answer = append(rrCache.Answer, rr)
}
}
for _, entry := range nameRecord.Ns {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Ns = append(rrCache.Ns, rr)
}
}
for _, entry := range nameRecord.Extra {
rr, err := dns.NewRR(entry)
if err == nil {
rrCache.Extra = append(rrCache.Extra, rr)
}
}
rrCache.Filtered = nameRecord.Filtered
rrCache.servedFromCache = true
return rrCache, nil
}
// ServedFromCache marks the RRCache as served from cache.
func (m *RRCache) ServedFromCache() bool {
return m.servedFromCache
}
// RequestingNew informs that it has expired and new RRs are being fetched.
func (m *RRCache) RequestingNew() bool {
return m.requestingNew
}
// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format.
func (m *RRCache) Flags() string {
var s string
if m.servedFromCache {
s += "C"
}
if m.requestingNew {
s += "R"
}
if m.Filtered {
s += "F"
}
if s != "" {
return fmt.Sprintf(" [%s]", s)
}
return ""
}
// IsNXDomain returnes whether the result is nxdomain.
func (m *RRCache) IsNXDomain() bool {
return len(m.Answer) == 0
}
// Duplicate returns a duplicate of the cache. slices are not copied, but referenced.
func (m *RRCache) Duplicate() *RRCache {
return &RRCache{
Domain: m.Domain,
Question: m.Question,
Answer: m.Answer,
Ns: m.Ns,
Extra: m.Extra,
TTL: m.TTL,
updated: m.updated,
servedFromCache: m.servedFromCache,
requestingNew: m.requestingNew,
Filtered: m.Filtered,
}
}
// FilterEntries filters resource records according to the given permission scope.
func (m *RRCache) FilterEntries(internet, lan, host bool) {
var filtered bool
m.Answer, filtered = filterEntries(m, m.Answer, internet, lan, host)
if filtered {
m.Filtered = true
}
m.Extra, filtered = filterEntries(m, m.Extra, internet, lan, host)
if filtered {
m.Filtered = true
}
}
func filterEntries(m *RRCache, entries []dns.RR, internet, lan, host bool) (filteredEntries []dns.RR, filtered bool) {
filteredEntries = make([]dns.RR, 0, len(entries))
var classification int8
var deletedEntries []string
entryLoop:
for _, rr := range entries {
classification = -1
switch v := rr.(type) {
case *dns.A:
classification = netutils.ClassifyIP(v.A)
case *dns.AAAA:
classification = netutils.ClassifyIP(v.AAAA)
}
if classification >= 0 {
switch {
case !internet && classification == netutils.Global:
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
case !lan && (classification == netutils.SiteLocal || classification == netutils.LinkLocal):
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
case !host && classification == netutils.HostLocal:
filtered = true
deletedEntries = append(deletedEntries, rr.String())
continue entryLoop
}
}
filteredEntries = append(filteredEntries, rr)
}
if len(deletedEntries) > 0 {
log.Infof("intel: filtered DNS replies for %s%s: %s (Settings: Int=%v LAN=%v Host=%v)",
m.Domain,
m.Question.String(),
strings.Join(deletedEntries, ", "),
internet,
lan,
host,
)
}
return
}

View file

@ -5,7 +5,7 @@ package intel
import "strings"
var (
localReverseScopes = &[]string{
localReverseScopes = []string{
".10.in-addr.arpa.",
".16.172.in-addr.arpa.",
".17.172.in-addr.arpa.",
@ -31,7 +31,8 @@ var (
".b.e.f.ip6.arpa.",
}
specialScopes = &[]string{
// RFC6761, RFC7686
specialScopes = []string{
".example.",
".example.com.",
".example.net.",
@ -42,8 +43,8 @@ var (
}
)
func domainInScopes(fqdn string, list *[]string) bool {
for _, scope := range *list {
func domainInScopes(fqdn string, list []string) bool {
for _, scope := range list {
if strings.HasSuffix(fqdn, scope) {
return true
}

87
main.go Normal file
View file

@ -0,0 +1,87 @@
package main
import (
"flag"
"fmt"
"os"
"os/signal"
"runtime/pprof"
"syscall"
"time"
"github.com/Safing/portbase/info"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
// include packages here
_ "github.com/Safing/portbase/api"
_ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/firewall"
_ "github.com/Safing/portmaster/nameserver"
_ "github.com/Safing/portmaster/ui"
)
var (
printStackOnExit bool
)
func init() {
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
}
func main() {
// Set Info
info.Set("Portmaster", "0.2.0")
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
err = modules.Shutdown()
if err != nil {
log.Shutdown()
}
os.Exit(1)
}
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
select {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
if printStackOnExit {
fmt.Println("=== PRINTING STACK ===")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
fmt.Println("=== END STACK ===")
}
go func() {
time.Sleep(3 * time.Second)
fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 2)
os.Exit(1)
}()
modules.Shutdown()
os.Exit(0)
case <-modules.ShuttingDown():
}
}

View file

@ -4,38 +4,60 @@ package nameserver
import (
"net"
"time"
"github.com/miekg/dns"
"github.com/Safing/safing-core/analytics/algs"
"github.com/Safing/safing-core/intel"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/modules"
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/safing-core/portmaster"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
"github.com/Safing/portmaster/analytics/algs"
"github.com/Safing/portmaster/firewall"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/netutils"
)
var (
nameserverModule *modules.Module
localhostIPs []dns.RR
)
func init() {
nameserverModule = modules.Register("Nameserver", 128)
modules.Register("nameserver", prep, start, nil, "intel")
}
func Start() {
func prep() error {
localhostIPv4, err := dns.NewRR("localhost. 17 IN A 127.0.0.1")
if err != nil {
return err
}
localhostIPv6, err := dns.NewRR("localhost. 17 IN AAAA ::1")
if err != nil {
return err
}
localhostIPs = []dns.RR{localhostIPv4, localhostIPv6}
return nil
}
func start() error {
server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"}
dns.HandleFunc(".", handleRequest)
go func() {
go run(server)
return nil
}
func run(server *dns.Server) {
for {
err := server.ListenAndServe()
if err != nil {
log.Errorf("nameserver: server failed: %s", err)
log.Info("nameserver: restarting server in 10 seconds")
time.Sleep(10 * time.Second)
}
}()
// TODO: stop mocking
defer nameserverModule.StopComplete()
<-nameserverModule.Stop
}
}
func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
@ -47,7 +69,6 @@ func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// TODO: handle securityLevelOff
// only process first question, that's how everyone does it.
question := query.Question[0]
@ -82,6 +103,14 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
return
}
// handle request for localhost
if fqdn == "localhost." {
m := new(dns.Msg)
m.SetReply(query)
m.Answer = localhostIPs
w.WriteMsg(m)
}
// get remote address
// start := time.Now()
rAddr, ok := w.RemoteAddr().(*net.UDPAddr)
@ -109,19 +138,19 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn)
// check profile before we even get intel and rr
if connection.Verdict == network.UNDECIDED {
if connection.GetVerdict() == network.UNDECIDED {
// start = time.Now()
portmaster.DecideOnConnectionBeforeIntel(connection, fqdn)
firewall.DecideOnConnectionBeforeIntel(connection, fqdn)
// log.Tracef("nameserver: took %s to make decision", time.Since(start))
}
if connection.Verdict == network.BLOCK || connection.Verdict == network.DROP {
if connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query)
return
}
// get intel and RRs
// start = time.Now()
domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().Profile.SecurityLevel)
domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().ProfileSet().SecurityLevel())
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
@ -131,14 +160,16 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
}
// set intel
connection.Lock()
connection.Intel = domainIntel
connection.Unlock()
connection.Save()
// do a full check with intel
if connection.Verdict == network.UNDECIDED {
rrCache = portmaster.DecideOnConnectionAfterIntel(connection, fqdn, rrCache)
if connection.GetVerdict() == network.UNDECIDED {
rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache)
}
if rrCache == nil || connection.Verdict == network.BLOCK || connection.Verdict == network.DROP {
if rrCache == nil || connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query)
return
}
@ -150,23 +181,27 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
ipInfo, err := intel.GetIPInfo(v.A.String())
if err != nil {
ipInfo = &intel.IPInfo{
IP: v.A.String(),
Domains: []string{fqdn},
}
ipInfo.Create(v.A.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
}
case *dns.AAAA:
ipInfo, err := intel.GetIPInfo(v.AAAA.String())
if err != nil {
ipInfo = &intel.IPInfo{
IP: v.AAAA.String(),
Domains: []string{fqdn},
}
ipInfo.Create(v.AAAA.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
}
}
}

View file

@ -0,0 +1,100 @@
package only
import (
"time"
"github.com/miekg/dns"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
"github.com/Safing/portmaster/analytics/algs"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network/netutils"
)
func init() {
modules.Register("nameserver", nil, start, nil, "intel")
}
func start() error {
server := &dns.Server{Addr: "127.0.0.1:53", Net: "udp"}
dns.HandleFunc(".", handleRequest)
go run(server)
return nil
}
func run(server *dns.Server) {
for {
err := server.ListenAndServe()
if err != nil {
log.Errorf("nameserver: server failed: %s", err)
log.Info("nameserver: restarting server in 10 seconds")
time.Sleep(10 * time.Second)
}
}
}
func nxDomain(w dns.ResponseWriter, query *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(query, dns.RcodeNameError)
w.WriteMsg(m)
}
func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// TODO: handle securityLevelOff
// only process first question, that's how everyone does it.
question := query.Question[0]
fqdn := dns.Fqdn(question.Name)
qtype := dns.Type(question.Qtype)
// use this to time how long it takes process this request
// timed := time.Now()
// defer log.Tracef("nameserver: took %s to handle request for %s%s", time.Now().Sub(timed).String(), fqdn, qtype.String())
// check if valid domain name
if !netutils.IsValidFqdn(fqdn) {
log.Tracef("nameserver: domain name %s is invalid, returning nxdomain", fqdn)
nxDomain(w, query)
return
}
// check for possible DNS tunneling / data transmission
// TODO: improve this
lms := algs.LmsScoreOfDomain(fqdn)
// log.Tracef("nameserver: domain %s has lms score of %f", fqdn, lms)
if lms < 10 {
log.Tracef("nameserver: possible data tunnel: %s has lms score of %f, returning nxdomain", fqdn, lms)
nxDomain(w, query)
return
}
// check class
if question.Qclass != dns.ClassINET {
// we only serve IN records, send NXDOMAIN
nxDomain(w, query)
return
}
// get intel and RRs
// start = time.Now()
_, rrCache := intel.GetIntelAndRRs(fqdn, qtype, 0)
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains
log.Infof("nameserver: %s is nxdomain", fqdn)
nxDomain(w, query)
return
}
// reply to query
m := new(dns.Msg)
m.SetReply(query)
m.Answer = rrCache.Answer
m.Ns = rrCache.Ns
m.Extra = rrCache.Extra
w.WriteMsg(m)
}

View file

@ -5,34 +5,52 @@ package network
import (
"time"
"github.com/Safing/safing-core/process"
"github.com/Safing/portmaster/process"
)
func init() {
go cleaner()
}
var (
cleanerTickDuration = 10 * time.Second
deadLinksTimeout = 3 * time.Minute
thresholdDuration = 3 * time.Minute
)
func cleaner() {
time.Sleep(15 * time.Second)
for {
markDeadLinks()
purgeDeadFor(5 * time.Minute)
time.Sleep(15 * time.Second)
time.Sleep(cleanerTickDuration)
cleanLinks()
time.Sleep(2 * time.Second)
cleanConnections()
time.Sleep(2 * time.Second)
cleanProcesses()
}
}
func markDeadLinks() {
func cleanLinks() {
activeIDs := process.GetActiveConnectionIDs()
allLinksLock.RLock()
defer allLinksLock.RUnlock()
now := time.Now().Unix()
var found bool
for key, link := range allLinks {
deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix()
// skip dead links
// log.Tracef("network.clean: now=%d", now)
// log.Tracef("network.clean: deleteOlderThan=%d", deleteOlderThan)
linksLock.RLock()
defer linksLock.RUnlock()
var found bool
for key, link := range links {
// delete dead links
if link.Ended > 0 {
link.Lock()
deleteThis := link.Ended < deleteOlderThan
link.Unlock()
if deleteThis {
// log.Tracef("network.clean: deleted %s", link.DatabaseKey())
go link.Delete()
}
continue
}
@ -48,56 +66,28 @@ func markDeadLinks() {
// mark end time
if !found {
link.Ended = now
link.Save()
// log.Tracef("network.clean: marked %s as ended.", link.DatabaseKey())
go link.Save()
}
}
}
func purgeDeadFor(age time.Duration) {
connections := make(map[*Connection]bool)
processes := make(map[*process.Process]bool)
func cleanConnections() {
connectionsLock.RLock()
defer connectionsLock.RUnlock()
allLinksLock.Lock()
defer allLinksLock.Unlock()
// delete old dead links
// make a list of connections without links
ageAgo := time.Now().Add(-1 * age).Unix()
for key, link := range allLinks {
if link.Ended != 0 && link.Ended < ageAgo {
link.Delete()
delete(allLinks, key)
_, ok := connections[link.Connection()]
if !ok {
connections[link.Connection()] = false
}
} else {
connections[link.Connection()] = true
threshold := time.Now().Add(-thresholdDuration).Unix()
for _, conn := range connections {
conn.Lock()
if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 {
// log.Tracef("network.clean: deleted %s", conn.DatabaseKey())
go conn.Delete()
}
conn.Unlock()
}
// delete connections without links
// make a list of processes without connections
for conn, active := range connections {
if conn != nil {
if !active {
conn.Delete()
_, ok := processes[conn.Process()]
if !ok {
processes[conn.Process()] = false
}
} else {
processes[conn.Process()] = true
}
}
}
// delete processes without connections
for proc, active := range processes {
if proc != nil && !active {
proc.Delete()
}
}
}
func cleanProcesses() {
process.CleanProcessStorage(thresholdDuration)
}

View file

@ -3,147 +3,187 @@
package network
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/intel"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/safing-core/process"
datastore "github.com/ipfs/go-datastore"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/process"
)
// Connection describes a connection between a process and a domain
type Connection struct {
database.Base
Domain string
Direction bool
Intel *intel.Intel
process *process.Process
Verdict Verdict
Reason string
Inspect bool
record.Base
sync.Mutex
Domain string
Direction bool
Intel *intel.Intel
process *process.Process
Verdict Verdict
Reason string
Inspect bool
FirstLinkEstablished int64
LastLinkEstablished int64
LinkCount uint
}
var connectionModel *Connection // only use this as parameter for database.EnsureModel-like functions
// Process returns the process that owns the connection.
func (conn *Connection) Process() *process.Process {
conn.Lock()
defer conn.Unlock()
func init() {
database.RegisterModel(connectionModel, func() database.Model { return new(Connection) })
return conn.process
}
func (m *Connection) Process() *process.Process {
return m.process
// GetVerdict returns the current verdict.
func (conn *Connection) GetVerdict() Verdict {
conn.Lock()
defer conn.Unlock()
return conn.Verdict
}
// Create creates a new database entry in the database in the default namespace for this object
func (m *Connection) Create(name string) error {
return m.CreateObject(&database.OrphanedConnection, name, m)
// Accept accepts the connection and adds the given reason.
func (conn *Connection) Accept(reason string) {
conn.AddReason(reason)
conn.UpdateVerdict(ACCEPT)
}
// CreateInProcessNamespace creates a new database entry in the namespace of the connection's process
func (m *Connection) CreateInProcessNamespace() error {
if m.process != nil {
return m.CreateObject(m.process.GetKey(), m.Domain, m)
// Deny blocks or drops the connection depending on the connection direction and adds the given reason.
func (conn *Connection) Deny(reason string) {
if conn.Direction {
conn.Drop(reason)
} else {
conn.Block(reason)
}
return m.CreateObject(&database.OrphanedConnection, m.Domain, m)
}
// Save saves the object to the database (It must have been either already created or loaded from the database)
func (m *Connection) Save() error {
return m.SaveObject(m)
// Block blocks the connection and adds the given reason.
func (conn *Connection) Block(reason string) {
conn.AddReason(reason)
conn.UpdateVerdict(BLOCK)
}
func (m *Connection) CantSay() {
if m.Verdict != CANTSAY {
m.Verdict = CANTSAY
m.SaveObject(m)
// Drop drops the connection and adds the given reason.
func (conn *Connection) Drop(reason string) {
conn.AddReason(reason)
conn.UpdateVerdict(DROP)
}
// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts
func (conn *Connection) UpdateVerdict(newVerdict Verdict) {
conn.Lock()
defer conn.Unlock()
if newVerdict > conn.Verdict {
conn.Verdict = newVerdict
go conn.Save()
}
return
}
func (m *Connection) Drop() {
if m.Verdict != DROP {
m.Verdict = DROP
m.SaveObject(m)
}
return
}
func (m *Connection) Block() {
if m.Verdict != BLOCK {
m.Verdict = BLOCK
m.SaveObject(m)
}
return
}
func (m *Connection) Accept() {
if m.Verdict != ACCEPT {
m.Verdict = ACCEPT
m.SaveObject(m)
}
return
}
// AddReason adds a human readable string as to why a certain verdict was set in regard to this connection
func (m *Connection) AddReason(newReason string) {
if m.Reason != "" {
m.Reason += " | "
func (conn *Connection) AddReason(reason string) {
if reason == "" {
return
}
m.Reason += newReason
conn.Lock()
defer conn.Unlock()
if conn.Reason != "" {
conn.Reason += " | "
}
conn.Reason += reason
}
// GetConnectionByFirstPacket returns the matching connection from the internal storage.
func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) {
// get Process
proc, direction, err := process.GetProcessByPacket(pkt)
if err != nil {
return nil, err
}
var domain string
// if INBOUND
// Incoming
if direction {
connection, err := GetConnectionFromProcessNamespace(proc, "I")
if err != nil {
switch netutils.ClassifyIP(pkt.GetIPHeader().Src) {
case netutils.HostLocal:
domain = IncomingHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = IncomingLAN
case netutils.Global, netutils.GlobalMulticast:
domain = IncomingInternet
case netutils.Invalid:
domain = IncomingInvalid
}
connection, ok := GetConnection(proc.Pid, domain)
if !ok {
connection = &Connection{
Domain: "I",
Direction: true,
Domain: domain,
Direction: Inbound,
process: proc,
Inspect: true,
FirstLinkEstablished: time.Now().Unix(),
}
}
connection.process.AddConnection()
return connection, nil
}
// get domain
ipinfo, err := intel.GetIPInfo(pkt.FmtRemoteIP())
// PeerToPeer
if err != nil {
// if no domain could be found, it must be a direct connection
connection, err := GetConnectionFromProcessNamespace(proc, "D")
if err != nil {
switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) {
case netutils.HostLocal:
domain = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
domain = PeerInternet
case netutils.Invalid:
domain = PeerInvalid
}
connection, ok := GetConnection(proc.Pid, domain)
if !ok {
connection = &Connection{
Domain: "D",
Domain: domain,
Direction: Outbound,
process: proc,
Inspect: true,
FirstLinkEstablished: time.Now().Unix(),
}
}
connection.process.AddConnection()
return connection, nil
}
// To Domain
// FIXME: how to handle multiple possible domains?
connection, err := GetConnectionFromProcessNamespace(proc, ipinfo.Domains[0])
if err != nil {
connection, ok := GetConnection(proc.Pid, ipinfo.Domains[0])
if !ok {
connection = &Connection{
Domain: ipinfo.Domains[0],
Direction: Outbound,
process: proc,
Inspect: true,
FirstLinkEstablished: time.Now().Unix(),
}
}
connection.process.AddConnection()
return connection, nil
}
@ -154,6 +194,7 @@ var (
dnsPort uint16 = 53
)
// GetConnectionByDNSRequest returns the matching connection from the internal storage.
func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection, error) {
// get Process
proc, err := process.GetProcessByEndpoints(ip, port, dnsAddress, dnsPort, packet.UDP)
@ -161,70 +202,124 @@ func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection
return nil, err
}
connection, err := GetConnectionFromProcessNamespace(proc, fqdn)
if err != nil {
connection, ok := GetConnection(proc.Pid, fqdn)
if !ok {
connection = &Connection{
Domain: fqdn,
process: proc,
Inspect: true,
}
connection.CreateInProcessNamespace()
connection.process.AddConnection()
connection.Save()
}
return connection, nil
}
// GetConnection fetches a Connection from the database from the default namespace for this object
func GetConnection(name string) (*Connection, error) {
return GetConnectionFromNamespace(&database.OrphanedConnection, name)
// GetConnection fetches a connection object from the internal storage.
func GetConnection(pid int, domain string) (conn *Connection, ok bool) {
connectionsLock.RLock()
defer connectionsLock.RUnlock()
conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)]
return
}
// GetConnectionFromProcessNamespace fetches a Connection from the namespace of its process
func GetConnectionFromProcessNamespace(process *process.Process, domain string) (*Connection, error) {
return GetConnectionFromNamespace(process.GetKey(), domain)
func (conn *Connection) makeKey() string {
return fmt.Sprintf("%d/%s", conn.process.Pid, conn.Domain)
}
// GetConnectionFromNamespace fetches a Connection form the database, but from a custom namespace
func GetConnectionFromNamespace(namespace *datastore.Key, name string) (*Connection, error) {
object, err := database.GetAndEnsureModel(namespace, name, connectionModel)
if err != nil {
return nil, err
// Save saves the connection object in the storage and propagates the change.
func (conn *Connection) Save() error {
conn.Lock()
defer conn.Unlock()
if conn.process == nil {
return errors.New("cannot save connection without process")
}
model, ok := object.(*Connection)
if !conn.KeyIsSet() {
conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain))
conn.CreateMeta()
}
key := conn.makeKey()
connectionsLock.RLock()
_, ok := connections[key]
connectionsLock.RUnlock()
if !ok {
return nil, database.NewMismatchError(object, connectionModel)
connectionsLock.Lock()
connections[key] = conn
connectionsLock.Unlock()
}
return model, nil
go dbController.PushUpdate(conn)
return nil
}
func (m *Connection) AddLink(link *Link, pkt packet.Packet) {
link.connection = m
link.Verdict = m.Verdict
link.Inspect = m.Inspect
if m.FirstLinkEstablished == 0 {
m.FirstLinkEstablished = time.Now().Unix()
m.Save()
}
link.CreateInConnectionNamespace(pkt.GetConnectionID())
// Delete deletes a connection from the storage and propagates the change.
func (conn *Connection) Delete() {
conn.Lock()
defer conn.Unlock()
connectionsLock.Lock()
delete(connections, conn.makeKey())
connectionsLock.Unlock()
conn.Meta().Delete()
go dbController.PushUpdate(conn)
conn.process.RemoveConnection()
go conn.process.Save()
}
// FORMATTING
// AddLink applies the connection to the link and increases sets counter and timestamps.
func (conn *Connection) AddLink(link *Link) {
link.Lock()
link.connection = conn
link.Verdict = conn.Verdict
link.Inspect = conn.Inspect
link.Unlock()
link.Save()
func (m *Connection) String() string {
switch m.Domain {
case "I":
if m.process == nil {
conn.Lock()
conn.LinkCount++
conn.LastLinkEstablished = time.Now().Unix()
if conn.FirstLinkEstablished == 0 {
conn.FirstLinkEstablished = conn.LastLinkEstablished
}
conn.Unlock()
conn.Save()
}
// RemoveLink lowers the link counter by one.
func (conn *Connection) RemoveLink() {
conn.Lock()
defer conn.Unlock()
if conn.LinkCount > 0 {
conn.LinkCount--
}
}
// String returns a string representation of Connection.
func (conn *Connection) String() string {
conn.Lock()
defer conn.Unlock()
switch conn.Domain {
case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid:
if conn.process == nil {
return "? <- *"
}
return fmt.Sprintf("%s <- *", m.process.String())
case "D":
if m.process == nil {
return fmt.Sprintf("%s <- *", conn.process.String())
case PeerHost, PeerLAN, PeerInternet, PeerInvalid:
if conn.process == nil {
return "? -> *"
}
return fmt.Sprintf("%s -> *", m.process.String())
return fmt.Sprintf("%s -> *", conn.process.String())
default:
if m.process == nil {
return fmt.Sprintf("? -> %s", m.Domain)
if conn.process == nil {
return fmt.Sprintf("? -> %s", conn.Domain)
}
return fmt.Sprintf("%s -> %s", m.process.String(), m.Domain)
return fmt.Sprintf("%s -> %s", conn.process.String(), conn.Domain)
}
}

122
network/database.go Normal file
View file

@ -0,0 +1,122 @@
package network
import (
"strconv"
"strings"
"sync"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/iterator"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/database/storage"
"github.com/Safing/portmaster/process"
)
var (
links = make(map[string]*Link)
linksLock sync.RWMutex
connections = make(map[string]*Connection)
connectionsLock sync.RWMutex
dbController *database.Controller
)
// StorageInterface provices a storage.Interface to the configuration manager.
type StorageInterface struct {
storage.InjectBase
}
// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
splitted := strings.Split(key, "/")
switch splitted[0] {
case "tree":
switch len(splitted) {
case 2:
pid, err := strconv.Atoi(splitted[1])
if err == nil {
proc, ok := process.GetProcessFromStorage(pid)
if ok {
return proc, nil
}
}
case 3:
connectionsLock.RLock()
defer connectionsLock.RUnlock()
conn, ok := connections[splitted[2]]
if ok {
return conn, nil
}
case 4:
linksLock.RLock()
defer linksLock.RUnlock()
link, ok := links[splitted[3]]
if ok {
return link, nil
}
}
}
return nil, storage.ErrNotFound
}
// Query returns a an iterator for the supplied query.
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
it := iterator.New()
go s.processQuery(q, it)
// TODO: check local and internal
return it, nil
}
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// processes
for _, proc := range process.All() {
if strings.HasPrefix(proc.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- proc
}
}
// connections
connectionsLock.RLock()
for _, conn := range connections {
if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- conn
}
}
connectionsLock.RUnlock()
// links
linksLock.RLock()
for _, link := range links {
if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- link
}
}
linksLock.RUnlock()
it.Finish(nil)
}
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "network",
Description: "Network and Firewall Data",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("network", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
process.SetDBController(dbController)
return nil
}

View file

@ -4,7 +4,7 @@ import (
"net"
"strings"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/portmaster/network/netutils"
)
func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) {

View file

@ -24,6 +24,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
var nameservers []Nameserver
var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil {
dbusConn, err = dbus.SystemBus()
}
@ -158,6 +161,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
func getConnectivityStateFromDbus() (uint8, error) {
var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil {
dbusConn, err = dbus.SystemBus()
}

View file

@ -11,7 +11,7 @@ import (
"sync/atomic"
"time"
"github.com/Safing/safing-core/log"
"github.com/Safing/portbase/log"
)
// TODO: find a good way to identify a network

View file

@ -0,0 +1,27 @@
package environment
import "net"
func Nameservers() []Nameserver {
return nil
}
func Gateways() []*net.IP {
return nil
}
// TODO: implement using
// ifconfig
// scutil --nwi
// scutil --proxy
// networksetup -listallnetworkservices
// networksetup -listnetworkserviceorder
// networksetup -getdnsservers "Wi-Fi"
// networksetup -getsearchdomains <networkservice>
// networksetup -getftpproxy <networkservice>
// networksetup -getwebproxy <networkservice>
// networksetup -getsecurewebproxy <networkservice>
// networksetup -getstreamingproxy <networkservice>
// networksetup -getgopherproxy <networkservice>
// networksetup -getsocksfirewallproxy <networkservice>
// route -n get default

View file

@ -12,8 +12,8 @@ import (
"github.com/miekg/dns"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/netutils"
)
// Gateways returns the currently active gateways

View file

@ -1,4 +1,4 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
// +build linux
package environment

View file

@ -8,9 +8,10 @@ import (
"log"
"net"
"os"
"github.com/Safing/safing-core/network/netutils"
"time"
"github.com/Safing/portmaster/network/netutils"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
@ -99,7 +100,7 @@ next:
if ip == nil {
return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String()))
}
if !netutils.IPIsLocal(ip) {
if !netutils.IPIsLAN(ip) {
return ip, nil
}
continue next

View file

@ -3,28 +3,30 @@
package network
import (
"errors"
"fmt"
"sync"
"time"
datastore "github.com/ipfs/go-datastore"
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/packet"
)
// FirewallHandler defines the function signature for a firewall handle function
type FirewallHandler func(pkt packet.Packet, link *Link)
var (
linkTimeout = 10 * time.Minute
allLinks = make(map[string]*Link)
allLinksLock sync.RWMutex
linkTimeout = 10 * time.Minute
)
// Link describes an distinct physical connection (e.g. TCP connection) - like an instance - of a Connection
// Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection.
type Link struct {
database.Base
record.Base
sync.Mutex
ID string
Verdict Verdict
Reason string
Tunneled bool
@ -32,180 +34,322 @@ type Link struct {
Inspect bool
Started int64
Ended int64
connection *Connection
RemoteAddress string
ActiveInspectors []bool `json:"-" bson:"-"`
InspectorData map[uint8]interface{} `json:"-" bson:"-"`
pktQueue chan packet.Packet
firewallHandler FirewallHandler
}
connection *Connection
var linkModel *Link // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(linkModel, func() database.Model { return new(Link) })
activeInspectors []bool
inspectorData map[uint8]interface{}
}
// Connection returns the Connection the Link is part of
func (m *Link) Connection() *Connection {
return m.connection
func (link *Link) Connection() *Connection {
link.Lock()
defer link.Unlock()
return link.connection
}
// GetVerdict returns the current verdict.
func (link *Link) GetVerdict() Verdict {
link.Lock()
defer link.Unlock()
return link.Verdict
}
// FirewallHandlerIsSet returns whether a firewall handler is set or not
func (m *Link) FirewallHandlerIsSet() bool {
return m.firewallHandler != nil
func (link *Link) FirewallHandlerIsSet() bool {
link.Lock()
defer link.Unlock()
return link.firewallHandler != nil
}
// SetFirewallHandler sets the firewall handler for this link
func (m *Link) SetFirewallHandler(handler FirewallHandler) {
if m.firewallHandler == nil {
m.firewallHandler = handler
m.pktQueue = make(chan packet.Packet, 1000)
go m.packetHandler()
func (link *Link) SetFirewallHandler(handler FirewallHandler) {
link.Lock()
defer link.Unlock()
if link.firewallHandler == nil {
link.firewallHandler = handler
link.pktQueue = make(chan packet.Packet, 1000)
go link.packetHandler()
return
}
m.firewallHandler = handler
link.firewallHandler = handler
}
// StopFirewallHandler unsets the firewall handler
func (m *Link) StopFirewallHandler() {
m.pktQueue <- nil
func (link *Link) StopFirewallHandler() {
link.Lock()
link.firewallHandler = nil
link.Unlock()
link.pktQueue <- nil
}
// HandlePacket queues packet of Link for handling
func (m *Link) HandlePacket(pkt packet.Packet) {
if m.firewallHandler != nil {
m.pktQueue <- pkt
func (link *Link) HandlePacket(pkt packet.Packet) {
link.Lock()
defer link.Unlock()
if link.firewallHandler != nil {
link.pktQueue <- pkt
return
}
log.Criticalf("network: link %s does not have a firewallHandler, maybe its a copy, dropping packet", m)
log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link)
pkt.Drop()
}
// Accept accepts the link and adds the given reason.
func (link *Link) Accept(reason string) {
link.AddReason(reason)
link.UpdateVerdict(ACCEPT)
}
// Deny blocks or drops the link depending on the connection direction and adds the given reason.
func (link *Link) Deny(reason string) {
if link.connection != nil && link.connection.Direction {
link.Drop(reason)
} else {
link.Block(reason)
}
}
// Block blocks the link and adds the given reason.
func (link *Link) Block(reason string) {
link.AddReason(reason)
link.UpdateVerdict(BLOCK)
}
// Drop drops the link and adds the given reason.
func (link *Link) Drop(reason string) {
link.AddReason(reason)
link.UpdateVerdict(DROP)
}
// RerouteToNameserver reroutes the link to the portmaster nameserver.
func (link *Link) RerouteToNameserver() {
link.UpdateVerdict(RerouteToNameserver)
}
// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection.
func (link *Link) RerouteToTunnel(reason string) {
link.AddReason(reason)
link.UpdateVerdict(RerouteToTunnel)
}
// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts
func (m *Link) UpdateVerdict(newVerdict Verdict) {
if newVerdict > m.Verdict {
m.Verdict = newVerdict
m.Save()
func (link *Link) UpdateVerdict(newVerdict Verdict) {
link.Lock()
defer link.Unlock()
if newVerdict > link.Verdict {
link.Verdict = newVerdict
go link.Save()
}
}
// AddReason adds a human readable string as to why a certain verdict was set in regard to this link
func (m *Link) AddReason(newReason string) {
if m.Reason != "" {
m.Reason += " | "
func (link *Link) AddReason(reason string) {
if reason == "" {
return
}
m.Reason += newReason
link.Lock()
defer link.Unlock()
if link.Reason != "" {
link.Reason += " | "
}
link.Reason += reason
}
// packetHandler sequentially handles queued packets
func (m *Link) packetHandler() {
func (link *Link) packetHandler() {
for {
pkt := <-m.pktQueue
pkt := <-link.pktQueue
if pkt == nil {
break
return
}
link.Lock()
fwH := link.firewallHandler
link.Unlock()
if fwH != nil {
fwH(pkt, link)
} else {
link.ApplyVerdict(pkt)
}
m.firewallHandler(pkt, m)
}
m.firewallHandler = nil
}
// Create creates a new database entry in the database in the default namespace for this object
func (m *Link) Create(name string) error {
m.CreateShallow(name)
return m.CreateObject(&database.OrphanedLink, name, m)
}
// ApplyVerdict appies the link verdict to a packet.
func (link *Link) ApplyVerdict(pkt packet.Packet) {
link.Lock()
defer link.Unlock()
// Create creates a new database entry in the database in the default namespace for this object
func (m *Link) CreateShallow(name string) {
allLinksLock.Lock()
allLinks[name] = m
allLinksLock.Unlock()
}
// CreateWithDefaultKey creates a new database entry in the database in the default namespace for this object using the default key
func (m *Link) CreateInConnectionNamespace(name string) error {
if m.connection != nil {
return m.CreateObject(m.connection.GetKey(), name, m)
if link.VerdictPermanent {
switch link.Verdict {
case ACCEPT:
pkt.PermanentAccept()
case BLOCK:
pkt.PermanentBlock()
case DROP:
pkt.PermanentDrop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
}
} else {
switch link.Verdict {
case ACCEPT:
pkt.Accept()
case BLOCK:
pkt.Block()
case DROP:
pkt.Drop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
}
}
return m.CreateObject(&database.OrphanedLink, name, m)
}
// Save saves the object to the database (It must have been either already created or loaded from the database)
func (m *Link) Save() error {
return m.SaveObject(m)
// Save saves the link object in the storage and propagates the change.
func (link *Link) Save() error {
link.Lock()
defer link.Unlock()
if link.connection == nil {
return errors.New("cannot save link without connection")
}
if !link.KeyIsSet() {
link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID))
link.CreateMeta()
}
linksLock.RLock()
_, ok := links[link.ID]
linksLock.RUnlock()
if !ok {
linksLock.Lock()
links[link.ID] = link
linksLock.Unlock()
}
go dbController.PushUpdate(link)
return nil
}
// Delete deletes a link from the storage and propagates the change.
func (link *Link) Delete() {
link.Lock()
defer link.Unlock()
linksLock.Lock()
delete(links, link.ID)
linksLock.Unlock()
link.Meta().Delete()
go dbController.PushUpdate(link)
link.connection.RemoveLink()
go link.connection.Save()
}
// GetLink fetches a Link from the database from the default namespace for this object
func GetLink(name string) (*Link, error) {
allLinksLock.RLock()
link, ok := allLinks[name]
allLinksLock.RUnlock()
if !ok {
return nil, database.ErrNotFound
}
return link, nil
// return GetLinkFromNamespace(&database.RunningLink, name)
}
func GetLink(id string) (*Link, bool) {
linksLock.RLock()
defer linksLock.RUnlock()
func SaveInCache(link *Link) {
}
// GetLinkFromNamespace fetches a Link form the database, but from a custom namespace
func GetLinkFromNamespace(namespace *datastore.Key, name string) (*Link, error) {
object, err := database.GetAndEnsureModel(namespace, name, linkModel)
if err != nil {
return nil, err
}
model, ok := object.(*Link)
if !ok {
return nil, database.NewMismatchError(object, linkModel)
}
return model, nil
link, ok := links[id]
return link, ok
}
// GetOrCreateLinkByPacket returns the associated Link for a packet and a bool expressing if the Link was newly created
func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) {
link, err := GetLink(pkt.GetConnectionID())
if err != nil {
return CreateLinkFromPacket(pkt), true
link, ok := GetLink(pkt.GetLinkID())
if ok {
return link, false
}
return link, false
return CreateLinkFromPacket(pkt), true
}
// CreateLinkFromPacket creates a new Link based on Packet. The Link is shallowly saved and SHOULD be saved to the database as soon more information is available
// CreateLinkFromPacket creates a new Link based on Packet.
func CreateLinkFromPacket(pkt packet.Packet) *Link {
link := &Link{
ID: pkt.GetLinkID(),
Verdict: UNDECIDED,
Started: time.Now().Unix(),
RemoteAddress: pkt.FmtRemoteAddress(),
}
link.CreateShallow(pkt.GetConnectionID())
return link
}
// FORMATTING
func (m *Link) String() string {
if m.connection == nil {
return fmt.Sprintf("? <-> %s", m.RemoteAddress)
// GetActiveInspectors returns the list of active inspectors.
func (link *Link) GetActiveInspectors() []bool {
link.Lock()
defer link.Unlock()
return link.activeInspectors
}
// SetActiveInspectors sets the list of active inspectors.
func (link *Link) SetActiveInspectors(new []bool) {
link.Lock()
defer link.Unlock()
link.activeInspectors = new
}
// GetInspectorData returns the list of inspector data.
func (link *Link) GetInspectorData() map[uint8]interface{} {
link.Lock()
defer link.Unlock()
return link.inspectorData
}
// SetInspectorData set the list of inspector data.
func (link *Link) SetInspectorData(new map[uint8]interface{}) {
link.Lock()
defer link.Unlock()
link.inspectorData = new
}
// String returns a string representation of Link.
func (link *Link) String() string {
link.Lock()
defer link.Unlock()
if link.connection == nil {
return fmt.Sprintf("? <-> %s", link.RemoteAddress)
}
switch m.connection.Domain {
switch link.connection.Domain {
case "I":
if m.connection.process == nil {
return fmt.Sprintf("? <- %s", m.RemoteAddress)
if link.connection.process == nil {
return fmt.Sprintf("? <- %s", link.RemoteAddress)
}
return fmt.Sprintf("%s <- %s", m.connection.process.String(), m.RemoteAddress)
return fmt.Sprintf("%s <- %s", link.connection.process.String(), link.RemoteAddress)
case "D":
if m.connection.process == nil {
return fmt.Sprintf("? -> %s", m.RemoteAddress)
if link.connection.process == nil {
return fmt.Sprintf("? -> %s", link.RemoteAddress)
}
return fmt.Sprintf("%s -> %s", m.connection.process.String(), m.RemoteAddress)
return fmt.Sprintf("%s -> %s", link.connection.process.String(), link.RemoteAddress)
default:
if m.connection.process == nil {
return fmt.Sprintf("? -> %s (%s)", m.connection.Domain, m.RemoteAddress)
if link.connection.process == nil {
return fmt.Sprintf("? -> %s (%s)", link.connection.Domain, link.RemoteAddress)
}
return fmt.Sprintf("%s to %s (%s)", m.connection.process.String(), m.connection.Domain, m.RemoteAddress)
return fmt.Sprintf("%s to %s (%s)", link.connection.process.String(), link.connection.Domain, link.RemoteAddress)
}
}

14
network/module.go Normal file
View file

@ -0,0 +1,14 @@
package network
import (
"github.com/Safing/portbase/modules"
)
func init() {
modules.Register("network", nil, start, nil, "database")
}
func start() error {
go cleaner()
return registerAsDatabase()
}

View file

@ -11,6 +11,7 @@ var (
cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$")
)
// IsValidFqdn returns whether the given string is a valid fqdn.
func IsValidFqdn(fqdn string) bool {
return cleanDomainRegex.MatchString(fqdn)
}

View file

@ -4,95 +4,101 @@ package netutils
import "net"
// IP types
// IP classifications
const (
hostLocal int8 = iota
linkLocal
siteLocal
global
localMulticast
globalMulticast
invalid
HostLocal int8 = iota
LinkLocal
SiteLocal
Global
LocalMulticast
GlobalMulticast
Invalid
)
func classifyAddress(ip net.IP) int8 {
// ClassifyIP returns the classification for the given IP address.
func ClassifyIP(ip net.IP) int8 {
if ip4 := ip.To4(); ip4 != nil {
// IPv4
switch {
case ip4[0] == 127:
// 127.0.0.0/8
return hostLocal
return HostLocal
case ip4[0] == 169 && ip4[1] == 254:
// 169.254.0.0/16
return linkLocal
return LinkLocal
case ip4[0] == 10:
// 10.0.0.0/8
return siteLocal
return SiteLocal
case ip4[0] == 172 && ip4[1]&0xf0 == 16:
// 172.16.0.0/12
return siteLocal
return SiteLocal
case ip4[0] == 192 && ip4[1] == 168:
// 192.168.0.0/16
return siteLocal
return SiteLocal
case ip4[0] == 224:
// 224.0.0.0/8
return localMulticast
return LocalMulticast
case ip4[0] >= 225 && ip4[0] <= 239:
// 225.0.0.0/8 - 239.0.0.0/8
return globalMulticast
return GlobalMulticast
case ip4[0] >= 240:
// 240.0.0.0/8 - 255.0.0.0/8
return invalid
return Invalid
default:
return global
return Global
}
} else if len(ip) == net.IPv6len {
// IPv6
switch {
case ip.Equal(net.IPv6loopback):
return hostLocal
return HostLocal
case ip[0]&0xfe == 0xfc:
// fc00::/7
return siteLocal
return SiteLocal
case ip[0] == 0xfe && ip[1]&0xc0 == 0x80:
// fe80::/10
return linkLocal
return LinkLocal
case ip[0] == 0xff && ip[1] <= 0x05:
// ff00::/16 - ff05::/16
return localMulticast
return LocalMulticast
case ip[0] == 0xff:
// other ff00::/8
return globalMulticast
return GlobalMulticast
default:
return global
return Global
}
}
return invalid
return Invalid
}
// IPIsLocal returns true if the given IP is a site-local or link-local address
func IPIsLocal(ip net.IP) bool {
switch classifyAddress(ip) {
case siteLocal:
// IPIsLocalhost returns whether the IP refers to the host itself.
func IPIsLocalhost(ip net.IP) bool {
return ClassifyIP(ip) == HostLocal
}
// IPIsLAN returns true if the given IP is a site-local or link-local address.
func IPIsLAN(ip net.IP) bool {
switch ClassifyIP(ip) {
case SiteLocal:
return true
case linkLocal:
case LinkLocal:
return true
default:
return false
}
}
// IPIsGlobal returns true if the given IP is a global address
// IPIsGlobal returns true if the given IP is a global address.
func IPIsGlobal(ip net.IP) bool {
return classifyAddress(ip) == global
return ClassifyIP(ip) == Global
}
// IPIsLinkLocal returns true if the given IP is a link-local address
// IPIsLinkLocal returns true if the given IP is a link-local address.
func IPIsLinkLocal(ip net.IP) bool {
return classifyAddress(ip) == linkLocal
return ClassifyIP(ip) == LinkLocal
}
// IPIsSiteLocal returns true if the given IP is a site-local address
// IPIsSiteLocal returns true if the given IP is a site-local address.
func IPIsSiteLocal(ip net.IP) bool {
return classifyAddress(ip) == siteLocal
return ClassifyIP(ip) == SiteLocal
}

View file

@ -6,14 +6,14 @@ import (
)
func TestIPClassification(t *testing.T) {
testClassification(t, net.IPv4(71, 87, 113, 211), global)
testClassification(t, net.IPv4(127, 0, 0, 1), hostLocal)
testClassification(t, net.IPv4(127, 255, 255, 1), hostLocal)
testClassification(t, net.IPv4(192, 168, 172, 24), siteLocal)
testClassification(t, net.IPv4(71, 87, 113, 211), Global)
testClassification(t, net.IPv4(127, 0, 0, 1), HostLocal)
testClassification(t, net.IPv4(127, 255, 255, 1), HostLocal)
testClassification(t, net.IPv4(192, 168, 172, 24), SiteLocal)
}
func testClassification(t *testing.T, ip net.IP, expectedClassification int8) {
c := classifyAddress(ip)
c := ClassifyIP(ip)
if c != expectedClassification {
t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification))
}
@ -21,19 +21,19 @@ func testClassification(t *testing.T, ip net.IP, expectedClassification int8) {
func classificationString(c int8) string {
switch c {
case hostLocal:
case HostLocal:
return "hostLocal"
case linkLocal:
case LinkLocal:
return "linkLocal"
case siteLocal:
case SiteLocal:
return "siteLocal"
case global:
case Global:
return "global"
case localMulticast:
case LocalMulticast:
return "localMulticast"
case globalMulticast:
case GlobalMulticast:
return "globalMulticast"
case invalid:
case Invalid:
return "invalid"
default:
return "unknown"

View file

@ -106,10 +106,10 @@ type TCPUDPHeader struct {
}
type PacketBase struct {
connectionID string
Direction bool
InTunnel bool
Payload []byte
linkID string
Direction bool
InTunnel bool
Payload []byte
*IPHeader
*TCPUDPHeader
}
@ -146,25 +146,25 @@ func (pkt *PacketBase) IPVersion() IPVersion {
return pkt.Version
}
func (pkt *PacketBase) GetConnectionID() string {
if pkt.connectionID == "" {
pkt.createConnectionID()
func (pkt *PacketBase) GetLinkID() string {
if pkt.linkID == "" {
pkt.createLinkID()
}
return pkt.connectionID
return pkt.linkID
}
func (pkt *PacketBase) createConnectionID() {
func (pkt *PacketBase) createLinkID() {
if pkt.IPHeader.Protocol == TCP || pkt.IPHeader.Protocol == UDP {
if pkt.Direction {
pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort)
pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Dst, pkt.DstPort, pkt.Src, pkt.SrcPort)
} else {
pkt.connectionID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort)
pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort)
}
} else {
if pkt.Direction {
pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src)
pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Dst, pkt.Src)
} else {
pkt.connectionID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst)
pkt.linkID = fmt.Sprintf("%d-%s-%s", pkt.Protocol, pkt.Src, pkt.Dst)
}
}
}
@ -299,7 +299,7 @@ type Packet interface {
IsOutbound() bool
SetInbound()
SetOutbound()
GetConnectionID() string
GetLinkID() string
IPVersion() IPVersion
// MATCHING

View file

@ -0,0 +1,45 @@
package reference
import "strconv"
var (
protocolNames = map[uint8]string{
1: "ICMP",
2: "IGMP",
6: "TCP",
17: "UDP",
27: "RDP",
58: "ICMPv6",
33: "DCCP",
136: "UDPLite",
}
protocolNumbers = map[string]uint8{
"ICMP": 1,
"IGMP": 2,
"TCP": 6,
"UDP": 17,
"RDP": 27,
"DCCP": 33,
"ICMPv6": 58,
"UDPLite": 136,
}
)
// GetProtocolName returns the name of a IP protocol number.
func GetProtocolName(protocol uint8) (name string) {
name, ok := protocolNames[protocol]
if ok {
return name
}
return strconv.Itoa(int(protocol))
}
// GetProtocolNumber returns the number of a IP protocol name.
func GetProtocolNumber(protocol string) (number uint8, ok bool) {
number, ok = protocolNumbers[protocol]
if ok {
return number, true
}
return 0, false
}

View file

@ -2,20 +2,34 @@
package network
// Status describes the status of a connection.
// Verdict describes the decision made about a connection or link.
type Verdict uint8
// List of values a Status can have
const (
// UNDECIDED is the default status of new connections
UNDECIDED Verdict = iota
CANTSAY
ACCEPT
BLOCK
DROP
RerouteToNameserver
RerouteToTunnel
)
// Packer Directions
const (
Inbound = true
Outbound = false
)
// Non-Domain Connections
const (
IncomingHost = "IH"
IncomingLAN = "IL"
IncomingInternet = "II"
IncomingInvalid = "IX"
PeerHost = "PH"
PeerLAN = "PL"
PeerInternet = "PI"
PeerInvalid = "PX"
)

31
network/unknown.go Normal file
View file

@ -0,0 +1,31 @@
package network
import "github.com/Safing/portmaster/process"
// Static reasons
const (
ReasonUnknownProcess = "unknown connection owner: process could not be found"
)
var (
UnknownDirectConnection = &Connection{
Domain: "PI",
Direction: Outbound,
Verdict: DROP,
Reason: ReasonUnknownProcess,
process: process.UnknownProcess,
}
UnknownIncomingConnection = &Connection{
Domain: "II",
Direction: Inbound,
Verdict: DROP,
Reason: ReasonUnknownProcess,
process: process.UnknownProcess,
}
)
func init() {
UnknownDirectConnection.Save()
UnknownIncomingConnection.Save()
}

View file

@ -1,395 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package portmaster
import (
"net"
"os"
"strings"
"github.com/Safing/safing-core/intel"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/network"
"github.com/Safing/safing-core/network/netutils"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/safing-core/port17/mode"
"github.com/Safing/safing-core/profiles"
"github.com/agext/levenshtein"
)
// use https://github.com/agext/levenshtein
// Call order:
//
// 1. DecideOnConnectionBeforeIntel (if connecting to domain)
// is called when a DNS query is made, before the query is resolved
// 2. DecideOnConnectionAfterIntel (if connecting to domain)
// is called when a DNS query is made, after the query is resolved
// 3. DecideOnConnection
// is called when the first packet of the first link of the connection arrives
// 4. DecideOnLink
// is called when when the first packet of a link arrives only if connection has verdict UNDECIDED or CANTSAY
func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string) {
// check:
// Profile.DomainWhitelist
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("sheriff: granting own connection %s", connection)
connection.Accept()
return
}
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return
}
// check user class
if profile.Flags.Has(profiles.System) {
if !connection.Process().IsSystem() {
log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection)
connection.AddReason("must be executed by system")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.Admin) {
if !connection.Process().IsAdmin() {
log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection)
connection.AddReason("must be executed by admin")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.User) {
if !connection.Process().IsUser() {
log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection)
connection.AddReason("must be executed by user")
connection.Block()
return
}
}
// check for any network access
if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection)
connection.Block()
return
}
// check domain whitelist/blacklist
if len(profile.DomainWhitelist) > 0 {
matched := false
for _, entry := range profile.DomainWhitelist {
if !strings.HasSuffix(entry, ".") {
entry += "."
}
if strings.HasPrefix(entry, "*") {
if strings.HasSuffix(fqdn, strings.Trim(entry, "*")) {
matched = true
break
}
} else {
if entry == fqdn {
matched = true
break
}
}
}
if matched {
if profile.DomainWhitelistIsBlacklist {
log.Infof("sheriff: denying connection %s, profile has %s in domain blacklist", connection, fqdn)
connection.AddReason("domain blacklisted")
connection.Block()
return
}
} else {
if !profile.DomainWhitelistIsBlacklist {
log.Infof("sheriff: denying connection %s, profile does not have %s in domain whitelist", connection, fqdn)
connection.AddReason("domain not in whitelist")
connection.Block()
return
}
}
}
}
func DecideOnConnectionAfterIntel(connection *network.Connection, fqdn string, rrCache *intel.RRCache) *intel.RRCache {
// check:
// TODO: Profile.ClassificationBlacklist
// TODO: Profile.ClassificationWhitelist
// Profile.Flags
// - network specific: Strict
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return rrCache
}
// check Strict flag
// TODO: drastically improve this!
if profile.Flags.Has(profiles.Strict) {
matched := false
pathElements := strings.Split(connection.Process().Path, "/")
if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:]
}
domainElements := strings.Split(fqdn, ".")
matchLoop:
for _, domainElement := range domainElements {
for _, pathElement := range pathElements {
if levenshtein.Match(domainElement, pathElement, nil) > 0.5 {
matched = true
break matchLoop
}
}
if levenshtein.Match(domainElement, profile.Name, nil) > 0.5 {
matched = true
break matchLoop
}
if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 {
matched = true
break matchLoop
}
}
if !matched {
log.Infof("sheriff: denying connection %s, profile has declared Strict flag and no match to domain was found", connection)
connection.AddReason("domain does not relate to process")
connection.Block()
return rrCache
}
}
// tunneling
// TODO: link this to real status
port17Active := mode.Client()
if port17Active {
tunnelInfo, err := AssignTunnelIP(fqdn)
if err != nil {
log.Errorf("portmaster: could not get tunnel IP for routing %s: %s", connection, err)
return nil // return nxDomain
}
// save original reply
tunnelInfo.RRCache = rrCache
// return tunnel IP
return tunnelInfo.ExportTunnelIP()
}
return rrCache
}
func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// check:
// Profile.Flags
// - process specific: System, Admin, User
// - network specific: Internet, LocalNet, Service, Directconnect
// grant self
if connection.Process().Pid == os.Getpid() {
log.Infof("sheriff: granting own connection %s", connection)
connection.Accept()
return
}
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying connection %s", connection)
connection.AddReason("no profile")
connection.Block()
return
}
// check user class
if profile.Flags.Has(profiles.System) {
if !connection.Process().IsSystem() {
log.Infof("sheriff: denying connection %s, profile has System flag set, but process is not executed by System", connection)
connection.AddReason("must be executed by system")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.Admin) {
if !connection.Process().IsAdmin() {
log.Infof("sheriff: denying connection %s, profile has Admin flag set, but process is not executed by Admin", connection)
connection.AddReason("must be executed by admin")
connection.Block()
return
}
}
if profile.Flags.Has(profiles.User) {
if !connection.Process().IsUser() {
log.Infof("sheriff: denying connection %s, profile has User flag set, but process is not executed by a User", connection)
connection.AddReason("must be executed by user")
connection.Block()
return
}
}
// check for any network access
if !profile.Flags.Has(profiles.Internet) && !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: denying connection %s, profile denies Internet and local network access", connection)
connection.AddReason("no network access allowed")
connection.Block()
return
}
switch connection.Domain {
case "I":
// check Service flag
if !profile.Flags.Has(profiles.Service) {
log.Infof("sheriff: denying connection %s, profile does not declare service", connection)
connection.AddReason("not a service")
connection.Drop()
return
}
// check if incoming connections are allowed on any port, but only if there no other restrictions
if !!profile.Flags.Has(profiles.Internet) && !!profile.Flags.Has(profiles.LocalNet) && len(profile.ListenPorts) == 0 {
log.Infof("sheriff: granting connection %s, profile allows incoming connections from anywhere and on any port", connection)
connection.Accept()
return
}
case "D":
// check Directconnect flag
if !profile.Flags.Has(profiles.Directconnect) {
log.Infof("sheriff: denying connection %s, profile does not declare direct connections", connection)
connection.AddReason("direct connections (without DNS) not allowed")
connection.Drop()
return
}
}
log.Infof("sheriff: could not decide on connection %s, deciding on per-link basis", connection)
connection.CantSay()
}
func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet.Packet) {
// check:
// Profile.Flags
// - network specific: Internet, LocalNet
// Profile.ConnectPorts
// Profile.ListenPorts
// check if there is a profile
profile := connection.Process().Profile
if profile == nil {
log.Infof("sheriff: no profile, denying %s", link)
link.AddReason("no profile")
link.UpdateVerdict(network.BLOCK)
return
}
// check LocalNet and Internet flags
var remoteIP net.IP
if connection.Direction {
remoteIP = pkt.GetIPHeader().Src
} else {
remoteIP = pkt.GetIPHeader().Dst
}
if netutils.IPIsLocal(remoteIP) {
if !profile.Flags.Has(profiles.LocalNet) {
log.Infof("sheriff: dropping link %s, profile does not allow communication in the local network", link)
link.AddReason("profile does not allow access to local network")
link.UpdateVerdict(network.BLOCK)
return
}
} else {
if !profile.Flags.Has(profiles.Internet) {
log.Infof("sheriff: dropping link %s, profile does not allow communication with the Internet", link)
link.AddReason("profile does not allow access to the Internet")
link.UpdateVerdict(network.BLOCK)
return
}
}
// check connect ports
if connection.Domain != "I" && len(profile.ConnectPorts) > 0 {
tcpUdpHeader := pkt.GetTCPUDPHeader()
if tcpUdpHeader == nil {
log.Infof("sheriff: blocking link %s, profile has declared connect port whitelist, but link is not TCP/UDP", link)
link.AddReason("profile has declared connect port whitelist, but link is not TCP/UDP")
link.UpdateVerdict(network.BLOCK)
return
}
// packet *should* be outbound, but we could be deciding on an already active connection.
var remotePort uint16
if connection.Direction {
remotePort = tcpUdpHeader.SrcPort
} else {
remotePort = tcpUdpHeader.DstPort
}
matched := false
for _, port := range profile.ConnectPorts {
if remotePort == port {
matched = true
break
}
}
if !matched {
log.Infof("sheriff: blocking link %s, remote port %d not in profile connect port whitelist", link, remotePort)
link.AddReason("destination port not in whitelist")
link.UpdateVerdict(network.BLOCK)
return
}
}
// check listen ports
if connection.Domain == "I" && len(profile.ListenPorts) > 0 {
tcpUdpHeader := pkt.GetTCPUDPHeader()
if tcpUdpHeader == nil {
log.Infof("sheriff: dropping link %s, profile has declared listen port whitelist, but link is not TCP/UDP", link)
link.AddReason("profile has declared listen port whitelist, but link is not TCP/UDP")
link.UpdateVerdict(network.DROP)
return
}
// packet *should* be inbound, but we could be deciding on an already active connection.
var localPort uint16
if connection.Direction {
localPort = tcpUdpHeader.DstPort
} else {
localPort = tcpUdpHeader.SrcPort
}
matched := false
for _, port := range profile.ListenPorts {
if localPort == port {
matched = true
break
}
}
if !matched {
log.Infof("sheriff: blocking link %s, local port %d not in profile listen port whitelist", link, localPort)
link.AddReason("listen port not in whitelist")
link.UpdateVerdict(network.BLOCK)
return
}
}
log.Infof("sheriff: accepting link %s", link)
link.UpdateVerdict(network.ACCEPT)
}

107
process/database.go Normal file
View file

@ -0,0 +1,107 @@
package process
import (
"fmt"
"sync"
"time"
"github.com/Safing/portbase/database"
"github.com/Safing/portmaster/profile"
"github.com/tevino/abool"
)
var (
processes = make(map[int]*Process)
processesLock sync.RWMutex
dbController *database.Controller
dbControllerFlag = abool.NewBool(false)
)
// GetProcessFromStorage returns a process from the internal storage.
func GetProcessFromStorage(pid int) (*Process, bool) {
processesLock.RLock()
defer processesLock.RUnlock()
p, ok := processes[pid]
return p, ok
}
// All returns a copy of all process objects.
func All() []*Process {
processesLock.RLock()
defer processesLock.RUnlock()
all := make([]*Process, 0, len(processes))
for _, proc := range processes {
all = append(all, proc)
}
return all
}
// Save saves the process to the internal state and pushes an update.
func (p *Process) Save() {
p.Lock()
defer p.Unlock()
if !p.KeyIsSet() {
p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid))
p.CreateMeta()
}
processesLock.RLock()
_, ok := processes[p.Pid]
processesLock.RUnlock()
if !ok {
processesLock.Lock()
processes[p.Pid] = p
processesLock.Unlock()
}
if dbControllerFlag.IsSet() {
go dbController.PushUpdate(p)
}
}
// Delete deletes a process from the storage and propagates the change.
func (p *Process) Delete() {
p.Lock()
defer p.Unlock()
processesLock.Lock()
delete(processes, p.Pid)
processesLock.Unlock()
p.Meta().Delete()
if dbControllerFlag.IsSet() {
go dbController.PushUpdate(p)
}
// TODO: this should not be necessary, as processes should always have a profileSet.
if p.profileSet != nil {
profile.DeactivateProfileSet(p.profileSet)
}
}
// CleanProcessStorage cleans the storage from old processes.
func CleanProcessStorage(thresholdDuration time.Duration) {
processesLock.Lock()
defer processesLock.Unlock()
threshold := time.Now().Add(-thresholdDuration).Unix()
for _, p := range processes {
p.Lock()
if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 {
go p.Delete()
}
p.Unlock()
}
}
// SetDBController sets the database controller and allows the package to push database updates on a save. It must be set by the package that registers the "network" database.
func SetDBController(controller *database.Controller) {
dbController = controller
dbControllerFlag.Set()
}

View file

@ -1,21 +1,7 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
/*
Profiles
Profiles describe the network behaviour
Profiles are found in 3 different paths:
- /Me/Profiles/: Profiles used for this system
- /Data/Profiles/: Profiles supplied by Safing
- /Company/Profiles/: Profiles supplied by the company
When a program wants to use the network for the first time, Safing first searches for a Profile in the Company namespace, then in the Data namespace. If neither is found, it searches for a default profile in the same order.
Default profiles are profiles with a path ending with a "/". The default profile with the longest matching path is chosen.
Package process fetches process and socket information from the operating system.
It can find the process owning a network connection.
*/
package process

43
process/executable.go Normal file
View file

@ -0,0 +1,43 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package process
import (
"crypto"
"encoding/hex"
"hash"
"io"
"os"
)
// GetExecHash returns the hash of the executable with the given algorithm.
func (p *Process) GetExecHash(algorithm string) (string, error) {
sum, ok := p.ExecHashes[algorithm]
if ok {
return sum, nil
}
var hasher hash.Hash
switch algorithm {
case "md5":
hasher = crypto.MD5.New()
case "sha1":
hasher = crypto.SHA1.New()
case "sha256":
hasher = crypto.SHA256.New()
}
file, err := os.Open(p.Path)
if err != nil {
return "", err
}
_, err = io.Copy(hasher, file)
if err != nil {
return "", err
}
sum = hex.EncodeToString(hasher.Sum(nil))
p.ExecHashes[algorithm] = sum
return sum, nil
}

View file

@ -1,74 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package process
import (
"github.com/Safing/safing-core/database"
"strings"
"time"
datastore "github.com/ipfs/go-datastore"
)
// ExecutableSignature stores a signature of an executable.
type ExecutableSignature []byte
// FileInfo stores (security) information about a file.
type FileInfo struct {
database.Base
HumanName string
Owners []string
ApproxLastSeen int64
Signature *ExecutableSignature
}
var fileInfoModel *FileInfo // only use this as parameter for database.EnsureModel-like functions
func init() {
database.RegisterModel(fileInfoModel, func() database.Model { return new(FileInfo) })
}
// Create saves FileInfo with the provided name in the default namespace.
func (m *FileInfo) Create(name string) error {
return m.CreateObject(&database.FileInfoCache, name, m)
}
// CreateInNamespace saves FileInfo with the provided name in the provided namespace.
func (m *FileInfo) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves FileInfo.
func (m *FileInfo) Save() error {
return m.SaveObject(m)
}
// getFileInfo fetches FileInfo with the provided name from the default namespace.
func getFileInfo(name string) (*FileInfo, error) {
return getFileInfoFromNamespace(&database.FileInfoCache, name)
}
// getFileInfoFromNamespace fetches FileInfo with the provided name from the provided namespace.
func getFileInfoFromNamespace(namespace *datastore.Key, name string) (*FileInfo, error) {
object, err := database.GetAndEnsureModel(namespace, name, fileInfoModel)
if err != nil {
return nil, err
}
model, ok := object.(*FileInfo)
if !ok {
return nil, database.NewMismatchError(object, fileInfoModel)
}
return model, nil
}
// GetFileInfo gathers information about a file and returns *FileInfo
func GetFileInfo(path string) *FileInfo {
// TODO: actually get file information
// TODO: try to load from DB
// TODO: save to DB (key: hash of some sorts)
splittedPath := strings.Split("/", path)
return &FileInfo{
HumanName: splittedPath[len(splittedPath)-1],
ApproxLastSeen: time.Now().Unix(),
}
}

View file

@ -4,14 +4,17 @@ import (
"errors"
"net"
"github.com/Safing/safing-core/network/packet"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/packet"
)
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection")
ErrProcessNotFound = errors.New("could not find process")
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrProcessNotFound = errors.New("could not find process in system state tables")
)
// GetPidByPacket returns the pid of the owner of the packet.
func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
var localIP net.IP
@ -50,26 +53,33 @@ func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) {
}
// GetProcessByPacket returns the process that owns the given packet.
func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) {
var pid int
pid, direction, err = GetPidByPacket(pkt)
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
if err != nil {
return nil, direction, err
}
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
process, err = GetOrFindProcess(pid)
if err != nil {
return nil, direction, err
}
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, direction, nil
}
// GetPidByEndpoints returns the pid of the owner of the described link.
func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) {
ipVersion := packet.IPv4
@ -92,6 +102,7 @@ func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remote
}
// GetProcessByEndpoints returns the process that owns the described link.
func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) {
var pid int
@ -108,41 +119,16 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re
return nil, err
}
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, nil
}
// GetActiveConnectionIDs returns a list of all active connection IDs.
func GetActiveConnectionIDs() []string {
return getActiveConnectionIDs()
}
// func GetProcessByPid(pid int) *Process {
// process, err := GetOrFindProcess(pid)
// if err != nil {
// log.Warningf("process: failed to get process %d: %s", pid, err)
// return nil
// }
// return process
// }
// func GetProcessOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (process *Process, status uint8) {
// pid, status := GetPidOfConnection(localIP, localPort, protocol)
// if status == Success {
// process = GetProcessByPid(pid)
// if process == nil {
// return nil, NoProcessInfo
// }
// }
// return
// }
// func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, status uint8) {
// pid, direction, status := GetPidByPacket(pkt)
// if status == Success {
// process = GetProcessByPid(pid)
// if process == nil {
// return nil, direction, NoProcessInfo
// }
// }
// return
// }

View file

@ -1,6 +1,8 @@
package process
import "github.com/Safing/safing-core/process/proc"
import (
"github.com/Safing/portmaster/process/proc"
)
var (
getTCP4PacketInfo = proc.GetTCP4PacketInfo

Binary file not shown.

115
process/matching.go Normal file
View file

@ -0,0 +1,115 @@
package process
import (
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/profile"
)
var (
profileDB = database.NewInterface(nil)
)
// FindProfiles finds and assigns a profile set to the process.
func (p *Process) FindProfiles() error {
p.Lock()
defer p.Unlock()
// only find profiles if not already done.
if p.profileSet != nil {
return nil
}
// User Profile
it, err := profileDB.Query(query.New(profile.MakeProfileKey(profile.UserNamespace, "")).Where(query.Where("LinkedPath", query.SameAs, p.Path)))
if err != nil {
return err
}
var userProfile *profile.Profile
for r := range it.Next {
it.Cancel()
userProfile, err = profile.EnsureProfile(r)
if err != nil {
return err
}
break
}
if it.Err() != nil {
return it.Err()
}
// create new profile if it does not exist.
if userProfile == nil {
// create new profile
userProfile = profile.New()
userProfile.Name = p.ExecName
userProfile.LinkedPath = p.Path
}
if userProfile.MarkUsed() {
userProfile.Save(profile.UserNamespace)
}
// Stamp
// Find/Re-evaluate Stamp profile
// 1. check linked stamp profile
// 2. if last check is was more than a week ago, fetch from stamp:
// 3. send path identifier to stamp
// 4. evaluate all returned profiles
// 5. select best
// 6. link stamp profile to user profile
// FIXME: implement!
p.UserProfileKey = userProfile.Key()
p.profileSet = profile.NewSet(userProfile, nil)
go p.Save()
return nil
}
func selectProfile(p *Process, profs []*profile.Profile) (selectedProfile *profile.Profile) {
var highestScore int
for _, prof := range profs {
score := matchProfile(p, prof)
if score > highestScore {
selectedProfile = prof
}
}
return
}
func matchProfile(p *Process, prof *profile.Profile) (score int) {
for _, fp := range prof.Fingerprints {
score += matchFingerprint(p, fp)
}
return
}
func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) {
if !fp.MatchesOS() {
return 0
}
switch fp.Type {
case "full_path":
if p.Path == fp.Value {
}
return profile.GetFingerprintWeight(fp.Type)
case "partial_path":
// FIXME: if full_path matches, do not match partial paths
return profile.GetFingerprintWeight(fp.Type)
case "md5_sum", "sha1_sum", "sha256_sum":
// FIXME: one sum is enough, check sums in a grouped form, start with the best
sum, err := p.GetExecHash(fp.Type)
if err != nil {
log.Errorf("process: failed to get hash of executable: %s", err)
} else if sum == fp.Value {
return profile.GetFingerprintWeight(fp.Type)
}
}
return 0
}

View file

@ -13,14 +13,19 @@ const (
NoProcess
)
func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
var (
waitTime = 15 * time.Millisecond
)
// GetPidOfConnection returns the PID of the given connection.
func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getConnectionSocket(localIP, localPort, protocol)
if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond)
time.Sleep(waitTime)
uid, inode, ok = getConnectionSocket(localIP, localPort, protocol)
if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
@ -30,27 +35,48 @@ func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid
return -1, NoSocket
}
}
pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond)
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok {
return -1, NoProcess
}
return
}
func GetPidOfIncomingConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
// GetPidOfConnection returns the PID of the given incoming connection.
func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getListeningSocket(localIP, localPort, protocol)
if !ok {
return -1, NoSocket
// for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version.
switch protocol {
case TCP4:
uid, inode, ok = getListeningSocket(localIP, localPort, TCP6)
case UDP4:
uid, inode, ok = getListeningSocket(localIP, localPort, UDP6)
}
if !ok {
return -1, NoSocket
}
}
pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok {
return -1, NoProcess
}
return
}

View file

@ -6,39 +6,39 @@ import (
)
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP4, localIP, localPort, direction)
return search(TCP4, localIP, localPort, pktDirection)
}
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP6, localIP, localPort, direction)
return search(TCP6, localIP, localPort, pktDirection)
}
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP4, localIP, localPort, direction)
return search(UDP4, localIP, localPort, pktDirection)
}
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP6, localIP, localPort, direction)
return search(UDP6, localIP, localPort, pktDirection)
}
func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) {
var status uint8
if pktDirection {
pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol)
pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
if pid >= 0 {
return pid, true, nil
}
// pid, status = GetPidOfConnection(&localIP, localPort, protocol)
// pid, status = GetPidOfConnection(localIP, localPort, protocol)
// if pid >= 0 {
// return pid, false, nil
// }
} else {
pid, status = GetPidOfConnection(&localIP, localPort, protocol)
pid, status = GetPidOfConnection(localIP, localPort, protocol)
if pid >= 0 {
return pid, false, nil
}
// pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol)
// pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
// if pid >= 0 {
// return pid, true, nil
// }

View file

@ -10,7 +10,7 @@ import (
"sync"
"syscall"
"github.com/Safing/safing-core/log"
"github.com/Safing/portbase/log"
)
var (

View file

@ -13,7 +13,7 @@ import (
"sync"
"unicode"
"github.com/Safing/safing-core/log"
"github.com/Safing/portbase/log"
)
/*
@ -81,7 +81,7 @@ var (
globalListeningUDP6 = make(map[uint16][]int)
)
func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int, int, bool) {
func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) {
// listeningSocketsLock.Lock()
// defer listeningSocketsLock.Unlock()
@ -98,10 +98,10 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
case TCP6:
procFile = TCP6Data
localIPHex = hex.EncodeToString([]byte(*localIP))
localIPHex = hex.EncodeToString([]byte(localIP))
case UDP6:
procFile = UDP6Data
localIPHex = hex.EncodeToString([]byte(*localIP))
localIPHex = hex.EncodeToString([]byte(localIP))
}
localPortHex := fmt.Sprintf("%04X", localPort)
@ -162,38 +162,38 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
}
func getListeningSocket(localIP *net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
listeningSocketsLock.Lock()
defer listeningSocketsLock.Unlock()
var addressListening *map[string][]int
var globalListening *map[uint16][]int
var addressListening map[string][]int
var globalListening map[uint16][]int
switch protocol {
case TCP4:
addressListening = &addressListeningTCP4
globalListening = &globalListeningTCP4
addressListening = addressListeningTCP4
globalListening = globalListeningTCP4
case UDP4:
addressListening = &addressListeningUDP4
globalListening = &globalListeningUDP4
addressListening = addressListeningUDP4
globalListening = globalListeningUDP4
case TCP6:
addressListening = &addressListeningTCP6
globalListening = &globalListeningTCP6
addressListening = addressListeningTCP6
globalListening = globalListeningTCP6
case UDP6:
addressListening = &addressListeningUDP6
globalListening = &globalListeningUDP6
addressListening = addressListeningUDP6
globalListening = globalListeningUDP6
}
data, ok := (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)]
data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok {
data, ok = (*globalListening)[localPort]
data, ok = globalListening[localPort]
}
if ok {
return data[0], data[1], true
}
updateListeners(protocol)
data, ok = (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)]
data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok {
data, ok = (*globalListening)[localPort]
data, ok = globalListening[localPort]
}
if ok {
return data[0], data[1], true
@ -206,7 +206,7 @@ func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':'
}
func convertIPv4(data string) *net.IP {
func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err)
@ -217,10 +217,10 @@ func convertIPv4(data string) *net.IP {
return nil
}
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return &ip
return ip
}
func convertIPv6(data string) *net.IP {
func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data)
if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err)
@ -231,7 +231,7 @@ func convertIPv6(data string) *net.IP {
return nil
}
ip := net.IP(decoded)
return &ip
return ip
}
func updateListeners(protocol uint8) {
@ -247,7 +247,7 @@ func updateListeners(protocol uint8) {
}
}
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) *net.IP) (map[string][]int, map[uint16][]int) {
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) {
addressListening := make(map[string][]int)
globalListening := make(map[uint16][]int)
@ -312,6 +312,7 @@ func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter
return addressListening, globalListening
}
// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS.
func GetActiveConnectionIDs() []string {
var connections []string
@ -323,7 +324,7 @@ func GetActiveConnectionIDs() []string {
return connections
}
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) *net.IP) []string {
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string {
var connections []string
// open file

View file

@ -22,14 +22,14 @@ func TestSockets(t *testing.T) {
t.Logf("addressListeningUDP6: %v", addressListeningUDP6)
t.Logf("globalListeningUDP6: %v", globalListeningUDP6)
getListeningSocket(&net.IPv4zero, 53, TCP4)
getListeningSocket(&net.IPv4zero, 53, UDP4)
getListeningSocket(&net.IPv6zero, 53, TCP6)
getListeningSocket(&net.IPv6zero, 53, UDP6)
getListeningSocket(net.IPv4zero, 53, TCP4)
getListeningSocket(net.IPv4zero, 53, UDP4)
getListeningSocket(net.IPv6zero, 53, TCP6)
getListeningSocket(net.IPv6zero, 53, UDP6)
// spotify: 192.168.0.102:5353 192.121.140.65:80
localIP := net.IPv4(192, 168, 127, 10)
uid, inode, ok := getConnectionSocket(&localIP, 46634, TCP4)
uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4)
t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok)
activeConnectionIDs := GetActiveConnectionIDs()

View file

@ -5,89 +5,93 @@ package process
import (
"fmt"
"runtime"
"strconv"
"strings"
"sync"
"time"
datastore "github.com/ipfs/go-datastore"
processInfo "github.com/shirou/gopsutil/process"
"github.com/Safing/safing-core/database"
"github.com/Safing/safing-core/log"
"github.com/Safing/safing-core/profiles"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/profile"
)
// A Process represents a process running on the operating system
type Process struct {
database.Base
UserID int
UserName string
UserHome string
Pid int
ParentPid int
Path string
Cwd string
FileInfo *FileInfo
CmdLine string
FirstArg string
ProfileKey string
Profile *profiles.Profile
Name string
Icon string
record.Base
sync.Mutex
UserID int
UserName string
UserHome string
Pid int
ParentPid int
Path string
Cwd string
CmdLine string
FirstArg string
ExecName string
ExecHashes map[string]string
// ExecOwner ...
// ExecSignature ...
UserProfileKey string
profileSet *profile.Set
Name string
Icon string
// Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for database cache path or "c:"/"a:" for a the icon key to fetch it from a company / authoritative node and cache it in its own cache.
FirstConnectionEstablished int64
LastConnectionEstablished int64
ConnectionCount uint
}
var processModel *Process // only use this as parameter for database.EnsureModel-like functions
// ProfileSet returns the assigned profile set.
func (p *Process) ProfileSet() *profile.Set {
p.Lock()
defer p.Unlock()
func init() {
database.RegisterModel(processModel, func() database.Model { return new(Process) })
return p.profileSet
}
// Create saves Process with the provided name in the default namespace.
func (m *Process) Create(name string) error {
return m.CreateObject(&database.Processes, name, m)
}
// Strings returns a string represenation of process.
func (p *Process) String() string {
p.Lock()
defer p.Unlock()
// CreateInNamespace saves Process with the provided name in the provided namespace.
func (m *Process) CreateInNamespace(namespace *datastore.Key, name string) error {
return m.CreateObject(namespace, name, m)
}
// Save saves Process.
func (m *Process) Save() error {
return m.SaveObject(m)
}
// GetProcess fetches Process with the provided name from the default namespace.
func GetProcess(name string) (*Process, error) {
return GetProcessFromNamespace(&database.Processes, name)
}
// GetProcessFromNamespace fetches Process with the provided name from the provided namespace.
func GetProcessFromNamespace(namespace *datastore.Key, name string) (*Process, error) {
object, err := database.GetAndEnsureModel(namespace, name, processModel)
if err != nil {
return nil, err
}
model, ok := object.(*Process)
if !ok {
return nil, database.NewMismatchError(object, processModel)
}
return model, nil
}
func (m *Process) String() string {
if m == nil {
if p == nil {
return "?"
}
if m.Profile != nil && !m.Profile.Default {
return fmt.Sprintf("%s:%s:%d", m.UserName, m.Profile, m.Pid)
}
return fmt.Sprintf("%s:%s:%d", m.UserName, m.Path, m.Pid)
return fmt.Sprintf("%s:%s:%d", p.UserName, p.Path, p.Pid)
}
// AddConnection increases the connection counter and the last connection timestamp.
func (p *Process) AddConnection() {
p.Lock()
defer p.Unlock()
p.ConnectionCount++
p.LastConnectionEstablished = time.Now().Unix()
if p.FirstConnectionEstablished == 0 {
p.FirstConnectionEstablished = p.LastConnectionEstablished
}
}
// RemoveConnection lowers the connection counter by one.
func (p *Process) RemoveConnection() {
p.Lock()
defer p.Unlock()
if p.ConnectionCount > 0 {
p.ConnectionCount--
}
}
// GetOrFindProcess returns the process for the given PID.
func GetOrFindProcess(pid int) (*Process, error) {
process, err := GetProcess(strconv.Itoa(pid))
if err == nil {
process, ok := GetProcessFromStorage(pid)
if ok {
return process, nil
}
@ -96,13 +100,9 @@ func GetOrFindProcess(pid int) (*Process, error) {
}
switch {
case (pid == 0 && runtime.GOOS == "linux") || (pid == 4 && runtime.GOOS == "windows"):
case new.IsKernel():
new.UserName = "Kernel"
new.Name = "Operating System"
new.Profile = &profiles.Profile{
Name: "OS",
Flags: []int8{profiles.Internet, profiles.LocalNet, profiles.Directconnect, profiles.Service},
}
default:
pInfo, err := processInfo.NewProcess(int32(pid))
@ -113,7 +113,8 @@ func GetOrFindProcess(pid int) (*Process, error) {
// UID
// net yet implemented for windows
if runtime.GOOS == "linux" {
uids, err := pInfo.Uids()
var uids []int32
uids, err = pInfo.Uids()
if err != nil {
log.Warningf("process: failed to get UID: %s", err)
} else {
@ -167,85 +168,87 @@ func GetOrFindProcess(pid int) (*Process, error) {
// new.Icon, err =
// get Profile
processPath := new.Path
var applyProfile *profiles.Profile
iterations := 0
for applyProfile == nil {
// processPath := new.Path
// var applyProfile *profiles.Profile
// iterations := 0
// for applyProfile == nil {
//
// iterations++
// if iterations > 10 {
// log.Warningf("process: got into loop while getting profile for %s", new)
// break
// }
//
// applyProfile, err = profiles.GetActiveProfileByPath(processPath)
// if err == database.ErrNotFound {
// applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome)
// }
// if err != nil {
// log.Warningf("process: could not get profile for %s: %s", new, err)
// } else if applyProfile == nil {
// log.Warningf("process: no default profile found for %s", new)
// } else {
//
// // TODO: there is a lot of undefined behaviour if chaining framework profiles
//
// // process framework
// if applyProfile.Framework != nil {
// if applyProfile.Framework.FindParent > 0 {
// var ppid int32
// for i := uint8(1); i < applyProfile.Framework.FindParent; i++ {
// parent, err := pInfo.Parent()
// if err != nil {
// return nil, err
// }
// ppid = parent.Pid
// }
// if applyProfile.Framework.MergeWithParent {
// return GetOrFindProcess(int(ppid))
// }
// // processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid))
// // if err != nil {
// // return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err)
// // }
// continue
// }
//
// newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd)
// if err != nil {
// return nil, err
// }
//
// // assign
// new.CmdLine = newCommand
// new.Path = strings.SplitN(newCommand, " ", 2)[0]
// processPath = new.Path
//
// // make sure we loop
// applyProfile = nil
// continue
// }
//
// // apply profile to process
// log.Debugf("process: applied profile to %s: %s", new, applyProfile)
// new.Profile = applyProfile
// new.ProfileKey = applyProfile.GetKey().String()
//
// // update Profile with Process icon if Profile does not have one
// if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" {
// new.Profile.Icon = new.Icon
// new.Profile.Save()
// }
// }
// }
iterations++
if iterations > 10 {
log.Warningf("process: got into loop while getting profile for %s", new)
break
}
applyProfile, err = profiles.GetActiveProfileByPath(processPath)
if err == database.ErrNotFound {
applyProfile, err = profiles.FindProfileByPath(processPath, new.UserHome)
}
if err != nil {
log.Warningf("process: could not get profile for %s: %s", new, err)
} else if applyProfile == nil {
log.Warningf("process: no default profile found for %s", new)
} else {
// TODO: there is a lot of undefined behaviour if chaining framework profiles
// process framework
if applyProfile.Framework != nil {
if applyProfile.Framework.FindParent > 0 {
var ppid int32
for i := uint8(1); i < applyProfile.Framework.FindParent; i++ {
parent, err := pInfo.Parent()
if err != nil {
return nil, err
}
ppid = parent.Pid
}
if applyProfile.Framework.MergeWithParent {
return GetOrFindProcess(int(ppid))
}
// processPath, err = os.Readlink(fmt.Sprintf("/proc/%d/exe", pid))
// if err != nil {
// return nil, fmt.Errorf("could not read /proc/%d/exe: %s", pid, err)
// }
continue
}
newCommand, err := applyProfile.Framework.GetNewPath(new.CmdLine, new.Cwd)
if err != nil {
return nil, err
}
// assign
new.CmdLine = newCommand
new.Path = strings.SplitN(newCommand, " ", 2)[0]
processPath = new.Path
// make sure we loop
applyProfile = nil
continue
}
// apply profile to process
log.Debugf("process: applied profile to %s: %s", new, applyProfile)
new.Profile = applyProfile
new.ProfileKey = applyProfile.GetKey().String()
// update Profile with Process icon if Profile does not have one
if !new.Profile.Default && new.Icon != "" && new.Profile.Icon == "" {
new.Profile.Icon = new.Icon
new.Profile.Save()
}
}
}
// get FileInfo
new.FileInfo = GetFileInfo(new.Path)
// Executable Information
// FIXME: use os specific path seperator
splittedPath := strings.Split(new.Path, "/")
new.ExecName = splittedPath[len(splittedPath)-1]
}
// save to DB
new.Create(strconv.Itoa(new.Pid))
// save to storage
new.Save()
return new, nil
}

View file

@ -1,13 +1,21 @@
package process
// IsUser returns whether the process is run by a normal user.
func (m *Process) IsUser() bool {
return m.UserID >= 1000
}
// IsAdmin returns whether the process is run by an admin user.
func (m *Process) IsAdmin() bool {
return m.UserID >= 0
}
// IsSystem returns whether the process is run by the operating system.
func (m *Process) IsSystem() bool {
return m.UserID == 0
}
// IsKernel returns whether the process is the Kernel.
func (m *Process) IsKernel() bool {
return m.Pid == 0
}

View file

@ -2,15 +2,23 @@ package process
import "strings"
// IsUser returns whether the process is run by a normal user.
func (m *Process) IsUser() bool {
return m.Pid != 4 && // Kernel
!strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!)
}
// IsAdmin returns whether the process is run by an admin user.
func (m *Process) IsAdmin() bool {
return strings.HasPrefix(m.UserName, "NT-") // NT-Authority (localized!)
}
// IsSystem returns whether the process is run by the operating system.
func (m *Process) IsSystem() bool {
return m.Pid == 4
}
// IsKernel returns whether the process is the Kernel.
func (m *Process) IsKernel() bool {
return m.Pid == 4
}

16
process/unknown.go Normal file
View file

@ -0,0 +1,16 @@
package process
var (
// UnknownProcess is used when a process cannot be found.
UnknownProcess = &Process{
UserID: -1,
UserName: "Unknown",
Pid: -1,
ParentPid: -1,
Name: "Unknown Processes",
}
)
func init() {
UnknownProcess.Save()
}

54
profile/active.go Normal file
View file

@ -0,0 +1,54 @@
package profile
import "sync"
var (
activeProfileSets = make(map[string]*Set)
activeProfileSetsLock sync.RWMutex
)
func activateProfileSet(set *Set) {
set.Lock()
defer set.Unlock()
activeProfileSetsLock.Lock()
defer activeProfileSetsLock.Unlock()
activeProfileSets[set.profiles[0].ID] = set
}
// DeactivateProfileSet marks a profile set as not active.
func DeactivateProfileSet(set *Set) {
set.Lock()
defer set.Unlock()
activeProfileSetsLock.Lock()
defer activeProfileSetsLock.Unlock()
delete(activeProfileSets, set.profiles[0].ID)
}
func updateActiveUserProfile(profile *Profile) {
activeProfileSetsLock.RLock()
defer activeProfileSetsLock.RUnlock()
activeSet, ok := activeProfileSets[profile.ID]
if ok {
activeSet.Lock()
defer activeSet.Unlock()
activeSet.profiles[0] = profile
}
}
func updateActiveStampProfile(profile *Profile) {
activeProfileSetsLock.RLock()
defer activeProfileSetsLock.RUnlock()
for _, activeSet := range activeProfileSets {
activeSet.Lock()
activeProfile := activeSet.profiles[2]
if activeProfile != nil {
activeProfile.Lock()
if activeProfile.ID == profile.ID {
activeSet.profiles[2] = profile
}
activeProfile.Unlock()
}
activeSet.Unlock()
}
}

8
profile/const.go Normal file
View file

@ -0,0 +1,8 @@
package profile
// Platform identifiers
const (
PlatformLinux = "linux"
PlatformWindows = "windows"
PlatformMac = "macos"
)

6
profile/const_darwin.go Normal file
View file

@ -0,0 +1,6 @@
package profile
// OS Identifier
const (
osIdentifier = PlatformMac
)

6
profile/const_linux.go Normal file
View file

@ -0,0 +1,6 @@
package profile
// OS Identifier
const (
osIdentifier = PlatformLinux
)

22
profile/database.go Normal file
View file

@ -0,0 +1,22 @@
package profile
import (
"github.com/Safing/portbase/database"
)
// core:profiles/user/12345-1234-125-1234-1235
// core:profiles/special/default
// /global
// core:profiles/stamp/12334-1235-1234-5123-1234
// core:profiles/identifier/base64
// Namespaces
const (
UserNamespace = "user"
StampNamespace = "stamp"
SpecialNamespace = "special"
)
var (
profileDB = database.NewInterface(nil)
)

44
profile/defaults.go Normal file
View file

@ -0,0 +1,44 @@
package profile
import (
"github.com/Safing/portmaster/status"
)
func makeDefaultGlobalProfile() *Profile {
return &Profile{
ID: "global",
Name: "Global Profile",
}
}
func makeDefaultFallbackProfile() *Profile {
return &Profile{
ID: "fallback",
Name: "Fallback Profile",
Flags: map[uint8]uint8{
// Profile Modes
Blacklist: status.SecurityLevelDynamic,
Prompt: status.SecurityLevelSecure,
Whitelist: status.SecurityLevelFortress,
// Network Locations
Internet: status.SecurityLevelsDynamicAndSecure,
LAN: status.SecurityLevelsDynamicAndSecure,
Localhost: status.SecurityLevelsAll,
// Specials
Related: status.SecurityLevelDynamic,
PeerToPeer: status.SecurityLevelDynamic,
},
ServiceEndpoints: []*EndpointPermission{
&EndpointPermission{
DomainOrIP: "",
Wildcard: true,
Protocol: 0,
StartPort: 0,
EndPort: 0,
Permit: false,
},
},
}
}

138
profile/endpoints.go Normal file
View file

@ -0,0 +1,138 @@
package profile
import (
"fmt"
"strconv"
"strings"
"github.com/Safing/portmaster/intel"
)
// Endpoints is a list of permitted or denied endpoints.
type Endpoints []*EndpointPermission
// EndpointPermission holds a decision about an endpoint.
type EndpointPermission struct {
DomainOrIP string
Wildcard bool
Protocol uint8
StartPort uint16
EndPort uint16
Permit bool
Created int64
}
// IsSet returns whether the Endpoints object is "set".
func (e Endpoints) IsSet() bool {
if len(e) > 0 {
return true
}
return false
}
// Check checks if the given domain is governed in the list of domains and returns whether it is permitted.
// If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, an IP will be resolved to a domain, if necessary.
func (e Endpoints) Check(domainOrIP string, protocol uint8, port uint16, checkReverseIP bool, securityLevel uint8) (permit bool, reason string, ok bool) {
// ip resolving
var cachedGetDomainOfIP func() string
if checkReverseIP {
var ipResolved bool
var ipName string
// setup caching wrapper
cachedGetDomainOfIP = func() string {
if !ipResolved {
result, err := intel.ResolveIPAndValidate(domainOrIP, securityLevel)
if err != nil {
// log.Debug()
ipName = result
}
ipResolved = true
}
return ipName
}
}
isDomain := strings.HasSuffix(domainOrIP, ".")
for _, entry := range e {
if entry != nil {
if ok, reason := entry.Matches(domainOrIP, protocol, port, isDomain, cachedGetDomainOfIP); ok {
return entry.Permit, reason, true
}
}
}
return false, "", false
}
func isSubdomainOf(domain, subdomain string) bool {
dotPrefixedDomain := "." + domain
return strings.HasSuffix(subdomain, dotPrefixedDomain)
}
// Matches checks whether the given endpoint has a managed permission. If getDomainOfIP (returns reverse and forward dns matching domain name) is supplied, this declares an incoming connection.
func (ep EndpointPermission) Matches(domainOrIP string, protocol uint8, port uint16, isDomain bool, getDomainOfIP func() string) (match bool, reason string) {
if ep.Protocol > 0 && protocol != ep.Protocol {
return false, ""
}
if ep.StartPort > 0 && (port < ep.StartPort || port > ep.EndPort) {
return false, ""
}
switch {
case ep.Wildcard && len(ep.DomainOrIP) == 0:
// host wildcard
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case domainOrIP == ep.DomainOrIP:
// host match
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case isDomain && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, domainOrIP):
// subdomain match
return true, fmt.Sprintf("%s matches %s", domainOrIP, ep)
case !isDomain && getDomainOfIP != nil && getDomainOfIP() == ep.DomainOrIP:
// resolved IP match
return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep)
case !isDomain && getDomainOfIP != nil && ep.Wildcard && isSubdomainOf(ep.DomainOrIP, getDomainOfIP()):
// resolved IP subdomain match
return true, fmt.Sprintf("%s->%s matches %s", domainOrIP, getDomainOfIP(), ep)
default:
// no match
return false, ""
}
}
func (e Endpoints) String() string {
var s []string
for _, entry := range e {
s = append(s, entry.String())
}
return fmt.Sprintf("[%s]", strings.Join(s, ", "))
}
func (ep EndpointPermission) String() string {
s := ep.DomainOrIP
s += " "
if ep.Protocol > 0 {
s += strconv.Itoa(int(ep.Protocol))
} else {
s += "*"
}
s += "/"
if ep.StartPort > 0 {
if ep.StartPort == ep.EndPort {
s += strconv.Itoa(int(ep.StartPort))
} else {
s += fmt.Sprintf("%d-%d", ep.StartPort, ep.EndPort)
}
} else {
s += "*"
}
return s
}

61
profile/endpoints_test.go Normal file
View file

@ -0,0 +1,61 @@
package profile
import (
"testing"
)
// TODO: RETIRED
// func testdeMatcher(t *testing.T, value string, expectedResult bool) {
// if domainEndingMatcher.MatchString(value) != expectedResult {
// if expectedResult {
// t.Errorf("domainEndingMatcher should match %s", value)
// } else {
// t.Errorf("domainEndingMatcher should not match %s", value)
// }
// }
// }
//
// func TestdomainEndingMatcher(t *testing.T) {
// testdeMatcher(t, "example.com", true)
// testdeMatcher(t, "com", true)
// testdeMatcher(t, "example.xn--lgbbat1ad8j", true)
// testdeMatcher(t, "xn--lgbbat1ad8j", true)
// testdeMatcher(t, "fe80::beef", false)
// testdeMatcher(t, "fe80::dead:beef", false)
// testdeMatcher(t, "10.2.3.4", false)
// testdeMatcher(t, "4", false)
// }
func TestEPString(t *testing.T) {
var endpoints Endpoints
endpoints = []*EndpointPermission{
&EndpointPermission{
DomainOrIP: "example.com",
Wildcard: false,
Protocol: 6,
Permit: true,
},
&EndpointPermission{
DomainOrIP: "8.8.8.8",
Protocol: 17, // TCP
StartPort: 53, // DNS
EndPort: 53,
Permit: false,
},
&EndpointPermission{
DomainOrIP: "google.com",
Wildcard: true,
Permit: false,
},
}
if endpoints.String() != "[example.com 6/*, 8.8.8.8 17/53, google.com */*]" {
t.Errorf("unexpected result: %s", endpoints.String())
}
var noEndpoints Endpoints
noEndpoints = []*EndpointPermission{}
if noEndpoints.String() != "[]" {
t.Errorf("unexpected result: %s", noEndpoints.String())
}
}

48
profile/fingerprint.go Normal file
View file

@ -0,0 +1,48 @@
package profile
import "time"
var (
fingerprintWeights = map[string]int{
"full_path": 2,
"partial_path": 1,
"md5_sum": 4,
"sha1_sum": 5,
"sha256_sum": 6,
}
)
// Fingerprint links processes to profiles.
type Fingerprint struct {
OS string
Type string
Value string
Comment string
LastUsed int64
}
// MatchesOS returns whether the Fingerprint is applicable for the current OS.
func (fp *Fingerprint) MatchesOS() bool {
return fp.OS == osIdentifier
}
// GetFingerprintWeight returns the weight of the given fingerprint type.
func GetFingerprintWeight(fpType string) (weight int) {
weight, ok := fingerprintWeights[fpType]
if ok {
return weight
}
return 0
}
// AddFingerprint adds the given fingerprint to the profile.
func (p *Profile) AddFingerprint(fp *Fingerprint) {
if fp.OS == "" {
fp.OS = osIdentifier
}
if fp.LastUsed == 0 {
fp.LastUsed = time.Now().Unix()
}
p.Fingerprints = append(p.Fingerprints, fp)
}

130
profile/flags.go Normal file
View file

@ -0,0 +1,130 @@
package profile
import (
"errors"
"fmt"
"strings"
"github.com/Safing/portmaster/status"
)
// Flags are used to quickly add common attributes to profiles
type Flags map[uint8]uint8
// Profile Flags
const (
// Profile Modes
Prompt uint8 = 0 // Prompt first-seen connections
Blacklist uint8 = 1 // Allow everything not explicitly denied
Whitelist uint8 = 2 // Only allow everything explicitly allowed
// Network Locations
Internet uint8 = 16 // Allow connections to the Internet
LAN uint8 = 17 // Allow connections to the local area network
Localhost uint8 = 18 // Allow connections on the local host
// Specials
Related uint8 = 32 // If and before prompting, allow domains that are related to the program
PeerToPeer uint8 = 33 // Allow program to directly communicate with peers, without resolving DNS first
Service uint8 = 34 // Allow program to accept incoming connections
Independent uint8 = 35 // Ignore profile settings coming from the Community
RequireGate17 uint8 = 36 // Require all connections to go over Gate17
)
var (
// ErrFlagsParseFailed is returned if a an invalid flag is encountered while parsing
ErrFlagsParseFailed = errors.New("profiles: failed to parse flags")
sortedFlags = []uint8{
Prompt,
Blacklist,
Whitelist,
Internet,
LAN,
Localhost,
Related,
PeerToPeer,
Service,
Independent,
RequireGate17,
}
flagIDs = map[string]uint8{
"Prompt": Prompt,
"Blacklist": Blacklist,
"Whitelist": Whitelist,
"Internet": Internet,
"LAN": LAN,
"Localhost": Localhost,
"Related": Related,
"PeerToPeer": PeerToPeer,
"Service": Service,
"Independent": Independent,
"RequireGate17": RequireGate17,
}
flagNames = map[uint8]string{
Prompt: "Prompt",
Blacklist: "Blacklist",
Whitelist: "Whitelist",
Internet: "Internet",
LAN: "LAN",
Localhost: "Localhost",
Related: "Related",
PeerToPeer: "PeerToPeer",
Service: "Service",
Independent: "Independent",
RequireGate17: "RequireGate17",
}
)
// Check checks if a flag is set at all and if it's active in the given security level.
func (flags Flags) Check(flag, level uint8) (active bool, ok bool) {
if flags == nil {
return false, false
}
setting, ok := flags[flag]
if ok {
if setting&level > 0 {
return true, true
}
return false, true
}
return false, false
}
func getLevelMarker(levels, level uint8) string {
if levels&level > 0 {
return "+"
}
return "-"
}
// String return a string representation of Flags
func (flags Flags) String() string {
var markedFlags []string
for _, flag := range sortedFlags {
levels, ok := flags[flag]
if ok {
s := flagNames[flag]
if levels != status.SecurityLevelsAll {
s += getLevelMarker(levels, status.SecurityLevelDynamic)
s += getLevelMarker(levels, status.SecurityLevelSecure)
s += getLevelMarker(levels, status.SecurityLevelFortress)
}
markedFlags = append(markedFlags, s)
}
}
return fmt.Sprintf("[%s]", strings.Join(markedFlags, ", "))
}
// Add adds a flag to the Flags with the given level.
func (flags Flags) Add(flag, levels uint8) {
flags[flag] = levels
}
// Remove removes a flag from the Flags.
func (flags Flags) Remove(flag uint8) {
delete(flags, flag)
}

69
profile/flags_test.go Normal file
View file

@ -0,0 +1,69 @@
package profile
import (
"testing"
"github.com/Safing/portmaster/status"
)
func TestProfileFlags(t *testing.T) {
// check if all IDs have a name
for key, entry := range flagIDs {
if _, ok := flagNames[entry]; !ok {
t.Errorf("could not find entry for %s in flagNames", key)
}
}
// check if all names have an ID
for key, entry := range flagNames {
if _, ok := flagIDs[entry]; !ok {
t.Errorf("could not find entry for %d in flagNames", key)
}
}
testFlags := Flags{
Prompt: status.SecurityLevelsAll,
Internet: status.SecurityLevelsDynamicAndSecure,
LAN: status.SecurityLevelsDynamicAndSecure,
Localhost: status.SecurityLevelsAll,
Related: status.SecurityLevelDynamic,
RequireGate17: status.SecurityLevelsSecureAndFortress,
}
if testFlags.String() != "[Prompt, Internet++-, LAN++-, Localhost, Related+--, RequireGate17-++]" {
t.Errorf("unexpected output: %s", testFlags.String())
}
// // check Has
// emptyFlags := ProfileFlags{}
// for flag, name := range flagNames {
// if !sortedFlags.Has(flag) {
// t.Errorf("sortedFlags should have flag %s (%d)", name, flag)
// }
// if emptyFlags.Has(flag) {
// t.Errorf("emptyFlags should not have flag %s (%d)", name, flag)
// }
// }
//
// // check ProfileFlags creation from strings
// var allFlagStrings []string
// for _, flag := range *sortedFlags {
// allFlagStrings = append(allFlagStrings, flagNames[flag])
// }
// newFlags, err := FlagsFromNames(allFlagStrings)
// if err != nil {
// t.Errorf("error while parsing flags: %s", err)
// }
// if newFlags.String() != sortedFlags.String() {
// t.Errorf("parsed flags are not correct (or tests have not been updated to reflect the right number), expected %v, got %v", *sortedFlags, *newFlags)
// }
//
// // check ProfileFlags Stringer
// flagString := newFlags.String()
// check := strings.Join(allFlagStrings, ",")
// if flagString != check {
// t.Errorf("flag string is not correct, expected %s, got %s", check, flagString)
// }
}

76
profile/framework.go Normal file
View file

@ -0,0 +1,76 @@
package profile
// DEACTIVATED
// import (
// "fmt"
// "os"
// "path/filepath"
// "regexp"
// "strings"
//
// "github.com/Safing/portbase/log"
// )
//
// type Framework struct {
// // go hirarchy up
// FindParent uint8 `json:",omitempty bson:",omitempty"`
// // get path from parent, amount of levels to go up the tree (1 means parent, 2 means parent of parents, and so on)
// MergeWithParent bool `json:",omitempty bson:",omitempty"`
// // instead of getting the path of the parent, merge with it by presenting connections as if they were from that parent
//
// // go hirarchy down
// Find string `json:",omitempty bson:",omitempty"`
// // Regular expression for finding path elements
// Build string `json:",omitempty bson:",omitempty"`
// // Path definitions for building path
// Virtual bool `json:",omitempty bson:",omitempty"`
// // Treat resulting path as virtual, do not check if valid
// }
//
// func (f *Framework) GetNewPath(command string, cwd string) (string, error) {
// // "/usr/bin/python script"
// // to
// // "/path/to/script"
// regex, err := regexp.Compile(f.Find)
// if err != nil {
// return "", fmt.Errorf("profiles(framework): failed to compile framework regex: %s", err)
// }
// matched := regex.FindAllStringSubmatch(command, -1)
// if len(matched) == 0 || len(matched[0]) < 2 {
// return "", fmt.Errorf("profiles(framework): regex \"%s\" for constructing path did not match command \"%s\"", f.Find, command)
// }
//
// var lastError error
// var buildPath string
// for _, buildPath = range strings.Split(f.Build, "|") {
//
// buildPath = strings.Replace(buildPath, "{CWD}", cwd, -1)
// for i := 1; i < len(matched[0]); i++ {
// buildPath = strings.Replace(buildPath, fmt.Sprintf("{%d}", i), matched[0][i], -1)
// }
//
// buildPath = filepath.Clean(buildPath)
//
// if !f.Virtual {
// if !strings.HasPrefix(buildPath, "~/") && !filepath.IsAbs(buildPath) {
// lastError = fmt.Errorf("constructed path \"%s\" from framework is not absolute", buildPath)
// continue
// }
// if _, err := os.Stat(buildPath); os.IsNotExist(err) {
// lastError = fmt.Errorf("constructed path \"%s\" does not exist", buildPath)
// continue
// }
// }
//
// lastError = nil
// break
//
// }
//
// if lastError != nil {
// return "", fmt.Errorf("profiles(framework): failed to construct valid path, last error: %s", lastError)
// }
// log.Tracef("profiles(framework): transformed \"%s\" (%s) to \"%s\"", command, cwd, buildPath)
// return buildPath, nil
// }

30
profile/framework_test.go Normal file
View file

@ -0,0 +1,30 @@
package profile
// DEACTIVATED
// import (
// "testing"
// )
//
// func testGetNewPath(t *testing.T, f *Framework, command, cwd, expect string) {
// newPath, err := f.GetNewPath(command, cwd)
// if err != nil {
// t.Errorf("GetNewPath failed: %s", err)
// }
// if newPath != expect {
// t.Errorf("GetNewPath return unexpected result: got %s, expected %s", newPath, expect)
// }
// }
//
// func TestFramework(t *testing.T) {
// f1 := &Framework{
// Find: "([^ ]+)$",
// Build: "{CWD}/{1}",
// }
// testGetNewPath(t, f1, "/usr/bin/python bash", "/bin", "/bin/bash")
// f2 := &Framework{
// Find: "([^ ]+)$",
// Build: "{1}|{CWD}/{1}",
// }
// testGetNewPath(t, f2, "/usr/bin/python /bin/bash", "/tmp", "/bin/bash")
// }

View file

@ -0,0 +1,47 @@
package profile
import (
"path/filepath"
"strings"
"github.com/Safing/portbase/utils"
)
// GetPathIdentifier returns the identifier from the given path
func GetPathIdentifier(path string) string {
// clean path
// TODO: is this necessary?
cleanedPath, err := filepath.EvalSymlinks(path)
if err == nil {
path = cleanedPath
} else {
path = filepath.Clean(path)
}
splittedPath := strings.Split(path, "/")
// strip sensitive data
switch {
case strings.HasPrefix(path, "/home/"):
splittedPath = splittedPath[3:]
case strings.HasPrefix(path, "/root/"):
splittedPath = splittedPath[2:]
}
// common directories with executable
if i := utils.IndexOfString(splittedPath, "bin"); i > 0 {
splittedPath = splittedPath[i:]
return strings.Join(splittedPath, "/")
}
if i := utils.IndexOfString(splittedPath, "sbin"); i > 0 {
splittedPath = splittedPath[i:]
return strings.Join(splittedPath, "/")
}
// shorten to max 3
if len(splittedPath) > 3 {
splittedPath = splittedPath[len(splittedPath)-3:]
}
return strings.Join(splittedPath, "/")
}

View file

@ -0,0 +1,23 @@
package profile
import "testing"
func testPathID(t *testing.T, execPath, identifierPath string) {
result := GetPathIdentifier(execPath)
if result != identifierPath {
t.Errorf("unexpected identifier path for %s: got %s, expected %s", execPath, result, identifierPath)
}
}
func TestGetPathIdentifier(t *testing.T) {
testPathID(t, "/bin/bash", "bin/bash")
testPathID(t, "/home/user/bin/bash", "bin/bash")
testPathID(t, "/home/user/project/main", "project/main")
testPathID(t, "/root/project/main", "project/main")
testPathID(t, "/tmp/a/b/c/d/install.sh", "c/d/install.sh")
testPathID(t, "/sbin/init", "sbin/init")
testPathID(t, "/lib/systemd/systemd-udevd", "lib/systemd/systemd-udevd")
testPathID(t, "/bundle/ruby/2.4.0/bin/passenger", "bin/passenger")
testPathID(t, "/usr/sbin/cron", "sbin/cron")
testPathID(t, "/usr/local/bin/python", "bin/python")
}

102
profile/index/index.go Normal file
View file

@ -0,0 +1,102 @@
package index
import (
"encoding/base64"
"errors"
"fmt"
"sync"
"github.com/Safing/portbase/database/record"
"github.com/Safing/portbase/utils"
)
// ProfileIndex links an Identifier to Profiles
type ProfileIndex struct {
record.Base
sync.Mutex
ID string
UserProfiles []string
StampProfiles []string
}
func makeIndexRecordKey(fpType, id string) string {
return fmt.Sprintf("index:profiles/%s:%s", fpType, base64.RawURLEncoding.EncodeToString([]byte(id)))
}
// NewIndex returns a new ProfileIndex.
func NewIndex(id string) *ProfileIndex {
return &ProfileIndex{
ID: id,
}
}
// AddUserProfile adds a User Profile to the index.
func (pi *ProfileIndex) AddUserProfile(identifier string) (changed bool) {
if !utils.StringInSlice(pi.UserProfiles, identifier) {
pi.UserProfiles = append(pi.UserProfiles, identifier)
return true
}
return false
}
// AddStampProfile adds a Stamp Profile to the index.
func (pi *ProfileIndex) AddStampProfile(identifier string) (changed bool) {
if !utils.StringInSlice(pi.StampProfiles, identifier) {
pi.StampProfiles = append(pi.StampProfiles, identifier)
return true
}
return false
}
// RemoveUserProfile removes a profile from the index.
func (pi *ProfileIndex) RemoveUserProfile(id string) {
pi.UserProfiles = utils.RemoveFromStringSlice(pi.UserProfiles, id)
}
// RemoveStampProfile removes a profile from the index.
func (pi *ProfileIndex) RemoveStampProfile(id string) {
pi.StampProfiles = utils.RemoveFromStringSlice(pi.StampProfiles, id)
}
// Get gets a ProfileIndex from the database.
func Get(fpType, id string) (*ProfileIndex, error) {
key := makeIndexRecordKey(fpType, id)
r, err := indexDB.Get(key)
if err != nil {
return nil, err
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &ProfileIndex{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*ProfileIndex)
if !ok {
return nil, fmt.Errorf("record not of type *ProfileIndex, but %T", r)
}
return new, nil
}
// Save saves the Identifiers to the database
func (pi *ProfileIndex) Save() error {
if !pi.KeyIsSet() {
if pi.ID != "" {
pi.SetKey(makeIndexRecordKey(pi.ID))
} else {
return errors.New("missing identification Key")
}
}
return indexDB.Put(pi)
}

Some files were not shown because too many files have changed in this diff Show more