diff --git a/firewall/interception.go b/firewall/interception.go index c87a8c57..91f3ba26 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -79,20 +79,9 @@ func interceptionPrep() (err error) { func interceptionStart() error { startAPIAuth() - interceptionModule.StartWorker("stat logger", func(ctx context.Context) error { - statLogger() - return nil - }) - - interceptionModule.StartWorker("packet handler", func(ctx context.Context) error { - run() - return nil - }) - - interceptionModule.StartWorker("ports state cleaner", func(ctx context.Context) error { - portsInUseCleaner() - return nil - }) + interceptionModule.StartWorker("stat logger", statLogger) + interceptionModule.StartWorker("packet handler", packetHandler) + interceptionModule.StartWorker("ports state cleaner", portsInUseCleaner) return interception.Start() } @@ -328,22 +317,22 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V // return // } -func run() { +func packetHandler(ctx context.Context) error { for { select { - case <-interceptionModule.Stopping(): - return + case <-ctx.Done(): + return nil case pkt := <-interception.Packets: handlePacket(pkt) } } } -func statLogger() { +func statLogger(ctx context.Context) error { for { select { - case <-interceptionModule.Stopping(): - return + case <-ctx.Done(): + return nil case <-time.After(10 * time.Second): log.Tracef( "filter: packets accepted %d, blocked %d, dropped %d, failed %d", diff --git a/firewall/interception/nfqexp/nfqexp.go b/firewall/interception/nfqexp/nfqexp.go index 89eced06..45c358e6 100644 --- a/firewall/interception/nfqexp/nfqexp.go +++ b/firewall/interception/nfqexp/nfqexp.go @@ -69,7 +69,7 @@ func New(qid uint16, v6 bool) (*Queue, error) { pkt.Payload = *attrs.Payload } - if err := pmpacket.Parse(pkt.Payload, &pkt.Base); err != nil { + if err := pmpacket.Parse(pkt.Payload, pkt.Info()); err != nil { log.Warningf("nfqexp: failed to parse payload: %s", err) _ = pkt.Drop() return 0 diff --git a/firewall/interception/nfqueue/nfqueue.go b/firewall/interception/nfqueue/nfqueue.go index 41c05346..545bcf2f 100644 --- a/firewall/interception/nfqueue/nfqueue.go +++ b/firewall/interception/nfqueue/nfqueue.go @@ -12,7 +12,6 @@ import "C" import ( "errors" "fmt" - "net" "os" "runtime" "sync" @@ -20,6 +19,7 @@ import ( "time" "unsafe" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/packet" ) @@ -155,8 +155,6 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, qid := *qidptr // nfq := (*NFQueue)(nfqptr) - ipVersion := packet.IPVersion(version) - ipsz := C.int(ipVersion.ByteSize()) bs := C.GoBytes(payload, (C.int)(payloadLen)) verdict := make(chan uint32, 1) @@ -173,19 +171,11 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, // Payload pkt.Payload = bs - // Info - info := pkt.Info() - info.Version = ipVersion - info.InTunnel = false - info.Protocol = packet.IPProtocol(protocol) - - // IPs - info.Src = net.IP(C.GoBytes(saddr, ipsz)) - info.Dst = net.IP(C.GoBytes(daddr, ipsz)) - - // Ports - info.SrcPort = sport - info.DstPort = dport + if err := packet.Parse(bs, pkt.Info()); err != nil { + log.Warningf("nfqueue: failed to parse packet: %s; dropping", err) + *mark = 1702 + return queues[qid].DefaultVerdict + } // fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000")) // BUG: "panic: send on closed channel" when shutting down diff --git a/firewall/ports.go b/firewall/ports.go index c0e3eb65..d8a61660 100644 --- a/firewall/ports.go +++ b/firewall/ports.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "sync" "time" @@ -69,11 +70,11 @@ func GetPermittedPort() uint16 { return 0 } -func portsInUseCleaner() { +func portsInUseCleaner(ctx context.Context) error { for { select { - case <-interceptionModule.Stopping(): - return + case <-ctx.Done(): + return nil case <-time.After(cleanerTickDuration): cleanPortsInUse() } diff --git a/network/packet/const.go b/network/packet/const.go index 6107d2e7..374ef57a 100644 --- a/network/packet/const.go +++ b/network/packet/const.go @@ -23,12 +23,13 @@ const ( InBound = true OutBound = false - ICMP = IPProtocol(1) - IGMP = IPProtocol(2) - TCP = IPProtocol(6) - UDP = IPProtocol(17) - ICMPv6 = IPProtocol(58) - RAW = IPProtocol(255) + ICMP = IPProtocol(1) + IGMP = IPProtocol(2) + TCP = IPProtocol(6) + UDP = IPProtocol(17) + ICMPv6 = IPProtocol(58) + UDPLite = IPProtocol(136) + RAW = IPProtocol(255) ) // Verdicts @@ -78,6 +79,8 @@ func (p IPProtocol) String() string { return "TCP" case UDP: return "UDP" + case UDPLite: + return "UDPLite" case ICMP: return "ICMP" case ICMPv6: diff --git a/network/packet/packet.go b/network/packet/packet.go index 8076ed69..315292ca 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -59,7 +59,7 @@ func (pkt *Base) HasPorts() bool { switch pkt.info.Protocol { case TCP: return true - case UDP: + case UDP, UDPLite: return true } return false diff --git a/network/packet/parse.go b/network/packet/parse.go index 28815620..1ceb8613 100644 --- a/network/packet/parse.go +++ b/network/packet/parse.go @@ -8,84 +8,141 @@ import ( "github.com/google/gopacket/layers" ) +var layerType2IPProtocol map[gopacket.LayerType]IPProtocol + +func genIPProtocolFromLayerType() { + layerType2IPProtocol = make(map[gopacket.LayerType]IPProtocol) + for k, v := range layers.IPProtocolMetadata { + layerType2IPProtocol[v.LayerType] = IPProtocol(k) + } +} + +func parseIPv4(packet gopacket.Packet, info *Info) error { + if ipv4, ok := packet.NetworkLayer().(*layers.IPv4); ok { + info.Version = IPv4 + info.Src = ipv4.SrcIP + info.Dst = ipv4.DstIP + info.Protocol = IPProtocol(ipv4.Protocol) + } + return nil +} + +func parseIPv6(packet gopacket.Packet, info *Info) error { + if ipv6, ok := packet.NetworkLayer().(*layers.IPv6); ok { + info.Version = IPv6 + info.Src = ipv6.SrcIP + info.Dst = ipv6.DstIP + // we set Protocol to NextHeader as a fallback. If TCP or + // UDP layers are detected (somewhere in the list of options) + // the Protocol field is adjusted correctly. + info.Protocol = IPProtocol(ipv6.NextHeader) + } + return nil +} + +func parseTCP(packet gopacket.Packet, info *Info) error { + if tcp, ok := packet.TransportLayer().(*layers.TCP); ok { + info.Protocol = TCP + info.SrcPort = uint16(tcp.SrcPort) + info.DstPort = uint16(tcp.DstPort) + } + return nil +} + +func parseUDP(packet gopacket.Packet, info *Info) error { + if udp, ok := packet.TransportLayer().(*layers.UDP); ok { + info.Protocol = UDP + info.SrcPort = uint16(udp.SrcPort) + info.DstPort = uint16(udp.DstPort) + } + return nil +} + +func parseUDPLite(packet gopacket.Packet, info *Info) error { + if udpLite, ok := packet.TransportLayer().(*layers.UDPLite); ok { + info.Protocol = UDPLite + info.SrcPort = uint16(udpLite.SrcPort) + info.DstPort = uint16(udpLite.DstPort) + } + return nil +} + +func parseICMPv4(packet gopacket.Packet, info *Info) error { + if icmp, ok := packet.Layer(layers.LayerTypeICMPv4).(*layers.ICMPv4); ok { + info.Protocol = ICMP + _ = icmp + } + return nil +} + +func parseICMPv6(packet gopacket.Packet, info *Info) error { + if icmp6, ok := packet.Layer(layers.LayerTypeICMPv6).(*layers.ICMPv6); ok { + info.Protocol = ICMPv6 + _ = icmp6 + } + return nil +} + +func parseIGMP(packet gopacket.Packet, info *Info) error { + // gopacket uses LayerTypeIGMP for v1, v2 and v3 and may thus + // either return layers.IGMP or layers.IGMPv1or2 + if layer := packet.Layer(layers.LayerTypeIGMP); layer != nil { + info.Protocol = IGMP + } + return nil +} + +func checkError(packet gopacket.Packet, _ *Info) error { + if err := packet.ErrorLayer(); err != nil { + return err.Error() + } + return nil +} + // Parse parses an IP packet and saves the information in the given packet object. -func Parse(packetData []byte, packet *Base) error { - - var parsedPacket gopacket.Packet - +func Parse(packetData []byte, pktInfo *Info) error { if len(packetData) == 0 { return errors.New("empty packet") } - switch packetData[0] >> 4 { + ipVersion := packetData[0] >> 4 + var networkLayerType gopacket.LayerType + + switch ipVersion { case 4: - parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) - if ipv4Layer := parsedPacket.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { - ipv4, _ := ipv4Layer.(*layers.IPv4) - packet.info.Version = IPv4 - packet.info.Protocol = IPProtocol(ipv4.Protocol) - packet.info.Src = ipv4.SrcIP - packet.info.Dst = ipv4.DstIP - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return fmt.Errorf("failed to parse IPv4 packet: %s", err) - } + networkLayerType = layers.LayerTypeIPv4 case 6: - parsedPacket = gopacket.NewPacket(packetData, layers.LayerTypeIPv6, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) - if ipv6Layer := parsedPacket.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { - ipv6, _ := ipv6Layer.(*layers.IPv6) - packet.info.Version = IPv6 - packet.info.Protocol = IPProtocol(ipv6.NextHeader) - packet.info.Src = ipv6.SrcIP - packet.info.Dst = ipv6.DstIP - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return fmt.Errorf("failed to parse IPv6 packet: %s", err) - } + networkLayerType = layers.LayerTypeIPv6 default: - return errors.New("unknown IP version") + return fmt.Errorf("unknown IP version or network protocol: %02x", ipVersion) } - switch packet.info.Protocol { - case TCP: - if tcpLayer := parsedPacket.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp, _ := tcpLayer.(*layers.TCP) - packet.info.SrcPort = uint16(tcp.SrcPort) - packet.info.DstPort = uint16(tcp.DstPort) - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return fmt.Errorf("could not parse TCP layer: %s", err) + packet := gopacket.NewPacket(packetData, networkLayerType, gopacket.DecodeOptions{ + Lazy: true, + NoCopy: true, + }) + + availableDecoders := []func(gopacket.Packet, *Info) error{ + parseIPv4, + parseIPv6, + parseTCP, + parseUDP, + //parseUDPLite, // we don't yet support udplite + parseICMPv4, + parseICMPv6, + parseIGMP, + checkError, + } + + for _, dec := range availableDecoders { + if err := dec(packet, pktInfo); err != nil { + return err } - case UDP: - if udpLayer := parsedPacket.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp, _ := udpLayer.(*layers.UDP) - packet.info.SrcPort = uint16(udp.SrcPort) - packet.info.DstPort = uint16(udp.DstPort) - } else { - var err error - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - err = errLayer.Error() - } - return fmt.Errorf("could not parse UDP layer: %s", err) - } - } - - if appLayer := parsedPacket.ApplicationLayer(); appLayer != nil { - packet.Payload = appLayer.Payload() - } - - if errLayer := parsedPacket.ErrorLayer(); errLayer != nil { - return errLayer.Error() } return nil } + +func init() { + genIPProtocolFromLayerType() +}