Clean up linter errors

This commit is contained in:
Daniel 2019-11-07 16:13:22 +01:00
parent 35c7f4955b
commit f75fc7d162
50 changed files with 402 additions and 334 deletions

View file

@ -3,7 +3,6 @@ package firewall
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"os" "os"
"sync/atomic" "sync/atomic"
"time" "time"
@ -21,26 +20,28 @@ import (
) )
var ( var (
module *modules.Module
// localNet net.IPNet // localNet net.IPNet
localhost net.IP // localhost net.IP
dnsServer net.IPNet // dnsServer net.IPNet
packetsAccepted *uint64 packetsAccepted *uint64
packetsBlocked *uint64 packetsBlocked *uint64
packetsDropped *uint64 packetsDropped *uint64
localNet4 *net.IPNet // localNet4 *net.IPNet
localhost4 = net.IPv4(127, 0, 0, 1) // localhost4 = net.IPv4(127, 0, 0, 1)
localhost6 = net.IPv6loopback // localhost6 = net.IPv6loopback
tunnelNet4 *net.IPNet // tunnelNet4 *net.IPNet
tunnelNet6 *net.IPNet // tunnelNet6 *net.IPNet
tunnelEntry4 = net.IPv4(127, 0, 0, 17) // tunnelEntry4 = net.IPv4(127, 0, 0, 17)
tunnelEntry6 = net.ParseIP("fd17::17") // tunnelEntry6 = net.ParseIP("fd17::17")
) )
func init() { func init() {
modules.Register("firewall", prep, start, stop, "core", "network", "nameserver", "profile", "updates") module = modules.Register("firewall", prep, start, stop, "core", "network", "nameserver", "profile", "updates")
} }
func prep() (err error) { func prep() (err error) {
@ -55,21 +56,21 @@ func prep() (err error) {
return err return err
} }
_, localNet4, err = net.ParseCIDR("127.0.0.0/24") // _, localNet4, err = net.ParseCIDR("127.0.0.0/24")
// Yes, this would normally be 127.0.0.0/8 // // Yes, this would normally be 127.0.0.0/8
// TODO: figure out any side effects // // TODO: figure out any side effects
if err != nil { // if err != nil {
return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err) // return fmt.Errorf("firewall: failed to parse cidr 127.0.0.0/24: %s", err)
} // }
_, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16") // _, tunnelNet4, err = net.ParseCIDR("127.17.0.0/16")
if err != nil { // if err != nil {
return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err) // return fmt.Errorf("firewall: failed to parse cidr 127.17.0.0/16: %s", err)
} // }
_, tunnelNet6, err = net.ParseCIDR("fd17::/64") // _, tunnelNet6, err = net.ParseCIDR("fd17::/64")
if err != nil { // if err != nil {
return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err) // return fmt.Errorf("firewall: failed to parse cidr fd17::/64: %s", err)
} // }
var pA uint64 var pA uint64
packetsAccepted = &pA packetsAccepted = &pA
@ -83,9 +84,21 @@ func prep() (err error) {
func start() error { func start() error {
startAPIAuth() startAPIAuth()
go statLogger()
go run() module.StartWorker("stat logger", func(ctx context.Context) error {
go portsInUseCleaner() statLogger()
return nil
})
module.StartWorker("packet handler", func(ctx context.Context) error {
run()
return nil
})
module.StartWorker("ports state cleaner", func(ctx context.Context) error {
portsInUseCleaner()
return nil
})
return interception.Start() return interception.Start()
} }
@ -106,7 +119,7 @@ func handlePacket(pkt packet.Packet) {
// allow local dns // allow local dns
if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && pkt.Info().Src.Equal(pkt.Info().Dst) { if (pkt.Info().DstPort == 53 || pkt.Info().SrcPort == 53) && pkt.Info().Src.Equal(pkt.Info().Dst) {
log.Debugf("accepting local dns: %s", pkt) log.Debugf("accepting local dns: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
} }
@ -114,7 +127,7 @@ func handlePacket(pkt packet.Packet) {
if apiPortSet { if apiPortSet {
if (pkt.Info().DstPort == apiPort || pkt.Info().SrcPort == apiPort) && pkt.Info().Src.Equal(pkt.Info().Dst) { if (pkt.Info().DstPort == apiPort || pkt.Info().SrcPort == apiPort) && pkt.Info().Src.Equal(pkt.Info().Dst) {
log.Debugf("accepting api connection: %s", pkt) log.Debugf("accepting api connection: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
} }
} }
@ -130,20 +143,20 @@ func handlePacket(pkt packet.Packet) {
switch pkt.Info().Protocol { switch pkt.Info().Protocol {
case packet.ICMP: case packet.ICMP:
log.Debugf("accepting ICMP: %s", pkt) log.Debugf("accepting ICMP: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
case packet.ICMPv6: case packet.ICMPv6:
log.Debugf("accepting ICMPv6: %s", pkt) log.Debugf("accepting ICMPv6: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
case packet.IGMP: case packet.IGMP:
log.Debugf("accepting IGMP: %s", pkt) log.Debugf("accepting IGMP: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
case packet.UDP: case packet.UDP:
if pkt.Info().DstPort == 67 || pkt.Info().DstPort == 68 { if pkt.Info().DstPort == 67 || pkt.Info().DstPort == 68 {
log.Debugf("accepting DHCP: %s", pkt) log.Debugf("accepting DHCP: %s", pkt)
pkt.PermanentAccept() _ = pkt.PermanentAccept()
return return
} }
// TODO: Howto handle NetBios? // TODO: Howto handle NetBios?
@ -310,39 +323,44 @@ func issueVerdict(pkt packet.Packet, link *network.Link, verdict network.Verdict
verdict = link.Verdict verdict = link.Verdict
} }
var err error
switch verdict { switch verdict {
case network.VerdictAccept: case network.VerdictAccept:
atomic.AddUint64(packetsAccepted, 1) atomic.AddUint64(packetsAccepted, 1)
if link.VerdictPermanent { if link.VerdictPermanent {
pkt.PermanentAccept() err = pkt.PermanentAccept()
} else { } else {
pkt.Accept() err = pkt.Accept()
} }
case network.VerdictBlock: case network.VerdictBlock:
atomic.AddUint64(packetsBlocked, 1) atomic.AddUint64(packetsBlocked, 1)
if link.VerdictPermanent { if link.VerdictPermanent {
pkt.PermanentBlock() err = pkt.PermanentBlock()
} else { } else {
pkt.Block() err = pkt.Block()
} }
case network.VerdictDrop: case network.VerdictDrop:
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
if link.VerdictPermanent { if link.VerdictPermanent {
pkt.PermanentDrop() err = pkt.PermanentDrop()
} else { } else {
pkt.Drop() err = pkt.Drop()
} }
case network.VerdictRerouteToNameserver: case network.VerdictRerouteToNameserver:
pkt.RerouteToNameserver() err = pkt.RerouteToNameserver()
case network.VerdictRerouteToTunnel: case network.VerdictRerouteToTunnel:
pkt.RerouteToTunnel() err = pkt.RerouteToTunnel()
default: default:
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
pkt.Drop() err = pkt.Drop()
} }
link.Unlock() link.Unlock()
if err != nil {
log.Warningf("firewall: failed to apply verdict to pkt %s: %s", pkt, err)
}
log.Tracer(pkt.Ctx()).Infof("firewall: %s %s", link.Verdict, link) log.Tracer(pkt.Ctx()).Infof("firewall: %s %s", link.Verdict, link)
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
) )
//nolint:golint,stylecheck // FIXME
const ( const (
DO_NOTHING uint8 = iota DO_NOTHING uint8 = iota
BLOCK_PACKET BLOCK_PACKET
@ -25,6 +26,7 @@ var (
inspectorsLock sync.Mutex inspectorsLock sync.Mutex
) )
// RegisterInspector registers a traffic inspector.
func RegisterInspector(name string, inspector inspectorFn, inspectVerdict network.Verdict) (index int) { func RegisterInspector(name string, inspector inspectorFn, inspectVerdict network.Verdict) (index int) {
inspectorsLock.Lock() inspectorsLock.Lock()
defer inspectorsLock.Unlock() defer inspectorsLock.Unlock()
@ -35,13 +37,14 @@ func RegisterInspector(name string, inspector inspectorFn, inspectVerdict networ
return return
} }
// RunInspectors runs all the applicable inspectors on the given packet.
func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool) { func RunInspectors(pkt packet.Packet, link *network.Link) (network.Verdict, bool) {
// inspectorsLock.Lock() // inspectorsLock.Lock()
// defer inspectorsLock.Unlock() // defer inspectorsLock.Unlock()
activeInspectors := link.GetActiveInspectors() activeInspectors := link.GetActiveInspectors()
if activeInspectors == nil { if activeInspectors == nil {
activeInspectors = make([]bool, len(inspectors), len(inspectors)) activeInspectors = make([]bool, len(inspectors))
link.SetActiveInspectors(activeInspectors) link.SetActiveInspectors(activeInspectors)
} }

View file

@ -1,2 +1,2 @@
// Package nfqueue provides network interception capabilites on linux via iptables nfqueue. // Package nfqueue provides network interception capabilities on linux via iptables nfqueue.
package nfqueue package nfqueue

View file

@ -27,6 +27,8 @@ func init() {
queues = make(map[uint16]*NFQueue) queues = make(map[uint16]*NFQueue)
} }
// NFQueue holds a Linux NFQ Handle and associated information.
//nolint:maligned // FIXME
type NFQueue struct { type NFQueue struct {
DefaultVerdict uint32 DefaultVerdict uint32
Timeout time.Duration Timeout time.Duration
@ -41,6 +43,7 @@ type NFQueue struct {
Packets chan packet.Packet Packets chan packet.Packet
} }
// NewNFQueue initializes a new netfilter queue.
func NewNFQueue(qid uint16) (nfq *NFQueue, err error) { func NewNFQueue(qid uint16) (nfq *NFQueue, err error) {
if os.Geteuid() != 0 { if os.Geteuid() != 0 {
return nil, errors.New("must be root to intercept packets") return nil, errors.New("must be root to intercept packets")
@ -61,96 +64,98 @@ func NewNFQueue(qid uint16) (nfq *NFQueue, err error) {
return nfq, nil return nfq, nil
} }
func (this *NFQueue) init() error { func (nfq *NFQueue) init() error {
var err error var err error
if this.h, err = C.nfq_open(); err != nil || this.h == nil { if nfq.h, err = C.nfq_open(); err != nil || nfq.h == nil {
return fmt.Errorf("could not open nfqueue: %s", err) return fmt.Errorf("could not open nfqueue: %s", err)
} }
//if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil { //if nfq.qh, err = C.nfq_create_queue(nfq.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || nfq.qh == nil {
this.Packets = make(chan packet.Packet, 1) nfq.Packets = make(chan packet.Packet, 1)
if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 { if C.nfq_unbind_pf(nfq.h, C.AF_INET) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?") return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?")
} }
if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 { if C.nfq_unbind_pf(nfq.h, C.AF_INET6) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_unbind_pf(AF_INET6) failed") return errors.New("nfq_unbind_pf(AF_INET6) failed")
} }
if C.nfq_bind_pf(this.h, C.AF_INET) < 0 { if C.nfq_bind_pf(nfq.h, C.AF_INET) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_bind_pf(AF_INET) failed") return errors.New("nfq_bind_pf(AF_INET) failed")
} }
if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 { if C.nfq_bind_pf(nfq.h, C.AF_INET6) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_bind_pf(AF_INET6) failed") return errors.New("nfq_bind_pf(AF_INET6) failed")
} }
if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid)); err != nil || this.qh == nil { if nfq.qh, err = C.create_queue(nfq.h, C.uint16_t(nfq.qid)); err != nil || nfq.qh == nil {
C.nfq_close(this.h) C.nfq_close(nfq.h)
return fmt.Errorf("could not create queue: %s", err) return fmt.Errorf("could not create queue: %s", err)
} }
this.fd = int(C.nfq_fd(this.h)) nfq.fd = int(C.nfq_fd(nfq.h))
if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 { if C.nfq_set_mode(nfq.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed") return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed")
} }
if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 { if C.nfq_set_queue_maxlen(nfq.qh, 1024*8) < 0 {
this.Destroy() nfq.Destroy()
return errors.New("nfq_set_queue_maxlen(1024 * 8) failed") return errors.New("nfq_set_queue_maxlen(1024 * 8) failed")
} }
return nil return nil
} }
func (this *NFQueue) Destroy() { // Destroy closes all the nfqueues.
this.lk.Lock() func (nfq *NFQueue) Destroy() {
defer this.lk.Unlock() nfq.lk.Lock()
defer nfq.lk.Unlock()
if this.fd != 0 && this.Valid() { if nfq.fd != 0 && nfq.Valid() {
syscall.Close(this.fd) syscall.Close(nfq.fd)
} }
if this.qh != nil { if nfq.qh != nil {
C.nfq_destroy_queue(this.qh) C.nfq_destroy_queue(nfq.qh)
this.qh = nil nfq.qh = nil
} }
if this.h != nil { if nfq.h != nil {
C.nfq_close(this.h) C.nfq_close(nfq.h)
this.h = nil nfq.h = nil
} }
// TODO: don't close, we're exiting anyway // TODO: don't close, we're exiting anyway
// if this.Packets != nil { // if nfq.Packets != nil {
// close(this.Packets) // close(nfq.Packets)
// } // }
} }
func (this *NFQueue) Valid() bool { // Valid returns whether the NFQueue is still valid.
return this.h != nil && this.qh != nil func (nfq *NFQueue) Valid() bool {
return nfq.h != nil && nfq.qh != nil
} }
//export go_nfq_callback //export go_nfq_callback
func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
version, protocol, tos, ttl uint8, saddr, daddr unsafe.Pointer, version, protocol, tos, ttl uint8, saddr, daddr unsafe.Pointer,
sport, dport, checksum uint16, payload_len uint32, payload, data unsafe.Pointer) (v uint32) { sport, dport, checksum uint16, payloadLen uint32, payload, data unsafe.Pointer) (v uint32) {
qidptr := (*uint16)(data) qidptr := (*uint16)(data)
qid := uint16(*qidptr) qid := *qidptr
// nfq := (*NFQueue)(nfqptr) // nfq := (*NFQueue)(nfqptr)
ipVersion := packet.IPVersion(version) ipVersion := packet.IPVersion(version)
ipsz := C.int(ipVersion.ByteSize()) ipsz := C.int(ipVersion.ByteSize())
bs := C.GoBytes(payload, (C.int)(payload_len)) bs := C.GoBytes(payload, (C.int)(payloadLen))
verdict := make(chan uint32, 1) verdict := make(chan uint32, 1)
pkt := Packet{ pkt := Packet{
QueueId: qid, QueueID: qid,
Id: id, ID: id,
HWProtocol: hwproto, HWProtocol: hwproto,
Hook: hook, Hook: hook,
Mark: *mark, Mark: *mark,

View file

@ -1,15 +1,18 @@
package nfqueue package nfqueue
import ( import (
"fmt" "errors"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
) )
// NFQ Errors
var ( var (
ErrVerdictSentOrTimedOut error = fmt.Errorf("The verdict was already sent or timed out.") ErrVerdictSentOrTimedOut = errors.New("the verdict was already sent or timed out")
) )
// NFQ Packet Constants
//nolint:golint,stylecheck // FIXME
const ( const (
NFQ_DROP uint32 = 0 // discarded the packet NFQ_DROP uint32 = 0 // discarded the packet
NFQ_ACCEPT uint32 = 1 // the packet passes, continue iterations NFQ_ACCEPT uint32 = 1 // the packet passes, continue iterations
@ -19,11 +22,12 @@ const (
NFQ_STOP uint32 = 5 // accept, but don't continue iterations NFQ_STOP uint32 = 5 // accept, but don't continue iterations
) )
// Packet represents a packet with a NFQ reference.
type Packet struct { type Packet struct {
packet.Base packet.Base
QueueId uint16 QueueID uint16
Id uint32 ID uint32
HWProtocol uint16 HWProtocol uint16
Hook uint8 Hook uint8
Mark uint32 Mark uint32
@ -35,9 +39,10 @@ type Packet struct {
// func (pkt *Packet) String() string { // func (pkt *Packet) String() string {
// return fmt.Sprintf("<Packet QId: %d, Id: %d, Type: %s, Src: %s:%d, Dst: %s:%d, Mark: 0x%X, Checksum: 0x%X, TOS: 0x%X, TTL: %d>", // return fmt.Sprintf("<Packet QId: %d, Id: %d, Type: %s, Src: %s:%d, Dst: %s:%d, Mark: 0x%X, Checksum: 0x%X, TOS: 0x%X, TTL: %d>",
// pkt.QueueId, pkt.Id, pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort, pkt.Mark, pkt.Checksum, pkt.Tos, pkt.TTL) // pkt.QueueID, pkt.Id, pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort, pkt.Mark, pkt.Checksum, pkt.Tos, pkt.TTL)
// } // }
//nolint:unparam // FIXME
func (pkt *Packet) setVerdict(v uint32) (err error) { func (pkt *Packet) setVerdict(v uint32) (err error) {
defer func() { defer func() {
if x := recover(); x != nil { if x := recover(); x != nil {
@ -68,41 +73,49 @@ func (pkt *Packet) setVerdict(v uint32) (err error) {
// return pkt.setVerdict(NFQ_DROP) // return pkt.setVerdict(NFQ_DROP)
// } // }
// Accept implements the packet interface.
func (pkt *Packet) Accept() error { func (pkt *Packet) Accept() error {
pkt.Mark = 1700 pkt.Mark = 1700
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// Block implements the packet interface.
func (pkt *Packet) Block() error { func (pkt *Packet) Block() error {
pkt.Mark = 1701 pkt.Mark = 1701
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// Drop implements the packet interface.
func (pkt *Packet) Drop() error { func (pkt *Packet) Drop() error {
pkt.Mark = 1702 pkt.Mark = 1702
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// PermanentAccept implements the packet interface.
func (pkt *Packet) PermanentAccept() error { func (pkt *Packet) PermanentAccept() error {
pkt.Mark = 1710 pkt.Mark = 1710
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// PermanentBlock implements the packet interface.
func (pkt *Packet) PermanentBlock() error { func (pkt *Packet) PermanentBlock() error {
pkt.Mark = 1711 pkt.Mark = 1711
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// PermanentDrop implements the packet interface.
func (pkt *Packet) PermanentDrop() error { func (pkt *Packet) PermanentDrop() error {
pkt.Mark = 1712 pkt.Mark = 1712
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// RerouteToNameserver implements the packet interface.
func (pkt *Packet) RerouteToNameserver() error { func (pkt *Packet) RerouteToNameserver() error {
pkt.Mark = 1799 pkt.Mark = 1799
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)
} }
// RerouteToTunnel implements the packet interface.
func (pkt *Packet) RerouteToTunnel() error { func (pkt *Packet) RerouteToTunnel() error {
pkt.Mark = 1717 pkt.Mark = 1717
return pkt.setVerdict(NFQ_ACCEPT) return pkt.setVerdict(NFQ_ACCEPT)

View file

@ -247,28 +247,28 @@ func StartNfqueueInterception() (err error) {
err = activateNfqueueFirewall() err = activateNfqueueFirewall()
if err != nil { if err != nil {
Stop() _ = Stop()
return fmt.Errorf("could not initialize nfqueue: %s", err) return fmt.Errorf("could not initialize nfqueue: %s", err)
} }
out4Queue, err = nfqueue.NewNFQueue(17040) out4Queue, err = nfqueue.NewNFQueue(17040)
if err != nil { if err != nil {
Stop() _ = Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
} }
in4Queue, err = nfqueue.NewNFQueue(17140) in4Queue, err = nfqueue.NewNFQueue(17140)
if err != nil { if err != nil {
Stop() _ = Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
} }
out6Queue, err = nfqueue.NewNFQueue(17060) out6Queue, err = nfqueue.NewNFQueue(17060)
if err != nil { if err != nil {
Stop() _ = Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
} }
in6Queue, err = nfqueue.NewNFQueue(17160) in6Queue, err = nfqueue.NewNFQueue(17160)
if err != nil { if err != nil {
Stop() _ = Stop()
return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err) return fmt.Errorf("interception: failed to create nfqueue(IPv4, in): %s", err)
} }
@ -321,12 +321,3 @@ func handleInterception() {
} }
} }
} }
func stringInSlice(slice []string, value string) bool {
for _, entry := range slice {
if value == entry {
return true
}
}
return false
}

View file

@ -1,4 +1,5 @@
// +build windows // +build windows
package windowskext package windowskext
import ( import (

View file

@ -1,4 +1,5 @@
// +build windows // +build windows
package windowskext package windowskext
import ( import (

View file

@ -1,4 +1,5 @@
// +build windows // +build windows
package windowskext package windowskext
import ( import (

View file

@ -1,26 +0,0 @@
package main
import (
"fmt"
"unsafe"
)
const integerSize int = int(unsafe.Sizeof(0))
func isBigEndian() bool {
var i int = 0x1
bs := (*[integerSize]byte)(unsafe.Pointer(&i))
if bs[0] == 0 {
return true
} else {
return false
}
}
func main() {
if isBigEndian() {
fmt.Println("System is Big Endian (Network Byte Order): uint16 0x1234 is 0x1234 in memory")
} else {
fmt.Println("System is Little Endian (Host Byte Order): uint16 0x1234 is 0x3412 in memory")
}
}

View file

@ -139,6 +139,7 @@ func DecideOnCommunicationAfterIntel(comm *network.Communication, fqdn string, r
} }
// FilterDNSResponse filters a dns response according to the application profile and settings. // FilterDNSResponse filters a dns response according to the application profile and settings.
//nolint:gocognit // FIXME
func FilterDNSResponse(comm *network.Communication, q *intel.Query, rrCache *intel.RRCache) *intel.RRCache { func FilterDNSResponse(comm *network.Communication, q *intel.Query, rrCache *intel.RRCache) *intel.RRCache {
// do not modify own queries - this should not happen anyway // do not modify own queries - this should not happen anyway
if comm.Process().Pid == os.Getpid() { if comm.Process().Pid == os.Getpid() {
@ -497,7 +498,7 @@ func checkRelation(comm *network.Communication, fqdn string) (related bool) {
// TODO: add #AI // TODO: add #AI
pathElements := strings.Split(comm.Process().Path, "/") // FIXME: path seperator pathElements := strings.Split(comm.Process().Path, "/") // FIXME: path separator
// only look at the last two path segments // only look at the last two path segments
if len(pathElements) > 2 { if len(pathElements) > 2 {
pathElements = pathElements[len(pathElements)-2:] pathElements = pathElements[len(pathElements)-2:]
@ -537,5 +538,5 @@ matchLoop:
log.Infof("firewall: permitting communication %s, match to domain was found: %s is related to %s", comm, domainElement, processElement) log.Infof("firewall: permitting communication %s, match to domain was found: %s is related to %s", comm, domainElement, processElement)
comm.Accept(fmt.Sprintf("domain is related to process: %s is related to %s", domainElement, processElement)) comm.Accept(fmt.Sprintf("domain is related to process: %s is related to %s", domainElement, processElement))
} }
return return related
} }

View file

@ -80,10 +80,10 @@ func cleanPortsInUse() {
portsInUseLock.Lock() portsInUseLock.Lock()
defer portsInUseLock.Unlock() defer portsInUseLock.Unlock()
threshhold := time.Now().Add(-cleanTimeout) threshold := time.Now().Add(-cleanTimeout)
for port, status := range portsInUse { for port, status := range portsInUse {
if status.lastSeen.Before(threshhold) { if status.lastSeen.Before(threshold) {
delete(portsInUse, port) delete(portsInUse, port)
} }
} }

View file

@ -1,6 +1,7 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@ -24,6 +25,11 @@ const (
denyServingIP = "deny-serving-ip" denyServingIP = "deny-serving-ip"
) )
var (
mtSaveProfile = "save profile"
)
//nolint:gocognit // FIXME
func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet, fqdn string) { func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet, fqdn string) {
nTTL := time.Duration(promptTimeout()) * time.Second nTTL := time.Duration(promptTimeout()) * time.Second
@ -66,11 +72,11 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet,
case comm.Direction: // incoming case comm.Direction: // incoming
n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().LocalPort()) n.Message = fmt.Sprintf("Application %s wants to accept connections from %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().LocalPort())
n.AvailableActions = []*notifications.Action{ n.AvailableActions = []*notifications.Action{
&notifications.Action{ {
ID: permitServingIP, ID: permitServingIP,
Text: "Permit", Text: "Permit",
}, },
&notifications.Action{ {
ID: denyServingIP, ID: denyServingIP,
Text: "Deny", Text: "Deny",
}, },
@ -78,11 +84,11 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet,
case fqdn == "": // direct connection case fqdn == "": // direct connection
n.Message = fmt.Sprintf("Application %s wants to connect to %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().RemotePort()) n.Message = fmt.Sprintf("Application %s wants to connect to %s (on %d/%d)", comm.Process(), pkt.Info().RemoteIP(), pkt.Info().Protocol, pkt.Info().RemotePort())
n.AvailableActions = []*notifications.Action{ n.AvailableActions = []*notifications.Action{
&notifications.Action{ {
ID: permitIP, ID: permitIP,
Text: "Permit", Text: "Permit",
}, },
&notifications.Action{ {
ID: denyIP, ID: denyIP,
Text: "Deny", Text: "Deny",
}, },
@ -94,15 +100,15 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet,
n.Message = fmt.Sprintf("Application %s wants to connect to %s", comm.Process(), comm.Domain) n.Message = fmt.Sprintf("Application %s wants to connect to %s", comm.Process(), comm.Domain)
} }
n.AvailableActions = []*notifications.Action{ n.AvailableActions = []*notifications.Action{
&notifications.Action{ {
ID: permitDomainAll, ID: permitDomainAll,
Text: "Permit all", Text: "Permit all",
}, },
&notifications.Action{ {
ID: permitDomainDistinct, ID: permitDomainDistinct,
Text: "Permit", Text: "Permit",
}, },
&notifications.Action{ {
ID: denyDomainDistinct, ID: denyDomainDistinct,
Text: "Deny", Text: "Deny",
}, },
@ -182,7 +188,9 @@ func prompt(comm *network.Communication, link *network.Link, pkt packet.Packet,
} }
// save! // save!
go userProfile.Save("") module.StartMicroTask(&mtSaveProfile, func(ctx context.Context) error {
return userProfile.Save("")
})
case <-n.Expired(): case <-n.Expired():
if link != nil { if link != nil {

View file

@ -37,6 +37,7 @@ var (
ErrNoCompliance = fmt.Errorf("%w: no compliant resolvers for this query", ErrBlocked) ErrNoCompliance = fmt.Errorf("%w: no compliant resolvers for this query", ErrBlocked)
) )
// Query describes a dns query.
type Query struct { type Query struct {
FQDN string FQDN string
QType dns.Type QType dns.Type

View file

@ -11,6 +11,7 @@ import (
"github.com/safing/portmaster/network/environment" "github.com/safing/portmaster/network/environment"
) )
// DNS Resolver Attributes
const ( const (
ServerTypeDNS = "dns" ServerTypeDNS = "dns"
ServerTypeTCP = "tcp" ServerTypeTCP = "tcp"
@ -85,6 +86,7 @@ func (brc *BasicResolverConn) LastFail() time.Time {
return brc.lastFail return brc.lastFail
} }
// Query executes the given query against the resolver.
func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) { func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) {
// convenience // convenience
resolver := brc.resolver resolver := brc.resolver

View file

@ -25,7 +25,7 @@ var (
mtDNSRequest = "dns request" mtDNSRequest = "dns request"
listenAddress = "0.0.0.0:53" listenAddress = "0.0.0.0:53"
IPv4Localhost = net.IPv4(127, 0, 0, 1) ipv4Localhost = net.IPv4(127, 0, 0, 1)
localhostRRs []dns.RR localhostRRs []dns.RR
) )
@ -135,7 +135,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType) log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType)
return nil return nil
} }
if !remoteAddr.IP.Equal(IPv4Localhost) { if !remoteAddr.IP.Equal(ipv4Localhost) {
// if request is not coming from 127.0.0.1, check if it's really local // if request is not coming from 127.0.0.1, check if it's really local
localAddr, ok := w.RemoteAddr().(*net.UDPAddr) localAddr, ok := w.RemoteAddr().(*net.UDPAddr)

View file

@ -22,7 +22,7 @@ var (
mtDNSRequest = "dns request" mtDNSRequest = "dns request"
listenAddress = "127.0.0.1:53" listenAddress = "127.0.0.1:53"
IPv4Localhost = net.IPv4(127, 0, 0, 1) ipv4Localhost = net.IPv4(127, 0, 0, 1)
localhostRRs []dns.RR localhostRRs []dns.RR
) )
@ -127,7 +127,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType) log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", q.FQDN, q.QType)
return nil return nil
} }
if !remoteAddr.IP.Equal(IPv4Localhost) { if !remoteAddr.IP.Equal(ipv4Localhost) {
// if request is not coming from 127.0.0.1, check if it's really local // if request is not coming from 127.0.0.1, check if it's really local
localAddr, ok := w.RemoteAddr().(*net.UDPAddr) localAddr, ok := w.RemoteAddr().(*net.UDPAddr)

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"context"
"time" "time"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -11,7 +12,8 @@ var (
cleanerTickDuration = 10 * time.Second cleanerTickDuration = 10 * time.Second
deleteLinksAfterEndedThreshold = 5 * time.Minute deleteLinksAfterEndedThreshold = 5 * time.Minute
deleteCommsWithoutLinksThreshhold = 3 * time.Minute deleteCommsWithoutLinksThreshhold = 3 * time.Minute
lastEstablishedUpdateThreshold = 30 * time.Second
mtSaveLink = "save network link"
) )
func cleaner() { func cleaner() {
@ -68,12 +70,17 @@ func cleanLinks() (activeComms map[string]struct{}) {
link.Ended = now link.Ended = now
link.Unlock() link.Unlock()
log.Tracef("network.clean: marked %s as ended", link.DatabaseKey()) log.Tracef("network.clean: marked %s as ended", link.DatabaseKey())
go link.save() // save
linkToSave := link
module.StartMicroTask(&mtSaveLink, func(ctx context.Context) error {
linkToSave.saveAndLog()
return nil
})
} }
} }
return return activeComms
} }
func cleanComms(activeLinks map[string]struct{}) (activeComms map[string]struct{}) { func cleanComms(activeLinks map[string]struct{}) (activeComms map[string]struct{}) {

View file

@ -18,6 +18,7 @@ import (
) )
// Communication describes a logical connection between a process and a domain. // Communication describes a logical connection between a process and a domain.
//nolint:maligned // TODO: fix alignment
type Communication struct { type Communication struct {
record.Base record.Base
sync.Mutex sync.Mutex
@ -288,7 +289,10 @@ func (comm *Communication) SaveIfNeeded() {
comm.Unlock() comm.Unlock()
if save { if save {
comm.save() err := comm.save()
if err != nil {
log.Warningf("network: failed to save comm %s: %s", comm, err)
}
} }
} }

View file

@ -32,7 +32,7 @@ type StorageInterface struct {
func (s *StorageInterface) Get(key string) (record.Record, error) { func (s *StorageInterface) Get(key string) (record.Record, error) {
splitted := strings.Split(key, "/") splitted := strings.Split(key, "/")
switch splitted[0] { switch splitted[0] { //nolint:gocritic // TODO: implement full key space
case "tree": case "tree":
switch len(splitted) { switch len(splitted) {
case 2: case 2:

View file

@ -15,12 +15,14 @@ var (
module *modules.Module module *modules.Module
) )
// InitSubModule initializes module specific things with the given module. Intended to be used as part of the "network" module.
func InitSubModule(m *modules.Module) { func InitSubModule(m *modules.Module) {
module = m module = m
module.RegisterEvent(networkChangedEvent) module.RegisterEvent(networkChangedEvent)
module.RegisterEvent(onlineStatusChangedEvent) module.RegisterEvent(onlineStatusChangedEvent)
} }
// StartSubModule starts module specific things with the given module. Intended to be used as part of the "network" module.
func StartSubModule() error { func StartSubModule() error {
if module == nil { if module == nil {
return errors.New("not initialized") return errors.New("not initialized")

View file

@ -0,0 +1,35 @@
package environment
import (
"os"
"testing"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/core"
)
func TestMain(m *testing.M) {
// setup
tmpDir, err := core.InitForTesting()
if err != nil {
panic(err)
}
// setup package
netModule := modules.Register("network", nil, nil, nil, "core")
InitSubModule(netModule)
err = StartSubModule()
if err != nil {
panic(err)
}
// run tests
rv := m.Run()
// teardown
core.StopTesting()
_ = os.RemoveAll(tmpDir)
// exit with test run return value
os.Exit(rv)
}

View file

@ -95,7 +95,7 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
} }
} }
return //nolint:nakedreturn return //nolint:nakedret
} }
// PrimitiveNetworkProximity calculates the numerical distance between two IP addresses. Returns a proximity value between 0 (far away) and 100 (nearby). // PrimitiveNetworkProximity calculates the numerical distance between two IP addresses. Returns a proximity value between 0 (far away) and 100 (nearby).

View file

@ -3,7 +3,6 @@ package geoip
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
@ -22,22 +21,12 @@ func start() error {
return fmt.Errorf("goeip: failed to load databases: %s", err) return fmt.Errorf("goeip: failed to load databases: %s", err)
} }
module.RegisterEventHook( return module.RegisterEventHook(
"updates", "updates",
"resource update", "resource update",
"upgrade databases", "upgrade databases",
upgradeDatabases, upgradeDatabases,
) )
// TODO: replace with update subscription
module.NewTask("update databases", func(ctx context.Context, task *modules.Task) {
dbFileLock.Lock()
defer dbFileLock.Unlock()
}).Repeat(10 * time.Minute).MaxDelay(1 * time.Hour)
return nil
} }
func upgradeDatabases(_ context.Context, _ interface{}) error { func upgradeDatabases(_ context.Context, _ interface{}) error {

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
@ -14,11 +15,8 @@ import (
// FirewallHandler defines the function signature for a firewall handle function // FirewallHandler defines the function signature for a firewall handle function
type FirewallHandler func(pkt packet.Packet, link *Link) type FirewallHandler func(pkt packet.Packet, link *Link)
var (
linkTimeout = 10 * time.Minute
)
// Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection. // Link describes a distinct physical connection (e.g. TCP connection) - like an instance - of a Connection.
//nolint:maligned // TODO: fix alignment
type Link struct { type Link struct {
record.Base record.Base
sync.Mutex sync.Mutex
@ -75,7 +73,13 @@ func (link *Link) SetFirewallHandler(handler FirewallHandler) {
if link.firewallHandler == nil { if link.firewallHandler == nil {
link.firewallHandler = handler link.firewallHandler = handler
link.pktQueue = make(chan packet.Packet, 1000) link.pktQueue = make(chan packet.Packet, 1000)
go link.packetHandler()
// start handling
module.StartWorker("", func(ctx context.Context) error {
link.packetHandler()
return nil
})
return return
} }
link.firewallHandler = handler link.firewallHandler = handler
@ -98,8 +102,13 @@ func (link *Link) HandlePacket(pkt packet.Packet) {
link.pktQueue <- pkt link.pktQueue <- pkt
return return
} }
log.Warningf("network: link %s does not have a firewallHandler, dropping packet", link) log.Warningf("network: link %s does not have a firewallHandler, dropping packet", link)
pkt.Drop()
err := pkt.Drop()
if err != nil {
log.Warningf("network: failed to drop packet %s: %s", pkt, err)
}
} }
// Accept accepts the link and adds the given reason. // Accept accepts the link and adds the given reason.
@ -195,41 +204,48 @@ func (link *Link) ApplyVerdict(pkt packet.Packet) {
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
var err error
if link.VerdictPermanent { if link.VerdictPermanent {
switch link.Verdict { switch link.Verdict {
case VerdictAccept: case VerdictAccept:
pkt.PermanentAccept() err = pkt.PermanentAccept()
case VerdictBlock: case VerdictBlock:
pkt.PermanentBlock() err = pkt.PermanentBlock()
case VerdictDrop: case VerdictDrop:
pkt.PermanentDrop() err = pkt.PermanentDrop()
case VerdictRerouteToNameserver: case VerdictRerouteToNameserver:
pkt.RerouteToNameserver() err = pkt.RerouteToNameserver()
case VerdictRerouteToTunnel: case VerdictRerouteToTunnel:
pkt.RerouteToTunnel() err = pkt.RerouteToTunnel()
default: default:
pkt.Drop() err = pkt.Drop()
} }
} else { } else {
switch link.Verdict { switch link.Verdict {
case VerdictAccept: case VerdictAccept:
pkt.Accept() err = pkt.Accept()
case VerdictBlock: case VerdictBlock:
pkt.Block() err = pkt.Block()
case VerdictDrop: case VerdictDrop:
pkt.Drop() err = pkt.Drop()
case VerdictRerouteToNameserver: case VerdictRerouteToNameserver:
pkt.RerouteToNameserver() err = pkt.RerouteToNameserver()
case VerdictRerouteToTunnel: case VerdictRerouteToTunnel:
pkt.RerouteToTunnel() err = pkt.RerouteToTunnel()
default: default:
pkt.Drop() err = pkt.Drop()
} }
} }
if err != nil {
log.Warningf("network: failed to apply link verdict to packet %s: %s", pkt, err)
}
} }
// SaveWhenFinished marks the Link for saving after all current actions are finished. // SaveWhenFinished marks the Link for saving after all current actions are finished.
func (link *Link) SaveWhenFinished() { func (link *Link) SaveWhenFinished() {
// FIXME: check if we should lock here
link.saveWhenFinished = true link.saveWhenFinished = true
} }
@ -243,11 +259,19 @@ func (link *Link) SaveIfNeeded() {
link.Unlock() link.Unlock()
if save { if save {
link.save() link.saveAndLog()
} }
} }
// Save saves the link object in the storage and propagates the change. // saveAndLog saves the link object in the storage and propagates the change. It does not return an error, but logs it.
func (link *Link) saveAndLog() {
err := link.save()
if err != nil {
log.Warningf("network: failed to save link %s: %s", link, err)
}
}
// save saves the link object in the storage and propagates the change.
func (link *Link) save() error { func (link *Link) save() error {
// update link // update link
link.Lock() link.Lock()

View file

@ -85,6 +85,7 @@ func initControlLogFile() *os.File {
return initializeLogFile(logFilePath, "control/portmaster-control", info.Version()) return initializeLogFile(logFilePath, "control/portmaster-control", info.Version())
} }
//nolint:deadcode,unused // false positive on linux, currently used by windows only
func logControlError(cErr error) { func logControlError(cErr error) {
// check if error present // check if error present
if cErr == nil { if cErr == nil {
@ -110,6 +111,7 @@ func logControlError(cErr error) {
errorFile.Close() errorFile.Close()
} }
//nolint:deadcode,unused // TODO
func logControlStack() { func logControlStack() {
// check logging dir // check logging dir
logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control")
@ -126,10 +128,11 @@ func logControlStack() {
} }
// write error and close // write error and close
pprof.Lookup("goroutine").WriteTo(errorFile, 1) _ = pprof.Lookup("goroutine").WriteTo(errorFile, 2)
errorFile.Close() errorFile.Close()
} }
//nolint:deadcode,unused // false positive on linux, currently used by windows only
func runAndLogControlError(wrappedFunc func(cmd *cobra.Command, args []string) error) func(cmd *cobra.Command, args []string) error { func runAndLogControlError(wrappedFunc func(cmd *cobra.Command, args []string) error) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error {
err := wrappedFunc(cmd, args) err := wrappedFunc(cmd, args)

View file

@ -65,8 +65,8 @@ func init() {
rootCmd.PersistentFlags().StringVar(&dataDir, "data", "", "set data directory") rootCmd.PersistentFlags().StringVar(&dataDir, "data", "", "set data directory")
rootCmd.PersistentFlags().StringVar(&databaseDir, "db", "", "alias to --data (deprecated)") rootCmd.PersistentFlags().StringVar(&databaseDir, "db", "", "alias to --data (deprecated)")
rootCmd.MarkPersistentFlagDirname("data") _ = rootCmd.MarkPersistentFlagDirname("data")
rootCmd.MarkPersistentFlagDirname("db") _ = rootCmd.MarkPersistentFlagDirname("db")
rootCmd.Flags().BoolVar(&showFullVersion, "version", false, "print version") rootCmd.Flags().BoolVar(&showFullVersion, "version", false, "print version")
rootCmd.Flags().BoolVar(&showShortVersion, "ver", false, "print version number only") rootCmd.Flags().BoolVar(&showShortVersion, "ver", false, "print version number only")
} }
@ -85,11 +85,10 @@ func main() {
// }() // }()
// catch interrupt for clean shutdown // catch interrupt for clean shutdown
signalCh := make(chan os.Signal) signalCh := make(chan os.Signal, 2)
signal.Notify( signal.Notify(
signalCh, signalCh,
os.Interrupt, os.Interrupt,
os.Kill,
syscall.SIGHUP, syscall.SIGHUP,
syscall.SIGINT, syscall.SIGINT,
syscall.SIGTERM, syscall.SIGTERM,

View file

@ -124,7 +124,13 @@ func run(cmd *cobra.Command, opts *Options) (err error) {
if pid != 0 { if pid != 0 {
return fmt.Errorf("another instance of Portmaster Core is already running: PID %d", pid) return fmt.Errorf("another instance of Portmaster Core is already running: PID %d", pid)
} }
defer deleteInstanceLock(opts.ShortIdentifier) defer func() {
err := deleteInstanceLock(opts.ShortIdentifier)
if err != nil {
log.Printf("failed to delete instance lock: %s\n", err)
}
}()
} }
// notify service after some time // notify service after some time
@ -192,6 +198,7 @@ func run(cmd *cobra.Command, opts *Options) (err error) {
} }
} }
// nolint:gocyclo,gocognit // TODO: simplify
func execute(opts *Options, args []string) (cont bool, err error) { func execute(opts *Options, args []string) (cont bool, err error) {
file, err := registry.GetFile(platform(opts.Identifier)) file, err := registry.GetFile(platform(opts.Identifier))
if err != nil { if err != nil {
@ -236,7 +243,7 @@ func execute(opts *Options, args []string) (cont bool, err error) {
} }
// create command // create command
exc := exec.Command(file.Path(), args...) exc := exec.Command(file.Path(), args...) //nolint:gosec // everything is okay
if !runningInConsole && opts.AllowHidingWindow { if !runningInConsole && opts.AllowHidingWindow {
// Windows only: // Windows only:

View file

@ -8,8 +8,9 @@ var (
startupComplete = make(chan struct{}) // signal that the start procedure completed (is never closed, just signaled once) startupComplete = make(chan struct{}) // signal that the start procedure completed (is never closed, just signaled once)
shuttingDown = make(chan struct{}) // signal that we are shutting down (will be closed, may not be closed directly, use initiateShutdown) shuttingDown = make(chan struct{}) // signal that we are shutting down (will be closed, may not be closed directly, use initiateShutdown)
shutdownInitiated = false // not to be used directly shutdownInitiated = false // not to be used directly
shutdownError error // may not be read or written to directly //nolint:deadcode,unused // false positive on linux, currently used by windows only
shutdownLock sync.Mutex shutdownError error // may not be read or written to directly
shutdownLock sync.Mutex
) )
func initiateShutdown(err error) { func initiateShutdown(err error) {
@ -23,6 +24,7 @@ func initiateShutdown(err error) {
} }
} }
//nolint:deadcode,unused // false positive on linux, currently used by windows only
func getShutdownError() error { func getShutdownError() error {
shutdownLock.Lock() shutdownLock.Lock()
defer shutdownLock.Unlock() defer shutdownLock.Unlock()

View file

@ -24,8 +24,7 @@ var (
dbController *database.Controller dbController *database.Controller
dbControllerFlag = abool.NewBool(false) dbControllerFlag = abool.NewBool(false)
deleteProcessesThreshold = 15 * time.Minute deleteProcessesThreshold = 15 * time.Minute
lastEstablishedUpdateThreshold = 30 * time.Second
) )
// GetProcessFromStorage returns a process from the internal storage. // GetProcessFromStorage returns a process from the internal storage.

View file

@ -33,16 +33,14 @@ func (p *Process) FindProfiles(ctx context.Context) error {
} }
var userProfile *profile.Profile var userProfile *profile.Profile
for r := range it.Next { // get first result
it.Cancel() r := <-it.Next
userProfile, err = profile.EnsureProfile(r) // cancel immediately
if err != nil { it.Cancel()
return err // ensure its a profile
} userProfile, err = profile.EnsureProfile(r)
break if err != nil {
} return err
if it.Err() != nil {
return it.Err()
} }
// create new profile if it does not exist. // create new profile if it does not exist.
@ -54,7 +52,7 @@ func (p *Process) FindProfiles(ctx context.Context) error {
} }
if userProfile.MarkUsed() { if userProfile.MarkUsed() {
userProfile.Save(profile.UserNamespace) _ = userProfile.Save(profile.UserNamespace)
} }
// Stamp // Stamp
@ -74,17 +72,7 @@ func (p *Process) FindProfiles(ctx context.Context) error {
return nil return nil
} }
func selectProfile(p *Process, profs []*profile.Profile) (selectedProfile *profile.Profile) { //nolint:deadcode,unused // FIXME
var highestScore int
for _, prof := range profs {
score := matchProfile(p, prof)
if score > highestScore {
selectedProfile = prof
}
}
return
}
func matchProfile(p *Process, prof *profile.Profile) (score int) { func matchProfile(p *Process, prof *profile.Profile) (score int) {
for _, fp := range prof.Fingerprints { for _, fp := range prof.Fingerprints {
score += matchFingerprint(p, fp) score += matchFingerprint(p, fp)
@ -92,6 +80,7 @@ func matchProfile(p *Process, prof *profile.Profile) (score int) {
return return
} }
//nolint:deadcode,unused // FIXME
func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) { func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) {
if !fp.MatchesOS() { if !fp.MatchesOS() {
return 0 return 0
@ -100,8 +89,8 @@ func matchFingerprint(p *Process, fp *profile.Fingerprint) (score int) {
switch fp.Type { switch fp.Type {
case "full_path": case "full_path":
if p.Path == fp.Value { if p.Path == fp.Value {
return profile.GetFingerprintWeight(fp.Type)
} }
return profile.GetFingerprintWeight(fp.Type)
case "partial_path": case "partial_path":
// FIXME: if full_path matches, do not match partial paths // FIXME: if full_path matches, do not match partial paths
return profile.GetFingerprintWeight(fp.Type) return profile.GetFingerprintWeight(fp.Type)

View file

@ -155,7 +155,7 @@ func readDirNames(dir string) (names []string) {
defer file.Close() defer file.Close()
names, err = file.Readdirnames(0) names, err = file.Readdirnames(0)
if err != nil { if err != nil {
log.Warningf("process: could not get entries from direcotry %s: %s", dir, err) log.Warningf("process: could not get entries from directory %s: %s", dir, err)
return []string{} return []string{}
} }
return return

View file

@ -47,19 +47,6 @@ const (
UDP6Data = "/proc/net/udp6" UDP6Data = "/proc/net/udp6"
ICMP4Data = "/proc/net/icmp" ICMP4Data = "/proc/net/icmp"
ICMP6Data = "/proc/net/icmp6" ICMP6Data = "/proc/net/icmp6"
TCP_ESTABLISHED = iota + 1
TCP_SYN_SENT
TCP_SYN_RECV
TCP_FIN_WAIT1
TCP_FIN_WAIT2
TCP_TIME_WAIT
TCP_CLOSE
TCP_CLOSE_WAIT
TCP_LAST_ACK
TCP_LISTEN
TCP_CLOSING
TCP_NEW_SYN_RECV
) )
var ( var (

View file

@ -16,7 +16,7 @@ import (
) )
var ( var (
dupReqMap = make(map[int]*sync.Mutex) dupReqMap = make(map[int]*sync.WaitGroup)
dupReqLock sync.Mutex dupReqLock sync.Mutex
) )
@ -61,7 +61,7 @@ func (p *Process) ProfileSet() *profile.Set {
return p.profileSet return p.profileSet
} }
// Strings returns a string represenation of process. // Strings returns a string representation of process.
func (p *Process) String() string { func (p *Process) String() string {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
@ -79,7 +79,7 @@ func (p *Process) AddCommunication() {
// check if we should save // check if we should save
save := false save := false
if p.LastCommEstablished < time.Now().Add(-3*time.Second).Unix() { if p.LastCommEstablished == 0 || p.LastCommEstablished < time.Now().Add(-3*time.Second).Unix() {
save = true save = true
} }
@ -206,6 +206,43 @@ func GetOrFindProcess(ctx context.Context, pid int) (*Process, error) {
return p, nil return p, nil
} }
func deduplicateRequest(ctx context.Context, pid int) (finishRequest func()) {
dupReqLock.Lock()
defer dupReqLock.Unlock()
// get duplicate request waitgroup
wg, requestActive := dupReqMap[pid]
// someone else is already on it!
if requestActive {
// log that we are waiting
log.Tracer(ctx).Tracef("intel: waiting for duplicate request for PID %d to complete", pid)
// wait
wg.Wait()
// done!
return nil
}
// we are currently the only one doing a request for this
// create new waitgroup
wg = new(sync.WaitGroup)
// add worker (us!)
wg.Add(1)
// add to registry
dupReqMap[pid] = wg
// return function to mark request as finished
return func() {
dupReqLock.Lock()
defer dupReqLock.Unlock()
// mark request as done
wg.Done()
// delete from registry
delete(dupReqMap, pid)
}
}
func loadProcess(ctx context.Context, pid int) (*Process, error) { func loadProcess(ctx context.Context, pid int) (*Process, error) {
if pid == -1 { if pid == -1 {
return UnknownProcess, nil return UnknownProcess, nil
@ -219,35 +256,20 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) {
return process, nil return process, nil
} }
// dedup requests // dedupe!
dupReqLock.Lock() markRequestFinished := deduplicateRequest(ctx, pid)
mutex, requestActive := dupReqMap[pid] if markRequestFinished == nil {
if !requestActive { // we waited for another request, recheck the storage!
mutex = new(sync.Mutex)
mutex.Lock()
dupReqMap[pid] = mutex
dupReqLock.Unlock()
} else {
dupReqLock.Unlock()
log.Tracer(ctx).Tracef("process: waiting for duplicate request for PID %d to complete", pid)
mutex.Lock()
// wait until duplicate request is finished, then fetch current Process and return
mutex.Unlock()
process, ok = GetProcessFromStorage(pid) process, ok = GetProcessFromStorage(pid)
if ok { if ok {
return process, nil return process, nil
} }
return nil, fmt.Errorf("previous request for process with PID %d failed", pid) // if cache is still empty, go ahead
} else {
// we are the first!
defer markRequestFinished()
} }
// lock request for this pid
defer func() {
dupReqLock.Lock()
delete(dupReqMap, pid)
dupReqLock.Unlock()
mutex.Unlock()
}()
// create new process // create new process
new := &Process{ new := &Process{
Pid: pid, Pid: pid,

View file

@ -1,26 +1,26 @@
package process package process
// IsUser returns whether the process is run by a normal user. // IsUser returns whether the process is run by a normal user.
func (m *Process) IsUser() bool { func (p *Process) IsUser() bool {
return m.UserID >= 1000 return p.UserID >= 1000
} }
// IsAdmin returns whether the process is run by an admin user. // IsAdmin returns whether the process is run by an admin user.
func (m *Process) IsAdmin() bool { func (p *Process) IsAdmin() bool {
return m.UserID >= 0 return p.UserID >= 0
} }
// IsSystem returns whether the process is run by the operating system. // IsSystem returns whether the process is run by the operating system.
func (m *Process) IsSystem() bool { func (p *Process) IsSystem() bool {
return m.UserID == 0 return p.UserID == 0
} }
// IsKernel returns whether the process is the Kernel. // IsKernel returns whether the process is the Kernel.
func (m *Process) IsKernel() bool { func (p *Process) IsKernel() bool {
return m.Pid == 0 return p.Pid == 0
} }
// specialOSInit does special OS specific Process initialization. // specialOSInit does special OS specific Process initialization.
func (m *Process) specialOSInit() { func (p *Process) specialOSInit() {
} }

View file

@ -29,7 +29,7 @@ func makeDefaultFallbackProfile() *Profile {
Related: status.SecurityLevelDynamic, Related: status.SecurityLevelDynamic,
}, },
ServiceEndpoints: []*EndpointPermission{ ServiceEndpoints: []*EndpointPermission{
&EndpointPermission{ {
Type: EptAny, Type: EptAny,
Protocol: 0, Protocol: 0,
StartPort: 0, StartPort: 0,

View file

@ -15,8 +15,8 @@ type Endpoints []*EndpointPermission
// EndpointPermission holds a decision about an endpoint. // EndpointPermission holds a decision about an endpoint.
type EndpointPermission struct { type EndpointPermission struct {
Type EPType
Value string Value string
Type EPType
Protocol uint8 Protocol uint8
StartPort uint16 StartPort uint16
@ -55,10 +55,7 @@ const (
// IsSet returns whether the Endpoints object is "set". // IsSet returns whether the Endpoints object is "set".
func (e Endpoints) IsSet() bool { func (e Endpoints) IsSet() bool {
if len(e) > 0 { return len(e) > 0
return true
}
return false
} }
// CheckDomain checks the if the given endpoint matches a EndpointPermission in the list. // CheckDomain checks the if the given endpoint matches a EndpointPermission in the list.
@ -246,7 +243,7 @@ func (ep EndpointPermission) MatchesIP(domain string, ip net.IP, protocol uint8,
} }
func (e Endpoints) String() string { func (e Endpoints) String() string {
var s []string s := make([]string, 0, len(e))
for _, entry := range e { for _, entry := range e {
s = append(s, entry.String()) s = append(s, entry.String())
} }

View file

@ -155,15 +155,14 @@ func TestEndpointMatching(t *testing.T) {
} }
func TestEPString(t *testing.T) { func TestEPString(t *testing.T) {
var endpoints Endpoints var endpoints Endpoints = []*EndpointPermission{
endpoints = []*EndpointPermission{ {
&EndpointPermission{
Type: EptDomain, Type: EptDomain,
Value: "example.com", Value: "example.com",
Protocol: 6, Protocol: 6,
Permit: true, Permit: true,
}, },
&EndpointPermission{ {
Type: EptIPv4, Type: EptIPv4,
Value: "1.1.1.1", Value: "1.1.1.1",
Protocol: 17, // TCP Protocol: 17, // TCP
@ -171,7 +170,7 @@ func TestEPString(t *testing.T) {
EndPort: 53, EndPort: 53,
Permit: false, Permit: false,
}, },
&EndpointPermission{ {
Type: EptDomain, Type: EptDomain,
Value: "example.org", Value: "example.org",
Permit: false, Permit: false,
@ -181,8 +180,7 @@ func TestEPString(t *testing.T) {
t.Errorf("unexpected result: %s", endpoints.String()) t.Errorf("unexpected result: %s", endpoints.String())
} }
var noEndpoints Endpoints var noEndpoints Endpoints = []*EndpointPermission{}
noEndpoints = []*EndpointPermission{}
if noEndpoints.String() != "[]" { if noEndpoints.String() != "[]" {
t.Errorf("unexpected result: %s", noEndpoints.String()) t.Errorf("unexpected result: %s", noEndpoints.String())
} }

View file

@ -36,7 +36,7 @@ func GetFingerprintWeight(fpType string) (weight int) {
} }
// AddFingerprint adds the given fingerprint to the profile. // AddFingerprint adds the given fingerprint to the profile.
func (p *Profile) AddFingerprint(fp *Fingerprint) { func (profile *Profile) AddFingerprint(fp *Fingerprint) {
if fp.OS == "" { if fp.OS == "" {
fp.OS = osIdentifier fp.OS = osIdentifier
} }
@ -44,5 +44,5 @@ func (p *Profile) AddFingerprint(fp *Fingerprint) {
fp.LastUsed = time.Now().Unix() fp.LastUsed = time.Now().Unix()
} }
p.Fingerprints = append(p.Fingerprints, fp) profile.Fingerprints = append(profile.Fingerprints, fp)
} }

View file

@ -58,8 +58,8 @@ func New() *Profile {
} }
// MakeProfileKey creates the correct key for a profile with the given namespace and ID. // MakeProfileKey creates the correct key for a profile with the given namespace and ID.
func MakeProfileKey(namespace, ID string) string { func MakeProfileKey(namespace, id string) string {
return fmt.Sprintf("core:profiles/%s/%s", namespace, ID) return fmt.Sprintf("core:profiles/%s/%s", namespace, id)
} }
// Save saves the profile to the database // Save saves the profile to the database
@ -98,17 +98,17 @@ func (profile *Profile) DetailedString() string {
} }
// GetUserProfile loads a profile from the database. // GetUserProfile loads a profile from the database.
func GetUserProfile(ID string) (*Profile, error) { func GetUserProfile(id string) (*Profile, error) {
return getProfile(UserNamespace, ID) return getProfile(UserNamespace, id)
} }
// GetStampProfile loads a profile from the database. // GetStampProfile loads a profile from the database.
func GetStampProfile(ID string) (*Profile, error) { func GetStampProfile(id string) (*Profile, error) {
return getProfile(StampNamespace, ID) return getProfile(StampNamespace, id)
} }
func getProfile(namespace, ID string) (*Profile, error) { func getProfile(namespace, id string) (*Profile, error) {
r, err := profileDB.Get(MakeProfileKey(namespace, ID)) r, err := profileDB.Get(MakeProfileKey(namespace, id))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,10 +8,6 @@ import (
"github.com/safing/portmaster/status" "github.com/safing/portmaster/status"
) )
var (
emptyFlags = Flags{}
)
// Set handles Profile chaining. // Set handles Profile chaining.
type Set struct { type Set struct {
sync.Mutex sync.Mutex

View file

@ -1,3 +1,4 @@
//nolint:unparam
package profile package profile
import ( import (
@ -30,25 +31,25 @@ func init() {
Independent: status.SecurityLevelFortress, Independent: status.SecurityLevelFortress,
}, },
Endpoints: []*EndpointPermission{ Endpoints: []*EndpointPermission{
&EndpointPermission{ {
Type: EptDomain, Type: EptDomain,
Value: "good.bad.example.com.", Value: "good.bad.example.com.",
Permit: true, Permit: true,
Created: time.Now().Unix(), Created: time.Now().Unix(),
}, },
&EndpointPermission{ {
Type: EptDomain, Type: EptDomain,
Value: "*bad.example.com.", Value: "*bad.example.com.",
Permit: false, Permit: false,
Created: time.Now().Unix(), Created: time.Now().Unix(),
}, },
&EndpointPermission{ {
Type: EptDomain, Type: EptDomain,
Value: "example.com.", Value: "example.com.",
Permit: true, Permit: true,
Created: time.Now().Unix(), Created: time.Now().Unix(),
}, },
&EndpointPermission{ {
Type: EptAny, Type: EptAny,
Permit: true, Permit: true,
Protocol: 6, Protocol: 6,
@ -67,13 +68,13 @@ func init() {
// Internet: status.SecurityLevelsAll, // Internet: status.SecurityLevelsAll,
// }, // },
Endpoints: []*EndpointPermission{ Endpoints: []*EndpointPermission{
&EndpointPermission{ {
Type: EptDomain, Type: EptDomain,
Value: "*bad2.example.com.", Value: "*bad2.example.com.",
Permit: false, Permit: false,
Created: time.Now().Unix(), Created: time.Now().Unix(),
}, },
&EndpointPermission{ {
Type: EptAny, Type: EptAny,
Permit: true, Permit: true,
Protocol: 6, Protocol: 6,
@ -83,7 +84,7 @@ func init() {
}, },
}, },
ServiceEndpoints: []*EndpointPermission{ ServiceEndpoints: []*EndpointPermission{
&EndpointPermission{ {
Type: EptAny, Type: EptAny,
Permit: true, Permit: true,
Protocol: 17, Protocol: 17,
@ -91,7 +92,7 @@ func init() {
EndPort: 12347, EndPort: 12347,
Created: time.Now().Unix(), Created: time.Now().Unix(),
}, },
&EndpointPermission{ // default deny { // default deny
Type: EptAny, Type: EptAny,
Permit: false, Permit: false,
Created: time.Now().Unix(), Created: time.Now().Unix(),
@ -173,6 +174,6 @@ func TestProfileSet(t *testing.T) {
} }
func getLineNumberOfCaller(levels int) int { func getLineNumberOfCaller(levels int) int {
_, _, line, _ := runtime.Caller(levels + 1) _, _, line, _ := runtime.Caller(levels + 1) //nolint:dogsled
return line return line
} }

View file

@ -24,7 +24,7 @@ func initSpecialProfiles() (err error) {
return err return err
} }
globalProfile = makeDefaultGlobalProfile() globalProfile = makeDefaultGlobalProfile()
globalProfile.Save(SpecialNamespace) _ = globalProfile.Save(SpecialNamespace)
} }
fallbackProfile, err = getSpecialProfile("fallback") fallbackProfile, err = getSpecialProfile("fallback")
@ -34,15 +34,15 @@ func initSpecialProfiles() (err error) {
} }
fallbackProfile = makeDefaultFallbackProfile() fallbackProfile = makeDefaultFallbackProfile()
ensureServiceEndpointsDenyAll(fallbackProfile) ensureServiceEndpointsDenyAll(fallbackProfile)
fallbackProfile.Save(SpecialNamespace) _ = fallbackProfile.Save(SpecialNamespace)
} }
ensureServiceEndpointsDenyAll(fallbackProfile) ensureServiceEndpointsDenyAll(fallbackProfile)
return nil return nil
} }
func getSpecialProfile(ID string) (*Profile, error) { func getSpecialProfile(id string) (*Profile, error) {
return getProfile(SpecialNamespace, ID) return getProfile(SpecialNamespace, id)
} }
func ensureServiceEndpointsDenyAll(p *Profile) (changed bool) { func ensureServiceEndpointsDenyAll(p *Profile) (changed bool) {
@ -52,7 +52,7 @@ func ensureServiceEndpointsDenyAll(p *Profile) (changed bool) {
ep.Protocol == 0 && ep.Protocol == 0 &&
ep.StartPort == 0 && ep.StartPort == 0 &&
ep.EndPort == 0 && ep.EndPort == 0 &&
ep.Permit == false { !ep.Permit {
return false return false
} }
} }

View file

@ -52,7 +52,7 @@ func updateListener(sub *database.Subscription) {
profile.Unlock() profile.Unlock()
if profileChanged { if profileChanged {
profile.Save(SpecialNamespace) _ = profile.Save(SpecialNamespace)
continue continue
} }

View file

@ -48,3 +48,7 @@ func initStatusHook() (err error) {
hook, err = database.RegisterHook(query.New(statusDBKey), &statusHook{}) hook, err = database.RegisterHook(query.New(statusDBKey), &statusHook{})
return err return err
} }
func stopStatusHook() error {
return hook.Cancel()
}

View file

@ -9,10 +9,6 @@ import (
_ "github.com/safing/portmaster/core" _ "github.com/safing/portmaster/core"
) )
var (
shutdownSignal = make(chan struct{})
)
func init() { func init() {
modules.Register("status", nil, start, stop, "core") modules.Register("status", nil, start, stop, "core")
} }
@ -57,11 +53,5 @@ func start() error {
} }
func stop() error { func stop() error {
select { return stopStatusHook()
case <-shutdownSignal:
// already closed
default:
close(shutdownSignal)
}
return nil
} }

View file

@ -129,7 +129,6 @@ func ServeFileFromBundle(w http.ResponseWriter, r *http.Request, bundleName stri
} }
readCloser.Close() readCloser.Close()
return
} }
// RedirectToBase redirects the requests to the control app // RedirectToBase redirects the requests to the control app

View file

@ -49,21 +49,19 @@ func initVersionExport() (err error) {
return err return err
} }
module.RegisterEventHook( return module.RegisterEventHook(
"updates", "updates",
eventVersionUpdate, eventVersionUpdate,
"export version status", "export version status",
export, export,
) )
return nil
} }
func stopVersionExport() error { func stopVersionExport() error {
return versionExportHook.Cancel() return versionExportHook.Cancel()
} }
var exportMicroTaskName = "update version status" // export is an event hook
func export(_ context.Context, _ interface{}) error { func export(_ context.Context, _ interface{}) error {
// populate // populate
versionExport.lock.Lock() versionExport.lock.Lock()

View file

@ -100,9 +100,7 @@ func start() error {
}).Repeat(24 * time.Hour).MaxDelay(1 * time.Hour).Schedule(time.Now().Add(10 * time.Second)) }).Repeat(24 * time.Hour).MaxDelay(1 * time.Hour).Schedule(time.Now().Add(10 * time.Second))
// react to upgrades // react to upgrades
initUpgrader() return initUpgrader()
return nil
} }
func stop() error { func stop() error {

View file

@ -28,16 +28,15 @@ const (
) )
var ( var (
upgraderActive = abool.NewBool(false) upgraderActive = abool.NewBool(false)
dontUpgradeBefore = time.Now().Add(5 * time.Minute) pmCtrlUpdate *updater.File
pmCtrlUpdate *updater.File pmCoreUpdate *updater.File
pmCoreUpdate *updater.File
rawVersionRegex = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+b?\*?$`) rawVersionRegex = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+b?\*?$`)
) )
func initUpgrader() { func initUpgrader() error {
module.RegisterEventHook( return module.RegisterEventHook(
"updates", "updates",
eventResourceUpdate, eventResourceUpdate,
"run upgrades", "run upgrades",