diff --git a/.gitignore b/.gitignore index ea8c4bf..4ef3e1e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +api_keys.json diff --git a/Cargo.lock b/Cargo.lock index 77a67e9..7385107 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -815,11 +815,13 @@ version = "0.1.0" dependencies = [ "dashmap", "futures", + "lazy_static", "okapi", "rocket", "rocket_db_pools", "rocket_okapi", "schemars", + "serde_json", "sqlx", "thiserror", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 131211d..03559c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,10 @@ dashmap = "5.3.4" thiserror = "1.0" schemars = "0.8.10" okapi = { version = "0.7.0-rc.1" } -rocket_okapi = { version = "0.8.0-rc.2", features = ["swagger", "rocket_db_pools"] } +rocket_okapi = { version = "0.8.0-rc.2", features = ["swagger", "rocket_db_pools", "rapidoc"] } futures = "0.3" +lazy_static = "1.4" +serde_json = "1" [dependencies.sqlx] version = "*" diff --git a/party.sqlite-shm b/party.sqlite-shm index f3295f4..8711cc1 100644 Binary files a/party.sqlite-shm and b/party.sqlite-shm differ diff --git a/party.sqlite-wal b/party.sqlite-wal index 6e724b9..18d349c 100644 Binary files a/party.sqlite-wal and b/party.sqlite-wal differ diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..22d12a3 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1 @@ +pub mod user; diff --git a/src/api/user.rs b/src/api/user.rs new file mode 100644 index 0000000..d806d37 --- /dev/null +++ b/src/api/user.rs @@ -0,0 +1,172 @@ +use crate::{ApiKey, Db, Ordering, PartyError}; +use okapi::openapi3::OpenApi; +use rocket::{ + fairing::AdHoc, + http::Status, + response::status::Created, + serde::{json::Json, Deserialize, Serialize}, +}; +use rocket_db_pools::{sqlx, Connection}; +use rocket_okapi::{ + openapi, openapi_get_routes, openapi_get_routes_spec, settings::OpenApiSettings, JsonSchema, +}; +use sqlx::FromRow; + +/// # User +/// +/// A user that represents a person participating in the LAN party +#[derive(Clone, Debug, FromForm, Serialize, Deserialize, JsonSchema, FromRow)] +#[serde(crate = "rocket::serde")] +pub struct User { + /// Name of the user + name: String, + /// Score of the user + #[serde(default)] + score: i64, + /// Unique identifier of the user + #[serde(default)] + id: i64, +} + +/// # Create new user with the give name +/// +/// Returns the created user +#[openapi(tag = "User")] +#[post("/", data = "")] +pub async fn add_user( + _api_key: ApiKey, + mut db: Connection, + name: Json<&str>, +) -> Result>, PartyError> { + let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", *name) + .execute(&mut *db) + .await?; + + let user = User { + id: result.last_insert_rowid(), + score: 0, + name: name.to_string(), + }; + + Ok(Created::new("/").body(Json(user))) +} + +/// # Delete user by id +#[openapi(tag = "User")] +#[delete("/")] +pub async fn delete_user( + _api_key: ApiKey, + mut db: Connection, + id: i64, +) -> Result { + sqlx::query!("DELETE FROM users where (id = ?)", id) + .execute(&mut *db) + .await?; + + Ok(Status::Ok) +} + +/// # Get user by id +/// +/// Returns a single user by id +#[openapi(tag = "User")] +#[get("/")] +pub async fn get_user( + _api_key: ApiKey, + mut db: Connection, + id: i64, +) -> Result, PartyError> { + let user = sqlx::query_as!(User, "SELECT id, name, score FROM users WHERE (id = ?)", id) + .fetch_one(&mut *db) + .await?; + Ok(Json(user)) +} + +/// # UserSort +/// +/// Field used to sort users +#[derive(Clone, Debug, FromFormField, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum UserSort { + #[field(value = "score")] + Score, + #[field(value = "name")] + Name, + #[field(value = "id")] + Id, +} + +impl ToString for UserSort { + fn to_string(&self) -> String { + match self { + Self::Score => "score", + Self::Name => "name", + Self::Id => "id", + } + .into() + } +} + +/// # Get all users +/// +/// Returns an array of all users sorted by the given sort field and ordering +#[openapi(tag = "User")] +#[get("/?&")] +pub async fn get_all_users( + _api_key: ApiKey, + mut db: Connection, + sort: Option, + order: Option, +) -> Result>, PartyError> { + let users = sqlx::query_as::<_, User>(&format!( + "SELECT id, name, score FROM users ORDER BY {} {}", + sort.unwrap_or(UserSort::Id).to_string(), + order.unwrap_or(Ordering::Asc).to_string() + )) + .fetch_all(&mut *db) + .await?; + + Ok(Json(users)) +} + +/// # Set score of user by id +#[openapi(tag = "User")] +#[post("//score", data = "")] +pub async fn set_score( + _api_key: ApiKey, + mut db: Connection, + id: i64, + score: Json, +) -> Result { + sqlx::query!("UPDATE users SET score = ? WHERE id = ?", *score, id) + .execute(&mut *db) + .await?; + Ok(Status::Ok) +} + +/// # Get score of user by id +/// +/// Returns the score of a single user by id +#[openapi(tag = "User")] +#[get("//score")] +pub async fn get_score( + _api_key: ApiKey, + mut db: Connection, + id: i64, +) -> Result, PartyError> { + let score = sqlx::query_scalar!("SELECT score FROM users WHERE id = ?", id) + .fetch_one(&mut *db) + .await?; + Ok(Json(score)) +} + +pub fn get_routes_and_docs(settings: &OpenApiSettings) -> (Vec, OpenApi) { + openapi_get_routes_spec![ + settings: add_user, + get_user, + get_all_users, + delete_user, + set_score, + get_score + ] +} diff --git a/src/main.rs b/src/main.rs index 98a763b..7b90cac 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,31 +1,45 @@ -use futures::{prelude::*, stream::TryStreamExt}; +mod api; + +use lazy_static::lazy_static; +use okapi::openapi3::{Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData}; use rocket::{ - fairing, - fairing::AdHoc, - http::Status, - response, - response::{status::Created, Responder}, - serde::{json::Json, Deserialize, Serialize}, - Build, Request, Rocket, -}; -use rocket_db_pools::{ - sqlx::{self}, - Connection, Database, + fairing, fairing::AdHoc, http::Status, request, request::FromRequest, response, + response::Responder, serde, Build, Request, Rocket, }; +use rocket_db_pools::{sqlx, Database}; use rocket_okapi::{ - openapi, openapi_get_routes, + gen::OpenApiGenerator, + mount_endpoints_and_merged_docs, openapi, openapi_get_routes, + rapidoc::*, + request::{OpenApiFromRequest, RequestHeaderInput}, response::OpenApiResponderInner, + settings::{OpenApiSettings, UrlObject}, swagger_ui::{make_swagger_ui, SwaggerUIConfig}, - JsonSchema, }; -use sqlx::{Acquire, Connection as SqlxConnection, FromRow, Sqlite}; -use std::collections::HashMap; #[macro_use] extern crate rocket; +use schemars::JsonSchema; use thiserror::Error; +/* +const API_KEYS: [&'static str; 3] = [ + "7de10bf6-278d-11ed-ad60-a8a15919d1b3", + "89eb06e0-278d-11ed-9b29-a8a15919d1b3", + "8a35ba14-278d-11ed-a200-a8a15919d1b3", +]; +*/ + +lazy_static! { + static ref API_KEYS: Vec = { + serde_json::from_str( + &std::fs::read_to_string("api_keys.json").expect("api_keys.json does not exist"), + ) + .expect("api_keys.json is not valid") + }; +} + #[derive(Error, Debug)] pub enum PartyError { #[error("user `{0}` does not exist")] @@ -56,11 +70,7 @@ impl OpenApiResponderInner for PartyError { impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { - // log `self` to your favored error tracker, e.g. - // sentry::capture_error(&self); - match self { - // in our simplistic example, we're happy to respond with the default 500 responder in all cases Self::UserNotFound(_) => Status::NotFound, _ => Status::InternalServerError, } @@ -68,73 +78,12 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { } } -#[derive(Clone, Debug, FromForm, Serialize, Deserialize, JsonSchema, FromRow)] -#[serde(crate = "rocket::serde")] -struct User { - name: String, - #[serde(default)] - score: i64, - #[serde(default)] - id: i64, -} - -#[derive(Database)] -#[database("party")] -struct Db(sqlx::SqlitePool); - -#[openapi] -#[get("/")] -fn index() -> String { - format!("Hello, world!") -} - -#[openapi] -#[post("/user", data = "")] -async fn add_user( - mut db: Connection, - mut user: Json, -) -> Result>, PartyError> { - let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", user.name) - .execute(&mut *db) - .await?; - user.id = result.last_insert_rowid(); - - Ok(Created::new("/").body(user)) -} - -#[openapi] -#[delete("/user/")] -async fn delete_user(mut db: Connection, id: i64) -> Result { - sqlx::query!("DELETE FROM users where (id = ?)", id) - .execute(&mut *db) - .await?; - - Ok(Status::Ok) -} - -#[openapi] -#[get("/user/")] -async fn get_user(mut db: Connection, id: i64) -> Result, PartyError> { - let user = sqlx::query_as!(User, "SELECT id, name, score FROM users WHERE (id = ?)", id) - .fetch_one(&mut *db) - .await?; - Ok(Json(user)) -} - +/// # Ordering +/// +/// Ordering of data in an array, ascending or descending #[derive(Clone, Debug, FromFormField, JsonSchema)] #[serde(rename_all = "snake_case")] -enum UserSort { - #[field(value = "score")] - Score, - #[field(value = "name")] - Name, - #[field(value = "id")] - Id, -} - -#[derive(Clone, Debug, FromFormField, JsonSchema)] -#[serde(rename_all = "snake_case")] -enum Ordering { +pub enum Ordering { #[field(value = "desc")] Desc, #[field(value = "asc")] @@ -151,55 +100,66 @@ impl ToString for Ordering { } } -impl ToString for UserSort { - fn to_string(&self) -> String { - match self { - Self::Score => "score", - Self::Name => "name", - Self::Id => "id", +#[derive(Database)] +#[database("party")] +pub struct Db(sqlx::SqlitePool); + +#[derive(Clone, Copy, Debug)] +pub struct ApiKey; + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for ApiKey { + type Error = ApiKey; + + async fn from_request(req: &'r Request<'_>) -> request::Outcome { + if req + .headers() + .get("X-API-Key") + .any(|k| API_KEYS.contains(&String::from(k))) + { + request::Outcome::Success(ApiKey) + } else { + request::Outcome::Failure((Status::Unauthorized, ApiKey)) } - .into() + } +} + +impl<'a> OpenApiFromRequest<'a> for ApiKey { + fn from_request_input( + _gen: &mut OpenApiGenerator, + _name: String, + _required: bool, + ) -> rocket_okapi::Result { + // Setup global requirement for Security scheme + let security_scheme = SecurityScheme { + description: Some("Requires an API key to access".to_owned()), + // Setup data requirements. + // This can be part of the `header`, `query` or `cookie`. + // In this case the header `x-api-key: mykey` needs to be set. + data: SecuritySchemeData::ApiKey { + name: "x-api-key".to_owned(), + location: "header".to_owned(), + }, + extensions: Object::default(), + }; + // Add the requirement for this route/endpoint + // This can change between routes. + let mut security_req = SecurityRequirement::new(); + // Each security requirement needs to be met before access is allowed. + security_req.insert("ApiKeyAuth".to_owned(), Vec::new()); + // These vvvvvvv-----^^^^^^^^^^ values need to match exactly! + Ok(RequestHeaderInput::Security( + "ApiKeyAuth".to_owned(), + security_scheme, + security_req, + )) } } #[openapi] -#[get("/user?&")] -async fn get_all_users( - mut db: Connection, - sort: Option, - order: Option, -) -> Result>, PartyError> { - let users = sqlx::query_as::<_, User>(&format!( - "SELECT id, name, score FROM users ORDER BY {} {}", - sort.unwrap_or(UserSort::Id).to_string(), - order.unwrap_or(Ordering::Asc).to_string() - )) - .fetch_all(&mut *db) - .await?; - - Ok(Json(users)) -} - -#[openapi] -#[post("/user//score", data = "")] -async fn set_score( - mut db: Connection, - id: i64, - score: Json, -) -> Result { - sqlx::query!("UPDATE users SET score = ? WHERE id = ?", *score, id) - .execute(&mut *db) - .await?; - Ok(Status::Ok) -} - -#[openapi] -#[get("/user//score")] -async fn get_score(mut db: Connection, id: i64) -> Result, PartyError> { - let score = sqlx::query_scalar!("SELECT score FROM users WHERE id = ?", id) - .fetch_one(&mut *db) - .await?; - Ok(Json(score)) +#[get("/")] +fn index() -> String { + format!("Hello, world!") } fn get_docs() -> SwaggerUIConfig { @@ -227,20 +187,33 @@ async fn run_migrations(rocket: Rocket) -> fairing::Result { #[launch] fn rocket() -> _ { - rocket::build() + println!("{:#?}", API_KEYS.len()); + let mut building_rocket = rocket::build() .attach(Db::init()) .attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations)) - .mount("/", openapi_get_routes![index]) - .mount( - "/api", - openapi_get_routes![ - add_user, - get_user, - get_all_users, - delete_user, - set_score, - get_score - ], - ) + .mount("/", routes![index]) .mount("/swagger", make_swagger_ui(&get_docs())) + .mount( + "/rapidoc/", + make_rapidoc(&RapiDocConfig { + general: GeneralConfig { + spec_urls: vec![UrlObject::new("General", "../api/openapi.json")], + ..Default::default() + }, + hide_show: HideShowConfig { + allow_spec_url_load: false, + allow_spec_file_load: false, + ..Default::default() + }, + ..Default::default() + }), + ); + + let openapi_settings = OpenApiSettings::default(); + mount_endpoints_and_merged_docs! { + building_rocket, "/api".to_owned(), openapi_settings, + "/user" => api::user::get_routes_and_docs(&openapi_settings), + }; + + building_rocket }