Skip to content

Commit

Permalink
Add OIDC authentication endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
pka committed Jun 22, 2023
1 parent ce58bd1 commit 81a81ec
Show file tree
Hide file tree
Showing 12 changed files with 789 additions and 18 deletions.
551 changes: 543 additions & 8 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ num_cpus = "1.13.1"
once_cell = "1.12.0"
opentelemetry = { version = "0.18", default-features = false, features = ["trace", "metrics", "rt-tokio"] }
prometheus = { version = "0.13", default-features = false }
reqwest = { version = "0.11.11", default-features = false, features = ["rustls-tls"] }
rust-embed = { version = "5.6.0", features = ["compression"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.57"
Expand Down
4 changes: 4 additions & 0 deletions bbox-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2018"
html = []

[dependencies]
actix-session = { version = "0.7", features = ["cookie-session"] }
actix-web = { workspace = true }
actix-web-opentelemetry = { version = "0.13", features = ["metrics-prometheus"] }
async-stream = { workspace = true }
Expand All @@ -23,16 +24,19 @@ mime_guess = "2.0.3"
minijinja = { workspace = true }
num_cpus = { workspace = true }
once_cell = "1.8.0"
openidconnect = "3.2.0"
opentelemetry = { workspace = true }
opentelemetry-jaeger = { version = "0.17", features = ["rt-tokio"] }
opentelemetry-prometheus = { version = "0.11" }
prometheus = { workspace = true }
reqwest = { workspace = true }
rust-embed = { workspace = true }
rustls = "0.20.8" # Same as actix-tls -> tokio-rustls
rustls-pemfile = "1.0.2"
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = "0.8.24"
thiserror = { workspace = true }

[dev-dependencies]

Expand Down
6 changes: 6 additions & 0 deletions bbox-common/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pub mod oidc;

pub struct Identity {
pub username: String,
pub groups: Vec<String>,
}
148 changes: 148 additions & 0 deletions bbox-common/src/auth/oidc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use super::Identity;
use log::{debug, info};
use openidconnect::{
core::{CoreClient, CoreErrorResponseType, CoreProviderMetadata, CoreResponseType},
reqwest::async_http_client,
AuthenticationFlow, AuthorizationCode, ClaimsVerificationError, ClientId, ClientSecret,
CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RequestTokenError, Scope,
StandardErrorResponse,
};
use serde::Deserialize;
use serde_json::Value;

#[derive(thiserror::Error, Debug)]
pub enum AuthError {
#[error(transparent)]
OidcRequestTokenError(
#[from]
RequestTokenError<
openidconnect::reqwest::Error<reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>,
),
#[error(transparent)]
OidcClaimsVerificationError(#[from] ClaimsVerificationError),
#[error("Server did not return an ID token")]
OpenidIdTokenError,
}

#[derive(Deserialize, Default, Clone, Debug)]
#[serde(default, deny_unknown_fields)]
pub struct OidcAuthCfg {
pub client_id: String,
pub client_secret: String,
pub issuer_url: String,
pub redirect_uri: Option<String>,
pub scopes: Option<String>,
pub username_claim: Option<String>,
pub groupinfo_claim: Option<String>,
}

#[derive(Clone, Debug)]
pub struct OidcClient {
client: CoreClient,
pub authorize_url: String,
nonce: Nonce,
username_claim: Option<String>,
groupinfo_claim: String,
}

impl OidcClient {
pub async fn from_config(cfg: &OidcAuthCfg) -> Self {
info!(
"Fetching {}/.well-known/openid-configuration",
&cfg.issuer_url
);
let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(cfg.issuer_url.clone()).expect("Invalid issuer URL"),
async_http_client,
)
.await
.expect("Failed to discover OpenID Provider");

// Set up the config for the OAuth2 process.
let redirect_uri = cfg
.redirect_uri
.clone()
.unwrap_or("http://127.0.0.1:8080/auth".to_string());
let client = CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(cfg.client_id.clone()),
Some(ClientSecret::new(cfg.client_secret.clone())),
)
.set_redirect_uri(RedirectUrl::new(redirect_uri).expect("Invalid redirect URL"));

// Generate the authorization URL to which we'll redirect the user.
let mut auth_client = client.authorize_url(
AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
let scopes = cfg.scopes.clone().unwrap_or("email profile".to_string());
for scope in scopes.split(' ') {
auth_client = auth_client.add_scope(Scope::new(scope.to_string()));
}
let (authorize_url, _csrf_state, nonce) = auth_client.url();
let groupinfo_claim = cfg.groupinfo_claim.clone().unwrap_or("group".to_string());
OidcClient {
client,
authorize_url: authorize_url.to_string(),
nonce,
username_claim: cfg.username_claim.clone(),
groupinfo_claim,
}
}
}

#[derive(Deserialize, Debug)]
pub struct AuthRequest {
pub code: String,
// pub state: String,
// pub scope: String,
}

impl AuthRequest {
pub async fn auth(&self, oidc: &OidcClient) -> Result<Identity, AuthError> {
// let state = CsrfToken::new(self.state.clone());
let code = AuthorizationCode::new(self.code.clone());
// Exchange the code with a token.
let token_response = oidc
.client
.exchange_code(code)
.request_async(async_http_client)
.await?;
debug!("IdP returned scopes: {:?}", token_response.scopes());

let id_token_verifier = oidc.client.id_token_verifier();
let id_token_claims = token_response
.extra_fields()
.id_token()
.ok_or(AuthError::OpenidIdTokenError)?
.claims(&id_token_verifier, &oidc.nonce)?;

// Convert back to raw JSON to simplify extracting configurable claims
let userinfo = serde_json::to_value(id_token_claims).unwrap();
info!("userinfo: {userinfo:#?}");

let username = if let Some(claim) = &oidc.username_claim {
userinfo[claim].as_str()
} else {
userinfo
.get("preferred_username")
.or(userinfo.get("upn"))
.or(userinfo.get("email"))
.and_then(|v| v.as_str())
}
.unwrap_or("")
.to_string();
let groups = match &userinfo[&oidc.groupinfo_claim] {
Value::String(s) => vec![s.as_str().to_string()],
Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str().map(str::to_string))
.collect(),
_ => Vec::new(),
};
Ok(Identity { username, groups })
}
}
13 changes: 13 additions & 0 deletions bbox-common/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::auth::oidc::OidcAuthCfg;
use actix_web::HttpRequest;
use core::fmt::Display;
use figment::providers::{Env, Format, Toml};
Expand Down Expand Up @@ -95,6 +96,18 @@ impl WebserverCfg {
}
}

