Add oidc support

This commit is contained in:
Simon Goller 2024-06-04 20:07:58 +02:00
parent ed609cf06c
commit a868ceb0cd
7 changed files with 1111 additions and 25 deletions

2
.gitignore vendored
View file

@ -1,2 +1,4 @@
localdb.sqlite3*
target
.idea
result

1028
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -30,3 +30,6 @@ version = "0.3.36"
[dependencies.time-macros]
version = "0.2.18"
[dependencies.dotenvy]
version = "0.15"

View file

@ -129,6 +129,7 @@ async fn create_dev_admin_user(pool: Arc<SqlitePool>) {
#[tokio::main]
async fn main() {
dotenvy::dotenv().ok();
let pool = Arc::new(
SqlitePool::connect("sqlite:./localdb.sqlite3")
.await

View file

@ -1,10 +1,18 @@
{ pkgs ? import <nixpkgs> {} }:
{ pkgs ? import <nixpkgs> {}, features ? [] }:
let
rustPlatform = pkgs.rustPlatform;
specificPkgs = import (pkgs.fetchFromGitHub {
owner = "NixOS";
repo = "nixpkgs";
rev = "57610d2f8f0937f39dbd72251e9614b1561942d8";
sha256 = "sha256-yZKhxVIKd2lsbOqYd5iDoUIwsRZFqE87smE2Vzf6Ck0=";
}) {};
rustPlatform = specificPkgs.rustPlatform;
in
rustPlatform.buildRustPackage {
pname = "shifty-service";
version = "0.1";
src = ./.;
cargoHash = "sha256-bgtX30TGRlBjCZ8qbqNgovsZrZqJ9kEGlv/qv6T5uZA=";
buildFeatures = features;
cargoHash = "sha256-sTKupn3HMBf3lumCu1RUkzutc+RUNpuqEyGR2BMxAso=";
}

View file

@ -4,6 +4,10 @@ version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
#default = ["oidc"]
default = []
oidc = ["dep:axum-oidc"]
[dependencies]
axum = "0.7.5"
@ -29,3 +33,13 @@ features = ["derive", "std", "alloc", "rc"]
[dependencies.thiserror]
version = "1.0"
[dependencies.tower]
version = "0.4.4"
[dependencies.tower-sessions]
version = "0.12"
[dependencies.axum-oidc]
version = "0.4"
optional = true

View file

@ -5,9 +5,14 @@ mod permission;
mod sales_person;
mod slot;
use axum::{body::Body, response::Response, Router};
use axum::http::Uri;
use axum::response::IntoResponse;
use axum::{body::Body, error_handling::HandleErrorLayer, response::Response, Router};
use service::ServiceError;
use thiserror::Error;
use time::Duration;
use tower::ServiceBuilder;
use tower_sessions::{cookie::SameSite, Expiry, MemoryStore, SessionManagerLayer};
use uuid::Uuid;
// TODO: In prod, it must be a different type than in dev mode.
@ -151,6 +156,31 @@ pub trait RestStateDef: Clone + Send + Sync + 'static {
fn booking_service(&self) -> Arc<Self::BookingService>;
}
pub struct OidcConfig {
pub app_url: String,
pub issuer: String,
pub client_id: String,
pub client_secret: Option<String>,
}
pub fn oidc_config() -> OidcConfig {
let app_url = std::env::var("APP_URL").expect("APP_URL env variable");
let issuer = std::env::var("ISSUER").expect("ISSUER env variable");
let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable");
let client_secret = std::env::var("CLIENT_SECRET").ok();
OidcConfig {
app_url: app_url.into(),
issuer: issuer.into(),
client_id: client_id.into(),
client_secret: client_secret.unwrap_or_default().into(),
}
}
pub fn bind_address() -> Arc<str> {
std::env::var("SERVER_ADDRESS")
.unwrap_or("127.0.0.1:3000".into())
.into()
}
pub async fn start_server<RestState: RestStateDef>(rest_state: RestState) {
let app = Router::new()
.nest("/permission", permission::generate_route())
@ -158,7 +188,47 @@ pub async fn start_server<RestState: RestStateDef>(rest_state: RestState) {
.nest("/sales-person", sales_person::generate_route())
.nest("/booking", booking::generate_route())
.with_state(rest_state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
#[cfg(feature = "oidc")]
let app = {
use axum_oidc::error::MiddlewareError;
use axum_oidc::{EmptyAdditionalClaims, OidcAuthLayer, OidcLoginLayer};
let oidc_config = oidc_config();
let session_store = MemoryStore::default();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_same_site(SameSite::Lax)
.with_expiry(Expiry::OnInactivity(Duration::seconds(120)));
let oidc_login_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
let oidc_auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
Uri::from_maybe_shared(oidc_config.app_url).expect("valid APP_URL"),
oidc_config.issuer,
oidc_config.client_id,
oidc_config.client_secret,
vec![],
)
.await
.unwrap(),
);
app.layer(oidc_login_service)
.layer(oidc_auth_service)
.layer(session_layer)
};
let listener = tokio::net::TcpListener::bind(bind_address().as_ref())
.await
.expect("Could not bind server");
axum::serve(listener, app)