Merge branch 'develop' into feature/ui-revamp

This commit is contained in:
Patrick Pacher 2020-10-15 11:15:07 +02:00
commit f65d3f36e7
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
29 changed files with 390 additions and 673 deletions

View file

@ -50,4 +50,4 @@ Operating System:
If applicable you can provide related sections from the log files and ensure to **remove sensitive or otherwise private information**. If applicable you can provide related sections from the log files and ensure to **remove sensitive or otherwise private information**.
- Linux: `/var/lib/portmaster/logs` - Linux: `/var/lib/portmaster/logs`
- Windows: `%PROGRAMDATA%\Portmaster\ļogs` - Windows: `%PROGRAMDATA%\Safing\Portmaster\logs`

View file

@ -18,7 +18,7 @@ import (
func main() { func main() {
// set information // set information
info.Set("Portmaster", "0.5.6", "AGPLv3", true) info.Set("Portmaster", "0.5.7", "AGPLv3", true)
// enable SPN client mode // enable SPN client mode
conf.EnableClient(true) conf.EnableClient(true)

View file

@ -31,7 +31,7 @@ func registerConfig() error {
Key: CfgOptionPermanentVerdictsKey, Key: CfgOptionPermanentVerdictsKey,
Description: "With permanent verdicts, control of a connection is fully handed back to the OS after the initial decision. This brings a great performance increase, but makes it impossible to change the decision of a link later on.", Description: "With permanent verdicts, control of a connection is fully handed back to the OS after the initial decision. This brings a great performance increase, but makes it impossible to change the decision of a link later on.",
OptType: config.OptTypeBool, OptType: config.OptTypeBool,
ExpertiseLevel: config.ExpertiseLevelExpert, ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelExperimental, ReleaseLevel: config.ReleaseLevelExperimental,
DefaultValue: true, DefaultValue: true,
Annotations: config.Annotations{ Annotations: config.Annotations{

View file

@ -288,7 +288,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
err = pkt.RerouteToTunnel() err = pkt.RerouteToTunnel()
case network.VerdictFailed: case network.VerdictFailed:
atomic.AddUint64(packetsFailed, 1) atomic.AddUint64(packetsFailed, 1)
fallthrough err = pkt.Drop()
default: default:
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
err = pkt.Drop() err = pkt.Drop()

View file

@ -25,7 +25,18 @@ func Start() error {
return nil return nil
} }
return start() var inputPackets = Packets
if packetMetricsDestination != "" {
go metrics.writeMetrics()
inputPackets = make(chan packet.Packet)
go func() {
for p := range inputPackets {
Packets <- tracePacket(p)
}
}()
}
return start(inputPackets)
} }
// Stop starts the interception. // Stop starts the interception.
@ -34,5 +45,7 @@ func Stop() error {
return nil return nil
} }
close(metrics.done)
return stop() return stop()
} }

View file

@ -4,10 +4,11 @@ package interception
import ( import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
) )
// start starts the interception. // start starts the interception.
func start() error { func start(_ chan packet.Packet) error {
log.Info("interception: this platform has no support for packet interception - a lot of functionality will be broken") log.Info("interception: this platform has no support for packet interception - a lot of functionality will be broken")
return nil return nil
} }

View file

@ -1,8 +1,10 @@
package interception package interception
import "github.com/safing/portmaster/network/packet"
// start starts the interception. // start starts the interception.
func start() error { func start(ch chan packet.Packet) error {
return StartNfqueueInterception() return StartNfqueueInterception(ch)
} }
// stop starts the interception. // stop starts the interception.

View file

@ -7,11 +7,12 @@ import (
"github.com/safing/portbase/notifications" "github.com/safing/portbase/notifications"
"github.com/safing/portbase/utils/osdetail" "github.com/safing/portbase/utils/osdetail"
"github.com/safing/portmaster/firewall/interception/windowskext" "github.com/safing/portmaster/firewall/interception/windowskext"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
) )
// start starts the interception. // start starts the interception.
func start() error { func start(ch chan packet.Packet) error {
dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll") dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll")
if err != nil { if err != nil {
return fmt.Errorf("interception: could not get kext dll: %s", err) return fmt.Errorf("interception: could not get kext dll: %s", err)
@ -31,7 +32,7 @@ func start() error {
return fmt.Errorf("interception: could not start windows kext: %s", err) return fmt.Errorf("interception: could not start windows kext: %s", err)
} }
go windowskext.Handler(Packets) go windowskext.Handler(ch)
go handleWindowsDNSCache() go handleWindowsDNSCache()
return nil return nil

View file

@ -0,0 +1,78 @@
package interception
import (
"flag"
"fmt"
"os"
"sync"
"time"
"github.com/safing/portbase/log"
)
var (
packetMetricsDestination string
metrics = &packetMetrics{
done: make(chan struct{}),
}
)
func init() {
flag.StringVar(&packetMetricsDestination, "write-packet-metrics", "", "Write packet metrics to the specified file")
}
type (
performanceRecord struct {
start int64
duration time.Duration
verdict string
}
packetMetrics struct {
done chan struct{}
l sync.Mutex
records []*performanceRecord
}
)
func (pm *packetMetrics) record(tp *tracedPacket, verdict string) {
go func(start int64, duration time.Duration) {
pm.l.Lock()
defer pm.l.Unlock()
pm.records = append(pm.records, &performanceRecord{
start: start,
duration: duration,
verdict: verdict,
})
}(tp.start.UnixNano(), time.Since(tp.start))
}
func (pm *packetMetrics) writeMetrics() {
if packetMetricsDestination == "" {
return
}
f, err := os.Create(packetMetricsDestination)
if err != nil {
log.Errorf("Failed to create packet metrics file: %s", err)
return
}
defer f.Close()
for {
select {
case <-pm.done:
return
case <-time.After(time.Second * 5):
}
pm.l.Lock()
records := pm.records
pm.records = nil
pm.l.Unlock()
for _, r := range records {
fmt.Fprintf(f, "%d;%s;%s;%.2f\n", r.start, r.verdict, r.duration, float64(r.duration)/float64(time.Microsecond))
}
}
}

View file

@ -1,7 +1,7 @@
// +build linux // +build linux
// Package nfqexp contains a nfqueue library experiment. // Package nfq contains a nfqueue library experiment.
package nfqexp package nfq
import ( import (
"context" "context"
@ -79,18 +79,18 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
} }
if err := pmpacket.Parse(pkt.Payload, pkt.Info()); err != nil { if err := pmpacket.Parse(pkt.Payload, pkt.Info()); err != nil {
log.Warningf("nfqexp: failed to parse payload: %s", err) log.Warningf("nfqueue: failed to parse payload: %s", err)
_ = pkt.Drop() _ = pkt.Drop()
return 0 return 0
} }
select { select {
case q.packets <- pkt: case q.packets <- pkt:
log.Tracef("nfqexp: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received)) log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
case <-ctx.Done(): case <-ctx.Done():
return 0 return 0
case <-time.After(time.Second): case <-time.After(time.Second):
log.Warningf("nfqexp: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.received)) log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.received))
} }
go func() { go func() {
@ -98,9 +98,9 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
case <-pkt.verdictSet: case <-pkt.verdictSet:
case <-time.After(20 * time.Second): case <-time.After(20 * time.Second):
log.Warningf("nfqexp: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received)) log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
if err := pkt.Drop(); err != nil { if err := pkt.Drop(); err != nil {
log.Warningf("nfqexp: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst) log.Warningf("nfqueue: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst)
} }
} }
}() }()
@ -118,7 +118,7 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
if opError.Timeout() || opError.Temporary() { if opError.Timeout() || opError.Temporary() {
c := atomic.LoadUint64(&q.pendingVerdicts) c := atomic.LoadUint64(&q.pendingVerdicts)
if c > 0 { if c > 0 {
log.Tracef("nfqexp: waiting for %d pending verdicts", c) log.Tracef("nfqueue: waiting for %d pending verdicts", c)
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
<-q.verdictCompleted <-q.verdictCompleted
@ -128,7 +128,7 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
return 0 return 0
} }
} }
log.Errorf("nfqexp: encountered error while receiving packets: %s\n", e.Error()) log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error())
return 1 return 1
} }
@ -146,7 +146,7 @@ func (q *Queue) Destroy() {
q.cancelSocketCallback() q.cancelSocketCallback()
if err := q.nf.Close(); err != nil { if err := q.nf.Close(); err != nil {
log.Errorf("nfqexp: failed to close queue %d: %s", q.id, err) log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err)
} }
} }

