diff --git a/Cargo.lock b/Cargo.lock index 6802cb7a..4c2b05b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1010,8 +1010,11 @@ version = "0.1.0" dependencies = [ "anyhow", "capnp", + "clap", + "clap_complete", "hex", "thiserror", + "tokio", ] [[package]] @@ -1735,7 +1738,6 @@ dependencies = [ "capnp", "capnp-rpc", "clap", - "clap_complete", "futures-util", "g3-ctl", "g3-tls-cert", @@ -1861,7 +1863,6 @@ dependencies = [ "capnp", "capnp-rpc", "clap", - "clap_complete", "futures-util", "g3-ctl", "g3-types", @@ -1961,7 +1962,6 @@ dependencies = [ "capnp", "capnp-rpc", "clap", - "clap_complete", "futures-util", "g3-ctl", "g3tiles-proto", diff --git a/g3keymess/utils/ctl/Cargo.toml b/g3keymess/utils/ctl/Cargo.toml index 1326ba5d..bb26c142 100644 --- a/g3keymess/utils/ctl/Cargo.toml +++ b/g3keymess/utils/ctl/Cargo.toml @@ -10,8 +10,7 @@ edition.workspace = true anyhow.workspace = true thiserror.workspace = true clap.workspace = true -clap_complete.workspace = true -tokio = { workspace = true, features = ["rt", "net", "macros", "io-util", "fs"] } +tokio = { workspace = true, features = ["rt", "macros", "io-util", "fs"] } tokio-util = { workspace = true, features = ["compat"] } futures-util.workspace = true capnp-rpc.workspace = true diff --git a/g3keymess/utils/ctl/src/main.rs b/g3keymess/utils/ctl/src/main.rs index e0276074..63ef4091 100644 --- a/g3keymess/utils/ctl/src/main.rs +++ b/g3keymess/utils/ctl/src/main.rs @@ -14,17 +14,11 @@ * limitations under the License. */ -use std::io; -use std::path::PathBuf; -use std::str::FromStr; - use anyhow::anyhow; use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem}; -use clap::builder::ArgPredicate; -use clap::{value_parser, Arg, ArgMatches, Command, ValueHint}; -use clap_complete::Shell; -use tokio::io::AsyncWriteExt; -use tokio::net::UnixStream; +use clap::Command; + +use g3_ctl::{CommandError, DaemonCtlArgs, DaemonCtlArgsExt}; use g3keymess_proto::proc_capnp::proc_control; @@ -35,99 +29,9 @@ mod server; mod local; -const DEFAULT_SYS_CONTROL_DIR: &str = "/run/g3keymess"; -const DEFAULT_TMP_CONTROL_DIR: &str = "/tmp/g3"; - -const GLOBAL_ARG_COMPLETION: &str = "completion"; -const GLOBAL_ARG_CONTROL_DIR: &str = "control-dir"; -const GLOBAL_ARG_GROUP: &str = "daemon-group"; -const GLOBAL_ARG_PID: &str = "pid"; - -async fn connect_to_daemon(args: &ArgMatches) -> anyhow::Result { - let control_dir = args.get_one::(GLOBAL_ARG_CONTROL_DIR).unwrap(); - let daemon_group = args - .get_one::(GLOBAL_ARG_GROUP) - .map(|s| s.as_str()) - .unwrap_or_default(); - - let socket_path = match args.get_one::(GLOBAL_ARG_PID) { - Some(pid) => control_dir.join(format!("{daemon_group}_{}.sock", *pid)), - None => control_dir.join(format!("{daemon_group}.sock")), - }; - - let mut stream = tokio::net::UnixStream::connect(&socket_path) - .await - .map_err(|e| { - anyhow!( - "failed to connect to control socket {}: {e:?}", - socket_path.display() - ) - })?; - stream - .write_all(b"capnp\n") - .await - .map_err(|e| anyhow!("enter capnp mode failed: {e:?}"))?; - stream - .flush() - .await - .map_err(|e| anyhow!("enter capnp mod failed: {e:?}"))?; - Ok(stream) -} - -fn dir_exist(dir: &str) -> bool { - let path = PathBuf::from_str(dir).unwrap(); - std::fs::read_dir(path).is_ok() -} - -fn auto_detect_control_dir() -> &'static str { - if dir_exist(DEFAULT_SYS_CONTROL_DIR) { - DEFAULT_SYS_CONTROL_DIR - } else { - DEFAULT_TMP_CONTROL_DIR - } -} - fn build_cli_args() -> Command { - Command::new("g3tiles-ctl") - .arg( - Arg::new(GLOBAL_ARG_COMPLETION) - .num_args(1) - .value_name("SHELL") - .long("completion") - .value_parser(value_parser!(Shell)) - .exclusive(true), - ) - .arg( - Arg::new(GLOBAL_ARG_CONTROL_DIR) - .help("Directory that contains the control socket") - .value_name("CONTROL DIR") - .value_hint(ValueHint::DirPath) - .value_parser(value_parser!(PathBuf)) - .short('C') - .long("control-dir") - .default_value(auto_detect_control_dir()) - .default_value_if(GLOBAL_ARG_COMPLETION, ArgPredicate::IsPresent, None), - ) - .arg( - Arg::new(GLOBAL_ARG_GROUP) - .required_unless_present_any([GLOBAL_ARG_PID, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("GROUP NAME") - .help("Daemon group name") - .short('G') - .long("daemon-group"), - ) - .arg( - Arg::new(GLOBAL_ARG_PID) - .help("Daemon pid") - .required_unless_present_any([GLOBAL_ARG_GROUP, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("PID") - .value_parser(value_parser!(usize)) - .short('p') - .long("daemon-pid"), - ) - .subcommand_required(true) + Command::new(env!("CARGO_PKG_NAME")) + .append_daemon_ctl_args() .subcommand(proc::commands::version()) .subcommand(proc::commands::offline()) .subcommand(proc::commands::list()) @@ -141,14 +45,12 @@ fn build_cli_args() -> Command { async fn main() -> anyhow::Result<()> { let args = build_cli_args().get_matches(); - if let Some(target) = args.get_one::(GLOBAL_ARG_COMPLETION) { - let mut app = build_cli_args(); - let bin_name = app.get_name().to_string(); - clap_complete::generate(*target, &mut app, bin_name, &mut io::stdout()); + let mut ctl_opts = DaemonCtlArgs::parse_clap(&args); + if ctl_opts.generate_shell_completion(build_cli_args) { return Ok(()); } - let stream = connect_to_daemon(&args).await?; + let stream = ctl_opts.connect_to_daemon("g3keymess").await?; let (reader, writer) = tokio::io::split(stream); let reader = tokio_util::compat::TokioAsyncReadCompatExt::compat(reader); @@ -179,7 +81,9 @@ async fn main() -> anyhow::Result<()> { proc::COMMAND_CHECK_KEY => proc::check_key(&proc_control, args).await, server::COMMAND => server::run(&proc_control, args).await, local::COMMAND_CHECK_DUP => local::check_dup(args), - _ => unreachable!(), + _ => Err(CommandError::Cli(anyhow!( + "unsupported command {subcommand}" + ))), } }) .await diff --git a/g3proxy/utils/ctl/Cargo.toml b/g3proxy/utils/ctl/Cargo.toml index 4ce12760..1d7a577c 100644 --- a/g3proxy/utils/ctl/Cargo.toml +++ b/g3proxy/utils/ctl/Cargo.toml @@ -10,8 +10,7 @@ edition.workspace = true anyhow.workspace = true thiserror.workspace = true clap.workspace = true -clap_complete.workspace = true -tokio = { workspace = true, features = ["rt", "net", "macros", "io-util", "fs"] } +tokio = { workspace = true, features = ["rt", "macros", "io-util", "fs"] } tokio-util = { workspace = true, features = ["compat"] } futures-util.workspace = true capnp-rpc.workspace = true diff --git a/g3proxy/utils/ctl/src/main.rs b/g3proxy/utils/ctl/src/main.rs index 4877400b..ad2e07ff 100644 --- a/g3proxy/utils/ctl/src/main.rs +++ b/g3proxy/utils/ctl/src/main.rs @@ -14,17 +14,11 @@ * limitations under the License. */ -use std::io; -use std::path::PathBuf; -use std::str::FromStr; - use anyhow::anyhow; use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem}; -use clap::builder::ArgPredicate; -use clap::{value_parser, Arg, ArgMatches, Command, ValueHint}; -use clap_complete::Shell; -use tokio::io::AsyncWriteExt; -use tokio::net::UnixStream; +use clap::Command; + +use g3_ctl::{CommandError, DaemonCtlArgs, DaemonCtlArgsExt}; use g3proxy_proto::proc_capnp::proc_control; @@ -36,99 +30,9 @@ mod resolver; mod server; mod user_group; -const DEFAULT_SYS_CONTROL_DIR: &str = "/run/g3proxy"; -const DEFAULT_TMP_CONTROL_DIR: &str = "/tmp/g3"; - -const GLOBAL_ARG_COMPLETION: &str = "completion"; -const GLOBAL_ARG_CONTROL_DIR: &str = "control-dir"; -const GLOBAL_ARG_GROUP: &str = "daemon-group"; -const GLOBAL_ARG_PID: &str = "pid"; - -async fn connect_to_daemon(args: &ArgMatches) -> anyhow::Result { - let control_dir = args.get_one::(GLOBAL_ARG_CONTROL_DIR).unwrap(); - let daemon_group = args - .get_one::(GLOBAL_ARG_GROUP) - .map(|s| s.as_str()) - .unwrap_or_default(); - - let socket_path = match args.get_one::(GLOBAL_ARG_PID) { - Some(pid) => control_dir.join(format!("{daemon_group}_{}.sock", *pid)), - None => control_dir.join(format!("{daemon_group}.sock")), - }; - - let mut stream = tokio::net::UnixStream::connect(&socket_path) - .await - .map_err(|e| { - anyhow!( - "failed to connect to control socket {}: {e:?}", - socket_path.display() - ) - })?; - stream - .write_all(b"capnp\n") - .await - .map_err(|e| anyhow!("enter capnp mode failed: {e:?}"))?; - stream - .flush() - .await - .map_err(|e| anyhow!("enter capnp mod failed: {e:?}"))?; - Ok(stream) -} - -fn dir_exist(dir: &str) -> bool { - let path = PathBuf::from_str(dir).unwrap(); - std::fs::read_dir(path).is_ok() -} - -fn auto_detect_control_dir() -> &'static str { - if dir_exist(DEFAULT_SYS_CONTROL_DIR) { - DEFAULT_SYS_CONTROL_DIR - } else { - DEFAULT_TMP_CONTROL_DIR - } -} - fn build_cli_args() -> Command { - Command::new("g3proxy-ctl") - .arg( - Arg::new(GLOBAL_ARG_COMPLETION) - .num_args(1) - .value_name("SHELL") - .long("completion") - .value_parser(value_parser!(Shell)) - .exclusive(true), - ) - .arg( - Arg::new(GLOBAL_ARG_CONTROL_DIR) - .help("Directory that contains the control socket") - .value_name("CONTROL DIR") - .value_hint(ValueHint::DirPath) - .value_parser(value_parser!(PathBuf)) - .short('C') - .long("control-dir") - .default_value(auto_detect_control_dir()) - .default_value_if(GLOBAL_ARG_COMPLETION, ArgPredicate::IsPresent, None), - ) - .arg( - Arg::new(GLOBAL_ARG_GROUP) - .required_unless_present_any([GLOBAL_ARG_PID, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("GROUP NAME") - .help("Daemon group name") - .short('G') - .long("daemon-group"), - ) - .arg( - Arg::new(GLOBAL_ARG_PID) - .help("Daemon pid") - .required_unless_present_any([GLOBAL_ARG_GROUP, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("PID") - .value_parser(value_parser!(usize)) - .short('p') - .long("daemon-pid"), - ) - .subcommand_required(true) + Command::new(env!("CARGO_PKG_NAME")) + .append_daemon_ctl_args() .subcommand(proc::commands::version()) .subcommand(proc::commands::offline()) .subcommand(proc::commands::force_quit()) @@ -149,14 +53,12 @@ fn build_cli_args() -> Command { async fn main() -> anyhow::Result<()> { let args = build_cli_args().get_matches(); - if let Some(target) = args.get_one::(GLOBAL_ARG_COMPLETION) { - let mut app = build_cli_args(); - let bin_name = app.get_name().to_string(); - clap_complete::generate(*target, &mut app, bin_name, &mut io::stdout()); + let mut ctl_opts = DaemonCtlArgs::parse_clap(&args); + if ctl_opts.generate_shell_completion(build_cli_args) { return Ok(()); } - let stream = connect_to_daemon(&args).await?; + let stream = ctl_opts.connect_to_daemon("g3proxy").await?; let (reader, writer) = tokio::io::split(stream); let reader = tokio_util::compat::TokioAsyncReadCompatExt::compat(reader); @@ -196,7 +98,9 @@ async fn main() -> anyhow::Result<()> { resolver::COMMAND => resolver::run(&proc_control, args).await, escaper::COMMAND => escaper::run(&proc_control, args).await, server::COMMAND => server::run(&proc_control, args).await, - _ => unreachable!(), + _ => Err(CommandError::Cli(anyhow!( + "unsupported command {subcommand}" + ))), } }) .await diff --git a/g3tiles/utils/ctl/Cargo.toml b/g3tiles/utils/ctl/Cargo.toml index 6ecbe83c..3f7df4c0 100644 --- a/g3tiles/utils/ctl/Cargo.toml +++ b/g3tiles/utils/ctl/Cargo.toml @@ -10,8 +10,7 @@ edition.workspace = true anyhow.workspace = true thiserror.workspace = true clap.workspace = true -clap_complete.workspace = true -tokio = { workspace = true, features = ["rt", "net", "macros", "io-util", "fs"] } +tokio = { workspace = true, features = ["rt", "macros", "io-util", "fs"] } tokio-util = { workspace = true, features = ["compat"] } futures-util.workspace = true capnp-rpc.workspace = true diff --git a/g3tiles/utils/ctl/src/main.rs b/g3tiles/utils/ctl/src/main.rs index b9508e19..c5fd59bb 100644 --- a/g3tiles/utils/ctl/src/main.rs +++ b/g3tiles/utils/ctl/src/main.rs @@ -14,17 +14,11 @@ * limitations under the License. */ -use std::io; -use std::path::PathBuf; -use std::str::FromStr; - use anyhow::anyhow; use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem}; -use clap::builder::ArgPredicate; -use clap::{value_parser, Arg, ArgMatches, Command, ValueHint}; -use clap_complete::Shell; -use tokio::io::AsyncWriteExt; -use tokio::net::UnixStream; +use clap::Command; + +use g3_ctl::{CommandError, DaemonCtlArgs, DaemonCtlArgsExt}; use g3tiles_proto::proc_capnp::proc_control; @@ -33,99 +27,9 @@ mod proc; mod server; -const DEFAULT_SYS_CONTROL_DIR: &str = "/run/g3tiles"; -const DEFAULT_TMP_CONTROL_DIR: &str = "/tmp/g3"; - -const GLOBAL_ARG_COMPLETION: &str = "completion"; -const GLOBAL_ARG_CONTROL_DIR: &str = "control-dir"; -const GLOBAL_ARG_GROUP: &str = "daemon-group"; -const GLOBAL_ARG_PID: &str = "pid"; - -async fn connect_to_daemon(args: &ArgMatches) -> anyhow::Result { - let control_dir = args.get_one::(GLOBAL_ARG_CONTROL_DIR).unwrap(); - let daemon_group = args - .get_one::(GLOBAL_ARG_GROUP) - .map(|s| s.as_str()) - .unwrap_or_default(); - - let socket_path = match args.get_one::(GLOBAL_ARG_PID) { - Some(pid) => control_dir.join(format!("{daemon_group}_{}.sock", *pid)), - None => control_dir.join(format!("{daemon_group}.sock")), - }; - - let mut stream = tokio::net::UnixStream::connect(&socket_path) - .await - .map_err(|e| { - anyhow!( - "failed to connect to control socket {}: {e:?}", - socket_path.display() - ) - })?; - stream - .write_all(b"capnp\n") - .await - .map_err(|e| anyhow!("enter capnp mode failed: {e:?}"))?; - stream - .flush() - .await - .map_err(|e| anyhow!("enter capnp mod failed: {e:?}"))?; - Ok(stream) -} - -fn dir_exist(dir: &str) -> bool { - let path = PathBuf::from_str(dir).unwrap(); - std::fs::read_dir(path).is_ok() -} - -fn auto_detect_control_dir() -> &'static str { - if dir_exist(DEFAULT_SYS_CONTROL_DIR) { - DEFAULT_SYS_CONTROL_DIR - } else { - DEFAULT_TMP_CONTROL_DIR - } -} - fn build_cli_args() -> Command { - Command::new("g3tiles-ctl") - .arg( - Arg::new(GLOBAL_ARG_COMPLETION) - .num_args(1) - .value_name("SHELL") - .long("completion") - .value_parser(value_parser!(Shell)) - .exclusive(true), - ) - .arg( - Arg::new(GLOBAL_ARG_CONTROL_DIR) - .help("Directory that contains the control socket") - .value_name("CONTROL DIR") - .value_hint(ValueHint::DirPath) - .value_parser(value_parser!(PathBuf)) - .short('C') - .long("control-dir") - .default_value(auto_detect_control_dir()) - .default_value_if(GLOBAL_ARG_COMPLETION, ArgPredicate::IsPresent, None), - ) - .arg( - Arg::new(GLOBAL_ARG_GROUP) - .required_unless_present_any([GLOBAL_ARG_PID, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("GROUP NAME") - .help("Daemon group name") - .short('G') - .long("daemon-group"), - ) - .arg( - Arg::new(GLOBAL_ARG_PID) - .help("Daemon pid") - .required_unless_present_any([GLOBAL_ARG_GROUP, GLOBAL_ARG_COMPLETION]) - .num_args(1) - .value_name("PID") - .value_parser(value_parser!(usize)) - .short('p') - .long("daemon-pid"), - ) - .subcommand_required(true) + Command::new(env!("CARGO_PKG_NAME")) + .append_daemon_ctl_args() .subcommand(proc::commands::version()) .subcommand(proc::commands::offline()) .subcommand(proc::commands::force_quit()) @@ -141,14 +45,12 @@ fn build_cli_args() -> Command { async fn main() -> anyhow::Result<()> { let args = build_cli_args().get_matches(); - if let Some(target) = args.get_one::(GLOBAL_ARG_COMPLETION) { - let mut app = build_cli_args(); - let bin_name = app.get_name().to_string(); - clap_complete::generate(*target, &mut app, bin_name, &mut io::stdout()); + let mut ctl_opts = DaemonCtlArgs::parse_clap(&args); + if ctl_opts.generate_shell_completion(build_cli_args) { return Ok(()); } - let stream = connect_to_daemon(&args).await?; + let stream = ctl_opts.connect_to_daemon("g3tiles").await?; let (reader, writer) = tokio::io::split(stream); let reader = tokio_util::compat::TokioAsyncReadCompatExt::compat(reader); @@ -181,7 +83,9 @@ async fn main() -> anyhow::Result<()> { proc::COMMAND_RELOAD_DISCOVER => proc::reload_discover(&proc_control, args).await, proc::COMMAND_RELOAD_BACKEND => proc::reload_backend(&proc_control, args).await, server::COMMAND => server::run(&proc_control, args).await, - _ => unreachable!(), + _ => Err(CommandError::Cli(anyhow!( + "unsupported command {subcommand}" + ))), } }) .await diff --git a/lib/g3-ctl/Cargo.toml b/lib/g3-ctl/Cargo.toml index c5c791f3..45a2595a 100644 --- a/lib/g3-ctl/Cargo.toml +++ b/lib/g3-ctl/Cargo.toml @@ -9,5 +9,8 @@ edition.workspace = true [dependencies] thiserror.workspace = true anyhow.workspace = true +clap.workspace = true +clap_complete.workspace = true capnp.workspace = true hex.workspace = true +tokio = { workspace = true, features = ["net", "io-util"] } diff --git a/lib/g3-ctl/src/io.rs b/lib/g3-ctl/src/io.rs new file mode 100644 index 00000000..31e4aafb --- /dev/null +++ b/lib/g3-ctl/src/io.rs @@ -0,0 +1,71 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use super::{CommandError, CommandResult}; + +pub fn print_ok_notice(notice_reader: capnp::text::Reader<'_>) -> CommandResult<()> { + match notice_reader.to_str() { + Ok(notice) => { + println!("notice: {notice}"); + Ok(()) + } + Err(e) => Err(CommandError::Utf8 { + field: "ok", + reason: e, + }), + } +} + +pub fn print_text(field: &'static str, text_reader: capnp::text::Reader<'_>) -> CommandResult<()> { + match text_reader.to_str() { + Ok(text) => { + println!("{text}"); + Ok(()) + } + Err(e) => Err(CommandError::Utf8 { field, reason: e }), + } +} + +#[inline] +pub fn print_version(version_reader: capnp::text::Reader<'_>) -> CommandResult<()> { + print_text("version", version_reader) +} + +pub fn print_text_list( + field: &'static str, + list: capnp::text_list::Reader<'_>, +) -> CommandResult<()> { + for text in list.iter() { + print_text(field, text?)?; + } + Ok(()) +} + +#[inline] +pub fn print_result_list(result_list_reader: capnp::text_list::Reader<'_>) -> CommandResult<()> { + print_text_list("result", result_list_reader) +} + +pub fn print_data(data_reader: capnp::data::Reader<'_>) { + println!("{}", hex::encode(data_reader)); +} + +pub fn print_data_list(list: capnp::data_list::Reader<'_>) -> CommandResult<()> { + for data in list.iter() { + print_data(data?); + } + Ok(()) +} diff --git a/lib/g3-ctl/src/lib.rs b/lib/g3-ctl/src/lib.rs index 491a7cfc..15dc5df2 100644 --- a/lib/g3-ctl/src/lib.rs +++ b/lib/g3-ctl/src/lib.rs @@ -14,59 +14,11 @@ * limitations under the License. */ +mod opts; +pub use opts::{DaemonCtlArgs, DaemonCtlArgsExt}; + mod error; pub use error::{CommandError, CommandResult}; -pub fn print_ok_notice(notice_reader: capnp::text::Reader<'_>) -> CommandResult<()> { - match notice_reader.to_str() { - Ok(notice) => { - println!("notice: {notice}"); - Ok(()) - } - Err(e) => Err(CommandError::Utf8 { - field: "ok", - reason: e, - }), - } -} - -pub fn print_text(field: &'static str, text_reader: capnp::text::Reader<'_>) -> CommandResult<()> { - match text_reader.to_str() { - Ok(text) => { - println!("{text}"); - Ok(()) - } - Err(e) => Err(CommandError::Utf8 { field, reason: e }), - } -} - -#[inline] -pub fn print_version(version_reader: capnp::text::Reader<'_>) -> CommandResult<()> { - print_text("version", version_reader) -} - -pub fn print_text_list( - field: &'static str, - list: capnp::text_list::Reader<'_>, -) -> CommandResult<()> { - for text in list.iter() { - print_text(field, text?)?; - } - Ok(()) -} - -#[inline] -pub fn print_result_list(result_list_reader: capnp::text_list::Reader<'_>) -> CommandResult<()> { - print_text_list("result", result_list_reader) -} - -pub fn print_data(data_reader: capnp::data::Reader<'_>) { - println!("{}", hex::encode(data_reader)); -} - -pub fn print_data_list(list: capnp::data_list::Reader<'_>) -> CommandResult<()> { - for data in list.iter() { - print_data(data?); - } - Ok(()) -} +mod io; +pub use io::*; diff --git a/lib/g3-ctl/src/opts.rs b/lib/g3-ctl/src/opts.rs new file mode 100644 index 00000000..1eb3ce14 --- /dev/null +++ b/lib/g3-ctl/src/opts.rs @@ -0,0 +1,160 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::io; +#[cfg(unix)] +use std::path::PathBuf; + +use anyhow::anyhow; +use clap::{value_parser, Arg, ArgMatches, Command, ValueHint}; +use clap_complete::Shell; +use tokio::io::AsyncWriteExt; +use tokio::net::UnixStream; + +const DEFAULT_TMP_CONTROL_DIR: &str = "/tmp/g3"; + +const GLOBAL_ARG_COMPLETION: &str = "completion"; +const GLOBAL_ARG_CONTROL_DIR: &str = "control-dir"; +const GLOBAL_ARG_GROUP: &str = "daemon-group"; +const GLOBAL_ARG_PID: &str = "pid"; + +pub trait DaemonCtlArgsExt { + fn append_daemon_ctl_args(self) -> Self; +} + +#[derive(Debug, Default)] +pub struct DaemonCtlArgs { + shell_completion: Option, + #[cfg(unix)] + control_dir: Option, + daemon_group: String, + pid: usize, +} + +impl DaemonCtlArgs { + pub fn parse_clap(args: &ArgMatches) -> Self { + let mut config = DaemonCtlArgs::default(); + + if let Some(shell) = args.get_one::(GLOBAL_ARG_COMPLETION) { + config.shell_completion = Some(*shell); + return config; + } + + #[cfg(unix)] + if let Some(dir) = args.get_one::(GLOBAL_ARG_CONTROL_DIR) { + config.control_dir = Some(dir.clone()); + } + + if let Some(group) = args.get_one::(GLOBAL_ARG_GROUP) { + config.daemon_group.clone_from(group); + } + + if let Some(pid) = args.get_one::(GLOBAL_ARG_PID) { + config.pid = *pid; + } + + config + } + + pub fn generate_shell_completion(&mut self, build_cmd: F) -> bool + where + F: Fn() -> Command, + { + let Some(shell) = self.shell_completion.take() else { + return false; + }; + let mut cmd = build_cmd(); + let bin_name = cmd.get_name().to_string(); + clap_complete::generate(shell, &mut cmd, bin_name, &mut io::stdout()); + true + } + + #[cfg(unix)] + pub async fn connect_to_daemon(&self, daemon_name: &'static str) -> anyhow::Result { + let control_dir = self.control_dir.clone().unwrap_or_else(|| { + let mut sys_ctl_dir = PathBuf::from("/run"); + sys_ctl_dir.push(daemon_name); + + if sys_ctl_dir.is_dir() { + sys_ctl_dir + } else { + PathBuf::from(DEFAULT_TMP_CONTROL_DIR) + } + }); + let socket_path = if self.pid != 0 { + control_dir.join(format!("{}_{}.sock", self.daemon_group, self.pid)) + } else { + control_dir.join(format!("{}.sock", self.daemon_group)) + }; + + let mut stream = UnixStream::connect(&socket_path).await.map_err(|e| { + anyhow!( + "failed to connect to control socket {}: {e:?}", + socket_path.display() + ) + })?; + stream + .write_all(b"capnp\n") + .await + .map_err(|e| anyhow!("enter capnp mode failed: {e:?}"))?; + stream + .flush() + .await + .map_err(|e| anyhow!("enter capnp mod failed: {e:?}"))?; + Ok(stream) + } +} + +impl DaemonCtlArgsExt for Command { + fn append_daemon_ctl_args(self) -> Self { + self.arg( + Arg::new(GLOBAL_ARG_COMPLETION) + .num_args(1) + .value_name("SHELL") + .long("completion") + .value_parser(value_parser!(Shell)) + .exclusive(true), + ) + .arg( + Arg::new(GLOBAL_ARG_CONTROL_DIR) + .help("Directory that contains the control socket") + .value_name("CONTROL DIR") + .value_hint(ValueHint::DirPath) + .value_parser(value_parser!(PathBuf)) + .short('C') + .long("control-dir"), + ) + .arg( + Arg::new(GLOBAL_ARG_GROUP) + .required_unless_present_any([GLOBAL_ARG_PID, GLOBAL_ARG_COMPLETION]) + .num_args(1) + .value_name("GROUP NAME") + .help("Daemon group name") + .short('G') + .long("daemon-group"), + ) + .arg( + Arg::new(GLOBAL_ARG_PID) + .help("Daemon pid") + .required_unless_present_any([GLOBAL_ARG_GROUP, GLOBAL_ARG_COMPLETION]) + .num_args(1) + .value_name("PID") + .value_parser(value_parser!(usize)) + .short('p') + .long("daemon-pid"), + ) + } +}