This commit is contained in:
Daan Vanoverloop 2022-08-29 13:53:39 +02:00
parent b96420ec1d
commit 04ca51c927
Signed by: Danacus
GPG Key ID: F2272B50E129FC5C
8 changed files with 294 additions and 143 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target
api_keys.json

2
Cargo.lock generated
View File

@ -815,11 +815,13 @@ version = "0.1.0"
dependencies = [
"dashmap",
"futures",
"lazy_static",
"okapi",
"rocket",
"rocket_db_pools",
"rocket_okapi",
"schemars",
"serde_json",
"sqlx",
"thiserror",
"uuid",

View File

@ -11,8 +11,10 @@ dashmap = "5.3.4"
thiserror = "1.0"
schemars = "0.8.10"
okapi = { version = "0.7.0-rc.1" }
rocket_okapi = { version = "0.8.0-rc.2", features = ["swagger", "rocket_db_pools"] }
rocket_okapi = { version = "0.8.0-rc.2", features = ["swagger", "rocket_db_pools", "rapidoc"] }
futures = "0.3"
lazy_static = "1.4"
serde_json = "1"
[dependencies.sqlx]
version = "*"

Binary file not shown.

Binary file not shown.

1
src/api/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod user;

172
src/api/user.rs Normal file
View File

@ -0,0 +1,172 @@
use crate::{ApiKey, Db, Ordering, PartyError};
use okapi::openapi3::OpenApi;
use rocket::{
fairing::AdHoc,
http::Status,
response::status::Created,
serde::{json::Json, Deserialize, Serialize},
};
use rocket_db_pools::{sqlx, Connection};
use rocket_okapi::{
openapi, openapi_get_routes, openapi_get_routes_spec, settings::OpenApiSettings, JsonSchema,
};
use sqlx::FromRow;
/// # User
///
/// A user that represents a person participating in the LAN party
#[derive(Clone, Debug, FromForm, Serialize, Deserialize, JsonSchema, FromRow)]
#[serde(crate = "rocket::serde")]
pub struct User {
/// Name of the user
name: String,
/// Score of the user
#[serde(default)]
score: i64,
/// Unique identifier of the user
#[serde(default)]
id: i64,
}
/// # Create new user with the give name
///
/// Returns the created user
#[openapi(tag = "User")]
#[post("/", data = "<name>")]
pub async fn add_user(
_api_key: ApiKey,
mut db: Connection<Db>,
name: Json<&str>,
) -> Result<Created<Json<User>>, PartyError> {
let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", *name)
.execute(&mut *db)
.await?;
let user = User {
id: result.last_insert_rowid(),
score: 0,
name: name.to_string(),
};
Ok(Created::new("/").body(Json(user)))
}
/// # Delete user by id
#[openapi(tag = "User")]
#[delete("/<id>")]
pub async fn delete_user(
_api_key: ApiKey,
mut db: Connection<Db>,
id: i64,
) -> Result<Status, PartyError> {
sqlx::query!("DELETE FROM users where (id = ?)", id)
.execute(&mut *db)
.await?;
Ok(Status::Ok)
}
/// # Get user by id
///
/// Returns a single user by id
#[openapi(tag = "User")]
#[get("/<id>")]
pub async fn get_user(
_api_key: ApiKey,
mut db: Connection<Db>,
id: i64,
) -> Result<Json<User>, PartyError> {
let user = sqlx::query_as!(User, "SELECT id, name, score FROM users WHERE (id = ?)", id)
.fetch_one(&mut *db)
.await?;
Ok(Json(user))
}
/// # UserSort
///
/// Field used to sort users
#[derive(Clone, Debug, FromFormField, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum UserSort {
#[field(value = "score")]
Score,
#[field(value = "name")]
Name,
#[field(value = "id")]
Id,
}
impl ToString for UserSort {
fn to_string(&self) -> String {
match self {
Self::Score => "score",
Self::Name => "name",
Self::Id => "id",
}
.into()
}
}
/// # Get all users
///
/// Returns an array of all users sorted by the given sort field and ordering
#[openapi(tag = "User")]
#[get("/?<sort>&<order>")]
pub async fn get_all_users(
_api_key: ApiKey,
mut db: Connection<Db>,
sort: Option<UserSort>,
order: Option<Ordering>,
) -> Result<Json<Vec<User>>, PartyError> {
let users = sqlx::query_as::<_, User>(&format!(
"SELECT id, name, score FROM users ORDER BY {} {}",
sort.unwrap_or(UserSort::Id).to_string(),
order.unwrap_or(Ordering::Asc).to_string()
))
.fetch_all(&mut *db)
.await?;
Ok(Json(users))
}
/// # Set score of user by id
#[openapi(tag = "User")]
#[post("/<id>/score", data = "<score>")]
pub async fn set_score(
_api_key: ApiKey,
mut db: Connection<Db>,
id: i64,
score: Json<i64>,
) -> Result<Status, PartyError> {
sqlx::query!("UPDATE users SET score = ? WHERE id = ?", *score, id)
.execute(&mut *db)
.await?;
Ok(Status::Ok)
}
/// # Get score of user by id
///
/// Returns the score of a single user by id
#[openapi(tag = "User")]
#[get("/<id>/score")]
pub async fn get_score(
_api_key: ApiKey,
mut db: Connection<Db>,
id: i64,
) -> Result<Json<i64>, PartyError> {
let score = sqlx::query_scalar!("SELECT score FROM users WHERE id = ?", id)
.fetch_one(&mut *db)
.await?;
Ok(Json(score))
}
pub fn get_routes_and_docs(settings: &OpenApiSettings) -> (Vec<rocket::Route>, OpenApi) {
openapi_get_routes_spec![
settings: add_user,
get_user,
get_all_users,
delete_user,
set_score,
get_score
]
}