View file

@ -1,6 +1,6 @@
// +build linux // +build linux
package nfqexp package nfq
import ( import (
"errors" "errors"
@ -108,12 +108,12 @@ func (pkt *packet) setMark(mark int) error {
} }
} }
log.Errorf("nfqexp: failed to set verdict %s for %s (%s -> %s): %s", markToString(mark), pkt.ID(), pkt.Info().Src, pkt.Info().Dst, err) log.Errorf("nfqueue: failed to set verdict %s for %s (%s -> %s): %s", markToString(mark), pkt.ID(), pkt.Info().Src, pkt.Info().Dst, err)
return err return err
} }
break break
} }
log.Tracef("nfqexp: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received)) log.Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received))
return nil return nil
} }

View file

@ -1,3 +0,0 @@
Parts of this package (this directory) are forked from the go-nfqueue repo: https://github.com/OneOfOne/go-nfqueue
These portions are copyrighted by Ahmed W.
The fork commit is (with high certainty): https://github.com/OneOfOne/go-nfqueue/commit/3bdd8bdfd98a1ed51119f9cf7494162484dfbe7c

View file

@ -1,4 +0,0 @@
// +build linux
// Package nfqueue provides network interception capabilities on linux via iptables nfqueue.
package nfqueue

View file

@ -1,49 +0,0 @@
// +build linux
package nfqueue
// suspended for now
// import (
// "sync"
//
// "github.com/safing/portmaster/network/packet"
// )
//
// type multiQueue struct {
// qs []*NFQueue
// }
//
// func NewMultiQueue(min, max uint16) (mq *multiQueue) {
// mq = &multiQueue{make([]*NFQueue, 0, max-min)}
// for i := min; i < max; i++ {
// mq.qs = append(mq.qs, NewNFQueue(i))
// }
// return mq
// }
//
// func (mq *multiQueue) Process() <-chan packet.Packet {
// var (
// wg sync.WaitGroup
// out = make(chan packet.Packet, len(mq.qs))
// )
// for _, q := range mq.qs {
// wg.Add(1)
// go func(ch <-chan packet.Packet) {
// for pkt := range ch {
// out <- pkt
// }
// wg.Done()
// }(q.Process())
// }
// go func() {
// wg.Wait()
// close(out)
// }()
// return out
// }
// func (mq *multiQueue) Destroy() {
// for _, q := range mq.qs {
// q.Destroy()
// }
// }

View file

