feat: add reporter, database

This commit is contained in:
Lowder 2025-12-10 04:50:14 +05:00
parent 2926da13a2
commit 3523dd30c8
No known key found for this signature in database
GPG key ID: E63B80B1F1DC187C
33 changed files with 2625 additions and 166 deletions

View file

@ -1,21 +1,34 @@
#[macro_use] extern crate rocket;
#[macro_use]
extern crate rocket;
mod agency;
mod db;
mod whitelist;
use crate::db::{check_whitelist, save_query};
use log::error;
use querying::resolver::Resolver;
use querying::target::Target;
use querying::{Check, CheckError, CheckVerdict, Checker};
use rocket::fairing::AdHoc;
use rocket::fs::FileServer;
use rocket::http::Status;
use rocket::response::content::RawJavaScript;
use rocket::tokio::sync::RwLock;
use rocket::tokio::time;
use rocket::{fairing, tokio, Build, Request, Rocket, State};
use rocket_cache_response::CacheResponse;
use rocket_client_addr::ClientRealAddr;
use rocket_db_pools::{Connection, Database};
use rocket_dyn_templates::{context, Metadata, Template};
use serde::Serialize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use rocket::fs::FileServer;
use rocket::http::Status;
use rocket::{tokio, Request, State};
use rocket::response::content::RawJavaScript;
use rocket_dyn_templates::{context, Metadata, Template};
use serde::Serialize;
use log::error;
use rocket::tokio::sync::RwLock;
use rocket::tokio::time;
use rocket_cache_response::CacheResponse;
use querying::{Check, CheckError, CheckVerdict, Checker};
use querying::resolver::Resolver;
use querying::target::Target;
use sqlx::types::Uuid;
#[derive(rocket_db_pools::Database)]
#[database("cheburcheck")]
struct Db(sqlx::PgPool);
#[derive(Serialize)]
struct GlobalContext {
@ -24,19 +37,24 @@ struct GlobalContext {
impl GlobalContext {
fn new() -> Self {
GlobalContext { version: env!("CARGO_PKG_VERSION") }
GlobalContext {
version: env!("CARGO_PKG_VERSION"),
}
}
}
#[get("/")]
async fn index(checker: &State<Arc<RwLock<Checker>>>) -> Template {
let checker_ref = checker.read().await;
Template::render("index", context! {
global: GlobalContext::new(),
domain_count: format_number(checker_ref.total_domains().await),
v4_count: format_number(checker_ref.total_v4s().await),
last_update: checker_ref.last_update(),
})
Template::render(
"index",
context! {
global: GlobalContext::new(),
domain_count: format_number(checker_ref.total_domains().await),
v4_count: format_number(checker_ref.total_v4s().await),
last_update: checker_ref.last_update(),
},
)
}
#[get("/kb/<page>")]
@ -46,9 +64,12 @@ fn page(metadata: Metadata, page: &str) -> Option<Template> {
return None;
}
Some(Template::render(page, context! {
global: GlobalContext::new(),
}))
Some(Template::render(
page,
context! {
global: GlobalContext::new(),
},
))
}
#[get("/healthcheck")]
@ -60,32 +81,86 @@ async fn healthcheck(checker: &State<Arc<RwLock<Checker>>>) -> (Status, String)
}
}
#[post("/feedback/<uuid>/<works>")]
async fn feedback(uuid: &str, works: bool, mut db: Connection<Db>, addr: &ClientRealAddr) -> Result<(), Status> {
sqlx::query!(
"INSERT INTO human_reports (id, source_ip, works) VALUES ($1, $2, $3)",
Uuid::try_parse(uuid).map_err(|_| Status::BadRequest)?,
addr.ip.to_string(),
works
).execute(&mut **db).await.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[get("/check?<target>")]
async fn check(target: &str, checker: &State<Arc<RwLock<Checker>>>) -> Result<Template, Status> {
async fn check(
target: &str,
checker: &State<Arc<RwLock<Checker>>>,
addr: &ClientRealAddr,
mut db: Connection<Db>,
) -> Result<Template, Status> {
let target = Target::from(target);
match checker.read().await.check(target.clone()).await {
Err(CheckError::NotFound) =>
Ok(Template::render("empty", context! {
global: GlobalContext::new(),
target: target.to_query(),
target_type: target.readable_type(),
})),
Ok(Check { verdict: CheckVerdict::Clear, geo, ips }) =>
Ok(Template::render("result", context! {
let check = checker.read().await.check(target.clone()).await;
let id = if let Ok(check) = &check {
match save_query(&mut db, &target, check, addr, checker.read().await).await {
Ok(id) => Some(id.to_string()),
Err(e) => {
warn!("Failed to save check: {:?}", e);
None
}
}
} else {
None
};
let whitelist = if let Target::Domain(domain) = &target {
check_whitelist(domain, &mut db)
.await
.map_err(|_| Status::InternalServerError)?
} else {
None
};
match check {
Err(CheckError::NotFound) => Ok(Template::render(
"empty",
context! {
global: GlobalContext::new(),
target: target.to_query(),
target_type: target.readable_type(),
},
)),
Ok(Check {
verdict: CheckVerdict::Clear,
geo,
ips,
}) => Ok(Template::render(
"result",
context! {
id,
global: GlobalContext::new(),
found: false,
target: target.to_query(),
target_type: target.readable_type(),
whitelist,
ips,
geo,
})),
},
)),
Ok(Check {
verdict: CheckVerdict::Blocked {
rkn_domain,
rkn_subnets,
cdn_provider_subnets
}, geo, ips }) =>
Ok(Template::render("result", context! {
verdict:
CheckVerdict::Blocked {
rkn_domain,
rkn_subnets,
cdn_provider_subnets,
},
geo,
ips,
}) => Ok(Template::render(
"result",
context! {
id,
global: GlobalContext::new(),
found: true,
domain: rkn_domain,
@ -95,9 +170,11 @@ async fn check(target: &str, checker: &State<Arc<RwLock<Checker>>>) -> Result<Te
.collect::<Vec<_>>(),
target: target.to_query(),
target_type: target.readable_type(),
whitelist,
ips,
geo,
})),
},
)),
Err(e) => {
error!("check failed {:?}", e);
Err(Status::InternalServerError)
@ -107,14 +184,17 @@ async fn check(target: &str, checker: &State<Arc<RwLock<Checker>>>) -> Result<Te
#[catch(default)]
fn default(status: Status, _req: &Request) -> Template {
Template::render("error", context! {
global: GlobalContext::new(),
status: status.code,
reason: status.reason_lossy(),
})
Template::render(
"error",
context! {
global: GlobalContext::new(),
status: status.code,
reason: status.reason_lossy(),
},
)
}
#[rocket::get("/vendor/lucide.js")]
#[rocket::get("/lucide.js")]
fn lucide() -> CacheResponse<RawJavaScript<&'static [u8]>> {
CacheResponse::Public {
responder: RawJavaScript(include_bytes!(concat!(env!("OUT_DIR"), "/lucide.js"))),
@ -122,9 +202,26 @@ fn lucide() -> CacheResponse<RawJavaScript<&'static [u8]>> {
must_revalidate: false,
}
}
#[rocket::get("/chart.js")]
fn chartjs() -> CacheResponse<RawJavaScript<&'static [u8]>> {
CacheResponse::Public {
responder: RawJavaScript(include_bytes!(concat!(env!("OUT_DIR"), "/chart.js"))),
max_age: 604800,
must_revalidate: false,
}
}
#[rocket::get("/chartjs-plugin-datalabels.js")]
fn chartjs_datalabels() -> CacheResponse<RawJavaScript<&'static [u8]>> {
CacheResponse::Public {
responder: RawJavaScript(include_bytes!(concat!(env!("OUT_DIR"), "/chartjs-plugin-datalabels.js"))),
max_age: 604800,
must_revalidate: false,
}
}
fn format_number(number: usize) -> String {
number.to_string()
number
.to_string()
.as_bytes()
.rchunks(3)
.rev()
@ -134,12 +231,31 @@ fn format_number(number: usize) -> String {
.join(" ")
}
async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
match Db::fetch(&rocket) {
Some(db) => match sqlx::migrate!("./migrations").run(&**db).await {
Ok(_) => Ok(rocket),
Err(e) => {
error!("Failed to run database migrations: {}", e);
Err(rocket)
}
},
None => Err(rocket),
}
}
#[launch]
async fn rocket() -> _ {
env_logger::builder().filter_level(log::LevelFilter::Info).init();
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let mut interval = time::interval(Duration::from_secs(std::env::var("DATABASE_INTERVAL_SECONDS")
.unwrap_or("21600".to_string()).parse().unwrap()));
let mut interval = time::interval(Duration::from_secs(
std::env::var("DATABASE_INTERVAL_SECONDS")
.unwrap_or("21600".to_string())
.parse()
.unwrap(),
));
let checker = Arc::new(RwLock::new(Checker::new().await));
@ -154,10 +270,20 @@ async fn rocket() -> _ {
}
});
rocket::build()
let figment = rocket::Config::figment().merge((
"databases.cheburcheck.url",
dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"),
));
rocket::custom(figment)
.manage(Resolver::new().await)
.manage(checker)
.mount("/", routes![index, lucide, check, healthcheck, page])
.attach(Db::init())
.attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations))
.mount("/", routes![index, check, healthcheck, page, feedback])
.mount("/vendor", routes![lucide, chartjs, chartjs_datalabels])
.mount("/agency", routes![agency::upload_report])
.mount("/whitelist", routes![whitelist::histogram, whitelist::export_csv])
.register("/", catchers![default])
.mount("/", FileServer::from(PathBuf::from("static")))
.attach(Template::fairing())