[service] Add check for kext command size

This commit is contained in:
Vladimir Stoilov 2024-06-28 13:20:18 +03:00
parent 176494550e
commit 4bf1736a83
6 changed files with 186 additions and 55 deletions
service/firewall/interception/windowskext2
windows_kext

View file

@ -5,11 +5,13 @@ package windowskext
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/windows_kext/kextinterface"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -32,8 +34,15 @@ func (v *VersionInfo) String() string {
func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) { func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) {
for { for {
packetInfo, err := RecvVerdictRequest() packetInfo, err := RecvVerdictRequest()
if errors.Is(err, kextinterface.ErrUnexpectedInfoSize) || errors.Is(err, kextinterface.ErrUnexpectedReadError) {
log.Criticalf("unexpected kext info data: %s", err)
continue // Depending on the info type this may not affect the functionality. Try to continue reading the next commands.
}
if err != nil { if err != nil {
log.Warningf("failed to get packet from windows kext: %s", err) log.Warningf("failed to get packet from windows kext: %s", err)
// Probably IO error, nothing else we can do.
return return
} }

View file

@ -3,6 +3,7 @@ package kextinterface
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
) )
@ -18,6 +19,7 @@ const (
var ( var (
ErrUnknownInfoType = errors.New("unknown info type") ErrUnknownInfoType = errors.New("unknown info type")
ErrUnexpectedInfoSize = errors.New("unexpected info size")
ErrUnexpectedReadError = errors.New("unexpected read error") ErrUnexpectedReadError = errors.New("unexpected read error")
) )
@ -135,117 +137,215 @@ type Info struct {
BandwidthStats *BandwidthStatsArray BandwidthStats *BandwidthStatsArray
} }
func RecvInfo(reader io.Reader) (*Info, error) { type readHelper struct {
var infoType byte infoType byte
err := binary.Read(reader, binary.LittleEndian, &infoType) commandSize uint32
readSize int
reader io.Reader
}
func newReadHelper(reader io.Reader) (*readHelper, error) {
helper := &readHelper{reader: reader}
err := binary.Read(reader, binary.LittleEndian, &helper.infoType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Read size of data err = binary.Read(reader, binary.LittleEndian, &helper.commandSize)
var size uint32
err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return helper, nil
}
func (r *readHelper) ReadData(data any) error {
err := binary.Read(r, binary.LittleEndian, data)
if err != nil {
return errors.Join(ErrUnexpectedReadError, err)
}
if err := r.checkOverRead(); err != nil {
return err
}
return nil
}
// Passing size = 0 will read the rest of the command.
func (r *readHelper) ReadBytes(size uint32) ([]byte, error) {
if uint32(r.readSize) >= r.commandSize {
return nil, errors.Join(fmt.Errorf("cannot read more bytes than the command size: %d >= %d", r.readSize, r.commandSize), ErrUnexpectedReadError)
}
if size == 0 {
size = r.commandSize - uint32(r.readSize)
}
if r.commandSize < uint32(r.readSize)+size {
return nil, ErrUnexpectedInfoSize
}
bytes := make([]byte, size)
err := binary.Read(r, binary.LittleEndian, bytes)
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
return bytes, nil
}
func (r *readHelper) ReadUntilTheEnd() {
_, _ = r.ReadBytes(0)
}
func (r *readHelper) checkOverRead() error {
if uint32(r.readSize) > r.commandSize {
return ErrUnexpectedInfoSize
}
return nil
}
func (r *readHelper) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.readSize += n
return
}
func RecvInfo(reader io.Reader) (*Info, error) {
helper, err := newReadHelper(reader)
if err != nil {
return nil, err
}
// Make sure the whole command is read before return.
defer helper.ReadUntilTheEnd()
// Read data // Read data
switch infoType { switch helper.infoType {
case InfoConnectionIpv4: case InfoConnectionIpv4:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionIpv4")
newInfo := ConnectionV4{}
var fixedSizeValues connectionV4Internal var fixedSizeValues connectionV4Internal
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) // Read fixed size values.
err = helper.ReadData(&fixedSizeValues)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err, fmt.Errorf("fixed"))
} }
// Read size of payload newInfo.connectionV4Internal = fixedSizeValues
var size uint32 // Read size of payload.
err = binary.Read(reader, binary.LittleEndian, &size) var payloadSize uint32
err = helper.ReadData(&payloadSize)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err, fmt.Errorf("payloadsize"))
} }
newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) // Check if there is payload.
if err != nil { if payloadSize > 0 {
return nil, errors.Join(ErrUnexpectedReadError, err) // Read payload.
newInfo.Payload, err = helper.ReadBytes(payloadSize)
if err != nil {
return nil, errors.Join(parseError, err, fmt.Errorf("payload"))
}
} }
return &Info{ConnectionV4: &newInfo}, nil return &Info{ConnectionV4: &newInfo}, nil
} }
case InfoConnectionIpv6: case InfoConnectionIpv6:
{ {
var fixedSizeValues connectionV6Internal parseError := fmt.Errorf("failed to parse InfoConnectionIpv6")
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) newInfo := ConnectionV6{}
// Read fixed size values.
err = helper.ReadData(&newInfo.connectionV6Internal)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of payload
var size uint32 // Read size of payload.
err = binary.Read(reader, binary.LittleEndian, &size) var payloadSize uint32
err = helper.ReadData(&payloadSize)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) // Check if there is payload.
if err != nil { if payloadSize > 0 {
return nil, errors.Join(ErrUnexpectedReadError, err) // Read payload.
newInfo.Payload, err = helper.ReadBytes(payloadSize)
if err != nil {
return nil, errors.Join(parseError, err)
}
} }
return &Info{ConnectionV6: &newInfo}, nil return &Info{ConnectionV6: &newInfo}, nil
} }
case InfoConnectionEndEventV4: case InfoConnectionEndEventV4:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV4")
var connectionEnd ConnectionEndV4 var connectionEnd ConnectionEndV4
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
// Read fixed size values.
err = helper.ReadData(&connectionEnd)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
return &Info{ConnectionEndV4: &connectionEnd}, nil return &Info{ConnectionEndV4: &connectionEnd}, nil
} }
case InfoConnectionEndEventV6: case InfoConnectionEndEventV6:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV6")
var connectionEnd ConnectionEndV6 var connectionEnd ConnectionEndV6
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
// Read fixed size values.
err = helper.ReadData(&connectionEnd)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
return &Info{ConnectionEndV6: &connectionEnd}, nil return &Info{ConnectionEndV6: &connectionEnd}, nil
} }
case InfoLogLine: case InfoLogLine:
{ {
parseError := fmt.Errorf("failed to parse InfoLogLine")
logLine := LogLine{} logLine := LogLine{}
// Read severity // Read severity
err = binary.Read(reader, binary.LittleEndian, &logLine.Severity) err = helper.ReadData(&logLine.Severity)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read string // Read string
line := make([]byte, size-1) // -1 for the severity enum. bytes, err := helper.ReadBytes(0)
err = binary.Read(reader, binary.LittleEndian, &line)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
logLine.Line = string(line) logLine.Line = string(bytes)
return &Info{LogLine: &logLine}, nil return &Info{LogLine: &logLine}, nil
} }
case InfoBandwidthStatsV4: case InfoBandwidthStatsV4:
{ {
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV4")
// Read Protocol // Read Protocol
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = helper.ReadData(&protocol)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = helper.ReadData(&size)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read array // Read array
statsArray := make([]BandwidthValueV4, size) statsArray := make([]BandwidthValueV4, size)
for i := range int(size) { for i := range int(size) {
err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) err = helper.ReadData(&statsArray[i])
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
} }
@ -253,24 +353,25 @@ func RecvInfo(reader io.Reader) (*Info, error) {
} }
case InfoBandwidthStatsV6: case InfoBandwidthStatsV6:
{ {
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV6")
// Read Protocol // Read Protocol
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = helper.ReadData(&protocol)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = helper.ReadData(&size)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read array // Read array
statsArray := make([]BandwidthValueV6, size) statsArray := make([]BandwidthValueV6, size)
for i := range int(size) { for i := range int(size) {
err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) err = helper.ReadData(&statsArray[i])
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
} }
@ -278,10 +379,5 @@ func RecvInfo(reader io.Reader) (*Info, error) {
} }
} }
// Command not recognized, read until the end of command and return.
// During normal operation this should not happen.
unknownData := make([]byte, size)
_, _ = reader.Read(unknownData)
return nil, ErrUnknownInfoType return nil, ErrUnknownInfoType
} }