@ -1,88 +0,0 @@
#include "nfqueue.h"
#include "_cgo_export.h"
int nfqueue_cb_new(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg, struct nfq_data *nfa, void *data) {
struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr(nfa);
if(ph == NULL) {
return 1;
}
int id = ntohl(ph->packet_id);
unsigned char * payload;
unsigned char * saddr, * daddr;
uint16_t sport = 0, dport = 0, checksum = 0;
uint32_t mark = nfq_get_nfmark(nfa);
int len = nfq_get_payload(nfa, &payload);
unsigned char * origpayload = payload;
int origlen = len;
if(len < sizeof(struct iphdr)) {
return 0;
}
struct iphdr * ip = (struct iphdr *) payload;
if(ip->version == 4) {
uint32_t ipsz = (ip->ihl << 2);
if(len < ipsz) {
return 0;
}
len -= ipsz;
payload += ipsz;
saddr = (unsigned char *)&ip->saddr;
daddr = (unsigned char *)&ip->daddr;
if(ip->protocol == IPPROTO_TCP) {
if(len < sizeof(struct tcphdr)) {
return 0;
}
struct tcphdr *tcp = (struct tcphdr *) payload;
uint32_t tcpsz = (tcp->doff << 2);
if(len < tcpsz) {
return 0;
}
len -= tcpsz;
payload += tcpsz;
sport = ntohs(tcp->source);
dport = ntohs(tcp->dest);
checksum = ntohs(tcp->check);
} else if(ip->protocol == IPPROTO_UDP) {
if(len < sizeof(struct udphdr)) {
return 0;
}
struct udphdr *u = (struct udphdr *) payload;
len -= sizeof(struct udphdr);
payload += sizeof(struct udphdr);
sport = ntohs(u->source);
dport = ntohs(u->dest);
checksum = ntohs(u->check);
}
} else {
struct ipv6hdr *ip6 = (struct ipv6hdr*) payload;
saddr = (unsigned char *)&ip6->saddr;
daddr = (unsigned char *)&ip6->daddr;
//ipv6
}
//pass everything we can and let Go handle it, I'm not a big fan of C
uint32_t verdict = go_nfq_callback(id, ntohs(ph->hw_protocol), ph->hook, &mark, ip->version, ip->protocol,
ip->tos, ip->ttl, saddr, daddr, sport, dport, checksum, origlen, origpayload, data);
return nfq_set_verdict2(qh, id, verdict, mark, 0, NULL);
}
void loop_for_packets(struct nfq_handle *h) {
int fd = nfq_fd(h);
char buf[65535] __attribute__ ((aligned));
int rv;
while ((rv = recv(fd, buf, sizeof(buf), 0)) && rv >= 0) {
nfq_handle_packet(h, buf, rv);
}
}

View file

@ -1,195 +0,0 @@
// +build linux
package nfqueue
/*
#cgo LDFLAGS: -lnetfilter_queue
#cgo CFLAGS: -Wall
#include "nfqueue.h"
*/
import "C"
import (
"errors"
"fmt"
"os"
"runtime"
"sync"
"syscall"
"time"
"unsafe"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
)
var queues map[uint16]*NFQueue
func init() {
queues = make(map[uint16]*NFQueue)
}
// NFQueue holds a Linux NFQ Handle and associated information.
//nolint:maligned // FIXME
type NFQueue struct {
DefaultVerdict uint32
Timeout time.Duration
qid uint16
qidptr *uint16
h *C.struct_nfq_handle
//qh *C.struct_q_handle
qh *C.struct_nfq_q_handle
fd int
lk sync.Mutex
Packets chan packet.Packet
}
// NewNFQueue initializes a new netfilter queue.
func NewNFQueue(qid uint16) (nfq *NFQueue, err error) {
if os.Geteuid() != 0 {
return nil, errors.New("must be root to intercept packets")
}
nfq = &NFQueue{DefaultVerdict: NFQ_DROP, Timeout: 3000 * time.Millisecond, qid: qid, qidptr: &qid}
queues[nfq.qid] = nfq
err = nfq.init()
if err != nil {
return nil, err
}
go func() {
runtime.LockOSThread()
C.loop_for_packets(nfq.h)
}()
return nfq, nil
}
// PacketChannel returns a packet channel
func (nfq *NFQueue) PacketChannel() <-chan packet.Packet {
return nfq.Packets
}
func (nfq *NFQueue) init() error {
var err error
if nfq.h, err = C.nfq_open(); err != nil || nfq.h == nil {
return fmt.Errorf("could not open nfqueue: %s", err)
}
//if nfq.qh, err = C.nfq_create_queue(nfq.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || nfq.qh == nil {
nfq.Packets = make(chan packet.Packet, 1)
if C.nfq_unbind_pf(nfq.h, C.AF_INET) < 0 {
nfq.Destroy()
return errors.New("nfq_unbind_pf(AF_INET) failed, are you root?")
}
if C.nfq_unbind_pf(nfq.h, C.AF_INET6) < 0 {
nfq.Destroy()
return errors.New("nfq_unbind_pf(AF_INET6) failed")
}
if C.nfq_bind_pf(nfq.h, C.AF_INET) < 0 {
nfq.Destroy()
return errors.New("nfq_bind_pf(AF_INET) failed")
}
if C.nfq_bind_pf(nfq.h, C.AF_INET6) < 0 {
nfq.Destroy()
return errors.New("nfq_bind_pf(AF_INET6) failed")
}
if nfq.qh, err = C.create_queue(nfq.h, C.uint16_t(nfq.qid)); err != nil || nfq.qh == nil {
C.nfq_close(nfq.h)
return fmt.Errorf("could not create queue: %s", err)
}
nfq.fd = int(C.nfq_fd(nfq.h))
if C.nfq_set_mode(nfq.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 {
nfq.Destroy()
return errors.New("nfq_set_mode(NFQNL_COPY_PACKET) failed")
}
if C.nfq_set_queue_maxlen(nfq.qh, 1024*8) < 0 {
nfq.Destroy()
return errors.New("nfq_set_queue_maxlen(1024 * 8) failed")
}
return nil
}
// Destroy closes all the nfqueues.
func (nfq *NFQueue) Destroy() {
nfq.lk.Lock()
defer nfq.lk.Unlock()
if nfq.fd != 0 && nfq.Valid() {
syscall.Close(nfq.fd)
}
if nfq.qh != nil {
C.nfq_destroy_queue(nfq.qh)
nfq.qh = nil
}
if nfq.h != nil {
C.nfq_close(nfq.h)
nfq.h = nil
}
// TODO: don't close, we're exiting anyway
// if nfq.Packets != nil {
// close(nfq.Packets)
// }
}
// Valid returns whether the NFQueue is still valid.
func (nfq *NFQueue) Valid() bool {
return nfq.h != nil && nfq.qh != nil
}
//export go_nfq_callback
func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32,
version, protocol, tos, ttl uint8, saddr, daddr unsafe.Pointer,
sport, dport, checksum uint16, payloadLen uint32, payload, data unsafe.Pointer) (v uint32) {
qidptr := (*uint16)(data)
qid := *qidptr
// nfq := (*NFQueue)(nfqptr)
bs := C.GoBytes(payload, (C.int)(payloadLen))
verdict := make(chan uint32, 1)
pkt := Packet{
QueueID: qid,
ID: id,
HWProtocol: hwproto,
Hook: hook,
Mark: *mark,
verdict: verdict,
// StartedHandling: time.Now(),
}
// Payload
pkt.Payload = bs
if err := packet.Parse(bs, pkt.Info()); err != nil {
log.Warningf("nfqueue: failed to parse packet: %s; dropping", err)
*mark = 1702
return queues[qid].DefaultVerdict
}
// fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000"))
// BUG: "panic: send on closed channel" when shutting down
queues[qid].Packets <- &pkt
select {
case v = <-pkt.verdict:
*mark = pkt.Mark
// *mark = 1710
case <-time.After(queues[qid].Timeout):
v = queues[qid].DefaultVerdict
}
// log.Tracef("nfqueue: took %s to handle packet", time.Now().Sub(pkt.StartedHandling).String())
return v
}

