mod api; use lazy_static::lazy_static; use okapi::openapi3::{Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData}; use rocket::{ fairing, fairing::AdHoc, http::Status, request, request::FromRequest, response, response::Responder, serde, Build, Request, Rocket, }; use rocket_db_pools::{sqlx, Database}; use rocket_okapi::{ gen::OpenApiGenerator, mount_endpoints_and_merged_docs, openapi, openapi_get_routes, rapidoc::*, request::{OpenApiFromRequest, RequestHeaderInput}, response::OpenApiResponderInner, settings::{OpenApiSettings, UrlObject}, swagger_ui::{make_swagger_ui, SwaggerUIConfig}, }; #[macro_use] extern crate rocket; use schemars::JsonSchema; use thiserror::Error; /* const API_KEYS: [&'static str; 3] = [ "7de10bf6-278d-11ed-ad60-a8a15919d1b3", "89eb06e0-278d-11ed-9b29-a8a15919d1b3", "8a35ba14-278d-11ed-a200-a8a15919d1b3", ]; */ lazy_static! { static ref API_KEYS: Vec = { serde_json::from_str( &std::fs::read_to_string("api_keys.json").expect("api_keys.json does not exist"), ) .expect("api_keys.json is not valid") }; } #[derive(Error, Debug)] pub enum PartyError { #[error("user `{0}` does not exist")] UserNotFound(i64), #[error("unknown error: {0}")] Unknown(String), #[error("invalid parameter: {0}")] InvalidParameter(String), #[error("uuid error {source:?}")] UuidError { #[from] source: uuid::Error, }, #[error("sqlx error {source:?}")] SqlxError { #[from] source: sqlx::Error, }, } impl OpenApiResponderInner for PartyError { fn responses( _gen: &mut rocket_okapi::gen::OpenApiGenerator, ) -> rocket_okapi::Result { Ok(okapi::openapi3::Responses::default()) } } impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { match self { Self::UserNotFound(_) => Status::NotFound, _ => Status::InternalServerError, } .respond_to(req) } } /// # Ordering /// /// Ordering of data in an array, ascending or descending #[derive(Clone, Debug, FromFormField, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum Ordering { #[field(value = "desc")] Desc, #[field(value = "asc")] Asc, } impl ToString for Ordering { fn to_string(&self) -> String { match self { Self::Desc => "DESC", Self::Asc => "ASC", } .into() } } #[derive(Database)] #[database("party")] pub struct Db(sqlx::SqlitePool); #[derive(Clone, Copy, Debug)] pub struct ApiKey; #[rocket::async_trait] impl<'r> FromRequest<'r> for ApiKey { type Error = ApiKey; async fn from_request(req: &'r Request<'_>) -> request::Outcome { if req .headers() .get("X-API-Key") .any(|k| API_KEYS.contains(&String::from(k))) { request::Outcome::Success(ApiKey) } else { request::Outcome::Failure((Status::Unauthorized, ApiKey)) } } } impl<'a> OpenApiFromRequest<'a> for ApiKey { fn from_request_input( _gen: &mut OpenApiGenerator, _name: String, _required: bool, ) -> rocket_okapi::Result { // Setup global requirement for Security scheme let security_scheme = SecurityScheme { description: Some("Requires an API key to access".to_owned()), // Setup data requirements. // This can be part of the `header`, `query` or `cookie`. // In this case the header `x-api-key: mykey` needs to be set. data: SecuritySchemeData::ApiKey { name: "x-api-key".to_owned(), location: "header".to_owned(), }, extensions: Object::default(), }; // Add the requirement for this route/endpoint // This can change between routes. let mut security_req = SecurityRequirement::new(); // Each security requirement needs to be met before access is allowed. security_req.insert("ApiKeyAuth".to_owned(), Vec::new()); // These vvvvvvv-----^^^^^^^^^^ values need to match exactly! Ok(RequestHeaderInput::Security( "ApiKeyAuth".to_owned(), security_scheme, security_req, )) } } #[openapi] #[get("/")] fn index() -> String { format!("Hello, world!") } fn get_docs() -> SwaggerUIConfig { SwaggerUIConfig { url: "../api/openapi.json".to_owned(), ..Default::default() } } async fn run_migrations(rocket: Rocket) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::migrate!("db/migrations").run(&**db).await { Ok(_) => { println!("Migrations completed"); Ok(rocket) } Err(e) => { error!("Failed to initialize SQLx database: {}", e); Err(rocket) } }, None => Err(rocket), } } #[launch] fn rocket() -> _ { println!("{:#?}", API_KEYS.len()); let mut building_rocket = rocket::build() .attach(Db::init()) .attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations)) .mount("/", routes![index]) .mount("/swagger", make_swagger_ui(&get_docs())) .mount( "/rapidoc/", make_rapidoc(&RapiDocConfig { general: GeneralConfig { spec_urls: vec![UrlObject::new("General", "../api/openapi.json")], ..Default::default() }, hide_show: HideShowConfig { allow_spec_url_load: false, allow_spec_file_load: false, ..Default::default() }, ..Default::default() }), ); let openapi_settings = OpenApiSettings::default(); mount_endpoints_and_merged_docs! { building_rocket, "/api".to_owned(), openapi_settings, "/user" => api::user::get_routes_and_docs(&openapi_settings), }; building_rocket }