diff --git a/service/firewall/interception/windowskext2/handler.go b/service/firewall/interception/windowskext2/handler.go index bb6348dd..166ebbc2 100644 --- a/service/firewall/interception/windowskext2/handler.go +++ b/service/firewall/interception/windowskext2/handler.go @@ -5,11 +5,13 @@ package windowskext import ( "context" + "errors" "fmt" "net" "time" "github.com/safing/portmaster/service/process" + "github.com/safing/portmaster/windows_kext/kextinterface" "github.com/tevino/abool" @@ -32,8 +34,15 @@ func (v *VersionInfo) String() string { func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) { for { packetInfo, err := RecvVerdictRequest() + + if errors.Is(err, kextinterface.ErrUnexpectedInfoSize) || errors.Is(err, kextinterface.ErrUnexpectedReadError) { + log.Criticalf("unexpected kext info data: %s", err) + continue // Depending on the info type this may not affect the functionality. Try to continue reading the next commands. + } + if err != nil { log.Warningf("failed to get packet from windows kext: %s", err) + // Probably IO error, nothing else we can do. return } diff --git a/windows_kext/kextinterface/info.go b/windows_kext/kextinterface/info.go index 763c3e8e..a2f5cd91 100644 --- a/windows_kext/kextinterface/info.go +++ b/windows_kext/kextinterface/info.go @@ -3,6 +3,7 @@ package kextinterface import ( "encoding/binary" "errors" + "fmt" "io" ) @@ -18,6 +19,7 @@ const ( var ( ErrUnknownInfoType = errors.New("unknown info type") + ErrUnexpectedInfoSize = errors.New("unexpected info size") ErrUnexpectedReadError = errors.New("unexpected read error") ) @@ -135,117 +137,215 @@ type Info struct { BandwidthStats *BandwidthStatsArray } -func RecvInfo(reader io.Reader) (*Info, error) { - var infoType byte - err := binary.Read(reader, binary.LittleEndian, &infoType) +type readHelper struct { + infoType byte + commandSize uint32 + + readSize int + + reader io.Reader +} + +func newReadHelper(reader io.Reader) (*readHelper, error) { + helper := &readHelper{reader: reader} + + err := binary.Read(reader, binary.LittleEndian, &helper.infoType) if err != nil { return nil, err } - // Read size of data - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = binary.Read(reader, binary.LittleEndian, &helper.commandSize) if err != nil { return nil, err } + return helper, nil +} + +func (r *readHelper) ReadData(data any) error { + err := binary.Read(r, binary.LittleEndian, data) + if err != nil { + return errors.Join(ErrUnexpectedReadError, err) + } + + if err := r.checkOverRead(); err != nil { + return err + } + + return nil +} + +// Passing size = 0 will read the rest of the command. +func (r *readHelper) ReadBytes(size uint32) ([]byte, error) { + if uint32(r.readSize) >= r.commandSize { + return nil, errors.Join(fmt.Errorf("cannot read more bytes than the command size: %d >= %d", r.readSize, r.commandSize), ErrUnexpectedReadError) + } + + if size == 0 { + size = r.commandSize - uint32(r.readSize) + } + + if r.commandSize < uint32(r.readSize)+size { + return nil, ErrUnexpectedInfoSize + } + + bytes := make([]byte, size) + err := binary.Read(r, binary.LittleEndian, bytes) + if err != nil { + return nil, errors.Join(ErrUnexpectedReadError, err) + } + + return bytes, nil +} + +func (r *readHelper) ReadUntilTheEnd() { + _, _ = r.ReadBytes(0) +} + +func (r *readHelper) checkOverRead() error { + if uint32(r.readSize) > r.commandSize { + return ErrUnexpectedInfoSize + } + + return nil +} + +func (r *readHelper) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.readSize += n + return +} + +func RecvInfo(reader io.Reader) (*Info, error) { + helper, err := newReadHelper(reader) + if err != nil { + return nil, err + } + + // Make sure the whole command is read before return. + defer helper.ReadUntilTheEnd() + // Read data - switch infoType { + switch helper.infoType { case InfoConnectionIpv4: { + parseError := fmt.Errorf("failed to parse InfoConnectionIpv4") + newInfo := ConnectionV4{} var fixedSizeValues connectionV4Internal - err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) + // Read fixed size values. + err = helper.ReadData(&fixedSizeValues) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err, fmt.Errorf("fixed")) } - // Read size of payload - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + newInfo.connectionV4Internal = fixedSizeValues + // Read size of payload. + var payloadSize uint32 + err = helper.ReadData(&payloadSize) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err, fmt.Errorf("payloadsize")) } - newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)} - err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) - if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + + // Check if there is payload. + if payloadSize > 0 { + // Read payload. + newInfo.Payload, err = helper.ReadBytes(payloadSize) + if err != nil { + return nil, errors.Join(parseError, err, fmt.Errorf("payload")) + } } return &Info{ConnectionV4: &newInfo}, nil } case InfoConnectionIpv6: { - var fixedSizeValues connectionV6Internal - err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) + parseError := fmt.Errorf("failed to parse InfoConnectionIpv6") + newInfo := ConnectionV6{} + + // Read fixed size values. + err = helper.ReadData(&newInfo.connectionV6Internal) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - // Read size of payload - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + + // Read size of payload. + var payloadSize uint32 + err = helper.ReadData(&payloadSize) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)} - err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) - if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + + // Check if there is payload. + if payloadSize > 0 { + // Read payload. + newInfo.Payload, err = helper.ReadBytes(payloadSize) + if err != nil { + return nil, errors.Join(parseError, err) + } } + return &Info{ConnectionV6: &newInfo}, nil } case InfoConnectionEndEventV4: { + parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV4") var connectionEnd ConnectionEndV4 - err = binary.Read(reader, binary.LittleEndian, &connectionEnd) + + // Read fixed size values. + err = helper.ReadData(&connectionEnd) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } return &Info{ConnectionEndV4: &connectionEnd}, nil } case InfoConnectionEndEventV6: { + parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV6") var connectionEnd ConnectionEndV6 - err = binary.Read(reader, binary.LittleEndian, &connectionEnd) + + // Read fixed size values. + err = helper.ReadData(&connectionEnd) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } return &Info{ConnectionEndV6: &connectionEnd}, nil } case InfoLogLine: { + parseError := fmt.Errorf("failed to parse InfoLogLine") logLine := LogLine{} // Read severity - err = binary.Read(reader, binary.LittleEndian, &logLine.Severity) + err = helper.ReadData(&logLine.Severity) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read string - line := make([]byte, size-1) // -1 for the severity enum. - err = binary.Read(reader, binary.LittleEndian, &line) + bytes, err := helper.ReadBytes(0) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - logLine.Line = string(line) + logLine.Line = string(bytes) return &Info{LogLine: &logLine}, nil } case InfoBandwidthStatsV4: { + parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV4") // Read Protocol var protocol uint8 - err = binary.Read(reader, binary.LittleEndian, &protocol) + err = helper.ReadData(&protocol) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read size of array var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = helper.ReadData(&size) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read array statsArray := make([]BandwidthValueV4, size) for i := range int(size) { - err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) + err = helper.ReadData(&statsArray[i]) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } } @@ -253,24 +353,25 @@ func RecvInfo(reader io.Reader) (*Info, error) { } case InfoBandwidthStatsV6: { + parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV6") // Read Protocol var protocol uint8 - err = binary.Read(reader, binary.LittleEndian, &protocol) + err = helper.ReadData(&protocol) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read size of array var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = helper.ReadData(&size) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read array statsArray := make([]BandwidthValueV6, size) for i := range int(size) { - err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) + err = helper.ReadData(&statsArray[i]) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } } @@ -278,10 +379,5 @@ func RecvInfo(reader io.Reader) (*Info, error) { } } - // Command not recognized, read until the end of command and return. - // During normal operation this should not happen. - unknownData := make([]byte, size) - _, _ = reader.Read(unknownData) - return nil, ErrUnknownInfoType } diff --git a/windows_kext/kextinterface/protocol_test.go b/windows_kext/kextinterface/protocol_test.go index cf047442..c09cf286 100644 --- a/windows_kext/kextinterface/protocol_test.go +++ b/windows_kext/kextinterface/protocol_test.go @@ -18,8 +18,17 @@ func TestRustInfoFile(t *testing.T) { defer func() { _ = file.Close() }() + first := true for { info, err := RecvInfo(file) + // First info should be with invalid size. + if first { + if !errors.Is(err, ErrUnexpectedInfoSize) { + t.Errorf("unexpected error: %s\n", err) + } + first = false + continue + } if err != nil { if errors.Is(err, ErrUnexpectedReadError) { t.Errorf("unexpected error: %s\n", err) diff --git a/windows_kext/kextinterface/testdata/rust_info_test.bin b/windows_kext/kextinterface/testdata/rust_info_test.bin index 3f9049a9..3b8588c7 100644 Binary files a/windows_kext/kextinterface/testdata/rust_info_test.bin and b/windows_kext/kextinterface/testdata/rust_info_test.bin differ diff --git a/windows_kext/protocol/src/info.rs b/windows_kext/protocol/src/info.rs index b8eb0c79..fc50d589 100644 --- a/windows_kext/protocol/src/info.rs +++ b/windows_kext/protocol/src/info.rs @@ -441,6 +441,22 @@ fn generate_test_info_file() -> Result<(), std::io::Error> { for _ in 0..selected.capacity() { selected.push(enums.choose(&mut rng).unwrap().clone()); } + // Write wrong size data. + let mut info = connection_info_v6( + 1, + 2, + 3, + 4, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + 5, + 6, + 7, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ); + info.assert_size(); + info.0[0] = InfoType::ConnectionIpv4 as u8; + file.write_all(&info.0)?; for value in selected { file.write_all(&match value { @@ -548,5 +564,6 @@ fn generate_test_info_file() -> Result<(), std::io::Error> { } })?; } + return Ok(()); } diff --git a/windows_kext/protocol/testdata/go_command_test.bin b/windows_kext/protocol/testdata/go_command_test.bin index 586c70ad..d518dbf1 100644 Binary files a/windows_kext/protocol/testdata/go_command_test.bin and b/windows_kext/protocol/testdata/go_command_test.bin differ