View file

@ -1,30 +0,0 @@
#pragma once
// #define _BSD_SOURCE
// #define __BSD_SOURCE
// #define __FAVOR_BSD // Just Using _BSD_SOURCE didn't work on my system for some reason
// #define __USE_BSD
#include <stdlib.h>
// #include <sys/socket.h>
// #include <netinet/in.h>
#include <arpa/inet.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>
#include <linux/ipv6.h>
// #include <linux/netfilter.h>
#include <libnetfilter_queue/libnetfilter_queue.h>
// extern int nfq_callback(uint8_t version, uint8_t protocol, unsigned char *saddr, unsigned char *daddr,
// uint16_t sport, uint16_t dport, unsigned char * extra, void* data);
int nfqueue_cb_new(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg, struct nfq_data *nfa, void *data);
void loop_for_packets(struct nfq_handle *h);
static inline struct nfq_q_handle * create_queue(struct nfq_handle *h, uint16_t qid) {
//we use this because it's more convient to pass the callback in C
// FIXME: check malloc success
uint16_t *data = malloc(sizeof(uint16_t));
*data = qid;
return nfq_create_queue(h, qid, &nfqueue_cb_new, (void*)data);
}

View file

@ -1,129 +0,0 @@
// +build linux
package nfqueue
import (
"errors"
"github.com/safing/portmaster/network/packet"
)
// NFQ Errors
var (
ErrVerdictSentOrTimedOut = errors.New("the verdict was already sent or timed out")
)
// NFQ Packet Constants
//nolint:golint,stylecheck // FIXME
const (
NFQ_DROP uint32 = 0 // discarded the packet
NFQ_ACCEPT uint32 = 1 // the packet passes, continue iterations
NFQ_STOLEN uint32 = 2 // gone away
NFQ_QUEUE uint32 = 3 // inject the packet into a different queue (the target queue number is in the high 16 bits of the verdict)
NFQ_REPEAT uint32 = 4 // iterate the same cycle once more
NFQ_STOP uint32 = 5 // accept, but don't continue iterations
)
// Packet represents a packet with a NFQ reference.
type Packet struct {
packet.Base
QueueID uint16
ID uint32
HWProtocol uint16
Hook uint8
Mark uint32
// StartedHandling time.Time
verdict chan uint32
}
// func (pkt *Packet) String() string {
// return fmt.Sprintf("<Packet QId: %d, Id: %d, Type: %s, Src: %s:%d, Dst: %s:%d, Mark: 0x%X, Checksum: 0x%X, TOS: 0x%X, TTL: %d>",
// pkt.QueueID, pkt.Id, pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort, pkt.Mark, pkt.Checksum, pkt.Tos, pkt.TTL)
// }
// nolint:unparam
func (pkt *Packet) setVerdict(v uint32) (err error) {
defer func() {
if x := recover(); x != nil {
err = ErrVerdictSentOrTimedOut
}
}()
pkt.verdict <- v
close(pkt.verdict)
// log.Tracef("filter: packet %s verdict %d", pkt, v)
return err
}
// Marks:
// 17: Identifier
// 0/1: Just this packet/this Link
// 0/1/2: Accept, Block, Drop
// func (pkt *Packet) Accept() error {
// return pkt.setVerdict(NFQ_STOP)
// }
//
// func (pkt *Packet) Block() error {
// pkt.Mark = 1701
// return pkt.setVerdict(NFQ_ACCEPT)
// }
//
// func (pkt *Packet) Drop() error {
// return pkt.setVerdict(NFQ_DROP)
// }
// Accept implements the packet interface.
func (pkt *Packet) Accept() error {
pkt.Mark = 1700
return pkt.setVerdict(NFQ_ACCEPT)
}
// Block implements the packet interface.
func (pkt *Packet) Block() error {
pkt.Mark = 1701
return pkt.setVerdict(NFQ_ACCEPT)
}
// Drop implements the packet interface.
func (pkt *Packet) Drop() error {
pkt.Mark = 1702
return pkt.setVerdict(NFQ_ACCEPT)
}
// PermanentAccept implements the packet interface.
func (pkt *Packet) PermanentAccept() error {
pkt.Mark = 1710
return pkt.setVerdict(NFQ_ACCEPT)
}
// PermanentBlock implements the packet interface.
func (pkt *Packet) PermanentBlock() error {
pkt.Mark = 1711
return pkt.setVerdict(NFQ_ACCEPT)
}
// PermanentDrop implements the packet interface.
func (pkt *Packet) PermanentDrop() error {
pkt.Mark = 1712
return pkt.setVerdict(NFQ_ACCEPT)
}
// RerouteToNameserver implements the packet interface.
func (pkt *Packet) RerouteToNameserver() error {
pkt.Mark = 1799
return pkt.setVerdict(NFQ_ACCEPT)
}
// RerouteToTunnel implements the packet interface.
func (pkt *Packet) RerouteToTunnel() error {
pkt.Mark = 1717
return pkt.setVerdict(NFQ_ACCEPT)
}
//HUGE warning, if the iptables rules aren't set correctly this can cause some problems.
// func (pkt *Packet) Repeat() error {
// return this.SetVerdict(REPEAT)
// }

View file

