488 lines
14 KiB
Rust
488 lines
14 KiB
Rust
use std::{net::SocketAddr, sync::Arc, time::Instant};
|
|
|
|
use anyhow::{Context, Result};
|
|
use axum::{
|
|
Json, Router,
|
|
extract::State,
|
|
http::{HeaderMap, StatusCode},
|
|
response::IntoResponse,
|
|
routing::{get, post},
|
|
};
|
|
use chrono::{Duration, TimeZone, Utc};
|
|
use hmac::{Hmac, Mac};
|
|
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
|
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
|
use redis::AsyncCommands;
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::Sha256;
|
|
use tokio_postgres::{Client as PgClient, NoTls};
|
|
use tracing::{info, warn};
|
|
use uuid::Uuid;
|
|
|
|
/// HMAC-SHA256 MAC type used by `verify_stripe_signature` to check `v1` signatures.
type HmacSha256 = Hmac<Sha256>;
|
|
|
|
/// Shared handler state, cloned into every request by axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// HS256 key used both to sign (`/v1/token/dev`) and to verify tokens.
    jwt_secret: Arc<String>,
    /// Webhook signing secret; `None` means `verify_stripe_signature` accepts
    /// every request without checking the `stripe-signature` header.
    stripe_webhook_secret: Option<Arc<String>>,
    /// Optional Redis connection for the jti and plan caches; `None` skips caching.
    redis: Option<redis::aio::ConnectionManager>,
    /// Optional Postgres client for subscription reads/writes; `None` skips the DB.
    pg: Option<Arc<PgClient>>,
    /// Handle used to render the Prometheus exposition for `/metrics`.
    metrics: PrometheusHandle,
}
|
|
|
|
/// JWT payload issued by `/v1/token/dev` and checked by `/v1/token/validate`.
///
/// Field order is the serialized order of the token payload.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Claims {
    /// Subject: the user id the token was issued for.
    sub: String,
    /// Plan tier embedded at issue time (used only when no DB plan is found).
    tier: String,
    /// Tunnel limit embedded at issue time (same fallback role as `tier`).
    max_tunnels: u32,
    /// Expiry as Unix seconds.
    exp: usize,
    /// Issued-at as Unix seconds.
    iat: usize,
    /// Token id; also the suffix of the Redis key `auth:jwt:jti:<jti>`.
    jti: String,
}
|
|
|
|
/// JSON body of `POST /v1/token/dev`.
#[derive(Debug, Deserialize)]
struct DevTokenRequest {
    /// Becomes the token's `sub` claim; any string is accepted here.
    user_id: String,
    /// Plan tier to embed; `default_tier()` when omitted.
    #[serde(default = "default_tier")]
    tier: String,
    /// Tunnel limit to embed; `default_max_tunnels()` when omitted.
    #[serde(default = "default_max_tunnels")]
    max_tunnels: u32,
}
|
|
|
|
/// JSON response of `POST /v1/token/dev`.
#[derive(Debug, Serialize)]
struct TokenResponse {
    /// Encoded, signed JWT.
    token: String,
    /// Expiry of `token` as Unix seconds (same value as its `exp` claim).
    expires_at: i64,
}
|
|
|
|
/// JSON body of `POST /v1/token/validate`.
#[derive(Debug, Deserialize)]
struct ValidateRequest {
    /// Encoded JWT to verify.
    token: String,
}
|
|
|
|
/// JSON response of `POST /v1/token/validate`.
///
/// When `valid` is `false` every other field is `None` (serialized as `null`).
#[derive(Debug, Serialize)]
struct ValidateResponse {
    /// Whether the token's signature and claims passed verification.
    valid: bool,
    /// The token's `sub` claim.
    user_id: Option<String>,
    /// Effective tier: the database plan if one was found, else the token's claim.
    tier: Option<String>,
    /// Effective tunnel limit, resolved the same way as `tier`.
    max_tunnels: Option<u32>,
}
|
|
|
|
/// JSON body accepted by `POST /v1/stripe/webhook`.
///
/// NOTE(review): these field names do not match Stripe's native Event object
/// (`type`, `data.object`); presumably an upstream relay normalizes events into
/// this shape — confirm.
#[derive(Debug, Deserialize)]
struct StripeWebhookEvent {
    /// Event name; used only as a metrics label.
    event_type: String,
    /// Internal user id; must parse as a UUID when Postgres is configured.
    user_id: String,
    /// Plan tier; stored as the subscription's `plan_id`.
    tier: String,
    /// Tunnel limit for the plan cache; `default_max_tunnels()` when omitted.
    #[serde(default = "default_max_tunnels")]
    max_tunnels: u32,
    /// Subscription status; `default_subscription_status()` when omitted.
    #[serde(default = "default_subscription_status")]
    status: String,
    /// Provider customer id, stored as-is (may be `None`).
    provider_customer_id: Option<String>,
    /// Provider subscription id; the upsert conflict key.
    provider_subscription_id: Option<String>,
    /// End of the current billing period as Unix seconds.
    current_period_end: Option<i64>,
}
|
|
|
|
/// Resolved plan limits for a user, as written to the `plan:user:<id>` cache.
#[derive(Debug, Clone)]
struct PlanEntitlement {
    /// Plan tier / plan id.
    tier: String,
    /// Maximum number of tunnels the plan allows.
    max_tunnels: u32,
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter(
|
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
|
.unwrap_or_else(|_| "auth_api=info".into()),
|
|
)
|
|
.init();
|
|
|
|
let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
|
|
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
|
let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").ok().map(Arc::new);
|
|
let metrics = PrometheusBuilder::new()
|
|
.install_recorder()
|
|
.context("install prometheus recorder")?;
|
|
|
|
let redis = if let Ok(url) = std::env::var("REDIS_URL") {
|
|
let client = redis::Client::open(url.clone()).context("open redis client")?;
|
|
match redis::aio::ConnectionManager::new(client).await {
|
|
Ok(conn) => {
|
|
info!("auth-api connected to redis");
|
|
Some(conn)
|
|
}
|
|
Err(e) => {
|
|
warn!(error = %e, "auth-api redis unavailable; continuing without cache");
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let pg = connect_postgres().await?;
|
|
|
|
let state = AppState {
|
|
jwt_secret: Arc::new(jwt_secret),
|
|
stripe_webhook_secret,
|
|
redis,
|
|
pg,
|
|
metrics,
|
|
};
|
|
|
|
let app = Router::new()
|
|
.route("/healthz", get(healthz))
|
|
.route("/metrics", get(metrics_endpoint))
|
|
.route("/v1/token/dev", post(issue_dev_token))
|
|
.route("/v1/token/validate", post(validate_token))
|
|
.route("/v1/stripe/webhook", post(stripe_webhook))
|
|
.with_state(state);
|
|
|
|
let listener = tokio::net::TcpListener::bind(&bind)
|
|
.await
|
|
.with_context(|| format!("bind {bind}"))?;
|
|
let local_addr: SocketAddr = listener.local_addr()?;
|
|
info!(addr = %local_addr, "auth-api listening");
|
|
axum::serve(listener, app).await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn connect_postgres() -> Result<Option<Arc<PgClient>>> {
|
|
let Some(url) = std::env::var("DATABASE_URL").ok() else {
|
|
return Ok(None);
|
|
};
|
|
let (client, conn) = tokio_postgres::connect(&url, NoTls)
|
|
.await
|
|
.context("connect postgres")?;
|
|
tokio::spawn(async move {
|
|
if let Err(e) = conn.await {
|
|
warn!(error = %e, "postgres connection task ended");
|
|
}
|
|
});
|
|
info!("auth-api connected to postgres");
|
|
Ok(Some(Arc::new(client)))
|
|
}
|
|
|
|
/// Liveness probe: always answers with the body `ok`.
async fn healthz() -> &'static str {
    const HEALTHY_BODY: &str = "ok";
    HEALTHY_BODY
}
|
|
|
|
async fn metrics_endpoint(State(state): State<AppState>) -> impl IntoResponse {
|
|
state.metrics.render()
|
|
}
|
|
|
|
#[tracing::instrument(skip(state))]
|
|
async fn issue_dev_token(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<DevTokenRequest>,
|
|
) -> Result<Json<TokenResponse>, ApiError> {
|
|
let started = Instant::now();
|
|
let now = Utc::now();
|
|
let exp = now + Duration::hours(24);
|
|
let claims = Claims {
|
|
sub: req.user_id,
|
|
tier: req.tier,
|
|
max_tunnels: req.max_tunnels,
|
|
exp: exp.timestamp() as usize,
|
|
iat: now.timestamp() as usize,
|
|
jti: format!("jti-{}", fastrand::u64(..)),
|
|
};
|
|
|
|
let token = encode(
|
|
&Header::default(),
|
|
&claims,
|
|
&EncodingKey::from_secret(state.jwt_secret.as_bytes()),
|
|
)
|
|
.map_err(ApiError::internal)?;
|
|
|
|
sync_jwt_cache(&state, &claims).await?;
|
|
metrics::counter!("auth_jwt_issued_total").increment(1);
|
|
metrics::histogram!("auth_issue_token_latency_ms")
|
|
.record(started.elapsed().as_secs_f64() * 1000.0);
|
|
|
|
Ok(Json(TokenResponse {
|
|
token,
|
|
expires_at: exp.timestamp(),
|
|
}))
|
|
}
|
|
|
|
#[tracing::instrument(skip(state, req))]
|
|
async fn validate_token(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<ValidateRequest>,
|
|
) -> Result<Json<ValidateResponse>, ApiError> {
|
|
let started = Instant::now();
|
|
let decoded = decode::<Claims>(
|
|
&req.token,
|
|
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
|
|
&Validation::new(Algorithm::HS256),
|
|
);
|
|
|
|
let response = match decoded {
|
|
Ok(tok) => {
|
|
let claims = tok.claims;
|
|
let ent = match load_plan_for_user(&state, &claims.sub).await? {
|
|
Some(db_plan) => db_plan,
|
|
None => PlanEntitlement {
|
|
tier: claims.tier.clone(),
|
|
max_tunnels: claims.max_tunnels,
|
|
},
|
|
};
|
|
sync_plan_cache(&state, &claims.sub, &ent, "auth-validate").await?;
|
|
metrics::counter!("auth_token_validate_total", "result" => "valid").increment(1);
|
|
ValidateResponse {
|
|
valid: true,
|
|
user_id: Some(claims.sub),
|
|
tier: Some(ent.tier),
|
|
max_tunnels: Some(ent.max_tunnels),
|
|
}
|
|
}
|
|
Err(_) => {
|
|
metrics::counter!("auth_token_validate_total", "result" => "invalid").increment(1);
|
|
ValidateResponse {
|
|
valid: false,
|
|
user_id: None,
|
|
tier: None,
|
|
max_tunnels: None,
|
|
}
|
|
}
|
|
};
|
|
|
|
metrics::histogram!("auth_validate_token_latency_ms")
|
|
.record(started.elapsed().as_secs_f64() * 1000.0);
|
|
Ok(Json(response))
|
|
}
|
|
|
|
#[tracing::instrument(skip(state, headers, body))]
|
|
async fn stripe_webhook(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
body: String,
|
|
) -> Result<impl IntoResponse, ApiError> {
|
|
let started = Instant::now();
|
|
verify_stripe_signature(state.stripe_webhook_secret.as_deref().map(|s| s.as_str()), &headers, &body)?;
|
|
|
|
let event: StripeWebhookEvent = serde_json::from_str(&body).map_err(ApiError::bad_request)?;
|
|
let ent = PlanEntitlement {
|
|
tier: event.tier.clone(),
|
|
max_tunnels: event.max_tunnels,
|
|
};
|
|
|
|
if let Some(pg) = &state.pg {
|
|
upsert_subscription_from_webhook(pg, &event).await?;
|
|
}
|
|
sync_plan_cache(&state, &event.user_id, &ent, "stripe_webhook").await?;
|
|
|
|
metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone())
|
|
.increment(1);
|
|
metrics::histogram!("stripe_webhook_latency_ms")
|
|
.record(started.elapsed().as_secs_f64() * 1000.0);
|
|
Ok(StatusCode::NO_CONTENT)
|
|
}
|
|
|
|
async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> {
|
|
if let Some(mut redis) = state.redis.clone() {
|
|
let key = format!("auth:jwt:jti:{}", claims.jti);
|
|
let ttl = (claims.exp as i64 - Utc::now().timestamp()).max(1);
|
|
let payload = serde_json::json!({
|
|
"user_id": claims.sub,
|
|
"plan_tier": claims.tier,
|
|
"max_tunnels": claims.max_tunnels
|
|
})
|
|
.to_string();
|
|
let _: () = redis
|
|
.set_ex(key, payload, ttl as u64)
|
|
.await
|
|
.map_err(ApiError::internal)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn sync_plan_cache(
|
|
state: &AppState,
|
|
user_id: &str,
|
|
ent: &PlanEntitlement,
|
|
source: &str,
|
|
) -> Result<(), ApiError> {
|
|
if let Some(mut redis) = state.redis.clone() {
|
|
let key = format!("plan:user:{user_id}");
|
|
let payload = serde_json::json!({
|
|
"tier": ent.tier,
|
|
"max_tunnels": ent.max_tunnels,
|
|
"source": source,
|
|
"updated_at": Utc::now().timestamp()
|
|
})
|
|
.to_string();
|
|
let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result<Option<PlanEntitlement>, ApiError> {
|
|
let Some(pg) = &state.pg else {
|
|
return Ok(None);
|
|
};
|
|
let _ = Uuid::parse_str(user_id).map_err(ApiError::bad_request)?;
|
|
let row = pg
|
|
.query_opt(
|
|
r#"
|
|
select p.id as plan_id, p.max_tunnels
|
|
from subscriptions s
|
|
join plans p on p.id = s.plan_id
|
|
where s.user_id = $1::uuid and s.status in ('active', 'trialing')
|
|
order by s.updated_at desc
|
|
limit 1
|
|
"#,
|
|
&[&user_id],
|
|
)
|
|
.await
|
|
.map_err(ApiError::internal)?;
|
|
|
|
Ok(row.map(|r| PlanEntitlement {
|
|
tier: r.get::<_, String>("plan_id"),
|
|
max_tunnels: r.get::<_, i32>("max_tunnels") as u32,
|
|
}))
|
|
}
|
|
|
|
async fn upsert_subscription_from_webhook(pg: &PgClient, event: &StripeWebhookEvent) -> Result<(), ApiError> {
|
|
let _ = Uuid::parse_str(&event.user_id).map_err(ApiError::bad_request)?;
|
|
|
|
let period_end = event
|
|
.current_period_end
|
|
.and_then(|ts| Utc.timestamp_opt(ts, 0).single())
|
|
.map(|dt| dt.naive_utc());
|
|
|
|
let provider_sub_id = event
|
|
.provider_subscription_id
|
|
.clone()
|
|
.unwrap_or_else(|| format!("sub-dev-{}", fastrand::u64(..)));
|
|
|
|
pg.execute(
|
|
r#"
|
|
insert into subscriptions (
|
|
user_id, plan_id, provider, provider_customer_id, provider_subscription_id, status,
|
|
current_period_end, updated_at
|
|
)
|
|
values ($1::uuid, $2, 'stripe', $3, $4, $5, $6, now())
|
|
on conflict (provider_subscription_id)
|
|
do update set
|
|
user_id = excluded.user_id,
|
|
plan_id = excluded.plan_id,
|
|
provider_customer_id = excluded.provider_customer_id,
|
|
status = excluded.status,
|
|
current_period_end = excluded.current_period_end,
|
|
updated_at = now()
|
|
"#,
|
|
&[
|
|
&event.user_id,
|
|
&event.tier,
|
|
&event.provider_customer_id,
|
|
&provider_sub_id,
|
|
&event.status,
|
|
&period_end,
|
|
],
|
|
)
|
|
.await
|
|
.map_err(ApiError::internal)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn verify_stripe_signature(
|
|
secret: Option<&str>,
|
|
headers: &HeaderMap,
|
|
body: &str,
|
|
) -> Result<(), ApiError> {
|
|
let Some(secret) = secret else {
|
|
return Ok(());
|
|
};
|
|
|
|
let header = headers
|
|
.get("stripe-signature")
|
|
.and_then(|h| h.to_str().ok())
|
|
.ok_or_else(|| ApiError::unauthorized("missing stripe-signature header"))?;
|
|
|
|
let mut timestamp: Option<i64> = None;
|
|
let mut sigs: Vec<&str> = Vec::new();
|
|
for part in header.split(',') {
|
|
let mut kv = part.trim().splitn(2, '=');
|
|
let k = kv.next().unwrap_or_default();
|
|
let v = kv.next().unwrap_or_default();
|
|
match k {
|
|
"t" => timestamp = v.parse::<i64>().ok(),
|
|
"v1" => sigs.push(v),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let ts = timestamp.ok_or_else(|| ApiError::unauthorized("invalid stripe signature timestamp"))?;
|
|
let now = Utc::now().timestamp();
|
|
if (now - ts).abs() > 300 {
|
|
return Err(ApiError::unauthorized("stale stripe signature"));
|
|
}
|
|
|
|
let signed_payload = format!("{ts}.{body}");
|
|
let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
|
|
mac.update(signed_payload.as_bytes());
|
|
let expected = mac.finalize().into_bytes();
|
|
|
|
for sig in sigs {
|
|
let Ok(bytes) = hex::decode(sig) else { continue; };
|
|
let mut verify_mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
|
|
verify_mac.update(signed_payload.as_bytes());
|
|
if verify_mac.verify_slice(&bytes).is_ok() {
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
let _ = expected;
|
|
Err(ApiError::unauthorized("stripe signature verification failed"))
|
|
}
|
|
|
|
/// Handler error: an HTTP status plus a message, turned into a response by its
/// `IntoResponse` impl.
#[derive(Debug)]
struct ApiError {
    /// Status code sent to the client.
    status: StatusCode,
    /// Human-readable error text.
    message: String,
}
|
|
|
|
impl ApiError {
|
|
fn internal<E: std::fmt::Display>(e: E) -> Self {
|
|
Self {
|
|
status: StatusCode::INTERNAL_SERVER_ERROR,
|
|
message: e.to_string(),
|
|
}
|
|
}
|
|
|
|
fn bad_request<E: std::fmt::Display>(e: E) -> Self {
|
|
Self {
|
|
status: StatusCode::BAD_REQUEST,
|
|
message: e.to_string(),
|
|
}
|
|
}
|
|
|
|
fn unauthorized<E: Into<String>>(msg: E) -> Self {
|
|
Self {
|
|
status: StatusCode::UNAUTHORIZED,
|
|
message: msg.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for ApiError {
|
|
fn into_response(self) -> axum::response::Response {
|
|
(self.status, self.message).into_response()
|
|
}
|
|
}
|
|
|
|
/// Serde default for an omitted `tier`: the `"free"` plan.
fn default_tier() -> String {
    String::from("free")
}
|
|
|
|
/// Serde default for an omitted `max_tunnels`: a single tunnel.
fn default_max_tunnels() -> u32 {
    const SINGLE_TUNNEL: u32 = 1;
    SINGLE_TUNNEL
}
|
|
|
|
/// Serde default for an omitted webhook `status`: `"active"`.
fn default_subscription_status() -> String {
    "active".to_owned()
}
|