From be517dd58ff2bcebca9cf952f88384a091b68e9c Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 6 May 2019 11:00:10 +0200 Subject: [PATCH] Update network.Packet structure --- firewall/interception/nfqueue/nfqueue.go | 36 ++--- firewall/interception/nfqueue/packet.go | 2 +- firewall/interception/windowskext/packet.go | 2 +- network/packet/const.go | 95 ++++++++++++ network/packet/packet.go | 164 +++++--------------- network/packet/packetinfo.go | 50 ++++++ network/packet/parse.go | 2 +- 7 files changed, 204 insertions(+), 147 deletions(-) create mode 100644 network/packet/const.go create mode 100644 network/packet/packetinfo.go diff --git a/firewall/interception/nfqueue/nfqueue.go b/firewall/interception/nfqueue/nfqueue.go index aedd1850..2f5a7383 100644 --- a/firewall/interception/nfqueue/nfqueue.go +++ b/firewall/interception/nfqueue/nfqueue.go @@ -10,6 +10,8 @@ package nfqueue import "C" import ( + "errors" + "fmt" "net" "os" "runtime" @@ -17,8 +19,6 @@ import ( "syscall" "time" "unsafe" - "errors" - "fmt" "github.com/Safing/portmaster/network/packet" ) @@ -145,9 +145,8 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, qid := uint16(*qidptr) // nfq := (*NFQueue)(nfqptr) - new_version := version - ipver := packet.IPVersion(new_version) - ipsz := C.int(ipver.ByteSize()) + ipVersion := packet.IPVersion(version) + ipsz := C.int(ipVersion.ByteSize()) bs := C.GoBytes(payload, (C.int)(payload_len)) verdict := make(chan uint32, 1) @@ -164,22 +163,19 @@ func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, // Payload pkt.Payload = bs - // IPHeader - pkt.IPHeader = &packet.IPHeader{ - Version: 4, - Protocol: packet.IPProtocol(protocol), - Tos: tos, - TTL: ttl, - Src: net.IP(C.GoBytes(saddr, ipsz)), - Dst: net.IP(C.GoBytes(daddr, ipsz)), - } + // Info + info := pkt.Info() + info.Version = ipVersion + info.InTunnel = false + info.Protocol = packet.IPProtocol(protocol) - // TCPUDPHeader - pkt.TCPUDPHeader = &packet.TCPUDPHeader{ - SrcPort: sport, - DstPort: dport, - Checksum: checksum, - } + // IPs + info.Src = net.IP(C.GoBytes(saddr, ipsz)) + info.Dst = net.IP(C.GoBytes(daddr, ipsz)) + + // Ports + info.SrcPort = sport + info.DstPort = dport // fmt.Printf("%s queuing packet\n", time.Now().Format("060102 15:04:05.000")) // BUG: "panic: send on closed channel" when shutting down diff --git a/firewall/interception/nfqueue/packet.go b/firewall/interception/nfqueue/packet.go index a05abc6d..51143f81 100644 --- a/firewall/interception/nfqueue/packet.go +++ b/firewall/interception/nfqueue/packet.go @@ -22,7 +22,7 @@ const ( ) type Packet struct { - packet.PacketBase + packet.Base QueueId uint16 Id uint32 diff --git a/firewall/interception/windowskext/packet.go b/firewall/interception/windowskext/packet.go index 408edde4..1356bbaf 100644 --- a/firewall/interception/windowskext/packet.go +++ b/firewall/interception/windowskext/packet.go @@ -12,7 +12,7 @@ import ( // Packet represents an IP packet. type Packet struct { - packet.PacketBase + packet.Base verdictRequest *VerdictRequest verdictSet *abool.AtomicBool diff --git a/network/packet/const.go b/network/packet/const.go new file mode 100644 index 00000000..09955577 --- /dev/null +++ b/network/packet/const.go @@ -0,0 +1,95 @@ +package packet + +import ( + "errors" + "fmt" +) + +type ( + IPVersion uint8 + IPProtocol uint8 + Verdict uint8 + Endpoint bool +) + +const ( + IPv4 = IPVersion(4) + IPv6 = IPVersion(6) + + InBound = true + OutBound = false + + Local = true + Remote = false + + // convenience + IGMP = IPProtocol(2) + RAW = IPProtocol(255) + TCP = IPProtocol(6) + UDP = IPProtocol(17) + ICMP = IPProtocol(1) + ICMPv6 = IPProtocol(58) +) + +const ( + DROP Verdict = iota + BLOCK + ACCEPT + STOLEN + QUEUE + REPEAT + STOP +) + +var ( + ErrFailedToLoadPayload = errors.New("could not load packet payload") +) + +// Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 +func (v IPVersion) ByteSize() int { + switch v { + case IPv4: + return 4 + case IPv6: + return 16 + } + return 0 +} + +func (v IPVersion) String() string { + switch v { + case IPv4: + return "IPv4" + case IPv6: + return "IPv6" + } + return fmt.Sprintf("", uint8(v)) +} + +func (p IPProtocol) String() string { + switch p { + case RAW: + return "RAW" + case TCP: + return "TCP" + case UDP: + return "UDP" + case ICMP: + return "ICMP" + case ICMPv6: + return "ICMPv6" + case IGMP: + return "IGMP" + } + return fmt.Sprintf("", uint8(p)) +} + +func (v Verdict) String() string { + switch v { + case DROP: + return "DROP" + case ACCEPT: + return "ACCEPT" + } + return fmt.Sprintf("", uint8(v)) +} diff --git a/network/packet/packet.go b/network/packet/packet.go index dac03710..6db8ed20 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -3,142 +3,49 @@ package packet import ( - "errors" "fmt" "net" ) -type ( - IPVersion uint8 - IPProtocol uint8 - Verdict uint8 - Endpoint bool -) - -const ( - IPv4 = IPVersion(4) - IPv6 = IPVersion(6) - - InBound = true - OutBound = false - - Local = true - Remote = false - - // convenience - IGMP = IPProtocol(2) - RAW = IPProtocol(255) - TCP = IPProtocol(6) - UDP = IPProtocol(17) - ICMP = IPProtocol(1) - ICMPv6 = IPProtocol(58) -) - -const ( - DROP Verdict = iota - BLOCK - ACCEPT - STOLEN - QUEUE - REPEAT - STOP -) - -var ( - ErrFailedToLoadPayload = errors.New("could not load packet payload") -) - -// Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 -func (v IPVersion) ByteSize() int { - switch v { - case IPv4: - return 4 - case IPv6: - return 16 - } - return 0 -} - -func (v IPVersion) String() string { - switch v { - case IPv4: - return "IPv4" - case IPv6: - return "IPv6" - } - return fmt.Sprintf("", uint8(v)) -} - -func (p IPProtocol) String() string { - switch p { - case RAW: - return "RAW" - case TCP: - return "TCP" - case UDP: - return "UDP" - case ICMP: - return "ICMP" - case ICMPv6: - return "ICMPv6" - case IGMP: - return "IGMP" - } - return fmt.Sprintf("", uint8(p)) -} - -func (v Verdict) String() string { - switch v { - case DROP: - return "DROP" - case ACCEPT: - return "ACCEPT" - } - return fmt.Sprintf("", uint8(v)) -} - -// PacketInfo holds IP and TCP/UDP header information -type PacketInfo struct { - Direction bool - InTunnel bool - - Version IPVersion - Src, Dst net.IP - Protocol IPProtocol - SrcPort, DstPort uint16 -} - -type PacketBase struct { - info PacketInfo +// Base is a base structure for satisfying the Packet interface. +type Base struct { + info Info linkID string Payload []byte } -func (pkt *PacketBase) Info() *PacketInfo { +// Info returns the packet Info. +func (pkt *Base) Info() *Info { return &pkt.info } -func (pkt *PacketBase) SetPacketInfo(packetInfo PacketInfo) { +// SetPacketInfo sets a new packet Info. This must only used when initializing the packet structure. +func (pkt *Base) SetPacketInfo(packetInfo Info) { pkt.info = packetInfo } -func (pkt *PacketBase) SetInbound() { +// SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure. +func (pkt *Base) SetInbound() { pkt.info.Direction = true } -func (pkt *PacketBase) SetOutbound() { +// SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure. +func (pkt *Base) SetOutbound() { pkt.info.Direction = false } -func (pkt *PacketBase) IsInbound() bool { +// IsInbound checks if the packet is inbound. +func (pkt *Base) IsInbound() bool { return pkt.info.Direction } -func (pkt *PacketBase) IsOutbound() bool { +// IsOutbound checks if the packet is outbound. +func (pkt *Base) IsOutbound() bool { return !pkt.info.Direction } -func (pkt *PacketBase) HasPorts() bool { +// HasPorts checks if the packet has a protocol that uses ports. +func (pkt *Base) HasPorts() bool { switch pkt.info.Protocol { case TCP: return true @@ -148,18 +55,20 @@ func (pkt *PacketBase) HasPorts() bool { return false } -func (pkt *PacketBase) GetPayload() ([]byte, error) { +// GetPayload returns the packet payload. In some cases, this will fetch the payload from the os integration system. +func (pkt *Base) GetPayload() ([]byte, error) { return pkt.Payload, ErrFailedToLoadPayload } -func (pkt *PacketBase) GetLinkID() string { +// GetLinkID returns the link ID for this packet. +func (pkt *Base) GetLinkID() string { if pkt.linkID == "" { pkt.createLinkID() } return pkt.linkID } -func (pkt *PacketBase) createLinkID() { +func (pkt *Base) createLinkID() { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { if pkt.info.Direction { pkt.linkID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) @@ -175,14 +84,14 @@ func (pkt *PacketBase) createLinkID() { } } -// Matches checks if a the packet matches a given endpoint (remote or local) in protocol, network and port. +// MatchesAddress checks if a the packet matches a given endpoint (remote or local) in protocol, network and port. // // Comparison matrix: // IN OUT // Local Dst Src // Remote Src Dst // -func (pkt *PacketBase) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool { +func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool { if pkt.info.Protocol != protocol { return false } @@ -204,7 +113,14 @@ func (pkt *PacketBase) MatchesAddress(remote bool, protocol IPProtocol, network return true } -func (pkt *PacketBase) MatchesIP(endpoint bool, network *net.IPNet) bool { +// MatchesIP checks if a the packet matches a given endpoint (remote or local) IP. +// +// Comparison matrix: +// IN OUT +// Local Dst Src +// Remote Src Dst +// +func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool { if pkt.info.Direction != endpoint { if network.Contains(pkt.info.Src) { return true @@ -219,12 +135,12 @@ func (pkt *PacketBase) MatchesIP(endpoint bool, network *net.IPNet) bool { // FORMATTING -func (pkt *PacketBase) String() string { +func (pkt *Base) String() string { return pkt.FmtPacket() } // FmtPacket returns the most important information about the packet as a string -func (pkt *PacketBase) FmtPacket() string { +func (pkt *Base) FmtPacket() string { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { if pkt.info.Direction { return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) @@ -238,12 +154,12 @@ func (pkt *PacketBase) FmtPacket() string { } // FmtProtocol returns the protocol as a string -func (pkt *PacketBase) FmtProtocol() string { +func (pkt *Base) FmtProtocol() string { return pkt.info.Protocol.String() } // FmtRemoteIP returns the remote IP address as a string -func (pkt *PacketBase) FmtRemoteIP() string { +func (pkt *Base) FmtRemoteIP() string { if pkt.info.Direction { return pkt.info.Src.String() } @@ -251,7 +167,7 @@ func (pkt *PacketBase) FmtRemoteIP() string { } // FmtRemotePort returns the remote port as a string -func (pkt *PacketBase) FmtRemotePort() string { +func (pkt *Base) FmtRemotePort() string { if pkt.info.SrcPort != 0 { if pkt.info.Direction { return fmt.Sprintf("%d", pkt.info.SrcPort) @@ -262,7 +178,7 @@ func (pkt *PacketBase) FmtRemotePort() string { } // FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string -func (pkt *PacketBase) FmtRemoteAddress() string { +func (pkt *Base) FmtRemoteAddress() string { return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort()) } @@ -279,8 +195,8 @@ type Packet interface { RerouteToTunnel() error // INFO - Info() *PacketInfo - SetPacketInfo(PacketInfo) + Info() *Info + SetPacketInfo(Info) IsInbound() bool IsOutbound() bool SetInbound() diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go new file mode 100644 index 00000000..d1dc1ce0 --- /dev/null +++ b/network/packet/packetinfo.go @@ -0,0 +1,50 @@ +// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. + +package packet + +import ( + "net" +) + +// Info holds IP and TCP/UDP header information +type Info struct { + Direction bool + InTunnel bool + + Version IPVersion + Src, Dst net.IP + Protocol IPProtocol + SrcPort, DstPort uint16 +} + +// LocalIP returns the local IP of the packet. +func (pi *Info) LocalIP() net.IP { + if pi.Direction { + return pi.Dst + } + return pi.Src +} + +// RemoteIP returns the remote IP of the packet. +func (pi *Info) RemoteIP() net.IP { + if pi.Direction { + return pi.Src + } + return pi.Dst +} + +// LocalPort returns the local port of the packet. +func (pi *Info) LocalPort() uint16 { + if pi.Direction { + return pi.DstPort + } + return pi.SrcPort +} + +// RemotePort returns the remote port of the packet. +func (pi *Info) RemotePort() uint16 { + if pi.Direction { + return pi.SrcPort + } + return pi.DstPort +} diff --git a/network/packet/parse.go b/network/packet/parse.go index 7987e40f..28815620 100644 --- a/network/packet/parse.go +++ b/network/packet/parse.go @@ -9,7 +9,7 @@ import ( ) // Parse parses an IP packet and saves the information in the given packet object. -func Parse(packetData []byte, packet *PacketBase) error { +func Parse(packetData []byte, packet *Base) error { var parsedPacket gopacket.Packet