@ -10,13 +10,10 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/firewall/interception/nfqexp" "github.com/safing/portmaster/firewall/interception/nfq"
"github.com/safing/portmaster/firewall/interception/nfqueue"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
) )
// iptables -A OUTPUT -p icmp -j", "NFQUEUE", "--queue-num", "1", "--queue-bypass
var ( var (
v4chains []string v4chains []string
v4rules []string v4rules []string
@ -37,13 +34,10 @@ var (
) )
func init() { func init() {
flag.BoolVar(&experimentalNfqueueBackend, "experimental-nfqueue", false, "use experimental nfqueue packet") flag.BoolVar(&experimentalNfqueueBackend, "experimental-nfqueue", false, "(deprecated flag; always used)")
} }
// nfQueueFactoryFunc creates a new nfQueue with qid as the queue number. // nfQueue encapsulates nfQueue providers.
type nfQueueFactoryFunc func(qid uint16, v6 bool) (nfQueue, error)
// nfQueue encapsulates nfQueue providers
type nfQueue interface { type nfQueue interface {
PacketChannel() <-chan packet.Packet PacketChannel() <-chan packet.Packet
Destroy() Destroy()
@ -227,16 +221,11 @@ func deactivateIPTables(protocol iptables.Protocol, rules, chains []string) erro
} }
// StartNfqueueInterception starts the nfqueue interception. // StartNfqueueInterception starts the nfqueue interception.
func StartNfqueueInterception() (err error) { func StartNfqueueInterception(packets chan<- packet.Packet) (err error) {
var nfQueueFactory nfQueueFactoryFunc = func(qid uint16, v6 bool) (nfQueue, error) { // @deprecated, remove in v1
return nfqueue.NewNFQueue(qid)
}
if experimentalNfqueueBackend { if experimentalNfqueueBackend {
log.Infof("nfqueue: using experimental nfqueue backend") log.Warningf("[DEPRECATED] --experimental-nfqueue has been deprecated as the backend is now used by default")
nfQueueFactory = func(qid uint16, v6 bool) (nfQueue, error) { log.Warningf("[DEPRECATED] please remove the flag from your configuration!")
return nfqexp.New(qid, v6)
}
} }
err = activateNfqueueFirewall() err = activateNfqueueFirewall()
@ -245,28 +234,28 @@ func StartNfqueueInterception() (err error) {
return fmt.Errorf("could not initialize nfqueue: %s", err) return fmt.Errorf("could not initialize nfqueue: %s", err)
} }
out4Queue, err = nfQueueFactory(17040, false) out4Queue, err = nfq.New(17040, false)
if err != nil { if err != nil {
_ = Stop() _ = Stop()
return fmt.Errorf("nfqueue(IPv4, out): %w", err) return fmt.Errorf("nfqueue(IPv4, out): %w", err)
} }
in4Queue, err = nfQueueFactory(17140, false) in4Queue, err = nfq.New(17140, false)
if err != nil { if err != nil {
_ = Stop() _ = Stop()
return fmt.Errorf("nfqueue(IPv4, in): %w", err) return fmt.Errorf("nfqueue(IPv4, in): %w", err)
} }
out6Queue, err = nfQueueFactory(17060, true) out6Queue, err = nfq.New(17060, true)
if err != nil { if err != nil {
_ = Stop() _ = Stop()
return fmt.Errorf("nfqueue(IPv6, out): %w", err) return fmt.Errorf("nfqueue(IPv6, out): %w", err)
} }
in6Queue, err = nfQueueFactory(17160, true) in6Queue, err = nfq.New(17160, true)
if err != nil { if err != nil {
_ = Stop() _ = Stop()
return fmt.Errorf("nfqueue(IPv6, in): %w", err) return fmt.Errorf("nfqueue(IPv6, in): %w", err)
} }
go handleInterception() go handleInterception(packets)
return nil return nil
} }
@ -295,23 +284,26 @@ func StopNfqueueInterception() error {
return nil return nil
} }
func handleInterception() { func handleInterception(packets chan<- packet.Packet) {
for { for {
var pkt packet.Packet
select { select {
case <-shutdownSignal: case <-shutdownSignal:
return return
case pkt := <-out4Queue.PacketChannel(): case pkt = <-out4Queue.PacketChannel():
pkt.SetOutbound() pkt.SetOutbound()
Packets <- pkt case pkt = <-in4Queue.PacketChannel():
case pkt := <-in4Queue.PacketChannel():
pkt.SetInbound() pkt.SetInbound()
Packets <- pkt case pkt = <-out6Queue.PacketChannel():
case pkt := <-out6Queue.PacketChannel():
pkt.SetOutbound() pkt.SetOutbound()
Packets <- pkt case pkt = <-in6Queue.PacketChannel():
case pkt := <-in6Queue.PacketChannel():
pkt.SetInbound() pkt.SetInbound()
Packets <- pkt }
select {
case packets <- pkt:
case <-shutdownSignal:
return
} }
} }
} }

View file

@ -0,0 +1,67 @@
package interception
import (
"time"
"github.com/safing/portmaster/network/packet"
)
type tracedPacket struct {
start time.Time
packet.Packet
}
func tracePacket(p packet.Packet) packet.Packet {
return &tracedPacket{
start: time.Now(),
Packet: p,
}
}
func (p *tracedPacket) markServed(v string) {
if packetMetricsDestination == "" {
return
}
metrics.record(p, v)
}
func (p *tracedPacket) Accept() error {
defer p.markServed("accept")
return p.Packet.Accept()
}
func (p *tracedPacket) Block() error {
defer p.markServed("block")
return p.Packet.Block()
}
func (p *tracedPacket) Drop() error {
defer p.markServed("drop")
return p.Packet.Drop()
}
func (p *tracedPacket) PermanentAccept() error {
defer p.markServed("perm-accept")
return p.Packet.PermanentAccept()
}
func (p *tracedPacket) PermanentBlock() error {
defer p.markServed("perm-block")
return p.Packet.PermanentBlock()
}
func (p *tracedPacket) PermanentDrop() error {
defer p.markServed("perm-drop")
return p.Packet.PermanentDrop()
}
func (p *tracedPacket) RerouteToNameserver() error {
defer p.markServed("reroute-ns")
return p.Packet.RerouteToNameserver()
}
func (p *tracedPacket) RerouteToTunnel() error {
defer p.markServed("reroute-tunnel")
return p.Packet.RerouteToTunnel()
}

