From 0751408450f4ad56c4096de78cf4a62223f168c8 Mon Sep 17 00:00:00 2001 From: Daan Vanoverloop Date: Mon, 29 Aug 2022 17:17:31 +0200 Subject: [PATCH] What am I even doing? --- .gitignore | 1 + db/migrations/1_create-events-table.sql | 6 + party.db | Bin 20480 -> 24576 bytes src/api/auth.rs | 72 ++++++++++ src/api/event.rs | 181 ++++++++++++++++++++++++ src/api/mod.rs | 51 ++++++- src/api/user.rs | 38 ++--- src/api/util.rs | 64 +++++++++ src/main.rs | 168 ++-------------------- 9 files changed, 398 insertions(+), 183 deletions(-) create mode 100644 db/migrations/1_create-events-table.sql create mode 100644 src/api/auth.rs create mode 100644 src/api/event.rs create mode 100644 src/api/util.rs diff --git a/.gitignore b/.gitignore index 4ef3e1e..70ba123 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target api_keys.json +*.sqlite-* diff --git a/db/migrations/1_create-events-table.sql b/db/migrations/1_create-events-table.sql new file mode 100644 index 0000000..25ed11f --- /dev/null +++ b/db/migrations/1_create-events-table.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR NOT NULL, + event_type VARCHAR NOT NULL, + event_id INTEGER NOT NULL +); diff --git a/party.db b/party.db index fb4a01ce79308d85e7d23ea5f0c2c89c44b23ae0..6ce82fba5422be251f5e14eabec82740b0fa4464 100644 GIT binary patch delta 329 zcmZozz}Rqrae}lU3j+fK8xX?)^F$qEaTW%>Xd_8WOEL3bqsM;2yt}saa90|O_t&Fu}T+gxsid9k**g_VKDx!UU1hyJWFBI^$y)ta+A zJk7Rl$Khk0KYp#hvE%yT4S#R*PwDBLt=e&(E3R-B{~I%n4dF42>>S!iHk1~p78UEJ oB!NwzykFiG=%Vio{D1hrZx&Q|&Cklh%*n{K`Hwz}!XgI&0GL>4S^xk5 delta 98 zcmZoTz}T>Wae}lUGXnzyD-go~(?lI(QDz3cXd_ = { + serde_json::from_str( + &std::fs::read_to_string("api_keys.json").expect("api_keys.json does not exist"), + ) + .expect("api_keys.json is not valid") + }; +} + +#[derive(Clone, Copy, Debug)] +pub struct ApiKey; + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for ApiKey { + type Error = ApiKey; + + async fn from_request(req: &'r Request<'_>) -> request::Outcome { + if req + .headers() + .get("X-API-Key") + .any(|k| API_KEYS.contains(&String::from(k))) + { + request::Outcome::Success(ApiKey) + } else { + request::Outcome::Failure((Status::Unauthorized, ApiKey)) + } + } +} + +impl<'a> OpenApiFromRequest<'a> for ApiKey { + fn from_request_input( + _gen: &mut OpenApiGenerator, + _name: String, + _required: bool, + ) -> rocket_okapi::Result { + // Setup global requirement for Security scheme + let security_scheme = SecurityScheme { + description: Some("Requires an API key to access".to_owned()), + // Setup data requirements. + // This can be part of the `header`, `query` or `cookie`. + // In this case the header `x-api-key: mykey` needs to be set. + data: SecuritySchemeData::ApiKey { + name: "x-api-key".to_owned(), + location: "header".to_owned(), + }, + extensions: Object::default(), + }; + // Add the requirement for this route/endpoint + // This can change between routes. + let mut security_req = SecurityRequirement::new(); + // Each security requirement needs to be met before access is allowed. + security_req.insert("ApiKeyAuth".to_owned(), Vec::new()); + // These vvvvvvv-----^^^^^^^^^^ values need to match exactly! + Ok(RequestHeaderInput::Security( + "ApiKeyAuth".to_owned(), + security_scheme, + security_req, + )) + } +} diff --git a/src/api/event.rs b/src/api/event.rs new file mode 100644 index 0000000..4e52d70 --- /dev/null +++ b/src/api/event.rs @@ -0,0 +1,181 @@ +use std::collections::HashMap; + +use sqlx::FromRow; + +use super::{prelude::*, util::PartyError}; + +api_routes!(); + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(crate = "rocket::serde")] +pub struct EventOutcome { + points: HashMap, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, FromRow)] +#[serde(crate = "rocket::serde")] +pub struct FreeForAllGame {} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, FromRow)] +#[serde(crate = "rocket::serde")] +pub struct TeamGame {} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, FromRow)] +#[serde(crate = "rocket::serde")] +pub struct Test {} + +// # Event +// +// An event in which participants can win or lose points +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(crate = "rocket::serde")] +pub enum Event { + FreeForAllGame(FreeForAllGame), + TeamGame(TeamGame), + Test(Test), +} + +pub struct EventRecord { + id: i64, + name: String, + event_type: String, + event_id: i64, +} + +macro_rules! dispatch { + ($event:ident) => { + dispatch!($event, + free_for_all_game => FreeForAllGame, + team_game => TeamGame, + test => Test, + ) + }; + ($event:ident, $($event_type:ident => $event_struct:ident),* $(,)?) => { + match $event.event_type.as_str() { + $(stringify!($event_type) => dispatch_run!($event_type, $event_struct),)* + _ => return Err(PartyError::Unknown("invalid event type".into())), + } + }; +} + +macro_rules! reverse_dispatch { + ($event:ident) => { + reverse_dispatch!($event, + FreeForAllGame => free_for_all_game, + TeamGame => team_game, + Test => test, + ) + }; + ($event:ident, $($event_struct:ident => $event_type:ident),* $(,)?) => { + match $event { + $(Event::$event_struct(e) => reverse_dispatch_run!($event_type, $event_struct, e),)* + } + }; +} + +impl EventRecord { + pub async fn get(db: &mut Connection, id: i64) -> Result { + Ok(sqlx::query_as!( + EventRecord, + "SELECT id, name, event_type, event_id FROM events WHERE id = ?", + id + ) + .fetch_one(&mut **db) + .await?) + } + + pub async fn remove(&self, db: &mut Connection) -> Result<(), PartyError> { + macro_rules! dispatch_run { + ($event_type:ident, $event_struct:ident) => {{ + sqlx::query(&format!( + "DELETE FROM events_{} WHERE id = {}", + stringify!($event_type), + self.event_id + )) + .fetch_one(&mut **db) + .await?; + }}; + } + + dispatch!(self); + + sqlx::query!("DELETE FROM events WHERE id = ?", self.id) + .execute(&mut **db) + .await?; + + Ok(()) + } +} + +impl Event { + pub async fn register( + &self, + db: &mut Connection, + name: String, + ) -> Result { + let event_id = match self { + Self::FreeForAllGame(e) => { + unimplemented!() + /* + sqlx::query!("INSERT INTO events_free_for_all_game () VALUES ()") + .execute(&mut **db) + .await? + .last_insert_rowid() + */ + } + Self::TeamGame(e) => { + unimplemented!() + /* + sqlx::query!("INSERT INTO events_team_game () VALUES ()") + .execute(&mut **db) + .await? + .last_insert_rowid() + */ + } + Self::Test(e) => sqlx::query!("INSERT INTO events_test () VALUES ()") + .execute(&mut **db) + .await? + .last_insert_rowid(), + }; + + macro_rules! reverse_dispatch_run { + ($event_type:ident, $event_struct:ident, $inner:ident) => { + sqlx::query!( + "INSERT INTO events (name, event_type, event_id) VALUES (?, ?, ?)", + stringify!($event_type), + name, + event_id + ) + .execute(&mut **db) + .await? + }; + } + + let id = reverse_dispatch!(self).last_insert_rowid(); + Ok(EventRecord::get(db, id).await?) + } + + pub async fn get(db: &mut Connection, record: EventRecord) -> Result { + macro_rules! dispatch_run { + ($event_type:ident, $event_struct:ident) => { + Event::$event_struct( + sqlx::query_as::<_, $event_struct>(&format!( + "SELECT id, name, event_type, event_id FROM events_{} WHERE id = {}", + stringify!($event_type), + record.event_id + )) + .fetch_one(&mut **db) + .await?, + ) + }; + } + + Ok(dispatch!(record)) + } +} + +#[openapi(tag = "Event")] +#[post("//stop")] +pub fn stop_event(id: i64) -> Result, PartyError> { + todo!() +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 22d12a3..7c99554 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1 +1,50 @@ -pub mod user; +mod auth; +pub mod util; + +pub use auth::ApiKey; +use rocket::{Build, Rocket}; +use rocket_okapi::{mount_endpoints_and_merged_docs, settings::OpenApiSettings}; + +mod prelude { + pub use super::{util, ApiKey}; + pub use crate::{api_routes, Db}; + pub use rocket::{ + http::Status, + response::status, + serde::{json::Json, Deserialize, Serialize}, + }; + pub use rocket_db_pools::{sqlx, Connection}; + pub use rocket_okapi::{openapi, JsonSchema}; +} + +#[macro_export] +macro_rules! api_routes { + ($($route:ident),* $(,)?) => { + pub fn get_routes_and_docs( + settings: &rocket_okapi::settings::OpenApiSettings + ) -> (Vec, okapi::openapi3::OpenApi) { + rocket_okapi::openapi_get_routes_spec![ + settings: $($route,)* + ] + } + }; +} + +macro_rules! mount_endpoints { + ($($endpoint:ident),* $(,)?) => { + $(pub mod $endpoint;)* + + pub fn mount_endpoints( + mut building_rocket: Rocket, + openapi_settings: &OpenApiSettings, + ) -> Rocket { + mount_endpoints_and_merged_docs! { + building_rocket, "/api".to_owned(), openapi_settings, + $(stringify!("/", $endpoint) => $endpoint::get_routes_and_docs(&openapi_settings),)* + }; + building_rocket + } + }; +} + +mount_endpoints!(user, event); diff --git a/src/api/user.rs b/src/api/user.rs index d806d37..fff8140 100644 --- a/src/api/user.rs +++ b/src/api/user.rs @@ -1,16 +1,15 @@ -use crate::{ApiKey, Db, Ordering, PartyError}; -use okapi::openapi3::OpenApi; -use rocket::{ - fairing::AdHoc, - http::Status, - response::status::Created, - serde::{json::Json, Deserialize, Serialize}, -}; -use rocket_db_pools::{sqlx, Connection}; -use rocket_okapi::{ - openapi, openapi_get_routes, openapi_get_routes_spec, settings::OpenApiSettings, JsonSchema, -}; +use super::prelude::*; use sqlx::FromRow; +use util::{Ordering, PartyError}; + +api_routes!( + add_user, + get_user, + get_all_users, + delete_user, + set_score, + get_score +); /// # User /// @@ -37,7 +36,7 @@ pub async fn add_user( _api_key: ApiKey, mut db: Connection, name: Json<&str>, -) -> Result>, PartyError> { +) -> Result>, PartyError> { let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", *name) .execute(&mut *db) .await?; @@ -48,7 +47,7 @@ pub async fn add_user( name: name.to_string(), }; - Ok(Created::new("/").body(Json(user))) + Ok(status::Created::new("/").body(Json(user))) } /// # Delete user by id @@ -159,14 +158,3 @@ pub async fn get_score( .await?; Ok(Json(score)) } - -pub fn get_routes_and_docs(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![ - settings: add_user, - get_user, - get_all_users, - delete_user, - set_score, - get_score - ] -} diff --git a/src/api/util.rs b/src/api/util.rs new file mode 100644 index 0000000..51eb57f --- /dev/null +++ b/src/api/util.rs @@ -0,0 +1,64 @@ +use rocket::{http::Status, response, response::Responder, Request}; +use rocket_okapi::response::OpenApiResponderInner; +use schemars::JsonSchema; +use thiserror::Error; + +/// # Ordering +/// +/// Ordering of data in an array, ascending or descending +#[derive(Clone, Debug, FromFormField, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum Ordering { + #[field(value = "desc")] + Desc, + #[field(value = "asc")] + Asc, +} + +impl ToString for Ordering { + fn to_string(&self) -> String { + match self { + Self::Desc => "DESC", + Self::Asc => "ASC", + } + .into() + } +} + +#[derive(Error, Debug)] +pub enum PartyError { + #[error("user `{0}` does not exist")] + UserNotFound(i64), + #[error("unknown error: {0}")] + Unknown(String), + #[error("invalid parameter: {0}")] + InvalidParameter(String), + #[error("uuid error {source:?}")] + UuidError { + #[from] + source: uuid::Error, + }, + #[error("sqlx error {source:?}")] + SqlxError { + #[from] + source: sqlx::Error, + }, +} + +impl OpenApiResponderInner for PartyError { + fn responses( + _gen: &mut rocket_okapi::gen::OpenApiGenerator, + ) -> rocket_okapi::Result { + Ok(okapi::openapi3::Responses::default()) + } +} + +impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + match self { + Self::UserNotFound(_) => Status::NotFound, + _ => Status::InternalServerError, + } + .respond_to(req) + } +} diff --git a/src/main.rs b/src/main.rs index 7b90cac..f4a835f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,10 @@ mod api; -use lazy_static::lazy_static; -use okapi::openapi3::{Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData}; -use rocket::{ - fairing, fairing::AdHoc, http::Status, request, request::FromRequest, response, - response::Responder, serde, Build, Request, Rocket, -}; +use rocket::{fairing, fairing::AdHoc, Build, Rocket}; use rocket_db_pools::{sqlx, Database}; use rocket_okapi::{ - gen::OpenApiGenerator, - mount_endpoints_and_merged_docs, openapi, openapi_get_routes, + mount_endpoints_and_merged_docs, openapi, rapidoc::*, - request::{OpenApiFromRequest, RequestHeaderInput}, - response::OpenApiResponderInner, settings::{OpenApiSettings, UrlObject}, swagger_ui::{make_swagger_ui, SwaggerUIConfig}, }; @@ -20,155 +12,16 @@ use rocket_okapi::{ #[macro_use] extern crate rocket; -use schemars::JsonSchema; -use thiserror::Error; - -/* -const API_KEYS: [&'static str; 3] = [ - "7de10bf6-278d-11ed-ad60-a8a15919d1b3", - "89eb06e0-278d-11ed-9b29-a8a15919d1b3", - "8a35ba14-278d-11ed-a200-a8a15919d1b3", -]; -*/ - -lazy_static! { - static ref API_KEYS: Vec = { - serde_json::from_str( - &std::fs::read_to_string("api_keys.json").expect("api_keys.json does not exist"), - ) - .expect("api_keys.json is not valid") - }; -} - -#[derive(Error, Debug)] -pub enum PartyError { - #[error("user `{0}` does not exist")] - UserNotFound(i64), - #[error("unknown error: {0}")] - Unknown(String), - #[error("invalid parameter: {0}")] - InvalidParameter(String), - #[error("uuid error {source:?}")] - UuidError { - #[from] - source: uuid::Error, - }, - #[error("sqlx error {source:?}")] - SqlxError { - #[from] - source: sqlx::Error, - }, -} - -impl OpenApiResponderInner for PartyError { - fn responses( - _gen: &mut rocket_okapi::gen::OpenApiGenerator, - ) -> rocket_okapi::Result { - Ok(okapi::openapi3::Responses::default()) - } -} - -impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { - match self { - Self::UserNotFound(_) => Status::NotFound, - _ => Status::InternalServerError, - } - .respond_to(req) - } -} - -/// # Ordering -/// -/// Ordering of data in an array, ascending or descending -#[derive(Clone, Debug, FromFormField, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub enum Ordering { - #[field(value = "desc")] - Desc, - #[field(value = "asc")] - Asc, -} - -impl ToString for Ordering { - fn to_string(&self) -> String { - match self { - Self::Desc => "DESC", - Self::Asc => "ASC", - } - .into() - } -} - #[derive(Database)] #[database("party")] pub struct Db(sqlx::SqlitePool); -#[derive(Clone, Copy, Debug)] -pub struct ApiKey; - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for ApiKey { - type Error = ApiKey; - - async fn from_request(req: &'r Request<'_>) -> request::Outcome { - if req - .headers() - .get("X-API-Key") - .any(|k| API_KEYS.contains(&String::from(k))) - { - request::Outcome::Success(ApiKey) - } else { - request::Outcome::Failure((Status::Unauthorized, ApiKey)) - } - } -} - -impl<'a> OpenApiFromRequest<'a> for ApiKey { - fn from_request_input( - _gen: &mut OpenApiGenerator, - _name: String, - _required: bool, - ) -> rocket_okapi::Result { - // Setup global requirement for Security scheme - let security_scheme = SecurityScheme { - description: Some("Requires an API key to access".to_owned()), - // Setup data requirements. - // This can be part of the `header`, `query` or `cookie`. - // In this case the header `x-api-key: mykey` needs to be set. - data: SecuritySchemeData::ApiKey { - name: "x-api-key".to_owned(), - location: "header".to_owned(), - }, - extensions: Object::default(), - }; - // Add the requirement for this route/endpoint - // This can change between routes. - let mut security_req = SecurityRequirement::new(); - // Each security requirement needs to be met before access is allowed. - security_req.insert("ApiKeyAuth".to_owned(), Vec::new()); - // These vvvvvvv-----^^^^^^^^^^ values need to match exactly! - Ok(RequestHeaderInput::Security( - "ApiKeyAuth".to_owned(), - security_scheme, - security_req, - )) - } -} - #[openapi] #[get("/")] fn index() -> String { format!("Hello, world!") } -fn get_docs() -> SwaggerUIConfig { - SwaggerUIConfig { - url: "../api/openapi.json".to_owned(), - ..Default::default() - } -} - async fn run_migrations(rocket: Rocket) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::migrate!("db/migrations").run(&**db).await { @@ -187,12 +40,17 @@ async fn run_migrations(rocket: Rocket) -> fairing::Result { #[launch] fn rocket() -> _ { - println!("{:#?}", API_KEYS.len()); - let mut building_rocket = rocket::build() + let building_rocket = rocket::build() .attach(Db::init()) .attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations)) .mount("/", routes![index]) - .mount("/swagger", make_swagger_ui(&get_docs())) + .mount( + "/swagger", + make_swagger_ui(&SwaggerUIConfig { + url: "../api/openapi.json".to_owned(), + ..Default::default() + }), + ) .mount( "/rapidoc/", make_rapidoc(&RapiDocConfig { @@ -210,10 +68,6 @@ fn rocket() -> _ { ); let openapi_settings = OpenApiSettings::default(); - mount_endpoints_and_merged_docs! { - building_rocket, "/api".to_owned(), openapi_settings, - "/user" => api::user::get_routes_and_docs(&openapi_settings), - }; - building_rocket + api::mount_endpoints(building_rocket, &openapi_settings) }