#[derive(Deserialize, Default, Clone, Debug)]
#[serde(default, deny_unknown_fields)]
pub struct AuthCfg {
pub oidc: Option<OidcAuthCfg>,
}

impl AuthCfg {
pub fn from_config() -> Self {
from_config_opt_or_exit("auth").unwrap_or_default()
}
}

// -- Metrics --

#[derive(Deserialize, Default, Debug)]
Expand Down
40 changes: 38 additions & 2 deletions bbox-common/src/endpoints.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use crate::api::{OgcApiInventory, OpenApiDoc};
use crate::auth::oidc::{AuthRequest, OidcClient};
use crate::config::WebserverCfg;
use crate::ogcapi::*;
use crate::service::CoreService;
use actix_web::web::Bytes;
use actix_session::Session;
use actix_web::{
guard, guard::Guard, guard::GuardContext, http::header, web, HttpRequest, HttpResponse,
error::ErrorInternalServerError, guard, guard::Guard, guard::GuardContext, http::header,
http::StatusCode, web, web::Bytes, HttpRequest, HttpResponse, Responder,
};
use actix_web_opentelemetry::PrometheusMetricsHandler;
use async_stream::stream;
use futures_core::stream::Stream;
use log::info;
use std::collections::HashMap;
use std::convert::Infallible;
use std::io::Read;
Expand Down Expand Up @@ -150,6 +153,32 @@ async fn health() -> HttpResponse {
HttpResponse::Ok().body("OK")
}

async fn login(oidc: web::Data<OidcClient>) -> impl Responder {
web::Redirect::to(oidc.authorize_url.clone()).using_status_code(StatusCode::FOUND)
}

async fn auth(
session: Session,
oidc: web::Data<OidcClient>,
params: web::Query<AuthRequest>,
) -> actix_web::Result<impl Responder> {
let identity = params.auth(&oidc).await.map_err(ErrorInternalServerError)?;
info!(
"username: `{}` groups: {:?}",
identity.username, identity.groups
);

session.insert("username", identity.username).unwrap();
session.insert("groups", identity.groups).unwrap();

Ok(web::Redirect::to("/").using_status_code(StatusCode::FOUND))
}

async fn logout(session: Session) -> impl Responder {
session.clear();
web::Redirect::to("/").using_status_code(StatusCode::FOUND)
}

