add thrift var int encoder/decoder

This commit is contained in:
Zhang Jingqiang 2026-02-07 22:19:22 +08:00
parent 53f491bcb2
commit 29f67f3e40
12 changed files with 289 additions and 59 deletions

7
Cargo.lock generated
View file

@ -1666,7 +1666,6 @@ dependencies = [
"hickory-proto",
"http",
"indicatif",
"integer-encoding",
"itoa",
"openssl-probe",
"quinn",
@ -2444,12 +2443,6 @@ dependencies = [
"libc",
]
[[package]]
name = "integer-encoding"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c00403deb17c3221a1fe4fb571b9ed0370b3dcd116553c77fa294a3d918699"
[[package]]
name = "ip_network"
version = "0.4.1"

View file

@ -36,7 +36,6 @@ rustc-hash.workspace = true
concurrent-queue = "2.5"
hex.workspace = true
itoa.workspace = true
integer-encoding = "4.0"
fastrand.workspace = true
base64.workspace = true
hickory-client.workspace = true

View file

@ -4,7 +4,8 @@
*/
use anyhow::anyhow;
use integer_encoding::VarInt;
use g3_types::codec::ThriftVarIntEncoder;
pub(crate) struct CompactMessageBuilder {
name: String,
@ -15,8 +16,10 @@ impl CompactMessageBuilder {
pub(crate) fn new_call(name: &str) -> anyhow::Result<Self> {
// the name length is encoded from an unsigned integer, which is different from
// https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
let name_len = u32::try_from(name.len()).map_err(|_| anyhow!("too long method name"))?;
let name_len_bytes = name_len.encode_var_vec();
let name_len = i32::try_from(name.len()).map_err(|_| anyhow!("too long method name"))?;
let mut encoder = ThriftVarIntEncoder::default();
let name_len_bytes = encoder.encode_i32(name_len).to_vec();
Ok(CompactMessageBuilder {
name: name.to_string(),
@ -39,10 +42,9 @@ impl CompactMessageBuilder {
// set fixed bits and message type to "Call"
buf.extend_from_slice(&[0x82, 0x21]);
let seq_id_size = seq_id.required_space();
let seq_id_offset = buf.len();
buf.resize(seq_id_offset + seq_id_size, 0);
seq_id.encode_var(&mut buf[seq_id_offset..]);
let mut encoder = ThriftVarIntEncoder::default();
let seq_id_bytes = encoder.encode_i32(seq_id);
buf.extend_from_slice(seq_id_bytes);
buf.extend_from_slice(&self.name_len_bytes);
buf.extend_from_slice(self.name.as_bytes());
@ -75,7 +77,7 @@ mod tests {
assert_eq!(
&buf,
&[
0x0, 0x0, 0x0, 0x9, 0x82, 0x21, 0x1, 0x4, 0x70, 0x69, 0x6e, 0x67, 0x0
0x0, 0x0, 0x0, 0x9, 0x82, 0x21, 0x1, 0x8, 0x70, 0x69, 0x6e, 0x67, 0x0
]
);
}

View file

@ -3,7 +3,7 @@
* Copyright 2025 ByteDance and/or its affiliates.
*/
use integer_encoding::VarInt;
use g3_types::codec::ThriftVarInt32;
use crate::target::thrift::protocol::{ThriftResponseMessage, ThriftResponseMessageParseError};
@ -36,33 +36,19 @@ impl CompactMessageParser {
}
let left = &buf[2..];
let Some((seq_id, nr)) = i32::decode_var(left) else {
return Err(ThriftResponseMessageParseError::InvalidVarIntEncoding(
"seq id",
));
};
if nr == 0 {
return Err(ThriftResponseMessageParseError::InvalidVarIntEncoding(
"seq id",
));
}
let seq_id = ThriftVarInt32::parse(left)
.map_err(|e| ThriftResponseMessageParseError::InvalidVarIntEncoding("seq id", e))?;
let left = &left[nr..];
let left = &left[seq_id.encoded_len()..];
if left.is_empty() {
return Err(ThriftResponseMessageParseError::NoEnoughData);
}
let Some((name_len, nr)) = i32::decode_var(left) else {
return Err(ThriftResponseMessageParseError::InvalidVarIntEncoding(
"name length",
));
};
if nr == 0 {
return Err(ThriftResponseMessageParseError::InvalidVarIntEncoding(
"name length",
));
}
let name_len = ThriftVarInt32::parse(left).map_err(|e| {
ThriftResponseMessageParseError::InvalidVarIntEncoding("name length", e)
})?;
let name_len = usize::try_from(name_len)
let left = &left[name_len.encoded_len()..];
let name_len = usize::try_from(name_len.value())
.map_err(|_| ThriftResponseMessageParseError::InvalidNameLength)?;
if left.len() < name_len {
return Err(ThriftResponseMessageParseError::NoEnoughData);
@ -74,7 +60,7 @@ impl CompactMessageParser {
Ok(ThriftResponseMessage {
method: name.to_string(),
seq_id,
seq_id: seq_id.value(),
encoded_length: data.len(),
})
}

View file

@ -5,6 +5,8 @@
use thiserror::Error;
use g3_types::codec::Leb128DecodeError;
mod binary;
pub(super) use binary::BinaryMessageBuilder;
use binary::BinaryMessageParser;
@ -66,8 +68,8 @@ pub(super) enum ThriftResponseMessageParseError {
InvalidVersion,
#[error("invalid message type {0}")]
InvalidMessageType(u8),
#[error("invalid varint encoding for {0}")]
InvalidVarIntEncoding(&'static str),
#[error("invalid varint encoding for {0}: {1}")]
InvalidVarIntEncoding(&'static str, Leb128DecodeError),
#[error("invalid name length")]
InvalidNameLength,
#[error("invalid name encoding")]

View file

@ -3,11 +3,13 @@
* Copyright 2025 ByteDance and/or its affiliates.
*/
use anyhow::{Context, anyhow};
use integer_encoding::VarInt;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use anyhow::{Context, anyhow};
use g3_types::codec::ThriftVarIntEncoder;
use super::HeaderBufOffsets;
use crate::target::thrift::protocol::ThriftProtocol;
@ -21,14 +23,15 @@ impl TryFrom<String> for StringValue {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self, Self::Error> {
let len = i16::try_from(value.len()).map_err(|_| {
let len = i32::try_from(value.len()).map_err(|_| {
anyhow!(
"too long Thrift THeader string value length {}",
value.len()
)
})?;
let mut encoder = ThriftVarIntEncoder::default();
Ok(StringValue {
len_bytes: len.encode_var_vec(),
len_bytes: encoder.encode_i32(len).to_vec(),
value,
})
}
@ -67,6 +70,8 @@ impl ThriftTHeaderBuilder {
let content_offset = buf.len();
let mut encoder = ThriftVarIntEncoder::default();
// PROTOCOL ID (varint, i32)
// See `THeaderProtocolID` in
// https://github.com/apache/thrift/blob/master/lib/go/thrift/header_transport.go
@ -74,17 +79,17 @@ impl ThriftTHeaderBuilder {
ThriftProtocol::Binary => 0i32,
ThriftProtocol::Compact => 2i32,
};
varint_encode(protocol_id, buf);
buf.extend_from_slice(encoder.encode_i32(protocol_id));
// NUM TRANSFORMS (varint, i32)
varint_encode(0i32, buf);
buf.extend_from_slice(encoder.encode_i32(0));
// INFO_KEYVALUE
if !self.info_key_values.is_empty() {
varint_encode(1i32, buf);
buf.extend_from_slice(encoder.encode_i32(1));
let kv_count = i32::try_from(self.info_key_values.len())
.map_err(|_| anyhow!("too many INFO_KEYVALUE headers"))?;
varint_encode(kv_count, buf);
buf.extend_from_slice(encoder.encode_i32(kv_count));
for (k, v) in self.info_key_values.iter() {
buf.extend_from_slice(&k.len_bytes);
buf.extend_from_slice(k.value.as_bytes());
@ -117,15 +122,6 @@ impl ThriftTHeaderBuilder {
}
}
fn varint_encode<T>(v: T, buf: &mut Vec<u8>)
where
T: VarInt,
{
let write_offset = buf.len();
buf.resize(write_offset + v.required_space(), 0);
v.encode_var(&mut buf[write_offset..]);
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -5,3 +5,6 @@
mod option;
pub use option::OptionExt;
mod zig_zag;
pub use zig_zag::{FromZigZag, ToZigZag};

View file

@ -0,0 +1,85 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2026 G3-OSS developers.
*/
pub trait ToZigZag<T> {
fn to_zig_zag(self) -> T;
}
pub trait FromZigZag<T> {
fn from_zig_zag(value: T) -> Self;
}
impl ToZigZag<u32> for i32 {
fn to_zig_zag(self) -> u32 {
((self >> 31) ^ (self << 1)) as u32
}
}
impl FromZigZag<u32> for i32 {
fn from_zig_zag(value: u32) -> Self {
(value >> 1) as i32 ^ -(value as i32 & 1)
}
}
impl ToZigZag<u64> for i64 {
fn to_zig_zag(self) -> u64 {
((self >> 63) ^ (self << 1)) as u64
}
}
impl FromZigZag<u64> for i64 {
fn from_zig_zag(value: u64) -> Self {
(value >> 1) as i64 ^ -(value as i64 & 1)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_i32() {
assert_eq!(0i32.to_zig_zag(), 0);
assert_eq!((-1i32).to_zig_zag(), 1);
assert_eq!(1i32.to_zig_zag(), 2);
assert_eq!((-2i32).to_zig_zag(), 3);
assert_eq!(2i32.to_zig_zag(), 4);
assert_eq!(i32::MAX.to_zig_zag(), u32::MAX - 1);
assert_eq!(i32::MIN.to_zig_zag(), u32::MAX);
}
#[test]
fn decode_u32() {
assert_eq!(i32::from_zig_zag(0), 0);
assert_eq!(i32::from_zig_zag(1), -1);
assert_eq!(i32::from_zig_zag(2), 1);
assert_eq!(i32::from_zig_zag(3), -2);
assert_eq!(i32::from_zig_zag(4), 2);
assert_eq!(i32::from_zig_zag(u32::MAX - 1), i32::MAX);
assert_eq!(i32::from_zig_zag(u32::MAX), i32::MIN);
}
#[test]
fn encode_i64() {
assert_eq!(0i64.to_zig_zag(), 0);
assert_eq!((-1i64).to_zig_zag(), 1);
assert_eq!(1i64.to_zig_zag(), 2);
assert_eq!((-2i64).to_zig_zag(), 3);
assert_eq!(2i64.to_zig_zag(), 4);
assert_eq!(i64::MAX.to_zig_zag(), u64::MAX - 1);
assert_eq!(i64::MIN.to_zig_zag(), u64::MAX);
}
#[test]
fn decode_u64() {
assert_eq!(i64::from_zig_zag(0), 0);
assert_eq!(i64::from_zig_zag(1), -1);
assert_eq!(i64::from_zig_zag(2), 1);
assert_eq!(i64::from_zig_zag(3), -2);
assert_eq!(i64::from_zig_zag(4), 2);
assert_eq!(i64::from_zig_zag(u64::MAX - 1), i64::MAX);
assert_eq!(i64::from_zig_zag(u64::MAX), i64::MIN);
}
}

View file

@ -0,0 +1,111 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2026 G3-OSS developers.
*/
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Leb128DecodeError {
#[error("need more data")]
NeedMoreData,
#[error("no ending byte found")]
NoEndFound,
}
pub struct Leb128<T> {
value: T,
encoded_len: usize,
}
impl<T> Leb128<T> {
pub fn encoded_len(&self) -> usize {
self.encoded_len
}
}
impl<T: Copy> Leb128<T> {
pub fn value(&self) -> T {
self.value
}
}
impl Leb128<u32> {
pub fn decode(data: &[u8]) -> Result<Self, Leb128DecodeError> {
if data.is_empty() {
return Err(Leb128DecodeError::NeedMoreData);
}
let bv = data[0] & 0x7F;
if data[0] & 0x80 == 0 {
return Ok(Leb128 {
value: bv as u32,
encoded_len: 1,
});
}
let mut value = bv as u32;
let mut encoded_len = 1;
let mut total_bits = 7;
let left = &data[1..];
for b in left {
encoded_len += 1;
let bv = *b & 0x7f;
value |= (bv as u32) << total_bits;
if (*b & 0x80) == 0 {
// 5 * 7 = 32, so no need to check bits for the last byte
return Ok(Leb128 { value, encoded_len });
} else {
total_bits += 7;
if total_bits > 32 {
return Err(Leb128DecodeError::NoEndFound);
}
}
}
Err(Leb128DecodeError::NeedMoreData)
}
}
#[derive(Debug, Clone, Default)]
pub struct Leb128Encoder {
data: [u8; 10],
}
impl Leb128Encoder {
pub fn encode_u32(&mut self, mut data: u32) -> &[u8] {
let mut offset = 0;
loop {
let bv = (data & 0x7f) as u8;
data >>= 7;
if data == 0 {
self.data[offset] = bv;
return &self.data[0..=offset];
} else {
self.data[offset] = bv | 0x80;
offset += 1;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn decode_u32() {
let v = Leb128::<u32>::decode(&[0x01]).unwrap();
assert_eq!(v.value, 1);
let v = Leb128::<u32>::decode(&[0xE5, 0x8E, 0x26]).unwrap();
assert_eq!(v.value, 624485);
}
#[test]
fn encode_u32() {
let mut encoder = Leb128Encoder::default();
assert_eq!(encoder.encode_u32(1), &[0x01]);
assert_eq!(encoder.encode_u32(624485), &[0xE5, 0x8E, 0x26]);
}
}

View file

@ -6,8 +6,14 @@
mod tlv;
pub use tlv::{T1L2BVParse, TlvParse};
mod leb128;
pub use leb128::{Leb128, Leb128DecodeError, Leb128Encoder};
mod ber;
pub use ber::*;
mod ldap;
pub use ldap::*;
mod thrift;
pub use thrift::*;

View file

@ -0,0 +1,7 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2026 G3-OSS developers.
*/
mod var_int;
pub use var_int::{ThriftVarInt32, ThriftVarIntEncoder};

View file

@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2026 G3-OSS developers.
*/
use g3_std_ext::core::{FromZigZag, ToZigZag};
use crate::codec::{Leb128, Leb128DecodeError, Leb128Encoder};
pub struct ThriftVarInt32 {
leb128: Leb128<u32>,
}
impl ThriftVarInt32 {
pub fn parse(data: &[u8]) -> Result<ThriftVarInt32, Leb128DecodeError> {
let leb128 = Leb128::decode(data)?;
Ok(ThriftVarInt32 { leb128 })
}
pub fn value(&self) -> i32 {
let uv = self.leb128.value();
i32::from_zig_zag(uv)
}
pub fn encoded_len(&self) -> usize {
self.leb128.encoded_len()
}
}
#[derive(Default)]
pub struct ThriftVarIntEncoder {
leb128: Leb128Encoder,
}
impl ThriftVarIntEncoder {
pub fn encode_i32(&mut self, v: i32) -> &[u8] {
let uv = v.to_zig_zag();
self.leb128.encode_u32(uv)
}
}