use futures::{prelude::*, stream::TryStreamExt}; use rocket::{ fairing, fairing::AdHoc, http::Status, response, response::{status::Created, Responder}, serde::{json::Json, Deserialize, Serialize}, Build, Request, Rocket, }; use rocket_db_pools::{ sqlx::{self}, Connection, Database, }; use rocket_okapi::{ openapi, openapi_get_routes, response::OpenApiResponderInner, swagger_ui::{make_swagger_ui, SwaggerUIConfig}, JsonSchema, }; use sqlx::{Acquire, Connection as SqlxConnection, FromRow, Sqlite}; use std::collections::HashMap; #[macro_use] extern crate rocket; use thiserror::Error; #[derive(Error, Debug)] pub enum PartyError { #[error("user `{0}` does not exist")] UserNotFound(i64), #[error("unknown error: {0}")] Unknown(String), #[error("invalid parameter: {0}")] InvalidParameter(String), #[error("uuid error {source:?}")] UuidError { #[from] source: uuid::Error, }, #[error("sqlx error {source:?}")] SqlxError { #[from] source: sqlx::Error, }, } impl OpenApiResponderInner for PartyError { fn responses( _gen: &mut rocket_okapi::gen::OpenApiGenerator, ) -> rocket_okapi::Result { Ok(okapi::openapi3::Responses::default()) } } impl<'r, 'o: 'r> Responder<'r, 'o> for PartyError { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { // log `self` to your favored error tracker, e.g. // sentry::capture_error(&self); match self { // in our simplistic example, we're happy to respond with the default 500 responder in all cases Self::UserNotFound(_) => Status::NotFound, _ => Status::InternalServerError, } .respond_to(req) } } #[derive(Clone, Debug, FromForm, Serialize, Deserialize, JsonSchema, FromRow)] #[serde(crate = "rocket::serde")] struct User { name: String, #[serde(default)] score: i64, #[serde(default)] id: i64, } #[derive(Database)] #[database("party")] struct Db(sqlx::SqlitePool); #[openapi] #[get("/")] fn index() -> String { format!("Hello, world!") } #[openapi] #[post("/user", data = "")] async fn add_user( mut db: Connection, mut user: Json, ) -> Result>, PartyError> { let result = sqlx::query!("INSERT INTO users (name) VALUES (?)", user.name) .execute(&mut *db) .await?; user.id = result.last_insert_rowid(); Ok(Created::new("/").body(user)) } #[openapi] #[delete("/user/")] async fn delete_user(mut db: Connection, id: i64) -> Result { sqlx::query!("DELETE FROM users where (id = ?)", id) .execute(&mut *db) .await?; Ok(Status::Ok) } #[openapi] #[get("/user/")] async fn get_user(mut db: Connection, id: i64) -> Result, PartyError> { let user = sqlx::query_as!(User, "SELECT id, name, score FROM users WHERE (id = ?)", id) .fetch_one(&mut *db) .await?; Ok(Json(user)) } #[derive(Clone, Debug, FromFormField, JsonSchema)] #[serde(rename_all = "snake_case")] enum UserSort { #[field(value = "score")] Score, #[field(value = "name")] Name, #[field(value = "id")] Id, } #[derive(Clone, Debug, FromFormField, JsonSchema)] #[serde(rename_all = "snake_case")] enum Ordering { #[field(value = "desc")] Desc, #[field(value = "asc")] Asc, } impl ToString for Ordering { fn to_string(&self) -> String { match self { Self::Desc => "DESC", Self::Asc => "ASC", } .into() } } impl ToString for UserSort { fn to_string(&self) -> String { match self { Self::Score => "score", Self::Name => "name", Self::Id => "id", } .into() } } #[openapi] #[get("/user?&")] async fn get_all_users( mut db: Connection, sort: Option, order: Option, ) -> Result>, PartyError> { let users = sqlx::query_as::<_, User>(&format!( "SELECT id, name, score FROM users ORDER BY {} {}", sort.unwrap_or(UserSort::Id).to_string(), order.unwrap_or(Ordering::Asc).to_string() )) .fetch_all(&mut *db) .await?; Ok(Json(users)) } #[openapi] #[post("/user//score", data = "")] async fn set_score( mut db: Connection, id: i64, score: Json, ) -> Result { sqlx::query!("UPDATE users SET score = ? WHERE id = ?", *score, id) .execute(&mut *db) .await?; Ok(Status::Ok) } #[openapi] #[get("/user//score")] async fn get_score(mut db: Connection, id: i64) -> Result, PartyError> { let score = sqlx::query_scalar!("SELECT score FROM users WHERE id = ?", id) .fetch_one(&mut *db) .await?; Ok(Json(score)) } fn get_docs() -> SwaggerUIConfig { SwaggerUIConfig { url: "../api/openapi.json".to_owned(), ..Default::default() } } async fn run_migrations(rocket: Rocket) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::migrate!("db/migrations").run(&**db).await { Ok(_) => { println!("Migrations completed"); Ok(rocket) } Err(e) => { error!("Failed to initialize SQLx database: {}", e); Err(rocket) } }, None => Err(rocket), } } #[launch] fn rocket() -> _ { rocket::build() .attach(Db::init()) .attach(AdHoc::try_on_ignite("SQLx Migrations", run_migrations)) .mount("/", openapi_get_routes![index]) .mount( "/api", openapi_get_routes![ add_user, get_user, get_all_users, delete_user, set_score, get_score ], ) .mount("/swagger", make_swagger_ui(&get_docs())) }