use std::{net::SocketAddr, sync::Arc, time::Instant}; use anyhow::{Context, Result}; use axum::{ Json, Router, extract::State, http::{HeaderMap, StatusCode}, response::IntoResponse, routing::{get, post}, }; use chrono::{Duration, TimeZone, Utc}; use hmac::{Hmac, Mac}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use sha2::Sha256; use tokio_postgres::{Client as PgClient, NoTls}; use tracing::{info, warn}; use uuid::Uuid; type HmacSha256 = Hmac; #[derive(Clone)] struct AppState { jwt_secret: Arc, stripe_webhook_secret: Option>, redis: Option, pg: Option>, metrics: PrometheusHandle, } #[derive(Debug, Serialize, Deserialize, Clone)] struct Claims { sub: String, tier: String, max_tunnels: u32, exp: usize, iat: usize, jti: String, } #[derive(Debug, Deserialize)] struct DevTokenRequest { user_id: String, #[serde(default = "default_tier")] tier: String, #[serde(default = "default_max_tunnels")] max_tunnels: u32, } #[derive(Debug, Serialize)] struct TokenResponse { token: String, expires_at: i64, } #[derive(Debug, Deserialize)] struct ValidateRequest { token: String, } #[derive(Debug, Serialize)] struct ValidateResponse { valid: bool, user_id: Option, tier: Option, max_tunnels: Option, } #[derive(Debug, Deserialize)] struct StripeWebhookEvent { event_type: String, user_id: String, tier: String, #[serde(default = "default_max_tunnels")] max_tunnels: u32, #[serde(default = "default_subscription_status")] status: String, provider_customer_id: Option, provider_subscription_id: Option, current_period_end: Option, } #[derive(Debug, Clone)] struct PlanEntitlement { tier: String, max_tunnels: u32, } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "auth_api=info".into()), ) .init(); let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into()); let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").ok().map(Arc::new); let metrics = PrometheusBuilder::new() .install_recorder() .context("install prometheus recorder")?; let redis = if let Ok(url) = std::env::var("REDIS_URL") { let client = redis::Client::open(url.clone()).context("open redis client")?; match redis::aio::ConnectionManager::new(client).await { Ok(conn) => { info!("auth-api connected to redis"); Some(conn) } Err(e) => { warn!(error = %e, "auth-api redis unavailable; continuing without cache"); None } } } else { None }; let pg = connect_postgres().await?; let state = AppState { jwt_secret: Arc::new(jwt_secret), stripe_webhook_secret, redis, pg, metrics, }; let app = Router::new() .route("/healthz", get(healthz)) .route("/metrics", get(metrics_endpoint)) .route("/v1/token/dev", post(issue_dev_token)) .route("/v1/token/validate", post(validate_token)) .route("/v1/stripe/webhook", post(stripe_webhook)) .with_state(state); let listener = tokio::net::TcpListener::bind(&bind) .await .with_context(|| format!("bind {bind}"))?; let local_addr: SocketAddr = listener.local_addr()?; info!(addr = %local_addr, "auth-api listening"); axum::serve(listener, app).await?; Ok(()) } async fn connect_postgres() -> Result>> { let Some(url) = std::env::var("DATABASE_URL").ok() else { return Ok(None); }; let (client, conn) = tokio_postgres::connect(&url, NoTls) .await .context("connect postgres")?; tokio::spawn(async move { if let Err(e) = conn.await { warn!(error = %e, "postgres connection task ended"); } }); info!("auth-api connected to postgres"); Ok(Some(Arc::new(client))) } async fn healthz() -> &'static str { "ok" } async fn metrics_endpoint(State(state): State) -> impl IntoResponse { state.metrics.render() } #[tracing::instrument(skip(state))] async fn issue_dev_token( State(state): State, Json(req): Json, ) -> Result, ApiError> { let started = Instant::now(); let now = Utc::now(); let exp = now + Duration::hours(24); let claims = Claims { sub: req.user_id, tier: req.tier, max_tunnels: req.max_tunnels, exp: exp.timestamp() as usize, iat: now.timestamp() as usize, jti: format!("jti-{}", fastrand::u64(..)), }; let token = encode( &Header::default(), &claims, &EncodingKey::from_secret(state.jwt_secret.as_bytes()), ) .map_err(ApiError::internal)?; sync_jwt_cache(&state, &claims).await?; metrics::counter!("auth_jwt_issued_total").increment(1); metrics::histogram!("auth_issue_token_latency_ms") .record(started.elapsed().as_secs_f64() * 1000.0); Ok(Json(TokenResponse { token, expires_at: exp.timestamp(), })) } #[tracing::instrument(skip(state, req))] async fn validate_token( State(state): State, Json(req): Json, ) -> Result, ApiError> { let started = Instant::now(); let decoded = decode::( &req.token, &DecodingKey::from_secret(state.jwt_secret.as_bytes()), &Validation::new(Algorithm::HS256), ); let response = match decoded { Ok(tok) => { let claims = tok.claims; let ent = match load_plan_for_user(&state, &claims.sub).await? { Some(db_plan) => db_plan, None => PlanEntitlement { tier: claims.tier.clone(), max_tunnels: claims.max_tunnels, }, }; sync_plan_cache(&state, &claims.sub, &ent, "auth-validate").await?; metrics::counter!("auth_token_validate_total", "result" => "valid").increment(1); ValidateResponse { valid: true, user_id: Some(claims.sub), tier: Some(ent.tier), max_tunnels: Some(ent.max_tunnels), } } Err(_) => { metrics::counter!("auth_token_validate_total", "result" => "invalid").increment(1); ValidateResponse { valid: false, user_id: None, tier: None, max_tunnels: None, } } }; metrics::histogram!("auth_validate_token_latency_ms") .record(started.elapsed().as_secs_f64() * 1000.0); Ok(Json(response)) } #[tracing::instrument(skip(state, headers, body))] async fn stripe_webhook( State(state): State, headers: HeaderMap, body: String, ) -> Result { let started = Instant::now(); verify_stripe_signature(state.stripe_webhook_secret.as_deref().map(|s| s.as_str()), &headers, &body)?; let event: StripeWebhookEvent = serde_json::from_str(&body).map_err(ApiError::bad_request)?; let ent = PlanEntitlement { tier: event.tier.clone(), max_tunnels: event.max_tunnels, }; if let Some(pg) = &state.pg { upsert_subscription_from_webhook(pg, &event).await?; } sync_plan_cache(&state, &event.user_id, &ent, "stripe_webhook").await?; metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone()) .increment(1); metrics::histogram!("stripe_webhook_latency_ms") .record(started.elapsed().as_secs_f64() * 1000.0); Ok(StatusCode::NO_CONTENT) } async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> { if let Some(mut redis) = state.redis.clone() { let key = format!("auth:jwt:jti:{}", claims.jti); let ttl = (claims.exp as i64 - Utc::now().timestamp()).max(1); let payload = serde_json::json!({ "user_id": claims.sub, "plan_tier": claims.tier, "max_tunnels": claims.max_tunnels }) .to_string(); let _: () = redis .set_ex(key, payload, ttl as u64) .await .map_err(ApiError::internal)?; } Ok(()) } async fn sync_plan_cache( state: &AppState, user_id: &str, ent: &PlanEntitlement, source: &str, ) -> Result<(), ApiError> { if let Some(mut redis) = state.redis.clone() { let key = format!("plan:user:{user_id}"); let payload = serde_json::json!({ "tier": ent.tier, "max_tunnels": ent.max_tunnels, "source": source, "updated_at": Utc::now().timestamp() }) .to_string(); let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?; } Ok(()) } async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result, ApiError> { let Some(pg) = &state.pg else { return Ok(None); }; let _ = Uuid::parse_str(user_id).map_err(ApiError::bad_request)?; let row = pg .query_opt( r#" select p.id as plan_id, p.max_tunnels from subscriptions s join plans p on p.id = s.plan_id where s.user_id = $1::uuid and s.status in ('active', 'trialing') order by s.updated_at desc limit 1 "#, &[&user_id], ) .await .map_err(ApiError::internal)?; Ok(row.map(|r| PlanEntitlement { tier: r.get::<_, String>("plan_id"), max_tunnels: r.get::<_, i32>("max_tunnels") as u32, })) } async fn upsert_subscription_from_webhook(pg: &PgClient, event: &StripeWebhookEvent) -> Result<(), ApiError> { let _ = Uuid::parse_str(&event.user_id).map_err(ApiError::bad_request)?; let period_end = event .current_period_end .and_then(|ts| Utc.timestamp_opt(ts, 0).single()) .map(|dt| dt.naive_utc()); let provider_sub_id = event .provider_subscription_id .clone() .unwrap_or_else(|| format!("sub-dev-{}", fastrand::u64(..))); pg.execute( r#" insert into subscriptions ( user_id, plan_id, provider, provider_customer_id, provider_subscription_id, status, current_period_end, updated_at ) values ($1::uuid, $2, 'stripe', $3, $4, $5, $6, now()) on conflict (provider_subscription_id) do update set user_id = excluded.user_id, plan_id = excluded.plan_id, provider_customer_id = excluded.provider_customer_id, status = excluded.status, current_period_end = excluded.current_period_end, updated_at = now() "#, &[ &event.user_id, &event.tier, &event.provider_customer_id, &provider_sub_id, &event.status, &period_end, ], ) .await .map_err(ApiError::internal)?; Ok(()) } fn verify_stripe_signature( secret: Option<&str>, headers: &HeaderMap, body: &str, ) -> Result<(), ApiError> { let Some(secret) = secret else { return Ok(()); }; let header = headers .get("stripe-signature") .and_then(|h| h.to_str().ok()) .ok_or_else(|| ApiError::unauthorized("missing stripe-signature header"))?; let mut timestamp: Option = None; let mut sigs: Vec<&str> = Vec::new(); for part in header.split(',') { let mut kv = part.trim().splitn(2, '='); let k = kv.next().unwrap_or_default(); let v = kv.next().unwrap_or_default(); match k { "t" => timestamp = v.parse::().ok(), "v1" => sigs.push(v), _ => {} } } let ts = timestamp.ok_or_else(|| ApiError::unauthorized("invalid stripe signature timestamp"))?; let now = Utc::now().timestamp(); if (now - ts).abs() > 300 { return Err(ApiError::unauthorized("stale stripe signature")); } let signed_payload = format!("{ts}.{body}"); let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?; mac.update(signed_payload.as_bytes()); let expected = mac.finalize().into_bytes(); for sig in sigs { let Ok(bytes) = hex::decode(sig) else { continue; }; let mut verify_mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?; verify_mac.update(signed_payload.as_bytes()); if verify_mac.verify_slice(&bytes).is_ok() { return Ok(()); } } let _ = expected; Err(ApiError::unauthorized("stripe signature verification failed")) } #[derive(Debug)] struct ApiError { status: StatusCode, message: String, } impl ApiError { fn internal(e: E) -> Self { Self { status: StatusCode::INTERNAL_SERVER_ERROR, message: e.to_string(), } } fn bad_request(e: E) -> Self { Self { status: StatusCode::BAD_REQUEST, message: e.to_string(), } } fn unauthorized>(msg: E) -> Self { Self { status: StatusCode::UNAUTHORIZED, message: msg.into(), } } } impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { (self.status, self.message).into_response() } } fn default_tier() -> String { "free".to_string() } fn default_max_tunnels() -> u32 { 1 } fn default_subscription_status() -> String { "active".to_string() }