View file

@ -48,6 +48,8 @@ var (
var ( var (
cache = database.NewInterface(&database.Options{ cache = database.NewInterface(&database.Options{
Local: true,
Internal: true,
CacheSize: 2 ^ 8, CacheSize: 2 ^ 8,
}) })
) )

View file

@ -82,10 +82,19 @@ func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO
// Only process first question, that's how everyone does it. // Only process first question, that's how everyone does it.
question := request.Question[0] originalQuestion := request.Question[0]
// Check if we are handling a non-standard query name.
var nonStandardQuestionFormat bool
lowerCaseQuestion := strings.ToLower(originalQuestion.Name)
if lowerCaseQuestion != originalQuestion.Name {
nonStandardQuestionFormat = true
}
// Create query for the resolver.
q := &resolver.Query{ q := &resolver.Query{
FQDN: question.Name, FQDN: lowerCaseQuestion,
QType: dns.Type(question.Qtype), QType: dns.Type(originalQuestion.Qtype),
} }
// Get remote address of request. // Get remote address of request.
@ -118,9 +127,9 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
} }
// Check the Query Class. // Check the Query Class.
if question.Qclass != dns.ClassINET { if originalQuestion.Qclass != dns.ClassINET {
// we only serve IN records, return nxdomain // we only serve IN records, return nxdomain
tracer.Warningf("nameserver: only IN record requests are supported but received Qclass %d, returning NXDOMAIN", question.Qclass) tracer.Warningf("nameserver: only IN record requests are supported but received QClass %d, returning NXDOMAIN", originalQuestion.Qclass)
return reply(nsutil.Refused("unsupported qclass")) return reply(nsutil.Refused("unsupported qclass"))
} }
@ -245,6 +254,11 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
// Save dns request as open. // Save dns request as open.
defer network.SaveOpenDNSRequest(conn) defer network.SaveOpenDNSRequest(conn)
// Revert back to non-standard question format, if we had to convert.
if nonStandardQuestionFormat {
rrCache.ReplaceAnswerNames(originalQuestion.Name)
}
// Reply with successful response. // Reply with successful response.
tracer.Infof("nameserver: returning %s response for %s to %s", conn.Verdict.Verb(), q.ID(), conn.Process()) tracer.Infof("nameserver: returning %s response for %s to %s", conn.Verdict.Verb(), q.ID(), conn.Process())
return reply(rrCache, conn, rrCache) return reply(rrCache, conn, rrCache)

View file

