diff --git a/app/Cargo.toml b/app/Cargo.toml index 16abd04..86415f5 100644 --- a/app/Cargo.toml +++ b/app/Cargo.toml @@ -5,6 +5,12 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["mock_auth"] +#default = ["oidc"] +oidc = [] +mock_auth = [] + [dependencies.rest] path = "../rest" diff --git a/app/src/main.rs b/app/src/main.rs index 58e0371..5d62c71 100644 --- a/app/src/main.rs +++ b/app/src/main.rs @@ -2,8 +2,12 @@ use std::sync::Arc; use sqlx::SqlitePool; +#[cfg(feature = "mock_auth")] +type UserService = service_impl::UserServiceDev; +#[cfg(feature = "oidc")] +type UserService = service_impl::UserServiceImpl; type PermissionService = - service_impl::PermissionServiceImpl; + service_impl::PermissionServiceImpl; type ClockService = service_impl::clock::ClockServiceImpl; type UuidService = service_impl::uuid_service::UuidServiceImpl; type SlotService = service_impl::slot::SlotServiceImpl< @@ -66,7 +70,10 @@ impl RestStateImpl { // TODO: Implement a proper authentication service when used in produciton. Maybe // use differnet implementations on debug then on release. Or control it via a // feature. + #[cfg(feature = "mock_auth")] let user_service = service_impl::UserServiceDev; + #[cfg(feature = "oidc")] + let user_service = service_impl::UserServiceImpl; let permission_service = Arc::new(service_impl::PermissionServiceImpl::new( permission_dao.into(), user_service.into(), diff --git a/rest/Cargo.toml b/rest/Cargo.toml index 2a8d11f..fe018a5 100644 --- a/rest/Cargo.toml +++ b/rest/Cargo.toml @@ -5,9 +5,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] +default = ["mock_auth"] #default = ["oidc"] -default = [] oidc = ["dep:axum-oidc"] +mock_auth = [] [dependencies] axum = "0.7.5" diff --git a/rest/src/booking.rs b/rest/src/booking.rs index 3273bc3..eeff622 100644 --- a/rest/src/booking.rs +++ b/rest/src/booking.rs @@ -4,12 +4,12 @@ use axum::body::Body; use axum::extract::Path; use axum::routing::{delete, get, post}; use axum::{extract::State, response::Response}; -use axum::{Json, Router}; +use axum::{Extension, Json, Router}; use serde::{Deserialize, Serialize}; use time::PrimitiveDateTime; use uuid::Uuid; -use crate::{error_handler, RestStateDef}; +use crate::{error_handler, Context, RestStateDef}; use service::booking::{Booking, BookingService}; #[derive(Serialize, Deserialize, Clone, Debug)] @@ -65,12 +65,15 @@ pub fn generate_route() -> Router { .route("/:id", delete(delete_booking::)) } -pub async fn get_all_bookings(rest_state: State) -> Response { +pub async fn get_all_bookings( + rest_state: State, + Extension(context): Extension, +) -> Response { error_handler( (async { let bookings: Arc<[BookingTO]> = rest_state .booking_service() - .get_all(().into()) + .get_all(context.into()) .await? .iter() .map(BookingTO::from) @@ -86,13 +89,14 @@ pub async fn get_all_bookings(rest_state: State( rest_state: State, + Extension(context): Extension, Path(booking_id): Path, ) -> Response { error_handler( (async { let booking = rest_state .booking_service() - .get(booking_id, ().into()) + .get(booking_id, context.into()) .await?; Ok(Response::builder() .status(200) @@ -107,13 +111,14 @@ pub async fn get_booking( pub async fn create_booking( rest_state: State, + Extension(context): Extension, Json(booking): Json, ) -> Response { error_handler( (async { let booking = rest_state .booking_service() - .create(&Booking::from(&booking), ().into()) + .create(&Booking::from(&booking), context.into()) .await?; Ok(Response::builder() .status(200) @@ -128,13 +133,14 @@ pub async fn create_booking( pub async fn delete_booking( rest_state: State, + Extension(context): Extension, Path(booking_id): Path, ) -> Response { error_handler( (async { rest_state .booking_service() - .delete(booking_id, ().into()) + .delete(booking_id, context.into()) .await?; Ok(Response::builder().status(200).body(Body::empty()).unwrap()) }) diff --git a/rest/src/lib.rs b/rest/src/lib.rs index 09dd96b..f9e8709 100644 --- a/rest/src/lib.rs +++ b/rest/src/lib.rs @@ -5,7 +5,9 @@ mod permission; mod sales_person; mod slot; +use axum::extract::Request; use axum::http::Uri; +use axum::middleware::{self, Next}; use axum::response::{IntoResponse, Redirect}; use axum::routing::get; use axum::{body::Body, error_handling::HandleErrorLayer, response::Response, Router}; @@ -19,7 +21,34 @@ use tower_sessions::{cookie::SameSite, Expiry, MemoryStore, SessionManagerLayer} use uuid::Uuid; // TODO: In prod, it must be a different type than in dev mode. +#[cfg(feature = "mock_auth")] type Context = (); +#[cfg(feature = "oidc")] +type Context = Option>; + +#[cfg(feature = "oidc")] +pub async fn context_extractor( + claims: Option>, + mut request: Request, + next: Next, +) -> Response { + let context: Context = if let Some(oidc_claims) = claims { + let username = oidc_claims + .preferred_username() + .map(|s| s.as_str().to_string()) + .unwrap_or_else(|| "NoUsername".to_string()); + Some(username.into()) + } else { + None + }; + request.extensions_mut().insert(context); + next.run(request).await +} +#[cfg(feature = "mock_auth")] +pub async fn context_extractor(mut request: Request, next: Next) -> Response { + request.extensions_mut().insert(()); + next.run(request).await +} pub struct RoString(Arc, bool); impl http_body::Body for RoString { @@ -77,6 +106,9 @@ fn error_handler(result: Result) -> Response { Err(RestError::ServiceError(service::ServiceError::Forbidden)) => { Response::builder().status(403).body(Body::empty()).unwrap() } + Err(RestError::ServiceError(service::ServiceError::Unauthorized)) => { + Response::builder().status(401).body(Body::empty()).unwrap() + } Err(RestError::ServiceError(service::ServiceError::DatabaseQueryError(e))) => { Response::builder() .status(500) @@ -245,7 +277,8 @@ pub async fn start_server(rest_state: RestState) { .nest("/slot", slot::generate_route()) .nest("/sales-person", sales_person::generate_route()) .nest("/booking", booking::generate_route()) - .with_state(rest_state); + .with_state(rest_state) + .layer(middleware::from_fn(context_extractor)); #[cfg(feature = "oidc")] let app = { diff --git a/rest/src/permission.rs b/rest/src/permission.rs index befc3f6..6a73f78 100644 --- a/rest/src/permission.rs +++ b/rest/src/permission.rs @@ -6,10 +6,10 @@ use axum::{ extract::State, response::Response, routing::{delete, get, post}, - Json, Router, + Extension, Json, Router, }; -use crate::{error_handler, RestStateDef}; +use crate::{error_handler, Context, RestStateDef}; use service::PermissionService; #[derive(Debug, Serialize, Deserialize)] @@ -80,6 +80,7 @@ pub fn generate_route() -> Router { pub async fn add_user( rest_state: State, + Extension(context): Extension, Json(user): Json, ) -> Response { println!("Adding user: {:?}", user); @@ -87,7 +88,7 @@ pub async fn add_user( (async { rest_state .permission_service() - .create_user(user.name.as_str(), ().into()) + .create_user(user.name.as_str(), context.into()) .await?; Ok(Response::builder() .status(201) @@ -100,6 +101,7 @@ pub async fn add_user( pub async fn remove_user( rest_state: State, + Extension(context): Extension, Json(user): Json, ) -> Response { println!("Removing user: {:?}", user); @@ -107,7 +109,7 @@ pub async fn remove_user( (async { rest_state .permission_service() - .delete_user(&user, ().into()) + .delete_user(&user, context.into()) .await?; Ok(Response::builder() .status(200) @@ -120,13 +122,14 @@ pub async fn remove_user( pub async fn add_role( rest_state: State, + Extension(context): Extension, Json(role): Json, ) -> Response { error_handler( (async { rest_state .permission_service() - .create_role(role.name.as_str(), ().into()) + .create_role(role.name.as_str(), context.into()) .await?; Ok(Response::builder() .status(200) @@ -139,13 +142,14 @@ pub async fn add_role( pub async fn delete_role( rest_state: State, + Extension(context): Extension, Json(role): Json, ) -> Response { error_handler( (async { rest_state .permission_service() - .delete_role(role.as_str(), ().into()) + .delete_role(role.as_str(), context.into()) .await?; Ok(Response::builder() .status(200) @@ -158,13 +162,18 @@ pub async fn delete_role( pub async fn add_user_role( rest_state: State, + Extension(context): Extension, Json(user_role): Json, ) -> Response { error_handler( (async { rest_state .permission_service() - .add_user_role(user_role.user.as_str(), user_role.role.as_str(), ().into()) + .add_user_role( + user_role.user.as_str(), + user_role.role.as_str(), + context.into(), + ) .await?; Ok(Response::builder() .status(201) @@ -177,13 +186,18 @@ pub async fn add_user_role( pub async fn remove_user_role( rest_state: State, + Extension(context): Extension, Json(user_role): Json, ) -> Response { error_handler( (async { rest_state .permission_service() - .delete_user_role(user_role.user.as_str(), user_role.role.as_str(), ().into()) + .delete_user_role( + user_role.user.as_str(), + user_role.role.as_str(), + context.into(), + ) .await?; Ok(Response::builder() .status(200) @@ -196,6 +210,7 @@ pub async fn remove_user_role( pub async fn add_role_privilege( rest_state: State, + Extension(context): Extension, Json(role_privilege): Json, ) -> Response { error_handler( @@ -205,7 +220,7 @@ pub async fn add_role_privilege( .add_role_privilege( role_privilege.role.as_str(), role_privilege.privilege.as_str(), - ().into(), + context.into(), ) .await?; Ok(Response::builder() @@ -219,6 +234,7 @@ pub async fn add_role_privilege( pub async fn remove_role_privilege( rest_state: State, + Extension(context): Extension, Json(role_privilege): Json, ) -> Response { error_handler( @@ -228,7 +244,7 @@ pub async fn remove_role_privilege( .delete_role_privilege( role_privilege.role.as_str(), role_privilege.privilege.as_str(), - ().into(), + context.into(), ) .await?; Ok(Response::builder() @@ -240,12 +256,15 @@ pub async fn remove_role_privilege( ) } -pub async fn get_all_users(rest_state: State) -> Response { +pub async fn get_all_users( + rest_state: State, + Extension(context): Extension, +) -> Response { error_handler( (async { let users: Arc<[UserTO]> = rest_state .permission_service() - .get_all_users(().into()) + .get_all_users(context.into()) .await? .iter() .map(UserTO::from) @@ -259,12 +278,15 @@ pub async fn get_all_users(rest_state: State ) } -pub async fn get_all_roles(rest_state: State) -> Response { +pub async fn get_all_roles( + rest_state: State, + Extension(context): Extension, +) -> Response { error_handler( (async { let roles: Arc<[RoleTO]> = rest_state .permission_service() - .get_all_roles(().into()) + .get_all_roles(context.into()) .await? .iter() .map(RoleTO::from) @@ -278,12 +300,15 @@ pub async fn get_all_roles(rest_state: State ) } -pub async fn get_all_privileges(rest_state: State) -> Response { +pub async fn get_all_privileges( + rest_state: State, + Extension(context): Extension, +) -> Response { error_handler( (async { let privileges: Arc<[PrivilegeTO]> = rest_state .permission_service() - .get_all_privileges(().into()) + .get_all_privileges(context.into()) .await? .iter() .map(PrivilegeTO::from) diff --git a/rest/src/sales_person.rs b/rest/src/sales_person.rs index 8c46167..512bc7a 100644 --- a/rest/src/sales_person.rs +++ b/rest/src/sales_person.rs @@ -4,13 +4,13 @@ use axum::body::Body; use axum::extract::Path; use axum::routing::{delete, get, post, put}; use axum::{extract::State, response::Response}; -use axum::{Json, Router}; +use axum::{Extension, Json, Router}; use serde::{Deserialize, Serialize}; use service::sales_person::SalesPerson; use service::sales_person::SalesPersonService; use uuid::Uuid; -use crate::{error_handler, RestError, RestStateDef}; +use crate::{error_handler, Context, RestError, RestStateDef}; #[derive(Serialize, Deserialize, Clone, Debug)] pub struct SalesPersonTO { @@ -62,12 +62,13 @@ pub fn generate_route() -> Router { pub async fn get_all_sales_persons( rest_state: State, + Extension(context): Extension, ) -> Response { error_handler( (async { let sales_persons: Arc<[SalesPersonTO]> = rest_state .sales_person_service() - .get_all(().into()) + .get_all(context.into()) .await? .iter() .map(SalesPersonTO::from) @@ -83,6 +84,7 @@ pub async fn get_all_sales_persons( pub async fn get_sales_person( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, ) -> Response { error_handler( @@ -90,7 +92,7 @@ pub async fn get_sales_person( let sales_person = SalesPersonTO::from( &rest_state .sales_person_service() - .get(sales_person_id, ().into()) + .get(sales_person_id, context.into()) .await?, ); Ok(Response::builder() @@ -104,6 +106,7 @@ pub async fn get_sales_person( pub async fn create_sales_person( rest_state: State, + Extension(context): Extension, Json(sales_person): Json, ) -> Response { error_handler( @@ -111,7 +114,7 @@ pub async fn create_sales_person( let sales_person = SalesPersonTO::from( &rest_state .sales_person_service() - .create(&(&sales_person).into(), ().into()) + .create(&(&sales_person).into(), context.into()) .await?, ); Ok(Response::builder() @@ -125,6 +128,7 @@ pub async fn create_sales_person( pub async fn update_sales_person( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, Json(sales_person): Json, ) -> Response { @@ -135,7 +139,7 @@ pub async fn update_sales_person( } rest_state .sales_person_service() - .update(&(&sales_person).into(), ().into()) + .update(&(&sales_person).into(), context.into()) .await?; Ok(Response::builder() .status(200) @@ -148,13 +152,14 @@ pub async fn update_sales_person( pub async fn delete_sales_person( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, ) -> Response { error_handler( (async { rest_state .sales_person_service() - .delete(sales_person_id, ().into()) + .delete(sales_person_id, context.into()) .await?; Ok(Response::builder().status(204).body(Body::empty()).unwrap()) }) @@ -164,13 +169,14 @@ pub async fn delete_sales_person( pub async fn get_sales_person_user( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, ) -> Response { error_handler( (async { let user = rest_state .sales_person_service() - .get_assigned_user(sales_person_id, ().into()) + .get_assigned_user(sales_person_id, context.into()) .await?; Ok(Response::builder() .status(200) @@ -183,6 +189,7 @@ pub async fn get_sales_person_user( pub async fn set_sales_person_user( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, Json(user): Json>, ) -> Response { @@ -190,7 +197,7 @@ pub async fn set_sales_person_user( (async { rest_state .sales_person_service() - .set_user(sales_person_id, user.into(), ().into()) + .set_user(sales_person_id, user.into(), context.into()) .await?; Ok(Response::builder().status(204).body(Body::empty()).unwrap()) }) @@ -200,13 +207,14 @@ pub async fn set_sales_person_user( pub async fn delete_sales_person_user( rest_state: State, + Extension(context): Extension, Path(sales_person_id): Path, ) -> Response { error_handler( (async { rest_state .sales_person_service() - .set_user(sales_person_id, None, ().into()) + .set_user(sales_person_id, None, context.into()) .await?; Ok(Response::builder().status(204).body(Body::empty()).unwrap()) }) diff --git a/rest/src/slot.rs b/rest/src/slot.rs index 73f27f8..2601ca6 100644 --- a/rest/src/slot.rs +++ b/rest/src/slot.rs @@ -5,13 +5,13 @@ use axum::{ extract::{Path, State}, response::Response, routing::{get, post, put}, - Json, Router, + Extension, Json, Router, }; use serde::{Deserialize, Serialize}; use service::slot::SlotService; use uuid::Uuid; -use crate::{error_handler, RestError, RestStateDef}; +use crate::{error_handler, Context, RestError, RestStateDef}; #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum DayOfWeek { @@ -102,12 +102,15 @@ pub fn generate_route() -> Router { .route("/:id", put(update_slot::)) } -pub async fn get_all_slots(rest_state: State) -> Response { +pub async fn get_all_slots( + rest_state: State, + Extension(context): Extension, +) -> Response { error_handler( (async { let slots: Arc<[SlotTO]> = rest_state .slot_service() - .get_slots(().into()) + .get_slots(context.into()) .await? .iter() .map(SlotTO::from) @@ -123,6 +126,7 @@ pub async fn get_all_slots(rest_state: State pub async fn get_slot( rest_state: State, + Extension(context): Extension, Path(slot_id): Path, ) -> Response { error_handler( @@ -130,7 +134,7 @@ pub async fn get_slot( let slot = SlotTO::from( &rest_state .slot_service() - .get_slot(&slot_id, ().into()) + .get_slot(&slot_id, context.into()) .await?, ); Ok(Response::builder() @@ -144,6 +148,7 @@ pub async fn get_slot( pub async fn create_slot( rest_state: State, + Extension(context): Extension, Json(slot): Json, ) -> Response { error_handler( @@ -151,7 +156,7 @@ pub async fn create_slot( let slot = SlotTO::from( &rest_state .slot_service() - .create_slot(&(&slot).into(), ().into()) + .create_slot(&(&slot).into(), context.into()) .await?, ); Ok(Response::builder() @@ -165,6 +170,7 @@ pub async fn create_slot( pub async fn update_slot( rest_state: State, + Extension(context): Extension, Path(slot_id): Path, Json(slot): Json, ) -> Response { @@ -175,7 +181,7 @@ pub async fn update_slot( } rest_state .slot_service() - .update_slot(&(&slot).into(), ().into()) + .update_slot(&(&slot).into(), context.into()) .await?; Ok(Response::builder() .status(200) diff --git a/service/src/lib.rs b/service/src/lib.rs index d891c23..cbcca98 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -31,6 +31,9 @@ pub enum ServiceError { #[error("Database query error: {0}")] DatabaseQueryError(#[from] dao::DaoError), + #[error("Unauthorized")] + Unauthorized, + #[error("Forbidden")] Forbidden, diff --git a/service_impl/src/lib.rs b/service_impl/src/lib.rs index 2c031a7..8e17e09 100644 --- a/service_impl/src/lib.rs +++ b/service_impl/src/lib.rs @@ -25,3 +25,19 @@ impl service::user_service::UserService for UserServiceDev { Ok("DEVUSER".into()) } } + +pub struct UserServiceImpl; + +#[async_trait] +impl service::user_service::UserService for UserServiceImpl { + type Context = Option>; + + async fn current_user( + &self, + context: Self::Context, + ) -> Result, service::ServiceError> { + context + .ok_or_else(|| service::ServiceError::Unauthorized) + .map(|user| user.clone()) + } +}