Allow auth via HTTP authorization header

This commit is contained in:
Antoine Gersant 2018-10-28 19:04:21 -07:00
parent 7e11b651ed
commit ed2ae20951
3 changed files with 47 additions and 19 deletions

View file

@ -7,6 +7,7 @@ use std::fs::File;
use std::ops::Deref; use std::ops::Deref;
use std::path::PathBuf; use std::path::PathBuf;
use std::str; use std::str;
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use config::{self, Config, Preferences}; use config::{self, Config, Preferences};
@ -59,19 +60,44 @@ struct Auth {
username: String, username: String,
} }
fn get_auth_cookie(username: &str) -> Cookie<'static> {
Cookie::build(SESSION_FIELD_USERNAME, username.to_owned())
.same_site(rocket::http::SameSite::Lax)
.finish()
}
impl<'a, 'r> FromRequest<'a, 'r> for Auth { impl<'a, 'r> FromRequest<'a, 'r> for Auth {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> { fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
let mut cookies = request.guard::<Cookies>().unwrap(); let mut cookies = request.guard::<Cookies>().unwrap();
match cookies.get_private(SESSION_FIELD_USERNAME) { if let Some(u) = cookies.get_private(SESSION_FIELD_USERNAME) {
Some(u) => Outcome::Success(Auth { return Outcome::Success(Auth {
username: u.value().to_string(), username: u.value().to_string(),
}), });
_ => Outcome::Failure((Status::Forbidden, ())),
} }
// TODO allow auth via authorization header if let Some(auth_header_string) = request.headers().get_one("Authorization") {
use rocket::http::hyper::header::*;
if let Ok(Basic {
username,
password: Some(password),
}) = Basic::from_str(auth_header_string.trim_start_matches("Basic ")) // Sadness
{
let db = match request.guard::<State<Arc<DB>>>() {
Outcome::Success(d) => d,
_ => return Outcome::Failure((Status::InternalServerError, ()))
};
if user::auth(db.deref().deref(), &username, &password).unwrap_or(false) {
cookies.add_private(get_auth_cookie(&username));
return Outcome::Success(Auth {
username: username.to_string(),
});
}
}
}
Outcome::Failure((Status::Unauthorized, ()))
} }
} }
@ -147,7 +173,10 @@ fn initial_setup(db: State<Arc<DB>>) -> Result<Json<InitialSetup>, errors::Error
} }
#[get("/settings")] #[get("/settings")]
fn get_settings(db: State<Arc<DB>>, _admin_rights: AdminRights) -> Result<Json<Config>, errors::Error> { fn get_settings(
db: State<Arc<DB>>,
_admin_rights: AdminRights,
) -> Result<Json<Config>, errors::Error> {
let config = config::read::<DB>(&db)?; let config = config::read::<DB>(&db)?;
Ok(Json(config)) Ok(Json(config))
} }
@ -205,11 +234,7 @@ fn auth(
mut cookies: Cookies, mut cookies: Cookies,
) -> Result<Json<AuthOutput>, errors::Error> { ) -> Result<Json<AuthOutput>, errors::Error> {
user::auth::<DB>(&db, &credentials.username, &credentials.password)?; user::auth::<DB>(&db, &credentials.username, &credentials.password)?;
cookies.add_private( cookies.add_private(get_auth_cookie(&credentials.username));
Cookie::build(SESSION_FIELD_USERNAME, credentials.username.clone())
.same_site(rocket::http::SameSite::Lax)
.finish(),
);
let auth_output = AuthOutput { let auth_output = AuthOutput {
admin: user::is_admin::<DB>(&db, &credentials.username)?, admin: user::is_admin::<DB>(&db, &credentials.username)?,
@ -355,7 +380,11 @@ fn delete_playlist(db: State<Arc<DB>>, auth: Auth, name: String) -> Result<(), e
} }
#[put("/lastfm/now_playing/<path>")] #[put("/lastfm/now_playing/<path>")]
fn lastfm_now_playing(db: State<Arc<DB>>, auth: Auth, path: VFSPathBuf) -> Result<(), errors::Error> { fn lastfm_now_playing(
db: State<Arc<DB>>,
auth: Auth,
path: VFSPathBuf,
) -> Result<(), errors::Error> {
lastfm::now_playing(db.deref().deref(), &auth.username, &path.into() as &PathBuf)?; lastfm::now_playing(db.deref().deref(), &auth.username, &path.into() as &PathBuf)?;
Ok(()) Ok(())
} }

View file

@ -3,7 +3,7 @@ use rocket::http::hyper::header::*;
use rocket::response::{self, Responder}; use rocket::response::{self, Responder};
use std::cmp; use std::cmp;
use std::convert::From; use std::convert::From;
use std::fs::{File}; use std::fs::File;
use std::io::{Read, Seek, SeekFrom}; use std::io::{Read, Seek, SeekFrom};
use std::str::FromStr; use std::str::FromStr;
@ -79,7 +79,6 @@ fn truncate_range(range: &PartialFileRange, file_length: &Option<u64>) -> Option
impl<'r> Responder<'r> for RangeResponder<File> { impl<'r> Responder<'r> for RangeResponder<File> {
fn respond_to(mut self, request: &rocket::request::Request) -> response::Result<'r> { fn respond_to(mut self, request: &rocket::request::Request) -> response::Result<'r> {
let range_header = request.headers().get_one("Range"); let range_header = request.headers().get_one("Range");
let range_header = match range_header { let range_header = match range_header {
None => return Ok(self.original.respond_to(request)?), None => return Ok(self.original.respond_to(request)?),