@ -21,7 +21,10 @@ const (
) )
var ( var (
profileDB = database.NewInterface(nil) profileDB = database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
) )
func makeScopedID(source profileSource, id string) string { func makeScopedID(source profileSource, id string) string {

View file

@ -5,6 +5,7 @@ import (
"strings" "strings"
"github.com/safing/portmaster/intel" "github.com/safing/portmaster/intel"
"github.com/safing/portmaster/network/netutils"
) )
const ( const (
@ -16,8 +17,7 @@ const (
) )
var ( var (
domainRegex = regexp.MustCompile(`^\*?(([a-z0-9][a-z0-9-]{0,61}[a-z0-9])?\.)*[a-z]{2,}\.?$`) allowedDomainChars = regexp.MustCompile(`^[a-z0-9\.-]+$`)
altDomainRegex = regexp.MustCompile(`^\*?[a-z0-9\.-]+\*$`)
) )
// EndpointDomain matches domains. // EndpointDomain matches domains.
@ -90,51 +90,63 @@ func (ep *EndpointDomain) String() string {
func parseTypeDomain(fields []string) (Endpoint, error) { func parseTypeDomain(fields []string) (Endpoint, error) {
domain := fields[1] domain := fields[1]
if domainRegex.MatchString(domain) || altDomainRegex.MatchString(domain) {
ep := &EndpointDomain{ ep := &EndpointDomain{
OriginalValue: domain, OriginalValue: domain,
} }
// fix domain ending // Fix domain ending.
switch domain[len(domain)-1] { switch domain[len(domain)-1] {
case '.': case '.', '*':
case '*':
default: default:
domain += "." domain += "."
} }
// fix domain case // Fix domain case.
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
needValidFQDN := true
switch { switch {
case strings.HasPrefix(domain, "*") && strings.HasSuffix(domain, "*"): case strings.HasPrefix(domain, "*") && strings.HasSuffix(domain, "*"):
ep.MatchType = domainMatchTypeContains ep.MatchType = domainMatchTypeContains
ep.Domain = strings.Trim(domain, "*") ep.Domain = strings.TrimPrefix(domain, "*")
return ep.parsePPP(ep, fields) ep.Domain = strings.TrimSuffix(ep.Domain, "*")
needValidFQDN = false
case strings.HasSuffix(domain, "*"): case strings.HasSuffix(domain, "*"):
ep.MatchType = domainMatchTypePrefix ep.MatchType = domainMatchTypePrefix
ep.Domain = strings.Trim(domain, "*") ep.Domain = strings.TrimSuffix(domain, "*")
return ep.parsePPP(ep, fields) needValidFQDN = false
// Prefix matching cannot be combined with zone matching
if strings.HasPrefix(ep.Domain, ".") {
return nil, nil
}
case strings.HasPrefix(domain, "*"): case strings.HasPrefix(domain, "*"):
ep.MatchType = domainMatchTypeSuffix ep.MatchType = domainMatchTypeSuffix
ep.Domain = strings.Trim(domain, "*") ep.Domain = strings.TrimPrefix(domain, "*")
return ep.parsePPP(ep, fields) needValidFQDN = false
case strings.HasPrefix(domain, "."): case strings.HasPrefix(domain, "."):
ep.MatchType = domainMatchTypeZone ep.MatchType = domainMatchTypeZone
ep.Domain = strings.TrimLeft(domain, ".") ep.Domain = strings.TrimPrefix(domain, ".")
ep.DomainZone = "." + ep.Domain ep.DomainZone = "." + ep.Domain
return ep.parsePPP(ep, fields)
default: default:
ep.MatchType = domainMatchTypeExact ep.MatchType = domainMatchTypeExact
ep.Domain = domain ep.Domain = domain
return ep.parsePPP(ep, fields)
}
} }
// Validate domain "content".
switch {
case needValidFQDN && !netutils.IsValidFqdn(ep.Domain):
return nil, nil
case !needValidFQDN && !allowedDomainChars.MatchString(ep.Domain):
return nil, nil
case strings.Contains(ep.Domain, ".."):
// The above regex does not catch double dots.
return nil, nil return nil, nil
} }
return ep.parsePPP(ep, fields)
}

View file

@ -5,6 +5,8 @@ import (
"runtime" "runtime"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/safing/portmaster/core/pmtesting" "github.com/safing/portmaster/core/pmtesting"
"github.com/safing/portmaster/intel" "github.com/safing/portmaster/intel"
) )
@ -27,6 +29,31 @@ func testEndpointMatch(t *testing.T, ep Endpoint, entity *intel.Entity, expected
} }
} }
func testFormat(t *testing.T, endpoint string, shouldSucceed bool) {
_, err := parseEndpoint(endpoint)
if shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
func TestEndpointFormat(t *testing.T) {
testFormat(t, "+ .", false)
testFormat(t, "+ .at", true)
testFormat(t, "+ .at.", true)
testFormat(t, "+ 1.at", true)
testFormat(t, "+ 1.at.", true)
testFormat(t, "+ 1.f.ix.de.", true)
testFormat(t, "+ *contains*", true)
testFormat(t, "+ *has.suffix", true)
testFormat(t, "+ *.has.suffix", true)
testFormat(t, "+ *has.prefix*", true)
testFormat(t, "+ *has.prefix.*", true)
testFormat(t, "+ .sub.and.prefix.*", false)
testFormat(t, "+ *.sub..and.prefix.*", false)
}
func TestEndpointMatching(t *testing.T) { func TestEndpointMatching(t *testing.T) {
// ANY // ANY

View file

@ -215,9 +215,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
// Check if the cache will expire soon and start an async request. // Check if the cache will expire soon and start an async request.
if rrCache.ExpiresSoon() { if rrCache.ExpiresSoon() {
// Set flag that we are refreshing this entry. // Set flag that we are refreshing this entry.
rrCache.Lock() rrCache.RequestingNew = true
rrCache.requestingNew = true
rrCache.Unlock()
log.Tracer(ctx).Tracef( log.Tracer(ctx).Tracef(
"resolver: cache for %s will expire in %s, refreshing async now", "resolver: cache for %s will expire in %s, refreshing async now",
@ -397,7 +395,7 @@ resolveLoop:
// Check if we want to use an older cache instead. // Check if we want to use an older cache instead.
if oldCache != nil { if oldCache != nil {
oldCache.isBackup = true oldCache.IsBackup = true
switch { switch {
case err != nil: case err != nil:

View file

@ -196,9 +196,9 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
if saveFullRequest { if saveFullRequest {
// get from database // get from database
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
// if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired: // if we have no cached entry, or it has been updated more than two seconds ago, or if it expired:
// create new and do not append // create new and do not append
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() { if err != nil || rrCache.Modified < time.Now().Add(-2*time.Second).Unix() || rrCache.Expired() {
rrCache = &RRCache{ rrCache = &RRCache{
Domain: question.Name, Domain: question.Name,
Question: dns.Type(question.Qtype), Question: dns.Type(question.Qtype),

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"sync"
"time" "time"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -15,31 +14,37 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// RRCache is used to cache DNS data // RRCache is a single-use structure to hold a DNS response.
// Persistence is handled through NameRecords because of a limitation of the
// underlying dns library.
//nolint:maligned // TODO //nolint:maligned // TODO
type RRCache struct { type RRCache struct {
sync.Mutex // Respnse Header
Domain string
Question dns.Type
RCode int
Domain string // constant // Response Content
Question dns.Type // constant Answer []dns.RR
RCode int // constant Ns []dns.RR
Extra []dns.RR
TTL int64
Answer []dns.RR // constant // Source Information
Ns []dns.RR // constant Server string
Extra []dns.RR // constant ServerScope int8
TTL int64 // constant ServerInfo string
Server string // constant // Metadata about the request and handling
ServerScope int8 // constant ServedFromCache bool
ServerInfo string // constant RequestingNew bool
IsBackup bool
Filtered bool
FilteredEntries []string
servedFromCache bool // mutable // Modified holds when this entry was last changed, ie. saved to database.
requestingNew bool // mutable // This field is only populated when the entry comes from the cache.
isBackup bool // mutable Modified int64
Filtered bool // mutable
FilteredEntries []string // mutable
updated int64 // mutable
} }
// ID returns the ID of the RRCache consisting of the domain and question type. // ID returns the ID of the RRCache consisting of the domain and question type.
@ -59,9 +64,6 @@ func (rrCache *RRCache) ExpiresSoon() bool {
// Clean sets all TTLs to 17 and sets cache expiry with specified minimum. // Clean sets all TTLs to 17 and sets cache expiry with specified minimum.
func (rrCache *RRCache) Clean(minExpires uint32) { func (rrCache *RRCache) Clean(minExpires uint32) {
rrCache.Lock()
defer rrCache.Unlock()
var lowestTTL uint32 = 0xFFFFFFFF var lowestTTL uint32 = 0xFFFFFFFF
var header *dns.RR_Header var header *dns.RR_Header
@ -200,7 +202,8 @@ func GetRRCache(domain string, question dns.Type) (*RRCache, error) {
rrCache.Server = nameRecord.Server rrCache.Server = nameRecord.Server
rrCache.ServerScope = nameRecord.ServerScope rrCache.ServerScope = nameRecord.ServerScope
rrCache.ServerInfo = nameRecord.ServerInfo rrCache.ServerInfo = nameRecord.ServerInfo
rrCache.servedFromCache = true rrCache.ServedFromCache = true
rrCache.Modified = nameRecord.Meta().Modified
return rrCache, nil return rrCache, nil
} }
@ -217,26 +220,16 @@ func parseRR(section []dns.RR, entry string) []dns.RR {
return section return section
} }
// ServedFromCache marks the RRCache as served from cache.
func (rrCache *RRCache) ServedFromCache() bool {
return rrCache.servedFromCache
}
// RequestingNew informs that it has expired and new RRs are being fetched.
func (rrCache *RRCache) RequestingNew() bool {
return rrCache.requestingNew
}
// Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format. // Flags formats ServedFromCache and RequestingNew to a condensed, flag-like format.
func (rrCache *RRCache) Flags() string { func (rrCache *RRCache) Flags() string {
var s string var s string
if rrCache.servedFromCache { if rrCache.ServedFromCache {
s += "C" s += "C"
} }
if rrCache.requestingNew { if rrCache.RequestingNew {
s += "R" s += "R"
} }
if rrCache.isBackup { if rrCache.IsBackup {
s += "B" s += "B"
} }
if rrCache.Filtered { if rrCache.Filtered {
@ -255,6 +248,7 @@ func (rrCache *RRCache) ShallowCopy() *RRCache {
Domain: rrCache.Domain, Domain: rrCache.Domain,
Question: rrCache.Question, Question: rrCache.Question,
RCode: rrCache.RCode, RCode: rrCache.RCode,
Answer: rrCache.Answer, Answer: rrCache.Answer,
Ns: rrCache.Ns, Ns: rrCache.Ns,
Extra: rrCache.Extra, Extra: rrCache.Extra,
@ -264,12 +258,25 @@ func (rrCache *RRCache) ShallowCopy() *RRCache {
ServerScope: rrCache.ServerScope, ServerScope: rrCache.ServerScope,
ServerInfo: rrCache.ServerInfo, ServerInfo: rrCache.ServerInfo,
updated: rrCache.updated, ServedFromCache: rrCache.ServedFromCache,
servedFromCache: rrCache.servedFromCache, RequestingNew: rrCache.RequestingNew,
requestingNew: rrCache.requestingNew, IsBackup: rrCache.IsBackup,
isBackup: rrCache.isBackup,
Filtered: rrCache.Filtered, Filtered: rrCache.Filtered,
FilteredEntries: rrCache.FilteredEntries, FilteredEntries: rrCache.FilteredEntries,
Modified: rrCache.Modified,
}
}
// ReplaceAnswerNames is a helper function that replaces all answer names, that
// match the query domain, with another value. This is used to support handling
// non-standard query names, which are resolved normalized, but have to be
// reverted back for the origin non-standard query name in order for the
// clients to recognize the response.
func (rrCache *RRCache) ReplaceAnswerNames(fqdn string) {
for _, answer := range rrCache.Answer {
if answer.Header().Name == rrCache.Domain {
answer.Header().Name = fqdn
}
} }
} }
@ -278,21 +285,16 @@ func (rrCache *RRCache) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
// reply to query // reply to query
reply := new(dns.Msg) reply := new(dns.Msg)
reply.SetRcode(request, rrCache.RCode) reply.SetRcode(request, rrCache.RCode)
reply.Answer = rrCache.Answer
reply.Ns = rrCache.Ns reply.Ns = rrCache.Ns
reply.Extra = rrCache.Extra reply.Extra = rrCache.Extra
if len(rrCache.Answer) > 0 {
// Copy answers, as we randomize their order a little.
reply.Answer = make([]dns.RR, len(rrCache.Answer))
copy(reply.Answer, rrCache.Answer)
// Randomize the order of the answer records a little to allow dumb clients // Randomize the order of the answer records a little to allow dumb clients
// (who only look at the first record) to reliably connect. // (who only look at the first record) to reliably connect.
for i := range reply.Answer { for i := range reply.Answer {
j := rand.Intn(i + 1) j := rand.Intn(i + 1)
reply.Answer[i], reply.Answer[j] = reply.Answer[j], reply.Answer[i] reply.Answer[i], reply.Answer[j] = reply.Answer[j], reply.Answer[i]
} }
}
return reply return reply
} }
@ -300,7 +302,7 @@ func (rrCache *RRCache) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
// GetExtraRRs returns a slice of RRs with additional informational records. // GetExtraRRs returns a slice of RRs with additional informational records.
func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra []dns.RR) { func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra []dns.RR) {
// Add cache status and source of data. // Add cache status and source of data.
if rrCache.servedFromCache { if rrCache.ServedFromCache {
extra = addExtra(ctx, extra, "served from cache, resolved by "+rrCache.ServerInfo) extra = addExtra(ctx, extra, "served from cache, resolved by "+rrCache.ServerInfo)
} else { } else {
extra = addExtra(ctx, extra, "freshly resolved by "+rrCache.ServerInfo) extra = addExtra(ctx, extra, "freshly resolved by "+rrCache.ServerInfo)
@ -312,10 +314,10 @@ func (rrCache *RRCache) GetExtraRRs(ctx context.Context, query *dns.Msg) (extra
} else { } else {
extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second))) extra = addExtra(ctx, extra, fmt.Sprintf("record valid for %s", time.Until(time.Unix(rrCache.TTL, 0)).Round(time.Second)))
} }
if rrCache.requestingNew { if rrCache.RequestingNew {
extra = addExtra(ctx, extra, "async request to refresh the cache has been started") extra = addExtra(ctx, extra, "async request to refresh the cache has been started")
} }
if rrCache.isBackup { if rrCache.IsBackup {
extra = addExtra(ctx, extra, "this record is served because a fresh request failed") extra = addExtra(ctx, extra, "this record is served because a fresh request failed")
} }

View file

@ -21,7 +21,10 @@ const (
// working vars // working vars
var ( var (
versionExport *versions versionExport *versions
versionExportDB = database.NewInterface(nil) versionExportDB = database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
versionExportHook *database.RegisteredHook versionExportHook *database.RegisteredHook
) )