View file

@ -18,8 +18,17 @@ func TestRustInfoFile(t *testing.T) {
defer func() { defer func() {
_ = file.Close() _ = file.Close()
}() }()
first := true
for { for {
info, err := RecvInfo(file) info, err := RecvInfo(file)
// First info should be with invalid size.
if first {
if !errors.Is(err, ErrUnexpectedInfoSize) {
t.Errorf("unexpected error: %s\n", err)
}
first = false
continue
}
if err != nil { if err != nil {
if errors.Is(err, ErrUnexpectedReadError) { if errors.Is(err, ErrUnexpectedReadError) {
t.Errorf("unexpected error: %s\n", err) t.Errorf("unexpected error: %s\n", err)

View file

@ -441,6 +441,22 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
for _ in 0..selected.capacity() { for _ in 0..selected.capacity() {
selected.push(enums.choose(&mut rng).unwrap().clone()); selected.push(enums.choose(&mut rng).unwrap().clone());
} }
// Write wrong size data.
let mut info = connection_info_v6(
1,
2,
3,
4,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
5,
6,
7,
&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
);
info.assert_size();
info.0[0] = InfoType::ConnectionIpv4 as u8;
file.write_all(&info.0)?;
for value in selected { for value in selected {
file.write_all(&match value { file.write_all(&match value {
@ -548,5 +564,6 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
} }
})?; })?;
} }
return Ok(()); return Ok(());
} }