From 9e48dc408ecfe585ce61d2ed76f8d999b1413ca7 Mon Sep 17 00:00:00 2001 From: Antoine Gersant Date: Wed, 15 Jan 2020 21:58:37 -0800 Subject: [PATCH] DB interactions for actix version --- Cargo.lock | 21 +++++ Cargo.toml | 2 +- src/config.rs | 117 ++++++++++------------- src/db/mod.rs | 55 ++++------- src/ddns.rs | 14 +-- src/index.rs | 145 +++++++++++------------------ src/lastfm.rs | 28 ++---- src/main.rs | 14 ++- src/playlist.rs | 68 ++++++-------- src/service/actix/api.rs | 47 +++++++++- src/service/actix/mod.rs | 11 ++- src/service/actix/server.rs | 3 +- src/service/actix/tests/api.rs | 105 +++++++++++++++++++-- src/service/actix/tests/mod.rs | 14 ++- src/service/actix/tests/swagger.rs | 11 ++- src/service/actix/tests/web.rs | 4 +- src/service/dto.rs | 5 + src/service/error.rs | 15 +++ src/service/mod.rs | 1 + src/service/rocket/api.rs | 112 +++++++--------------- src/service/rocket/api_tests.rs | 8 +- src/service/rocket/server.rs | 4 +- src/service/rocket/test.rs | 7 +- src/user.rs | 71 +++++--------- src/vfs.rs | 4 +- 25 files changed, 449 insertions(+), 437 deletions(-) create mode 100644 src/service/error.rs diff --git a/Cargo.lock b/Cargo.lock index bf17dae..88b64d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -774,6 +774,7 @@ dependencies = [ "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "diesel_derives 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)", "libsqlite3-sys 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)", + "r2d2 0.8.8 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -2160,6 +2161,16 @@ dependencies = [ "proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "r2d2" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", + "parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", + "scheduled-thread-pool 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rand" version = "0.3.23" @@ -2590,6 +2601,14 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "scoped_threadpool" version = "0.1.9" @@ -3644,6 +3663,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" "checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1" "checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe" +"checksum r2d2 0.8.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1497e40855348e4a8a40767d8e55174bce1e445a3ac9254ad44ad468ee0485af" "checksum rand 0.3.23 (registry+https://github.com/rust-lang/crates.io-index)" = "64ac302d8f83c0c1974bf758f6b041c6c8ada916fbb44a609158ca8b064cc76c" "checksum rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293" "checksum rand 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c618c47cd3ebd209790115ab837de41425723956ad3ce2e6a7f09890947cacb9" @@ -3686,6 +3706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum safemem 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072" "checksum same-file 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "585e8ddcedc187886a30fa705c47985c3fa88d06624095856b36ca0b82ff4421" "checksum schannel 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "87f550b06b6cba9c8b8be3ee73f391990116bf527450d2556e9b9ce263b9a021" +"checksum scheduled-thread-pool 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f5de7bc31f28f8e6c28df5e1bf3d10610f5fdc14cc95f272853512c70a2bd779" "checksum scoped_threadpool 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8" "checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" "checksum sd-notify 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "aef40838bbb143707f8309b1e92e6ba3225287592968ba6f6e3b6de4a9816486" diff --git a/Cargo.toml b/Cargo.toml index 0d36348..304f6e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ anyhow = "1.0" ape = "0.2.0" app_dirs = "1.1.1" base64 = "0.11.0" -diesel = { version = "1.4", features = ["sqlite"] } +diesel = { version = "1.4", features = ["sqlite", "r2d2"] } diesel_migrations = { version = "1.4", features = ["sqlite"] } flame = { version = "0.2.2", optional = true } flamer = { version = "0.4", optional = true } diff --git a/src/config.rs b/src/config.rs index de141b5..6af5520 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,8 +10,7 @@ use std::io::Read; use std::path; use toml; -use crate::db::ConnectionSource; -use crate::db::{ddns_config, misc_settings, mount_points, users}; +use crate::db::{ddns_config, misc_settings, mount_points, users, DB}; use crate::ddns::DDNSConfig; use crate::user::*; use crate::vfs::MountPoint; @@ -73,14 +72,11 @@ pub fn parse_toml_file(path: &path::Path) -> Result { Ok(config) } -pub fn read(db: &T) -> Result -where - T: ConnectionSource, -{ +pub fn read(db: &DB) -> Result { use self::ddns_config::dsl::*; use self::misc_settings::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; let mut config = Config { album_art_pattern: None, @@ -97,7 +93,7 @@ where index_sleep_duration_seconds, prefix_url, )) - .get_result(connection.deref())?; + .get_result(&connection)?; config.album_art_pattern = Some(art_pattern); config.reindex_every_n_seconds = Some(sleep_duration); @@ -108,13 +104,13 @@ where use self::mount_points::dsl::*; mount_dirs = mount_points .select((source, name)) - .get_results(connection.deref())?; + .get_results(&connection)?; config.mount_dirs = Some(mount_dirs); } let found_users: Vec<(String, i32)> = users::table .select((users::columns::name, users::columns::admin)) - .get_results(connection.deref())?; + .get_results(&connection)?; config.users = Some( found_users .into_iter() @@ -128,46 +124,39 @@ where let ydns = ddns_config .select((host, username, password)) - .get_result(connection.deref())?; + .get_result(&connection)?; config.ydns = Some(ydns); Ok(config) } #[cfg(test)] -pub fn reset(db: &T) -> Result<()> -where - T: ConnectionSource, -{ +pub fn reset(db: &DB) -> Result<()> { use self::ddns_config::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; - diesel::delete(mount_points::table).execute(connection.deref())?; - diesel::delete(users::table).execute(connection.deref())?; + diesel::delete(mount_points::table).execute(&connection)?; + diesel::delete(users::table).execute(&connection)?; diesel::update(ddns_config) .set((host.eq(""), username.eq(""), password.eq(""))) - .execute(connection.deref())?; + .execute(&connection)?; Ok(()) } -pub fn amend(db: &T, new_config: &Config) -> Result<()> -where - T: ConnectionSource, -{ - let connection = db.get_connection(); +pub fn amend(db: &DB, new_config: &Config) -> Result<()> { + let connection = db.connect()?; if let Some(ref mount_dirs) = new_config.mount_dirs { - diesel::delete(mount_points::table).execute(connection.deref())?; + diesel::delete(mount_points::table).execute(&connection)?; diesel::insert_into(mount_points::table) .values(mount_dirs) - .execute(connection.deref())?; + .execute(&*connection)?; // TODO https://github.com/diesel-rs/diesel/issues/1822 } if let Some(ref config_users) = new_config.users { - let old_usernames: Vec = users::table - .select(users::name) - .get_results(connection.deref())?; + let old_usernames: Vec = + users::table.select(users::name).get_results(&connection)?; // Delete users that are not in new list let delete_usernames: Vec = old_usernames @@ -176,7 +165,7 @@ where .filter(|old_name| config_users.iter().find(|u| &u.name == old_name).is_none()) .collect::<_>(); diesel::delete(users::table.filter(users::name.eq_any(&delete_usernames))) - .execute(connection.deref())?; + .execute(&connection)?; // Insert new users let insert_users: Vec<&ConfigUser> = config_users @@ -194,7 +183,7 @@ where let new_user = User::new(&config_user.name, &config_user.password)?; diesel::insert_into(users::table) .values(&new_user) - .execute(connection.deref())?; + .execute(&connection)?; } // Update users @@ -204,26 +193,26 @@ where let hash = hash_password(&user.password)?; diesel::update(users::table.filter(users::name.eq(&user.name))) .set(users::password_hash.eq(hash)) - .execute(connection.deref())?; + .execute(&connection)?; } // Update admin rights diesel::update(users::table.filter(users::name.eq(&user.name))) .set(users::admin.eq(user.admin as i32)) - .execute(connection.deref())?; + .execute(&connection)?; } } if let Some(sleep_duration) = new_config.reindex_every_n_seconds { diesel::update(misc_settings::table) .set(misc_settings::index_sleep_duration_seconds.eq(sleep_duration as i32)) - .execute(connection.deref())?; + .execute(&connection)?; } if let Some(ref album_art_pattern) = new_config.album_art_pattern { diesel::update(misc_settings::table) .set(misc_settings::index_album_art_pattern.eq(album_art_pattern)) - .execute(connection.deref())?; + .execute(&connection)?; } if let Some(ref ydns) = new_config.ydns { @@ -234,28 +223,25 @@ where username.eq(ydns.username.clone()), password.eq(ydns.password.clone()), )) - .execute(connection.deref())?; + .execute(&connection)?; } if let Some(ref prefix_url) = new_config.prefix_url { diesel::update(misc_settings::table) .set(misc_settings::prefix_url.eq(prefix_url)) - .execute(connection.deref())?; + .execute(&connection)?; } Ok(()) } -pub fn read_preferences(db: &T, username: &str) -> Result -where - T: ConnectionSource, -{ +pub fn read_preferences(db: &DB, username: &str) -> Result { use self::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; let (theme_base, theme_accent, read_lastfm_username) = users .select((web_theme_base, web_theme_accent, lastfm_username)) .filter(name.eq(username)) - .get_result(connection.deref())?; + .get_result(&connection)?; Ok(Preferences { web_theme_base: theme_base, web_theme_accent: theme_accent, @@ -263,33 +249,24 @@ where }) } -pub fn write_preferences(db: &T, username: &str, preferences: &Preferences) -> Result<()> -where - T: ConnectionSource, -{ +pub fn write_preferences(db: &DB, username: &str, preferences: &Preferences) -> Result<()> { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; diesel::update(users.filter(name.eq(username))) .set(( web_theme_base.eq(&preferences.web_theme_base), web_theme_accent.eq(&preferences.web_theme_accent), )) - .execute(connection.deref())?; + .execute(&connection)?; Ok(()) } -pub fn get_auth_secret(db: &T) -> Result> -where - T: ConnectionSource, -{ +pub fn get_auth_secret(db: &DB) -> Result> { use self::misc_settings::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; - match misc_settings - .select(auth_secret) - .get_result(connection.deref()) - { + match misc_settings.select(auth_secret).get_result(&connection) { Err(diesel::result::Error::NotFound) => bail!("Cannot find authentication secret"), Ok(secret) => Ok(secret), Err(e) => Err(e.into()), @@ -391,11 +368,11 @@ fn test_amend_preserve_password_hashes() { amend(&db, &initial_config).unwrap(); { - let connection = db.get_connection(); + let connection = db.connect().unwrap(); initial_hash = users .select(password_hash) .filter(name.eq("Teddy🐻")) - .get_result(connection.deref()) + .get_result(&connection) .unwrap(); } @@ -421,11 +398,11 @@ fn test_amend_preserve_password_hashes() { amend(&db, &new_config).unwrap(); { - let connection = db.get_connection(); + let connection = db.connect().unwrap(); new_hash = users .select(password_hash) .filter(name.eq("Teddy🐻")) - .get_result(connection.deref()) + .get_result(&connection) .unwrap(); } @@ -453,8 +430,8 @@ fn test_amend_ignore_blank_users() { }; amend(&db, &config).unwrap(); - let connection = db.get_connection(); - let user_count: i64 = users.count().get_result(connection.deref()).unwrap(); + let connection = db.connect().unwrap(); + let user_count: i64 = users.count().get_result(&connection).unwrap(); assert_eq!(user_count, 0); } @@ -473,8 +450,8 @@ fn test_amend_ignore_blank_users() { }; amend(&db, &config).unwrap(); - let connection = db.get_connection(); - let user_count: i64 = users.count().get_result(connection.deref()).unwrap(); + let connection = db.connect().unwrap(); + let user_count: i64 = users.count().get_result(&connection).unwrap(); assert_eq!(user_count, 0); } } @@ -500,8 +477,8 @@ fn test_toggle_admin() { amend(&db, &initial_config).unwrap(); { - let connection = db.get_connection(); - let is_admin: i32 = users.select(admin).get_result(connection.deref()).unwrap(); + let connection = db.connect().unwrap(); + let is_admin: i32 = users.select(admin).get_result(&connection).unwrap(); assert_eq!(is_admin, 1); } @@ -520,8 +497,8 @@ fn test_toggle_admin() { amend(&db, &new_config).unwrap(); { - let connection = db.get_connection(); - let is_admin: i32 = users.select(admin).get_result(connection.deref()).unwrap(); + let connection = db.connect().unwrap(); + let is_admin: i32 = users.select(admin).get_result(&connection).unwrap(); assert_eq!(is_admin, 0); } } diff --git a/src/db/mod.rs b/src/db/mod.rs index 17e5ea3..6f923af 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,11 +1,9 @@ use anyhow::*; -use core::ops::Deref; -use diesel::prelude::*; +use diesel::r2d2::{self, ConnectionManager, PooledConnection}; use diesel::sqlite::SqliteConnection; use diesel_migrations; use log::info; use std::path::Path; -use std::sync::{Arc, Mutex, MutexGuard}; mod schema; @@ -15,44 +13,34 @@ pub use self::schema::*; const DB_MIGRATIONS_PATH: &str = "migrations"; embed_migrations!("migrations"); -pub trait ConnectionSource { - fn get_connection(&self) -> MutexGuard<'_, SqliteConnection>; - fn get_connection_mutex(&self) -> Arc>; -} - +#[derive(Clone)] pub struct DB { - connection: Arc>, + pool: r2d2::Pool>, } impl DB { pub fn new(path: &Path) -> Result { info!("Database file path: {}", path.to_string_lossy()); - let connection = Arc::new(Mutex::new(SqliteConnection::establish( - &path.to_string_lossy(), - )?)); - let db = DB { - connection: connection.clone(), - }; - db.init()?; + let manager = ConnectionManager::::new(path.to_string_lossy()); + let pool = r2d2::Pool::builder() + .build(manager) + .expect("Failed to create pool."); // TODO handle error + + let db = DB { pool: pool }; + db.migrate_up()?; Ok(db) } - fn init(&self) -> Result<()> { - { - let connection = self.connection.lock().unwrap(); - connection.execute("PRAGMA synchronous = NORMAL")?; - } - self.migrate_up()?; - Ok(()) + pub fn connect(&self) -> Result>> { + self.pool.get().map_err(Error::new) } #[allow(dead_code)] fn migrate_down(&self) -> Result<()> { - let connection = self.connection.lock().unwrap(); - let connection = connection.deref(); + let connection = self.connect().unwrap(); loop { match diesel_migrations::revert_latest_migration_in_directory( - connection, + &connection, Path::new(DB_MIGRATIONS_PATH), ) { Ok(_) => (), @@ -66,23 +54,12 @@ impl DB { } fn migrate_up(&self) -> Result<()> { - let connection = self.connection.lock().unwrap(); - let connection = connection.deref(); - embedded_migrations::run(connection)?; + let connection = self.connect().unwrap(); + embedded_migrations::run(&connection)?; Ok(()) } } -impl ConnectionSource for DB { - fn get_connection(&self) -> MutexGuard<'_, SqliteConnection> { - self.connection.lock().unwrap() - } - - fn get_connection_mutex(&self) -> Arc> { - self.connection.clone() - } -} - #[cfg(test)] pub fn get_test_db(name: &str) -> DB { use crate::config; diff --git a/src/ddns.rs b/src/ddns.rs index c5f5e4e..8eda3bd 100644 --- a/src/ddns.rs +++ b/src/ddns.rs @@ -8,7 +8,7 @@ use std::thread; use std::time; use crate::db::ddns_config; -use crate::db::{ConnectionSource, DB}; +use crate::db::DB; #[derive(Clone, Debug, Deserialize, Insertable, PartialEq, Queryable, Serialize)] #[table_name = "ddns_config"] @@ -25,7 +25,7 @@ pub trait DDNSConfigSource { impl DDNSConfigSource for DB { fn get_ddns_config(&self) -> Result { use self::ddns_config::dsl::*; - let connection = self.get_connection(); + let connection = self.connect()?; Ok(ddns_config .select((host, username, password)) .get_result(connection.deref())?) @@ -34,10 +34,7 @@ impl DDNSConfigSource for DB { const DDNS_UPDATE_URL: &str = "https://ydns.io/api/v1/update/"; -fn update_my_ip(config_source: &T) -> Result<()> -where - T: DDNSConfigSource, -{ +fn update_my_ip(config_source: &DB) -> Result<()> { let config = config_source.get_ddns_config()?; if config.host.is_empty() || config.username.is_empty() { info!("Skipping DDNS update because credentials are missing"); @@ -59,10 +56,7 @@ where Ok(()) } -pub fn run(config_source: &T) -where - T: DDNSConfigSource, -{ +pub fn run(config_source: &DB) { loop { if let Err(e) = update_my_ip(config_source) { error!("Dynamic DNS update error: {:?}", e); diff --git a/src/index.rs b/src/index.rs index d41cc98..70b80a6 100644 --- a/src/index.rs +++ b/src/index.rs @@ -4,7 +4,6 @@ use diesel; use diesel::dsl::sql; use diesel::prelude::*; use diesel::sql_types; -use diesel::sqlite::SqliteConnection; #[cfg(feature = "profile-index")] use flame; use log::{error, info}; @@ -22,8 +21,7 @@ use std::time; use crate::config::MiscSettings; #[cfg(test)] use crate::db; -use crate::db::{directories, misc_settings, songs}; -use crate::db::{ConnectionSource, DB}; +use crate::db::{directories, misc_settings, songs, DB}; use crate::metadata; use crate::vfs::{VFSSource, VFS}; @@ -80,16 +78,14 @@ impl CommandSender { } } -pub fn init(db: Arc) -> Arc { +pub fn init(db: DB) -> Arc { let (index_sender, index_receiver) = channel(); let command_sender = Arc::new(CommandSender::new(index_sender)); let command_receiver = CommandReceiver::new(index_receiver); // Start update loop - let db_ref = db.clone(); std::thread::spawn(move || { - let db = db_ref.deref(); - update_loop(db, &command_receiver); + update_loop(&db, &command_receiver); }); command_sender @@ -162,19 +158,16 @@ struct NewDirectory { date_added: i32, } -struct IndexBuilder<'conn> { +struct IndexBuilder { new_songs: Vec, new_directories: Vec, - connection: &'conn Mutex, + db: DB, album_art_pattern: Regex, } -impl<'conn> IndexBuilder<'conn> { +impl IndexBuilder { #[cfg_attr(feature = "profile-index", flame)] - fn new( - connection: &Mutex, - album_art_pattern: Regex, - ) -> Result> { + fn new(db: DB, album_art_pattern: Regex) -> Result { let mut new_songs = Vec::new(); let mut new_directories = Vec::new(); new_songs.reserve_exact(INDEX_BUILDING_INSERT_BUFFER_SIZE); @@ -182,19 +175,18 @@ impl<'conn> IndexBuilder<'conn> { Ok(IndexBuilder { new_songs, new_directories, - connection, + db, album_art_pattern, }) } #[cfg_attr(feature = "profile-index", flame)] fn flush_songs(&mut self) -> Result<()> { - let connection = self.connection.lock().unwrap(); - let connection = connection.deref(); + let connection = self.db.connect()?; connection.transaction::<_, anyhow::Error, _>(|| { diesel::insert_into(songs::table) .values(&self.new_songs) - .execute(connection)?; + .execute(&*connection)?; // TODO https://github.com/diesel-rs/diesel/issues/1822 Ok(()) })?; self.new_songs.clear(); @@ -203,12 +195,11 @@ impl<'conn> IndexBuilder<'conn> { #[cfg_attr(feature = "profile-index", flame)] fn flush_directories(&mut self) -> Result<()> { - let connection = self.connection.lock().unwrap(); - let connection = connection.deref(); + let connection = self.db.connect()?; connection.transaction::<_, anyhow::Error, _>(|| { diesel::insert_into(directories::table) .values(&self.new_directories) - .execute(connection)?; + .execute(&*connection)?; // TODO https://github.com/diesel-rs/diesel/issues/1822 Ok(()) })?; self.new_directories.clear(); @@ -364,17 +355,14 @@ impl<'conn> IndexBuilder<'conn> { } #[cfg_attr(feature = "profile-index", flame)] -fn clean(db: &T) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +fn clean(db: &DB) -> Result<()> { let vfs = db.get_vfs()?; { let all_songs: Vec; { - let connection = db.get_connection(); - all_songs = songs::table.select(songs::path).load(connection.deref())?; + let connection = db.connect()?; + all_songs = songs::table.select(songs::path).load(&connection)?; } let missing_songs = all_songs @@ -386,7 +374,7 @@ where .collect::>(); { - let connection = db.get_connection(); + let connection = db.connect()?; for chunk in missing_songs[..].chunks(INDEX_BUILDING_CLEAN_BUFFER_SIZE) { diesel::delete(songs::table.filter(songs::path.eq_any(chunk))) .execute(connection.deref())?; @@ -397,7 +385,7 @@ where { let all_directories: Vec; { - let connection = db.get_connection(); + let connection = db.connect()?; all_directories = directories::table .select(directories::path) .load(connection.deref())?; @@ -412,7 +400,7 @@ where .collect::>(); { - let connection = db.get_connection(); + let connection = db.connect()?; for chunk in missing_directories[..].chunks(INDEX_BUILDING_CLEAN_BUFFER_SIZE) { diesel::delete(directories::table.filter(directories::path.eq_any(chunk))) .execute(connection.deref())?; @@ -424,22 +412,18 @@ where } #[cfg_attr(feature = "profile-index", flame)] -fn populate(db: &T) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +fn populate(db: &DB) -> Result<()> { let vfs = db.get_vfs()?; let mount_points = vfs.get_mount_points(); let album_art_pattern; { - let connection = db.get_connection(); - let settings: MiscSettings = misc_settings::table.get_result(connection.deref())?; + let connection = db.connect()?; + let settings: MiscSettings = misc_settings::table.get_result(&connection)?; album_art_pattern = Regex::new(&settings.index_album_art_pattern)?; } - let connection_mutex = db.get_connection_mutex(); - let mut builder = IndexBuilder::new(connection_mutex.deref(), album_art_pattern)?; + let mut builder = IndexBuilder::new(db.clone(), album_art_pattern)?; for target in mount_points.values() { builder.populate_directory(None, target.as_path())?; } @@ -448,10 +432,7 @@ where Ok(()) } -pub fn update(db: &T) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn update(db: &DB) -> Result<()> { let start = time::Instant::now(); info!("Beginning library index update"); clean(db)?; @@ -465,10 +446,7 @@ where Ok(()) } -fn update_loop(db: &T, command_buffer: &CommandReceiver) -where - T: ConnectionSource + VFSSource, -{ +fn update_loop(db: &DB, command_buffer: &CommandReceiver) { loop { // Wait for a command if command_buffer.receiver.recv().is_err() { @@ -492,10 +470,7 @@ where } } -pub fn self_trigger(db: &T, command_buffer: &Arc) -where - T: ConnectionSource, -{ +pub fn self_trigger(db: &DB, command_buffer: &Arc) { loop { { let command_buffer = command_buffer.deref(); @@ -504,19 +479,17 @@ where return; } } - let sleep_duration; - { - let connection = db.get_connection(); - let settings: Result = misc_settings::table - .get_result(connection.deref()) - .map_err(|e| e.into()); - if let Err(ref e) = settings { - error!("Could not retrieve index sleep duration: {}", e); - } - sleep_duration = settings - .map(|s| s.index_sleep_duration_seconds) - .unwrap_or(1800); - } + let sleep_duration = { + let connection = db.connect(); + connection + .and_then(|c| { + misc_settings::table + .get_result(&c) + .map_err(|e| Error::new(e)) + }) + .map(|s: MiscSettings| s.index_sleep_duration_seconds) + .unwrap_or(1800) // TODO log error + }; thread::sleep(time::Duration::from_secs(sleep_duration as u64)); } } @@ -551,14 +524,13 @@ fn virtualize_directory(vfs: &VFS, mut directory: Directory) -> Option(db: &T, virtual_path: P) -> Result> +pub fn browse

(db: &DB, virtual_path: P) -> Result> where - T: ConnectionSource + VFSSource, P: AsRef, { let mut output = Vec::new(); let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; if virtual_path.as_ref().components().count() == 0 { // Browse top-level @@ -596,14 +568,13 @@ where Ok(output) } -pub fn flatten(db: &T, virtual_path: P) -> Result> +pub fn flatten

(db: &DB, virtual_path: P) -> Result> where - T: ConnectionSource + VFSSource, P: AsRef, { use self::songs::dsl::*; let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; let real_songs: Vec = if virtual_path.as_ref().parent() != None { let real_path = vfs.virtual_to_real(virtual_path)?; @@ -622,13 +593,10 @@ where Ok(virtual_songs.collect::>()) } -pub fn get_random_albums(db: &T, count: i64) -> Result> -where - T: ConnectionSource + VFSSource, -{ +pub fn get_random_albums(db: &DB, count: i64) -> Result> { use self::directories::dsl::*; let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; let real_directories = directories .filter(album.is_not_null()) .limit(count) @@ -640,13 +608,10 @@ where Ok(virtual_directories.collect::>()) } -pub fn get_recent_albums(db: &T, count: i64) -> Result> -where - T: ConnectionSource + VFSSource, -{ +pub fn get_recent_albums(db: &DB, count: i64) -> Result> { use self::directories::dsl::*; let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; let real_directories: Vec = directories .filter(album.is_not_null()) .order(date_added.desc()) @@ -658,12 +623,9 @@ where Ok(virtual_directories.collect::>()) } -pub fn search(db: &T, query: &str) -> Result> -where - T: ConnectionSource + VFSSource, -{ +pub fn search(db: &DB, query: &str) -> Result> { let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; let like_test = format!("%{}%", query); let mut output = Vec::new(); @@ -706,12 +668,9 @@ where Ok(output) } -pub fn get_song(db: &T, virtual_path: &Path) -> Result -where - T: ConnectionSource + VFSSource, -{ +pub fn get_song(db: &DB, virtual_path: &Path) -> Result { let vfs = db.get_vfs()?; - let connection = db.get_connection(); + let connection = db.connect()?; let real_path = vfs.virtual_to_real(virtual_path)?; let real_path_string = real_path.as_path().to_string_lossy(); @@ -732,9 +691,9 @@ fn test_populate() { update(&db).unwrap(); update(&db).unwrap(); // Check that subsequent updates don't run into conflicts - let connection = db.get_connection(); - let all_directories: Vec = directories::table.load(connection.deref()).unwrap(); - let all_songs: Vec = songs::table.load(connection.deref()).unwrap(); + let connection = db.connect().unwrap(); + let all_directories: Vec = directories::table.load(&connection).unwrap(); + let all_songs: Vec = songs::table.load(&connection).unwrap(); assert_eq!(all_directories.len(), 5); assert_eq!(all_songs.len(), 12); } @@ -756,7 +715,7 @@ fn test_metadata() { let db = db::get_test_db("metadata.sqlite"); update(&db).unwrap(); - let connection = db.get_connection(); + let connection = db.connect().unwrap(); let songs: Vec = songs::table .filter(songs::title.eq("シャーベット (Sherbet)")) .load(connection.deref()) diff --git a/src/lastfm.rs b/src/lastfm.rs index 843e8a6..5045f80 100644 --- a/src/lastfm.rs +++ b/src/lastfm.rs @@ -3,10 +3,9 @@ use rustfm_scrobble::{Scrobble, Scrobbler}; use serde::Deserialize; use std::path::Path; -use crate::db::ConnectionSource; +use crate::db::DB; use crate::index; use crate::user; -use crate::vfs::VFSSource; const LASTFM_API_KEY: &str = "02b96c939a2b451c31dfd67add1f696e"; const LASTFM_API_SECRET: &str = "0f25a80ceef4b470b5cb97d99d4b3420"; @@ -42,10 +41,7 @@ struct AuthResponse { pub session: AuthResponseSession, } -fn scrobble_from_path(db: &T, track: &Path) -> Result -where - T: ConnectionSource + VFSSource, -{ +fn scrobble_from_path(db: &DB, track: &Path) -> Result { let song = index::get_song(db, track)?; Ok(Scrobble::new( song.artist.unwrap_or_else(|| "".into()), @@ -54,27 +50,18 @@ where )) } -pub fn link(db: &T, username: &str, token: &str) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn link(db: &DB, username: &str, token: &str) -> Result<()> { let mut scrobbler = Scrobbler::new(LASTFM_API_KEY.into(), LASTFM_API_SECRET.into()); let auth_response = scrobbler.authenticate_with_token(token.to_string())?; user::lastfm_link(db, username, &auth_response.name, &auth_response.key) } -pub fn unlink(db: &T, username: &str) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn unlink(db: &DB, username: &str) -> Result<()> { user::lastfm_unlink(db, username) } -pub fn scrobble(db: &T, username: &str, track: &Path) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn scrobble(db: &DB, username: &str, track: &Path) -> Result<()> { let mut scrobbler = Scrobbler::new(LASTFM_API_KEY.into(), LASTFM_API_SECRET.into()); let scrobble = scrobble_from_path(db, track)?; let auth_token = user::get_lastfm_session_key(db, username)?; @@ -83,10 +70,7 @@ where Ok(()) } -pub fn now_playing(db: &T, username: &str, track: &Path) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn now_playing(db: &DB, username: &str, track: &Path) -> Result<()> { let mut scrobbler = Scrobbler::new(LASTFM_API_KEY.into(), LASTFM_API_SECRET.into()); let scrobble = scrobble_from_path(db, track)?; let auth_token = user::get_lastfm_session_key(db, username)?; diff --git a/src/main.rs b/src/main.rs index 88f1a62..ced7af9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,12 +21,10 @@ use std::io::prelude::*; use unix_daemonize::{daemonize_redirect, ChdirMode}; use anyhow::*; -use core::ops::Deref; use getopts::Options; use log::info; use simplelog::{LevelFilter, SimpleLogger, TermLogger, TerminalMode}; use std::path::Path; -use std::sync::Arc; mod config; mod db; @@ -160,7 +158,7 @@ fn main() -> Result<()> { let db_path = db_path .map(|n| Path::new(n.as_str()).to_path_buf()) .unwrap_or(default_db_path); - let db = Arc::new(db::DB::new(&db_path)?); + let db = db::DB::new(&db_path)?; // Parse config info!("Parsing configuration"); @@ -169,10 +167,10 @@ fn main() -> Result<()> { if let Some(path) = config_file_path { let config = config::parse_toml_file(&path)?; info!("Applying configuration"); - config::amend(db.deref(), &config)?; + config::amend(&db, &config)?; } - let config = config::read(db.deref())?; - let auth_secret = config::get_auth_secret(db.deref())?; + let config = config::read(&db)?; + let auth_secret = config::get_auth_secret(&db)?; // Init index info!("Initializing index"); @@ -182,7 +180,7 @@ fn main() -> Result<()> { let db_auto_index = db.clone(); let command_sender_auto_index = command_sender.clone(); std::thread::spawn(move || { - index::self_trigger(db_auto_index.deref(), &command_sender_auto_index); + index::self_trigger(&db_auto_index, &command_sender_auto_index); }); // API mount target @@ -235,7 +233,7 @@ fn main() -> Result<()> { // Start DDNS updates let db_ddns = db.clone(); std::thread::spawn(move || { - ddns::run(db_ddns.deref()); + ddns::run(&db_ddns); }); // Send readiness notification diff --git a/src/playlist.rs b/src/playlist.rs index 5d66843..39dfc77 100644 --- a/src/playlist.rs +++ b/src/playlist.rs @@ -9,7 +9,7 @@ use std::path::Path; #[cfg(test)] use crate::db; -use crate::db::ConnectionSource; +use crate::db::DB; use crate::db::{playlist_songs, playlists, users}; use crate::index::{self, Song}; use crate::vfs::VFSSource; @@ -48,11 +48,8 @@ pub struct NewPlaylistSong { ordering: i32, } -pub fn list_playlists(owner: &str, db: &T) -> Result> -where - T: ConnectionSource + VFSSource, -{ - let connection = db.get_connection(); +pub fn list_playlists(owner: &str, db: &DB) -> Result> { + let connection = db.connect()?; let user: User; { @@ -72,17 +69,14 @@ where } } -pub fn save_playlist(playlist_name: &str, owner: &str, content: &[String], db: &T) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ +pub fn save_playlist(playlist_name: &str, owner: &str, content: &[String], db: &DB) -> Result<()> { let user: User; let new_playlist: NewPlaylist; let playlist: Playlist; let vfs = db.get_vfs()?; { - let connection = db.get_connection(); + let connection = db.connect()?; // Find owner { @@ -90,7 +84,7 @@ where user = users .filter(name.eq(owner)) .select((id,)) - .get_result(connection.deref())?; + .get_result(&connection)?; } // Create playlist @@ -101,14 +95,14 @@ where diesel::insert_into(playlists::table) .values(&new_playlist) - .execute(connection.deref())?; + .execute(&connection)?; { use self::playlists::dsl::*; playlist = playlists .select((id, owner)) .filter(name.eq(playlist_name).and(owner.eq(user.id))) - .get_result(connection.deref())?; + .get_result(&connection)?; } } @@ -131,34 +125,29 @@ where } { - let connection = db.get_connection(); - connection - .deref() - .transaction::<_, diesel::result::Error, _>(|| { - // Delete old content (if any) - let old_songs = PlaylistSong::belonging_to(&playlist); - diesel::delete(old_songs).execute(connection.deref())?; + let connection = db.connect()?; + connection.transaction::<_, diesel::result::Error, _>(|| { + // Delete old content (if any) + let old_songs = PlaylistSong::belonging_to(&playlist); + diesel::delete(old_songs).execute(connection.deref())?; - // Insert content - diesel::insert_into(playlist_songs::table) - .values(&new_songs) - .execute(connection.deref())?; - Ok(()) - })?; + // Insert content + diesel::insert_into(playlist_songs::table) + .values(&new_songs) + .execute(&*connection)?; // TODO https://github.com/diesel-rs/diesel/issues/1822 + Ok(()) + })?; } Ok(()) } -pub fn read_playlist(playlist_name: &str, owner: &str, db: &T) -> Result> -where - T: ConnectionSource + VFSSource, -{ +pub fn read_playlist(playlist_name: &str, owner: &str, db: &DB) -> Result> { let vfs = db.get_vfs()?; let songs: Vec; { - let connection = db.get_connection(); + let connection = db.connect()?; let user: User; let playlist: Playlist; @@ -168,7 +157,7 @@ where user = users .filter(name.eq(owner)) .select((id,)) - .get_result(connection.deref())?; + .get_result(&connection)?; } // Find playlist @@ -177,7 +166,7 @@ where playlist = playlists .select((id, owner)) .filter(name.eq(playlist_name).and(owner.eq(user.id))) - .get_result(connection.deref())?; + .get_result(&connection)?; } // Select songs. Not using Diesel because we need to LEFT JOIN using a custom column @@ -191,7 +180,7 @@ where "#, ); let query = query.clone().bind::(playlist.id); - songs = query.get_results(connection.deref())?; + songs = query.get_results(&connection)?; } // Map real path to virtual paths @@ -203,11 +192,8 @@ where Ok(virtual_songs) } -pub fn delete_playlist(playlist_name: &str, owner: &str, db: &T) -> Result<()> -where - T: ConnectionSource + VFSSource, -{ - let connection = db.get_connection(); +pub fn delete_playlist(playlist_name: &str, owner: &str, db: &DB) -> Result<()> { + let connection = db.connect()?; let user: User; { @@ -215,7 +201,7 @@ where user = users .filter(name.eq(owner)) .select((id,)) - .first(connection.deref())?; + .first(&connection)?; } { diff --git a/src/service/actix/api.rs b/src/service/actix/api.rs index 34bd807..ef7e2e3 100644 --- a/src/service/actix/api.rs +++ b/src/service/actix/api.rs @@ -1,13 +1,56 @@ -use actix_web::{get, HttpResponse}; +use actix_http::ResponseBuilder; +use actix_web::{error, get, http::header, http::StatusCode, put, web, HttpResponse}; +use anyhow::*; +use crate::config::{self, Config, Preferences}; +use crate::db::DB; use crate::service::constants::*; use crate::service::dto; +use crate::service::error::APIError; +use crate::user; + +impl error::ResponseError for APIError { + fn error_response(&self) -> HttpResponse { + ResponseBuilder::new(self.status_code()) + .set_header(header::CONTENT_TYPE, "text/html; charset=utf-8") + .body(self.to_string()) + } + fn status_code(&self) -> StatusCode { + match *self { + APIError::IncorrectCredentials => StatusCode::UNAUTHORIZED, + APIError::Unspecified => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} #[get("/version")] -pub async fn get_version() -> Result { +pub async fn get_version() -> Result { let current_version = dto::Version { major: CURRENT_MAJOR_VERSION, minor: CURRENT_MINOR_VERSION, }; Ok(HttpResponse::Ok().json(current_version)) } + +#[get("/initial_setup")] +pub async fn get_initial_setup(db: web::Data) -> Result { + let user_count = web::block(move || user::count(&db)) + .await + .map_err(|_| anyhow!("Could not count users"))?; + let initial_setup = dto::InitialSetup { + has_any_users: user_count > 0, + }; + Ok(HttpResponse::Ok().json(initial_setup)) +} + +#[put("/settings")] +pub async fn put_settings( + db: web::Data, + // _admin_rights: AdminRights, // TODO + config: web::Json, +) -> Result { + web::block(move || config::amend(&db, &config)) + .await + .map_err(|_| anyhow!("Could not amend config"))?; + Ok(HttpResponse::Ok().finish()) +} diff --git a/src/service/actix/mod.rs b/src/service/actix/mod.rs index f8e7365..e420973 100644 --- a/src/service/actix/mod.rs +++ b/src/service/actix/mod.rs @@ -2,6 +2,8 @@ use actix_files as fs; use actix_web::web; use std::path::Path; +use crate::db::DB; + pub mod server; mod api; @@ -14,9 +16,16 @@ fn configure_app( web_dir_path: &Path, swagger_url: &str, swagger_dir_path: &Path, + db: &DB, ) { // TODO logging - cfg.service(web::scope("/api").service(api::get_version)) + cfg.data(db.clone()) + .service( + web::scope("/api") + .service(api::get_initial_setup) + .service(api::get_version) + .service(api::put_settings), + ) .service(fs::Files::new(swagger_url, swagger_dir_path).index_file("index.html")) .service(fs::Files::new(web_url, web_dir_path).index_file("index.html")); } diff --git a/src/service/actix/server.rs b/src/service/actix/server.rs index 9b82631..2dcc699 100644 --- a/src/service/actix/server.rs +++ b/src/service/actix/server.rs @@ -15,7 +15,7 @@ pub async fn run( web_dir_path: PathBuf, swagger_url: String, swagger_dir_path: PathBuf, - db: Arc, + db: DB, command_sender: Arc, ) -> Result<()> { HttpServer::new(move || { @@ -26,6 +26,7 @@ pub async fn run( web_dir_path.as_path(), &swagger_url, swagger_dir_path.as_path(), + &db, ) }) }) diff --git a/src/service/actix/tests/api.rs b/src/service/actix/tests/api.rs index 6564f1f..e6b72c1 100644 --- a/src/service/actix/tests/api.rs +++ b/src/service/actix/tests/api.rs @@ -1,23 +1,114 @@ -use actix_web::body::Body::Bytes; +use actix_http::Request; use actix_web::dev::*; use actix_web::test::TestRequest; use actix_web::{test, App}; +use super::configure_test_app; +use crate::config; use crate::service::dto; +use crate::vfs; + +const TEST_USERNAME: &str = "test_user"; +const TEST_PASSWORD: &str = "test_password"; +const TEST_MOUNT_NAME: &str = "collection"; +const TEST_MOUNT_SOURCE: &str = "test/collection"; + +trait BodyTest { + fn as_u8(&self) -> &[u8]; +} + +impl BodyTest for ResponseBody { + fn as_u8(&self) -> &[u8] { + match self { + ResponseBody::Body(ref b) => match b { + Body::Bytes(ref by) => by.as_ref(), + _ => panic!(), + }, + ResponseBody::Other(ref b) => match b { + Body::Bytes(ref by) => by.as_ref(), + _ => panic!(), + }, + } + } +} + +fn initial_setup() -> Request { + let configuration = config::Config { + album_art_pattern: None, + prefix_url: None, + reindex_every_n_seconds: None, + ydns: None, + users: Some(vec![config::ConfigUser { + name: TEST_USERNAME.into(), + password: TEST_PASSWORD.into(), + admin: true, + }]), + mount_dirs: Some(vec![vfs::MountPoint { + name: TEST_MOUNT_NAME.into(), + source: TEST_MOUNT_SOURCE.into(), + }]), + }; + + TestRequest::put() + .uri("/api/settings") + .set_json(&configuration) + .to_request() +} #[actix_rt::test] async fn test_version() { - let app = App::new().configure(super::configure_test_app); + let app = App::new().configure(|cfg| configure_test_app(cfg, "test_version")); let mut service = test::init_service(app).await; let req = TestRequest::get().uri("/api/version").to_request(); let resp = service.call(req).await.unwrap(); assert!(resp.status().is_success()); - let body = match resp.response().body().as_ref() { - Some(Bytes(bytes)) => bytes, - _ => panic!("Response error"), - }; - + let body = resp.response().body().as_u8(); let response_json: dto::Version = serde_json::from_slice(body).unwrap(); assert_eq!(response_json, dto::Version { major: 4, minor: 0 }); } + +#[actix_rt::test] +async fn test_initial_setup() { + let app = App::new().configure(|cfg| configure_test_app(cfg, "test_initial_setup")); + let mut service = test::init_service(app).await; + + { + let req = TestRequest::get().uri("/api/initial_setup").to_request(); + let resp = service.call(req).await.unwrap(); + assert!(resp.status().is_success()); + + let body = resp.response().body().as_u8(); + let response_json: dto::InitialSetup = serde_json::from_slice(body).unwrap(); + + assert_eq!( + response_json, + dto::InitialSetup { + has_any_users: false + } + ); + } + + assert!(service + .call(initial_setup()) + .await + .unwrap() + .status() + .is_success()); + + { + let req = TestRequest::get().uri("/api/initial_setup").to_request(); + let resp = service.call(req).await.unwrap(); + assert!(resp.status().is_success()); + + let body = resp.response().body().as_u8(); + let response_json: dto::InitialSetup = serde_json::from_slice(body).unwrap(); + + assert_eq!( + response_json, + dto::InitialSetup { + has_any_users: true + } + ); + } +} diff --git a/src/service/actix/tests/mod.rs b/src/service/actix/tests/mod.rs index ea62291..897a9ea 100644 --- a/src/service/actix/tests/mod.rs +++ b/src/service/actix/tests/mod.rs @@ -1,10 +1,13 @@ +use std::fs; use std::path::PathBuf; +use crate::db::DB; + mod api; mod swagger; mod web; -fn configure_test_app(cfg: &mut actix_web::web::ServiceConfig) { +fn configure_test_app(cfg: &mut actix_web::web::ServiceConfig, db_name: &str) { let web_url = "/"; let web_dir_path = PathBuf::from("web"); @@ -12,11 +15,20 @@ fn configure_test_app(cfg: &mut actix_web::web::ServiceConfig) { let mut swagger_dir_path = PathBuf::from("docs"); swagger_dir_path.push("swagger"); + let mut db_path = PathBuf::new(); + db_path.push("test"); + db_path.push(format!("{}.sqlite", db_name)); + if db_path.exists() { + fs::remove_file(&db_path).unwrap(); + } + let db = DB::new(&db_path).unwrap(); + super::configure_app( cfg, web_url, web_dir_path.as_path(), swagger_url, swagger_dir_path.as_path(), + &db, ); } diff --git a/src/service/actix/tests/swagger.rs b/src/service/actix/tests/swagger.rs index f717ae8..73f5ecd 100644 --- a/src/service/actix/tests/swagger.rs +++ b/src/service/actix/tests/swagger.rs @@ -2,9 +2,11 @@ use actix_web::dev::Service; use actix_web::test::TestRequest; use actix_web::{test, App}; +use super::configure_test_app; + #[actix_rt::test] -async fn test_index() { - let app = App::new().configure(super::configure_test_app); +async fn test_swagger_index() { + let app = App::new().configure(|cfg| configure_test_app(cfg, "test_swagger_index")); let mut service = test::init_service(app).await; let req = TestRequest::get().uri("/swagger").to_request(); let resp = service.call(req).await.unwrap(); @@ -12,8 +14,9 @@ async fn test_index() { } #[actix_rt::test] -async fn test_index_with_trailing_slash() { - let app = App::new().configure(super::configure_test_app); +async fn test_swagger_index_with_trailing_slash() { + let app = App::new() + .configure(|cfg| configure_test_app(cfg, "test_swagger_index_with_trailing_slash")); let mut service = test::init_service(app).await; let req = TestRequest::get().uri("/swagger/").to_request(); let resp = service.call(req).await.unwrap(); diff --git a/src/service/actix/tests/web.rs b/src/service/actix/tests/web.rs index 898e9d1..70b7a15 100644 --- a/src/service/actix/tests/web.rs +++ b/src/service/actix/tests/web.rs @@ -2,9 +2,11 @@ use actix_web::dev::Service; use actix_web::test::TestRequest; use actix_web::{test, App}; +use super::configure_test_app; + #[actix_rt::test] async fn test_index() { - let app = App::new().configure(super::configure_test_app); + let app = App::new().configure(|cfg| configure_test_app(cfg, "test_index")); let mut service = test::init_service(app).await; let req = TestRequest::get().uri("/").to_request(); let resp = service.call(req).await.unwrap(); diff --git a/src/service/dto.rs b/src/service/dto.rs index 51973f1..fb31e5c 100644 --- a/src/service/dto.rs +++ b/src/service/dto.rs @@ -5,3 +5,8 @@ pub struct Version { pub major: i32, pub minor: i32, } + +#[derive(PartialEq, Debug, Serialize, Deserialize)] +pub struct InitialSetup { + pub has_any_users: bool, +} diff --git a/src/service/error.rs b/src/service/error.rs new file mode 100644 index 0000000..0ff3340 --- /dev/null +++ b/src/service/error.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum APIError { + #[error("Incorrect Credentials")] + IncorrectCredentials, + #[error("Unspecified")] + Unspecified, +} + +impl From for APIError { + fn from(_: anyhow::Error) -> Self { + APIError::Unspecified + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index fe22d69..6be9bb4 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,5 +1,6 @@ mod constants; mod dto; +mod error; #[cfg(feature = "service-actix")] mod actix; diff --git a/src/service/rocket/api.rs b/src/service/rocket/api.rs index 538f02f..a61e088 100644 --- a/src/service/rocket/api.rs +++ b/src/service/rocket/api.rs @@ -11,7 +11,6 @@ use std::path::PathBuf; use std::str; use std::str::FromStr; use std::sync::Arc; -use thiserror::Error; use time::Duration; use super::serve; @@ -22,6 +21,7 @@ use crate::lastfm; use crate::playlist; use crate::service::constants::*; use crate::service::dto; +use crate::service::error::APIError; use crate::thumbnails; use crate::user; use crate::utils; @@ -57,14 +57,6 @@ pub fn get_routes() -> Vec { ] } -#[derive(Error, Debug)] -enum APIError { - #[error("Incorrect Credentials")] - IncorrectCredentials, - #[error("Unspecified")] - Unspecified, -} - impl<'r> rocket::response::Responder<'r> for APIError { fn respond_to(self, _: &rocket::request::Request<'_>) -> rocket::response::Result<'r> { let status = match self { @@ -75,12 +67,6 @@ impl<'r> rocket::response::Responder<'r> for APIError { } } -impl From for APIError { - fn from(_: anyhow::Error) -> Self { - APIError::Unspecified - } -} - struct Auth { username: String, } @@ -118,7 +104,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Auth { fn from_request(request: &'a Request<'r>) -> request::Outcome { let mut cookies = request.guard::>().unwrap(); - let db = match request.guard::>>() { + let db = match request.guard::>() { Outcome::Success(d) => d, _ => return Outcome::Failure((Status::InternalServerError, ())), }; @@ -165,16 +151,16 @@ impl<'a, 'r> FromRequest<'a, 'r> for AdminRights { type Error = (); fn from_request(request: &'a Request<'r>) -> request::Outcome { - let db = request.guard::>>()?; + let db = request.guard::>()?; - match user::count::(&db) { + match user::count(&db) { Err(_) => return Outcome::Failure((Status::InternalServerError, ())), Ok(0) => return Outcome::Success(AdminRights {}), _ => (), }; let auth = request.guard::()?; - match user::is_admin::(&db, &auth.username) { + match user::is_admin(&db, &auth.username) { Err(_) => Outcome::Failure((Status::InternalServerError, ())), Ok(true) => Outcome::Success(AdminRights {}), Ok(false) => Outcome::Failure((Status::Forbidden, ())), @@ -212,48 +198,35 @@ fn version() -> Json { Json(current_version) } -#[derive(PartialEq, Debug, Serialize, Deserialize)] -pub struct InitialSetup { - pub has_any_users: bool, -} - #[get("/initial_setup")] -fn initial_setup(db: State<'_, Arc>) -> Result> { - let initial_setup = InitialSetup { - has_any_users: user::count::(&db)? > 0, +fn initial_setup(db: State<'_, DB>) -> Result> { + let initial_setup = dto::InitialSetup { + has_any_users: user::count(&db)? > 0, }; Ok(Json(initial_setup)) } #[get("/settings")] -fn get_settings(db: State<'_, Arc>, _admin_rights: AdminRights) -> Result> { - let config = config::read::(&db)?; +fn get_settings(db: State<'_, DB>, _admin_rights: AdminRights) -> Result> { + let config = config::read(&db)?; Ok(Json(config)) } #[put("/settings", data = "")] -fn put_settings( - db: State<'_, Arc>, - _admin_rights: AdminRights, - config: Json, -) -> Result<()> { - config::amend::(&db, &config)?; +fn put_settings(db: State<'_, DB>, _admin_rights: AdminRights, config: Json) -> Result<()> { + config::amend(&db, &config)?; Ok(()) } #[get("/preferences")] -fn get_preferences(db: State<'_, Arc>, auth: Auth) -> Result> { - let preferences = config::read_preferences::(&db, &auth.username)?; +fn get_preferences(db: State<'_, DB>, auth: Auth) -> Result> { + let preferences = config::read_preferences(&db, &auth.username)?; Ok(Json(preferences)) } #[put("/preferences", data = "")] -fn put_preferences( - db: State<'_, Arc>, - auth: Auth, - preferences: Json, -) -> Result<()> { - config::write_preferences::(&db, &auth.username, &preferences)?; +fn put_preferences(db: State<'_, DB>, auth: Auth, preferences: Json) -> Result<()> { + config::write_preferences(&db, &auth.username, &preferences)?; Ok(()) } @@ -279,27 +252,27 @@ struct AuthOutput { #[post("/auth", data = "")] fn auth( - db: State<'_, Arc>, + db: State<'_, DB>, credentials: Json, mut cookies: Cookies<'_>, ) -> std::result::Result<(), APIError> { - if !user::auth::(&db, &credentials.username, &credentials.password)? { + if !user::auth(&db, &credentials.username, &credentials.password)? { return Err(APIError::IncorrectCredentials); } - let is_admin = user::is_admin::(&db, &credentials.username)?; + let is_admin = user::is_admin(&db, &credentials.username)?; add_session_cookies(&mut cookies, &credentials.username, is_admin); Ok(()) } #[get("/browse")] -fn browse_root(db: State<'_, Arc>, _auth: Auth) -> Result>> { +fn browse_root(db: State<'_, DB>, _auth: Auth) -> Result>> { let result = index::browse(db.deref().deref(), &PathBuf::new())?; Ok(Json(result)) } #[get("/browse/")] fn browse( - db: State<'_, Arc>, + db: State<'_, DB>, _auth: Auth, path: VFSPathBuf, ) -> Result>> { @@ -308,42 +281,38 @@ fn browse( } #[get("/flatten")] -fn flatten_root(db: State<'_, Arc>, _auth: Auth) -> Result>> { +fn flatten_root(db: State<'_, DB>, _auth: Auth) -> Result>> { let result = index::flatten(db.deref().deref(), &PathBuf::new())?; Ok(Json(result)) } #[get("/flatten/")] -fn flatten( - db: State<'_, Arc>, - _auth: Auth, - path: VFSPathBuf, -) -> Result>> { +fn flatten(db: State<'_, DB>, _auth: Auth, path: VFSPathBuf) -> Result>> { let result = index::flatten(db.deref().deref(), &path.into() as &PathBuf)?; Ok(Json(result)) } #[get("/random")] -fn random(db: State<'_, Arc>, _auth: Auth) -> Result>> { +fn random(db: State<'_, DB>, _auth: Auth) -> Result>> { let result = index::get_random_albums(db.deref().deref(), 20)?; Ok(Json(result)) } #[get("/recent")] -fn recent(db: State<'_, Arc>, _auth: Auth) -> Result>> { +fn recent(db: State<'_, DB>, _auth: Auth) -> Result>> { let result = index::get_recent_albums(db.deref().deref(), 20)?; Ok(Json(result)) } #[get("/search")] -fn search_root(db: State<'_, Arc>, _auth: Auth) -> Result>> { +fn search_root(db: State<'_, DB>, _auth: Auth) -> Result>> { let result = index::search(db.deref().deref(), "")?; Ok(Json(result)) } #[get("/search/")] fn search( - db: State<'_, Arc>, + db: State<'_, DB>, _auth: Auth, query: String, ) -> Result>> { @@ -352,12 +321,7 @@ fn search( } #[get("/serve/")] -fn serve( - db: State<'_, Arc>, - _auth: Auth, - path: VFSPathBuf, -) -> Result> { - let db: &DB = db.deref().deref(); +fn serve(db: State<'_, DB>, _auth: Auth, path: VFSPathBuf) -> Result> { let vfs = db.get_vfs()?; let real_path = vfs.virtual_to_real(&path.into() as &PathBuf)?; @@ -377,7 +341,7 @@ pub struct ListPlaylistsEntry { } #[get("/playlists")] -fn list_playlists(db: State<'_, Arc>, auth: Auth) -> Result>> { +fn list_playlists(db: State<'_, DB>, auth: Auth) -> Result>> { let playlist_names = playlist::list_playlists(&auth.username, db.deref().deref())?; let playlists: Vec = playlist_names .into_iter() @@ -394,7 +358,7 @@ pub struct SavePlaylistInput { #[put("/playlist/", data = "")] fn save_playlist( - db: State<'_, Arc>, + db: State<'_, DB>, auth: Auth, name: String, playlist: Json, @@ -404,23 +368,19 @@ fn save_playlist( } #[get("/playlist/")] -fn read_playlist( - db: State<'_, Arc>, - auth: Auth, - name: String, -) -> Result>> { +fn read_playlist(db: State<'_, DB>, auth: Auth, name: String) -> Result>> { let songs = playlist::read_playlist(&name, &auth.username, db.deref().deref())?; Ok(Json(songs)) } #[delete("/playlist/")] -fn delete_playlist(db: State<'_, Arc>, auth: Auth, name: String) -> Result<()> { +fn delete_playlist(db: State<'_, DB>, auth: Auth, name: String) -> Result<()> { playlist::delete_playlist(&name, &auth.username, db.deref().deref())?; Ok(()) } #[put("/lastfm/now_playing/")] -fn lastfm_now_playing(db: State<'_, Arc>, auth: Auth, path: VFSPathBuf) -> Result<()> { +fn lastfm_now_playing(db: State<'_, DB>, auth: Auth, path: VFSPathBuf) -> Result<()> { if user::is_lastfm_linked(db.deref().deref(), &auth.username) { lastfm::now_playing(db.deref().deref(), &auth.username, &path.into() as &PathBuf)?; } @@ -428,7 +388,7 @@ fn lastfm_now_playing(db: State<'_, Arc>, auth: Auth, path: VFSPathBuf) -> R } #[post("/lastfm/scrobble/")] -fn lastfm_scrobble(db: State<'_, Arc>, auth: Auth, path: VFSPathBuf) -> Result<()> { +fn lastfm_scrobble(db: State<'_, DB>, auth: Auth, path: VFSPathBuf) -> Result<()> { if user::is_lastfm_linked(db.deref().deref(), &auth.username) { lastfm::scrobble(db.deref().deref(), &auth.username, &path.into() as &PathBuf)?; } @@ -437,7 +397,7 @@ fn lastfm_scrobble(db: State<'_, Arc>, auth: Auth, path: VFSPathBuf) -> Resu #[get("/lastfm/link?&")] fn lastfm_link( - db: State<'_, Arc>, + db: State<'_, DB>, auth: Auth, token: String, content: String, @@ -457,7 +417,7 @@ fn lastfm_link( } #[delete("/lastfm/link")] -fn lastfm_unlink(db: State<'_, Arc>, auth: Auth) -> Result<()> { +fn lastfm_unlink(db: State<'_, DB>, auth: Auth) -> Result<()> { lastfm::unlink(db.deref().deref(), &auth.username)?; Ok(()) } diff --git a/src/service/rocket/api_tests.rs b/src/service/rocket/api_tests.rs index bd3dcac..52914c5 100644 --- a/src/service/rocket/api_tests.rs +++ b/src/service/rocket/api_tests.rs @@ -71,10 +71,10 @@ fn initial_setup() { let mut response = client.get("/api/initial_setup").dispatch(); assert_eq!(response.status(), Status::Ok); let response_body = response.body_string().unwrap(); - let response_json: api::InitialSetup = serde_json::from_str(&response_body).unwrap(); + let response_json: dto::InitialSetup = serde_json::from_str(&response_body).unwrap(); assert_eq!( response_json, - api::InitialSetup { + dto::InitialSetup { has_any_users: false } ); @@ -86,10 +86,10 @@ fn initial_setup() { let mut response = client.get("/api/initial_setup").dispatch(); assert_eq!(response.status(), Status::Ok); let response_body = response.body_string().unwrap(); - let response_json: api::InitialSetup = serde_json::from_str(&response_body).unwrap(); + let response_json: dto::InitialSetup = serde_json::from_str(&response_body).unwrap(); assert_eq!( response_json, - api::InitialSetup { + dto::InitialSetup { has_any_users: true } ); diff --git a/src/service/rocket/server.rs b/src/service/rocket/server.rs index 05dbda7..84fca9d 100644 --- a/src/service/rocket/server.rs +++ b/src/service/rocket/server.rs @@ -17,7 +17,7 @@ pub fn get_server( web_dir_path: &PathBuf, swagger_url: &str, swagger_dir_path: &PathBuf, - db: Arc, + db: DB, command_sender: Arc, ) -> Result { let mut config = rocket::Config::build(Environment::Production) @@ -56,7 +56,7 @@ pub fn run( web_dir_path: PathBuf, swagger_url: String, swagger_dir_path: PathBuf, - db: Arc, + db: DB, command_sender: Arc, ) -> Result<()> { let server = get_server( diff --git a/src/service/rocket/test.rs b/src/service/rocket/test.rs index 2705822..a0072f8 100644 --- a/src/service/rocket/test.rs +++ b/src/service/rocket/test.rs @@ -12,12 +12,12 @@ use crate::index; pub struct TestEnvironment { pub client: Client, command_sender: Arc, - db: Arc, + db: DB, } impl TestEnvironment { pub fn update_index(&self) { - index::update(self.db.deref()).unwrap(); + index::update(&self.db).unwrap(); } } @@ -34,8 +34,7 @@ pub fn get_test_environment(db_name: &str) -> TestEnvironment { if db_path.exists() { fs::remove_file(&db_path).unwrap(); } - - let db = Arc::new(DB::new(&db_path).unwrap()); + let db = DB::new(&db_path).unwrap(); let web_dir_path = PathBuf::from("web"); let mut swagger_dir_path = PathBuf::from("docs"); diff --git a/src/user.rs b/src/user.rs index a4055f7..e0b4dbe 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,10 +1,9 @@ use anyhow::*; -use core::ops::Deref; use diesel; use diesel::prelude::*; use crate::db::users; -use crate::db::ConnectionSource; +use crate::db::DB; #[derive(Debug, Insertable, Queryable)] #[table_name = "users"] @@ -38,16 +37,13 @@ fn verify_password(password_hash: &str, attempted_password: &str) -> bool { pbkdf2::pbkdf2_check(attempted_password, password_hash).is_ok() } -pub fn auth(db: &T, username: &str, password: &str) -> Result -where - T: ConnectionSource, -{ +pub fn auth(db: &DB, username: &str, password: &str) -> Result { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; match users .select(password_hash) .filter(name.eq(username)) - .get_result(connection.deref()) + .get_result(&connection) { Err(diesel::result::Error::NotFound) => Ok(false), Ok(hash) => { @@ -58,88 +54,67 @@ where } } -pub fn count(db: &T) -> Result -where - T: ConnectionSource, -{ +pub fn count(db: &DB) -> Result { use crate::db::users::dsl::*; - let connection = db.get_connection(); - let count = users.count().get_result(connection.deref())?; + let connection = db.connect()?; + let count = users.count().get_result(&connection)?; Ok(count) } -pub fn exists(db: &T, username: &str) -> Result -where - T: ConnectionSource, -{ +pub fn exists(db: &DB, username: &str) -> Result { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; let results: Vec = users .select(name) .filter(name.eq(username)) - .get_results(connection.deref())?; + .get_results(&connection)?; Ok(results.len() > 0) } -pub fn is_admin(db: &T, username: &str) -> Result -where - T: ConnectionSource, -{ +pub fn is_admin(db: &DB, username: &str) -> Result { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; let is_admin: i32 = users .filter(name.eq(username)) .select(admin) - .get_result(connection.deref())?; + .get_result(&connection)?; Ok(is_admin != 0) } -pub fn lastfm_link(db: &T, username: &str, lastfm_login: &str, session_key: &str) -> Result<()> -where - T: ConnectionSource, -{ +pub fn lastfm_link(db: &DB, username: &str, lastfm_login: &str, session_key: &str) -> Result<()> { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; diesel::update(users.filter(name.eq(username))) .set(( lastfm_username.eq(lastfm_login), lastfm_session_key.eq(session_key), )) - .execute(connection.deref())?; + .execute(&connection)?; Ok(()) } -pub fn get_lastfm_session_key(db: &T, username: &str) -> Result -where - T: ConnectionSource, -{ +pub fn get_lastfm_session_key(db: &DB, username: &str) -> Result { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; let token = users .filter(name.eq(username)) .select(lastfm_session_key) - .get_result(connection.deref())?; + .get_result(&connection)?; match token { Some(t) => Ok(t), _ => Err(anyhow!("Missing LastFM credentials")), } } -pub fn is_lastfm_linked(db: &T, username: &str) -> bool -where - T: ConnectionSource, -{ +pub fn is_lastfm_linked(db: &DB, username: &str) -> bool { get_lastfm_session_key(db, username).is_ok() } -pub fn lastfm_unlink(db: &T, username: &str) -> Result<()> -where - T: ConnectionSource, -{ +pub fn lastfm_unlink(db: &DB, username: &str) -> Result<()> { use crate::db::users::dsl::*; - let connection = db.get_connection(); + let connection = db.connect()?; diesel::update(users.filter(name.eq(username))) .set((lastfm_session_key.eq(""), lastfm_username.eq(""))) - .execute(connection.deref())?; + .execute(&connection)?; Ok(()) } diff --git a/src/vfs.rs b/src/vfs.rs index 44f8584..d5b9280 100644 --- a/src/vfs.rs +++ b/src/vfs.rs @@ -7,7 +7,7 @@ use std::path::Path; use std::path::PathBuf; use crate::db::mount_points; -use crate::db::{ConnectionSource, DB}; +use crate::db::DB; pub trait VFSSource { fn get_vfs(&self) -> Result; @@ -17,7 +17,7 @@ impl VFSSource for DB { fn get_vfs(&self) -> Result { use self::mount_points::dsl::*; let mut vfs = VFS::new(); - let connection = self.get_connection(); + let connection = self.connect()?; let points: Vec = mount_points .select((source, name)) .get_results(connection.deref())?;