View File

@ -1,31 +1,45 @@
use futures::{prelude::*, stream::TryStreamExt};
mod api;
use lazy_static::lazy_static;
use okapi::openapi3::{Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData};
use rocket::{
fairing,
fairing::AdHoc,
http::Status,
response,
response::{status::Created, Responder},
serde::{json::Json, Deserialize, Serialize},
Build, Request, Rocket,
};
use rocket_db_pools::{
sqlx::{self},
Connection, Database,
fairing, fairing::AdHoc, http::Status, request, request::FromRequest, response,
response::Responder, serde, Build, Request, Rocket,
};
use rocket_db_pools::{sqlx, Database};
use rocket_okapi::{
openapi, openapi_get_routes,
gen::OpenApiGenerator,
mount_endpoints_and_merged_docs, openapi, openapi_get_routes,
rapidoc::*,
request::{OpenApiFromRequest, RequestHeaderInput},
response::OpenApiResponderInner,
settings::{OpenApiSettings, UrlObject},
swagger_ui::{make_swagger_ui, SwaggerUIConfig},
JsonSchema,
};
use sqlx::{Acquire, Connection as SqlxConnection, FromRow, Sqlite};
use std::collections::HashMap;
#[macro_use]
extern crate rocket;
use schemars::JsonSchema;
use thiserror::Error;
/*
const API_KEYS: [&'static str; 3] = [
"7de10bf6-278d-11ed-ad60-a8a15919d1b3",
"89eb06e0-278d-11ed-9b29-a8a15919d1b3",
"8a35ba14-278d-11ed-a200-a8a15919d1b3",
];
*/
lazy_static! {
static ref API_KEYS: Vec<String> = {
serde_json::from_str(
&std::fs::read_to_string("api_keys.json").expect("api_keys.json does not exist"),
)
.expect("api_keys.json is not valid")
};
}
#[derive(Error, Debug)]
pub enum PartyError {
#[error("user `{0}` does not exist")]
@ -56,11 +70,7 @@ impl OpenApiResponderInner for PartyError {
impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError {
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> {
// log `self` to your favored error tracker, e.g.
// sentry::capture_error(&self);
match self {
// in our simplistic example, we're happy to respond with the default 500 responder in all cases
Self::UserNotFound(_) => Status::NotFound,
_ => Status::InternalServerError,
}
@ -68,73 +78,12 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError {
}
}
#[derive(Clone, Debug, FromForm, Serialize, Deserialize, JsonSchema, FromRow)]
#[serde(crate = "rocket::serde")]
struct User {
name: String,
#[serde(default)]
score: i64,
#[serde(default)]
id: i64,
}
#[derive(Database)]
#[database("party")]
struct Db(sqlx::SqlitePool);
#[openapi]
#[get("/")]
fn index() -> String {
format!("Hello, world!")
}
#[openapi]
#[post("/user", data = "<user>")]
async fn add_user(
mut db: Connection<Db>,
mut user: Json<User>,
) -> Result<Created<Json<User>>, PartyError> {
let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", user.name)
.execute(&mut *db)
.await?;
user.id = result.last_insert_rowid();
Ok(Created::new("/").body(user))
}
#[openapi]
#[delete("/user/<id>")]
async fn delete_user(mut db: Connection<Db>, id: i64) -> Result<Status, PartyError> {
sqlx::query!("DELETE FROM users where (id = ?)", id)
.execute(&mut *db)
.await?;
Ok(Status::Ok)
}
#[openapi]
#[get("/user/<id>")]
async fn get_user(mut db: Connection<Db>, id: i64) -> Result<Json<User>, PartyError> {
let user = sqlx::query_as!(User, "SELECT id, name, score FROM users WHERE (id = ?)", id)
.fetch_one(&mut *db)
.await?;
Ok(Json(user))
}
/// # Ordering
///
/// Ordering of data in an array, ascending or descending
#[derive(Clone, Debug, FromFormField, JsonSchema)]
#[serde(rename_all = "snake_case")]
enum UserSort {
#[field(value = "score")]
Score,
#[field(value = "name")]
Name,
#[field(value = "id")]
Id,
}
#[derive(Clone, Debug, FromFormField, JsonSchema)]
#[serde(rename_all = "snake_case")]
enum Ordering {
pub enum Ordering {
#[field(value = "desc")]
Desc,
#[field(value = "asc")]
@ -151,55 +100,66 @@ impl ToString for Ordering {
}
}
impl ToString for UserSort {
fn to_string(&self) -> String {
match self {
Self::Score => "score",
Self::Name => "name",
Self::Id => "id",
#[derive(Database)]
#[database("party")]
pub struct Db(sqlx::SqlitePool);
#[derive(Clone, Copy, Debug)]
pub struct ApiKey;
#[rocket::async_trait]
impl<'r> FromRequest<'r> for ApiKey {
type Error = ApiKey;
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
if req
.headers()
.get("X-API-Key")
.any(|k| API_KEYS.contains(&String::from(k)))
{
request::Outcome::Success(ApiKey)
} else {
request::Outcome::Failure((Status::Unauthorized, ApiKey))
}
.into()
}
}
#[openapi]
#[get("/user?<sort>&<order>")]
async fn get_all_users(
mut db: Connection<Db>,
sort: Option<UserSort>,
order: Option<Ordering>,
) -> Result<Json<Vec<User>>, PartyError> {
let users = sqlx::query_as::<_, User>(&format!(
"SELECT id, name, score FROM users ORDER BY {} {}",
sort.unwrap_or(UserSort::Id).to_string(),
order.unwrap_or(Ordering::Asc).to_string()
impl<'a> OpenApiFromRequest<'a> for ApiKey {
fn from_request_input(
_gen: &mut OpenApiGenerator,
_name: String,
_required: bool,
) -> rocket_okapi::Result<RequestHeaderInput> {
// Setup global requirement for Security scheme
let security_scheme = SecurityScheme {
description: Some("Requires an API key to access".to_owned()),
// Setup data requirements.
// This can be part of the `header`, `query` or `cookie`.
// In this case the header `x-api-key: mykey` needs to be set.
data: SecuritySchemeData::ApiKey {
name: "x-api-key".to_owned(),
location: "header".to_owned(),
},
extensions: Object::default(),
};
// Add the requirement for this route/endpoint
// This can change between routes.
let mut security_req = SecurityRequirement::new();
// Each security requirement needs to be met before access is allowed.
security_req.insert("ApiKeyAuth".to_owned(), Vec::new());
// These vvvvvvv-----^^^^^^^^^^ values need to match exactly!
Ok(RequestHeaderInput::Security(
"ApiKeyAuth".to_owned(),
security_scheme,
security_req,
))
.fetch_all(&mut *db)
.await?;
Ok(Json(users))
}
}
#[openapi]
#[post("/user/<id>/score", data = "<score>")]
async fn set_score(
mut db: Connection<Db>,
id: i64,
score: Json<i64>,
) -> Result<Status, PartyError> {
sqlx::query!("UPDATE users SET score = ? WHERE id = ?", *score, id)
.execute(&mut *db)
.await?;
Ok(Status::Ok)
}
#[openapi]
#[get("/user/<id>/score")]
async fn get_score(mut db: Connection<Db>, id: i64) -> Result<Json<i64>, PartyError> {
let score = sqlx::query_scalar!("SELECT score FROM users WHERE id = ?", id)
.fetch_one(&mut *db)
.await?;
Ok(Json(score))
#[get("/")]
fn index() -> String {
format!("Hello, world!")
}
fn get_docs() -> SwaggerUIConfig {
@ -227,20 +187,33 @@ async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
#[launch]
fn rocket() -> _ {
rocket::build()
println!("{:#?}", API_KEYS.len());
let mut building_rocket = rocket::build()
.attach(Db::init())
.attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations))
.mount("/", openapi_get_routes![index])
.mount(
"/api",
openapi_get_routes![
add_user,
get_user,
get_all_users,
delete_user,
set_score,
get_score
],
)
.mount("/", routes![index])
.mount("/swagger", make_swagger_ui(&get_docs()))
.mount(
"/rapidoc/",
make_rapidoc(&RapiDocConfig {
general: GeneralConfig {
spec_urls: vec![UrlObject::new("General", "../api/openapi.json")],
..Default::default()
},
hide_show: HideShowConfig {
allow_spec_url_load: false,
allow_spec_file_load: false,
..Default::default()
},
..Default::default()
}),
);
let openapi_settings = OpenApiSettings::default();
mount_endpoints_and_merged_docs! {
building_rocket, "/api".to_owned(), openapi_settings,
"/user" => api::user::get_routes_and_docs(&openapi_settings),
};
building_rocket
}