impl CoreService {
pub(crate) fn register(&self, cfg: &mut web::ServiceConfig, _core: &CoreService) {
cfg.app_data(web::Data::new(self.web_config.clone()))
Expand Down Expand Up @@ -182,6 +211,13 @@ impl CoreService {
)
.service(web::resource("/health").to(health));

if let Some(oidc) = &self.oidc {
cfg.app_data(web::Data::new(oidc.clone()))
.service(web::resource("/login").route(web::get().to(login)))
.service(web::resource("/auth").route(web::get().to(auth)))
.service(web::resource("/logout").route(web::get().to(logout)));
}

if let Some(metrics) = &self.metrics {
let metrics_handler = PrometheusMetricsHandler::new(metrics.exporter.clone());
//TODO: path from MetricsCfg
Expand Down
3 changes: 1 addition & 2 deletions bbox-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod api;
pub mod auth;
pub mod cli;
pub mod config;
pub mod endpoints;
Expand All @@ -12,8 +13,6 @@ pub mod static_files;
pub mod templates;
pub mod tls;

// pub use utoipa::{path as api_path, Component as ApiComponent, OpenApi};

use std::env;
use std::path::Path;

Expand Down
30 changes: 26 additions & 4 deletions bbox-common/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::api::{OgcApiInventory, OpenApiDoc};
use crate::auth::oidc::OidcClient;
use crate::cli::{CommonCommands, GlobalArgs, NoArgs, NoCommands};
use crate::config::WebserverCfg;
use crate::config::{AuthCfg, WebserverCfg};
use crate::logger;
use crate::metrics::{init_metrics, Metrics};
use crate::ogcapi::{ApiLink, CoreCollection};
use crate::tls::load_rustls_config;
use actix_web::{middleware, web, App, HttpServer};
use actix_session::{config::PersistentSession, storage::CookieSessionStore, SessionMiddleware};
use actix_web::{
cookie::{time::Duration, Key},
middleware, web, App, HttpServer,
};
use actix_web_opentelemetry::{RequestMetrics, RequestMetricsBuilder, RequestTracing};
use async_trait::async_trait;
use clap::{ArgMatches, Args, Command, CommandFactory, FromArgMatches, Parser, Subcommand};
Expand Down Expand Up @@ -57,6 +62,7 @@ pub struct CoreService {
pub(crate) ogcapi: OgcApiInventory,
pub(crate) openapi: OpenApiDoc,
pub(crate) metrics: Option<Metrics>,
pub(crate) oidc: Option<OidcClient>,
}

impl Default for CoreService {
Expand All @@ -67,6 +73,7 @@ impl Default for CoreService {
ogcapi: OgcApiInventory::default(),
openapi: OpenApiDoc::new(),
metrics: None,
oidc: None,
}
}
}
Expand Down Expand Up @@ -157,7 +164,11 @@ impl OgcApiService for CoreService {
logger::init();

self.web_config = WebserverCfg::from_config();
let auth_cfg = AuthCfg::from_config();
self.metrics = init_metrics();
if let Some(cfg) = &auth_cfg.oidc {
self.oidc = Some(OidcClient::from_config(cfg).await);
}
}
fn landing_page_links(&self, _api_base: &str) -> Vec<ApiLink> {
vec![
Expand Down Expand Up @@ -225,15 +236,26 @@ pub async fn run_service<T: OgcApiService + Sync + 'static>() -> std::io::Result
return Ok(());
}

let secret_key = Key::generate();
let session_ttl = Duration::minutes(1);

let workers = core.workers();
let server_addr = core.server_addr().to_string();
let tls_config = core.tls_config();
let mut server = HttpServer::new(move || {
App::new()
.wrap(middleware::Logger::default())
.wrap(middleware::Compress::default())
.configure(|mut cfg| core.register_endpoints(&mut cfg, &core))
.configure(|mut cfg| service.register_endpoints(&mut cfg, &core))
.wrap(
SessionMiddleware::builder(CookieSessionStore::default(), secret_key.clone())
.cookie_name("bbox".to_owned())
.cookie_secure(false)
.session_lifecycle(PersistentSession::default().session_ttl(session_ttl))
.build(),
)
.wrap(middleware::Compress::default())
.wrap(middleware::NormalizePath::trim())
.wrap(middleware::Logger::default())
});
if let Some(tls_config) = tls_config {
server = server.bind_rustls(server_addr, tls_config)?;
Expand Down
2 changes: 1 addition & 1 deletion bbox-common/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn load_rustls_config(tls_cert: &str, tls_key: &str) -> rustls::ServerConfig

// load TLS key/cert files
let cert_file = &mut BufReader::new(File::open(&app_dir(tls_cert)).unwrap_or_else(error_exit));
let key_file = &mut BufReader::new(File::open(app_dir(tls_key)).unwrap_or_else(error_exit));
let key_file = &mut BufReader::new(File::open(&app_dir(tls_key)).unwrap_or_else(error_exit));

// convert files to key/cert objects
let cert_chain = certs(cert_file)
Expand Down
2 changes: 1 addition & 1 deletion bbox-tile-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ log = { workspace = true }
martin-mbtiles = { version = "0.2.1", default-features = false, features = ["rustls"], git = "https://github.com/pka/martin", branch = "sqlx-rustls" }
num_cpus = { workspace = true }
prometheus = { workspace = true }
reqwest = { version = "0.11.11", default-features = false, features = ["rustls-tls"] }
reqwest = { workspace = true }
rusoto_core = { version = "0.47.0", default-features = false, features = ["rustls"] }
rusoto_s3 = { version = "0.47.0", default-features = false, features = ["rustls"] }
serde = { workspace = true }
Expand Down
Loading

0 comments on commit 81a81ec

Please sign in to comment.