diff --git a/g3proxy/doc/protocol/helper/ip_locate.rst b/g3proxy/doc/protocol/helper/ip_locate.rst index b40f6a1b..3611eb93 100644 --- a/g3proxy/doc/protocol/helper/ip_locate.rst +++ b/g3proxy/doc/protocol/helper/ip_locate.rst @@ -43,19 +43,54 @@ The target ip address as specified in the request. This should be present if it's a response to a request, or absent if it's a push response. -location --------- - -**optional**, **id**: 2, **type**: :ref:`ip location ` - -Set the IP location value. - ttl --- -**optional**, **id**: 3, **type**: u32 +**optional**, **id**: 2, **type**: u32 Set the expire ttl of the response. If not set, the :ref:`default expire ttl ` config will take effect. + +network +------- + +**required**, **id**: 3, **type**: :ref:`ip network str ` + +Set the registered network address. + +country +------- + +**optional**, **id**: 4, **type**: :ref:`iso country code ` + +Set the country. + +continent +--------- + +**optional**, **id**: 5, **type**: :ref:`continent code ` + +Set the continent + +as_number +--------- + +**optional**, **id**: 6, **type**: u32 + +Set the AS Number. + +isp_name +-------- + +**optional**, **id**: 7, **type**: str + +Set the name of it's ISP. + +isp_domain +---------- + +**optional**, **id**: 8, **type**: str + +Set the domain of it's ISP. diff --git a/lib/g3-geoip/src/continent.rs b/lib/g3-geoip/src/continent.rs index 47b46d80..c6f9bfe9 100644 --- a/lib/g3-geoip/src/continent.rs +++ b/lib/g3-geoip/src/continent.rs @@ -27,6 +27,8 @@ const ALL_CONTINENT_NAMES: &[&str] = &[ "South America", ]; +const ALL_CONTINENT_CODES: &[&str] = &["AF", "AN", "AS", "EU", "NA", "OC", "SA"]; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(u8)] pub enum ContinentCode { @@ -44,6 +46,10 @@ impl ContinentCode { ALL_CONTINENT_NAMES[*self as usize] } + pub fn code(&self) -> &'static str { + ALL_CONTINENT_CODES[*self as usize] + } + pub fn variant_count() -> usize { Self::SA as usize } diff --git a/lib/g3-ip-locate/src/lib.rs b/lib/g3-ip-locate/src/lib.rs index e642847a..c32cb237 100644 --- a/lib/g3-ip-locate/src/lib.rs +++ b/lib/g3-ip-locate/src/lib.rs @@ -40,9 +40,10 @@ mod protocol; pub use protocol::{request_key, request_key_id, response_key, response_key_id}; mod request; +pub use request::Request; mod response; -use response::Response; +pub use response::Response; struct CacheQueryRequest { ip: IpAddr, diff --git a/lib/g3-ip-locate/src/protocol.rs b/lib/g3-ip-locate/src/protocol.rs index c95841e1..a87b4545 100644 --- a/lib/g3-ip-locate/src/protocol.rs +++ b/lib/g3-ip-locate/src/protocol.rs @@ -25,11 +25,21 @@ pub mod request_key_id { pub mod response_key { pub const IP: &str = "ip"; pub const TTL: &str = "ttl"; - pub const LOCATION: &str = "location"; + pub const NETWORK: &str = "network"; + pub const COUNTRY: &str = "country"; + pub const CONTINENT: &str = "continent"; + pub const AS_NUMBER: &str = "as_number"; + pub const ISP_NAME: &str = "isp_name"; + pub const ISP_DOMAIN: &str = "isp_domain"; } pub mod response_key_id { pub const IP: u64 = 1; pub const TTL: u64 = 2; - pub const LOCATION: u64 = 3; + pub const NETWORK: u64 = 3; + pub const COUNTRY: u64 = 4; + pub const CONTINENT: u64 = 5; + pub const AS_NUMBER: u64 = 6; + pub const ISP_NAME: u64 = 7; + pub const ISP_DOMAIN: u64 = 8; } diff --git a/lib/g3-ip-locate/src/query.rs b/lib/g3-ip-locate/src/query.rs index b96534f9..2185cfff 100644 --- a/lib/g3-ip-locate/src/query.rs +++ b/lib/g3-ip-locate/src/query.rs @@ -27,13 +27,15 @@ use log::warn; use tokio::io::ReadBuf; use tokio::net::UdpSocket; -use super::{IpLocateServiceConfig, IpLocationCacheResponse, IpLocationQueryHandle, Response}; +use super::{ + IpLocateServiceConfig, IpLocationCacheResponse, IpLocationQueryHandle, Request, Response, +}; pub(crate) struct IpLocationQueryRuntime { socket: UdpSocket, query_handle: IpLocationQueryHandle, read_buffer: Box<[u8]>, - write_queue: VecDeque, + write_queue: VecDeque<(IpAddr, Vec)>, default_expire_ttl: u32, maximum_expire_ttl: u32, query_wait: Duration, @@ -68,8 +70,10 @@ impl IpLocationQueryRuntime { fn handle_req(&mut self, ip: IpAddr) { if self.query_handle.should_send_raw_query(ip, self.query_wait) { - // TODO encode request - self.write_queue.push_back(ip); + match Request::encode_new(ip) { + Ok(buf) => self.write_queue.push_back((ip, buf)), + Err(_) => self.send_empty_result(ip, self.default_expire_ttl, false), + } } } @@ -119,14 +123,10 @@ impl IpLocationQueryRuntime { } // send req from write queue - while let Some(ip) = self.write_queue.pop_front() { - let r = match ip { - IpAddr::V4(v4) => self.socket.poll_send(cx, &v4.octets()), - IpAddr::V6(v6) => self.socket.poll_send(cx, &v6.octets()), - }; - match r { + while let Some((ip, buf)) = self.write_queue.pop_front() { + match self.socket.poll_send(cx, &buf) { Poll::Pending => { - self.write_queue.push_front(ip); + self.write_queue.push_front((ip, buf)); break; } Poll::Ready(Ok(_)) => {} diff --git a/lib/g3-ip-locate/src/request.rs b/lib/g3-ip-locate/src/request.rs index 867e1a18..1920efea 100644 --- a/lib/g3-ip-locate/src/request.rs +++ b/lib/g3-ip-locate/src/request.rs @@ -13,3 +13,81 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +use std::net::IpAddr; + +use anyhow::{anyhow, Context}; +use rmpv::ValueRef; + +use super::{request_key, request_key_id}; + +#[derive(Default)] +pub struct Request { + ip: Option, +} + +impl Request { + fn set(&mut self, k: ValueRef, v: ValueRef) -> anyhow::Result<()> { + match k { + ValueRef::String(s) => { + let key = s + .as_str() + .ok_or_else(|| anyhow!("invalid string key {k}"))?; + match g3_msgpack::key::normalize(key).as_str() { + request_key::IP => self + .set_ip_value(v) + .context(format!("invalid ip address value for key {key}")), + _ => Err(anyhow!("invalid key {key}")), + } + } + ValueRef::Integer(i) => { + let key_id = i.as_u64().ok_or_else(|| anyhow!("invalid u64 key {k}"))?; + match key_id { + request_key_id::IP => self + .set_ip_value(v) + .context(format!("invalid ip address value for key id {key_id}")), + _ => Err(anyhow!("invalid key id {key_id}")), + } + } + _ => Err(anyhow!("unsupported key type: {k}")), + } + } + + fn set_ip_value(&mut self, v: ValueRef) -> anyhow::Result<()> { + let ip = g3_msgpack::value::as_ipaddr(&v)?; + self.ip = Some(ip); + Ok(()) + } + + #[inline] + pub fn ip(&self) -> Option { + self.ip + } + + pub fn parse_req(mut data: &[u8]) -> anyhow::Result { + let v = rmpv::decode::read_value_ref(&mut data) + .map_err(|e| anyhow!("invalid req data: {e}"))?; + + let mut request = Request::default(); + if let ValueRef::Map(map) = v { + for (k, v) in map { + request.set(k, v)?; + } + } else { + request + .set_ip_value(v) + .context("invalid single host string value")?; + } + + Ok(request) + } + + pub fn encode_new(ip: IpAddr) -> Result, ()> { + let ip = ip.to_string(); + let value = ValueRef::String(ip.as_str().into()); + + let mut buf = Vec::with_capacity(320); + rmpv::encode::write_value_ref(&mut buf, &value).map_err(|_| ())?; + Ok(buf) + } +} diff --git a/lib/g3-ip-locate/src/response.rs b/lib/g3-ip-locate/src/response.rs index 6af0a40b..3104920c 100644 --- a/lib/g3-ip-locate/src/response.rs +++ b/lib/g3-ip-locate/src/response.rs @@ -19,14 +19,14 @@ use std::net::IpAddr; use anyhow::{anyhow, Context}; use rmpv::ValueRef; -use g3_geoip::IpLocation; +use g3_geoip::{IpLocation, IpLocationBuilder}; use super::{response_key, response_key_id}; #[derive(Default)] -pub(super) struct Response { +pub struct Response { ip: Option, - location: Option, + location_builder: IpLocationBuilder, ttl: Option, } @@ -48,10 +48,35 @@ impl Response { .context(format!("invalid u32 value for key {key}"))?; self.ttl = Some(ttl); } - response_key::LOCATION => { - let location = g3_msgpack::value::as_ip_location(&v) - .context(format!("invalid ip location value for key {key}"))?; - self.location = Some(location); + response_key::NETWORK => { + let network = g3_msgpack::value::as_ip_network(&v) + .context(format!("invalid ip network value for key {key}"))?; + self.location_builder.set_network(network); + } + response_key::COUNTRY => { + let country = g3_msgpack::value::as_iso_country_code(&v) + .context(format!("invalid iso country code value for key {key}"))?; + self.location_builder.set_country(country); + } + response_key::CONTINENT => { + let continent = g3_msgpack::value::as_continent_code(&v) + .context(format!("invalid continent code value for key {key}"))?; + self.location_builder.set_continent(continent); + } + response_key::AS_NUMBER => { + let number = g3_msgpack::value::as_u32(&v) + .context(format!("invalid u32 value for key {key}"))?; + self.location_builder.set_as_number(number); + } + response_key::ISP_NAME => { + let name = g3_msgpack::value::as_string(&v) + .context(format!("invalid string value for key {key}"))?; + self.location_builder.set_isp_name(name); + } + response_key::ISP_DOMAIN => { + let domain = g3_msgpack::value::as_string(&v) + .context(format!("invalid string value for key {key}"))?; + self.location_builder.set_isp_domain(domain); } _ => {} // ignore unknown keys } @@ -69,10 +94,36 @@ impl Response { .context(format!("invalid u32 value for key id {key_id}"))?; self.ttl = Some(ttl); } - response_key_id::LOCATION => { - let location = g3_msgpack::value::as_ip_location(&v) - .context(format!("invalid ip location value for key id {key_id}"))?; - self.location = Some(location); + response_key_id::NETWORK => { + let network = g3_msgpack::value::as_ip_network(&v) + .context(format!("invalid ip network value for key id {key_id}"))?; + self.location_builder.set_network(network); + } + response_key_id::COUNTRY => { + let country = g3_msgpack::value::as_iso_country_code(&v).context( + format!("invalid iso country code value for key id {key_id}"), + )?; + self.location_builder.set_country(country); + } + response_key_id::CONTINENT => { + let continent = g3_msgpack::value::as_continent_code(&v) + .context(format!("invalid continent code value for key id {key_id}"))?; + self.location_builder.set_continent(continent); + } + response_key_id::AS_NUMBER => { + let number = g3_msgpack::value::as_u32(&v) + .context(format!("invalid u32 value for key id {key_id}"))?; + self.location_builder.set_as_number(number); + } + response_key_id::ISP_NAME => { + let name = g3_msgpack::value::as_string(&v) + .context(format!("invalid string value for key id {key_id}"))?; + self.location_builder.set_isp_name(name); + } + response_key_id::ISP_DOMAIN => { + let domain = g3_msgpack::value::as_string(&v) + .context(format!("invalid string value for key id {key_id}"))?; + self.location_builder.set_isp_domain(domain); } _ => {} // ignore unknown keys } @@ -95,6 +146,61 @@ impl Response { } pub(super) fn into_parts(self) -> (Option, Option, Option) { - (self.ip, self.location, self.ttl) + let location = self.location_builder.build().ok(); + (self.ip, location, self.ttl) + } + + pub fn encode_new(ip: IpAddr, location: IpLocation, ttl: u32) -> anyhow::Result> { + let ip = ip.to_string(); + let network = location.network_addr().to_string(); + let mut map = vec![ + ( + ValueRef::Integer(response_key_id::IP.into()), + ValueRef::String(ip.as_str().into()), + ), + ( + ValueRef::Integer(response_key_id::NETWORK.into()), + ValueRef::String(network.as_str().into()), + ), + ( + ValueRef::Integer(response_key_id::TTL.into()), + ValueRef::Integer(ttl.into()), + ), + ]; + if let Some(country) = location.country() { + map.push(( + ValueRef::Integer(response_key_id::COUNTRY.into()), + ValueRef::String(country.alpha2_code().into()), + )); + } + if let Some(continent) = location.continent() { + map.push(( + ValueRef::Integer(response_key_id::CONTINENT.into()), + ValueRef::String(continent.code().into()), + )); + } + if let Some(number) = location.network_asn() { + map.push(( + ValueRef::Integer(response_key_id::AS_NUMBER.into()), + ValueRef::Integer(number.into()), + )); + } + if let Some(name) = location.isp_name() { + map.push(( + ValueRef::Integer(response_key_id::ISP_NAME.into()), + ValueRef::String(name.into()), + )); + } + if let Some(domain) = location.isp_domain() { + map.push(( + ValueRef::Integer(response_key_id::ISP_DOMAIN.into()), + ValueRef::String(domain.into()), + )); + } + let mut buf = Vec::with_capacity(4096); + let v = ValueRef::Map(map); + rmpv::encode::write_value_ref(&mut buf, &v) + .map_err(|e| anyhow!("msgpack encode failed: {e}"))?; + Ok(buf) } }