diff --git a/Cargo.lock b/Cargo.lock index 35f8333..02e9a4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,17 @@ dependencies = [ "rustversion", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -61,15 +72,20 @@ dependencies = [ "axum", "chrono", "fastrand", + "hex", + "hmac", "jsonwebtoken", "metrics", "metrics-exporter-prometheus", "redis", "serde", "serde_json", + "sha2", "tokio", + "tokio-postgres", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -173,12 +189,27 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -278,6 +309,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -293,6 +333,16 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deranged" version = "0.5.8" @@ -302,6 +352,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -335,6 +396,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fastrand" version = "2.3.0" @@ -387,6 +454,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -420,6 +488,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -428,7 +506,7 @@ checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", ] [[package]] @@ -499,6 +577,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.4.0" @@ -814,6 +907,16 @@ version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +[[package]] +name = "libredox" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "litemap" version = "0.8.1" @@ -850,6 +953,16 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "memchr" version = "2.8.0" @@ -916,7 +1029,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.61.2", ] @@ -963,6 +1076,24 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2-core-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +dependencies = [ + "bitflags", +] + +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -1014,6 +1145,25 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "phf" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1562dc717473dbaa4c1f85a36410e03c047b2e7df7f45ee938fbef64ae7fadf" +dependencies = [ + "phf_shared", + "serde", +] + +[[package]] +name = "phf_shared" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e57fef6bc5981e38c2ce2d63bfa546861309f875b8a75f092d1d54ae2d64f266" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1032,6 +1182,38 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" +[[package]] +name = "postgres-protocol" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee9dd5fe15055d2b6806f4736aa0c9637217074e224bbec46d4041b91bb9491" +dependencies = [ + "base64", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" +dependencies = [ + "bytes", + "chrono", + "fallible-iterator", + "postgres-protocol", + "serde_core", + "serde_json", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -1085,7 +1267,7 @@ dependencies = [ "libc", "once_cell", "raw-cpuid", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -1420,6 +1602,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1466,6 +1659,12 @@ dependencies = [ "time", ] +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + [[package]] name = "sketches-ddsketch" version = "0.3.0" @@ -1500,6 +1699,17 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "subtle" version = "2.6.1" @@ -1604,6 +1814,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.49.0" @@ -1632,6 +1857,32 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-postgres" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -1751,12 +2002,39 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -1826,6 +2104,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" version = "1.0.2+wasi-0.2.9" @@ -1844,6 +2131,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasite" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] + [[package]] name = "wasm-bindgen" version = "0.2.111" @@ -1933,6 +2229,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "whoami" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" +dependencies = [ + "libc", + "libredox", + "objc2-system-configuration", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 9b4cb50..46f8b46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,7 @@ jsonwebtoken = "10" chrono = { version = "0.4", features = ["serde", "clock"] } metrics = "0.24" metrics-exporter-prometheus = "0.17" +tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1"] } +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" diff --git a/auth-api/Cargo.toml b/auth-api/Cargo.toml index a732b6a..acb9f54 100644 --- a/auth-api/Cargo.toml +++ b/auth-api/Cargo.toml @@ -17,3 +17,8 @@ tracing-subscriber.workspace = true fastrand.workspace = true metrics.workspace = true metrics-exporter-prometheus.workspace = true +tokio-postgres.workspace = true +hmac.workspace = true +sha2.workspace = true +hex.workspace = true +uuid.workspace = true diff --git a/auth-api/src/main.rs b/auth-api/src/main.rs index e37639e..f6931a7 100644 --- a/auth-api/src/main.rs +++ b/auth-api/src/main.rs @@ -1,24 +1,32 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc, time::Instant}; use anyhow::{Context, Result}; use axum::{ Json, Router, extract::State, - http::StatusCode, + http::{HeaderMap, StatusCode}, response::IntoResponse, routing::{get, post}, }; -use chrono::{Duration, Utc}; +use chrono::{Duration, TimeZone, Utc}; +use hmac::{Hmac, Mac}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; +use sha2::Sha256; +use tokio_postgres::{Client as PgClient, NoTls}; use tracing::{info, warn}; +use uuid::Uuid; + +type HmacSha256 = Hmac; #[derive(Clone)] struct AppState { jwt_secret: Arc, + stripe_webhook_secret: Option>, redis: Option, + pg: Option>, metrics: PrometheusHandle, } @@ -67,6 +75,17 @@ struct StripeWebhookEvent { tier: String, #[serde(default = "default_max_tunnels")] max_tunnels: u32, + #[serde(default = "default_subscription_status")] + status: String, + provider_customer_id: Option, + provider_subscription_id: Option, + current_period_end: Option, +} + +#[derive(Debug, Clone)] +struct PlanEntitlement { + tier: String, + max_tunnels: u32, } #[tokio::main] @@ -80,9 +99,11 @@ async fn main() -> Result<()> { let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into()); + let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").ok().map(Arc::new); let metrics = PrometheusBuilder::new() .install_recorder() .context("install prometheus recorder")?; + let redis = if let Ok(url) = std::env::var("REDIS_URL") { let client = redis::Client::open(url.clone()).context("open redis client")?; match redis::aio::ConnectionManager::new(client).await { @@ -99,9 +120,13 @@ async fn main() -> Result<()> { None }; + let pg = connect_postgres().await?; + let state = AppState { jwt_secret: Arc::new(jwt_secret), + stripe_webhook_secret, redis, + pg, metrics, }; @@ -122,6 +147,22 @@ async fn main() -> Result<()> { Ok(()) } +async fn connect_postgres() -> Result>> { + let Some(url) = std::env::var("DATABASE_URL").ok() else { + return Ok(None); + }; + let (client, conn) = tokio_postgres::connect(&url, NoTls) + .await + .context("connect postgres")?; + tokio::spawn(async move { + if let Err(e) = conn.await { + warn!(error = %e, "postgres connection task ended"); + } + }); + info!("auth-api connected to postgres"); + Ok(Some(Arc::new(client))) +} + async fn healthz() -> &'static str { "ok" } @@ -135,7 +176,7 @@ async fn issue_dev_token( State(state): State, Json(req): Json, ) -> Result, ApiError> { - let started = std::time::Instant::now(); + let started = Instant::now(); let now = Utc::now(); let exp = now + Duration::hours(24); let claims = Claims { @@ -154,20 +195,7 @@ async fn issue_dev_token( ) .map_err(ApiError::internal)?; - if let Some(mut redis) = state.redis.clone() { - let key = format!("auth:jwt:jti:{}", claims.jti); - let ttl = (claims.exp as i64 - now.timestamp()).max(1); - let payload = serde_json::json!({ - "user_id": claims.sub, - "plan_tier": claims.tier, - "max_tunnels": claims.max_tunnels - }) - .to_string(); - let _: () = redis - .set_ex(key, payload, ttl as u64) - .await - .map_err(ApiError::internal)?; - } + sync_jwt_cache(&state, &claims).await?; metrics::counter!("auth_jwt_issued_total").increment(1); metrics::histogram!("auth_issue_token_latency_ms") .record(started.elapsed().as_secs_f64() * 1000.0); @@ -183,72 +211,232 @@ async fn validate_token( State(state): State, Json(req): Json, ) -> Result, ApiError> { - let started = std::time::Instant::now(); + let started = Instant::now(); let decoded = decode::( &req.token, &DecodingKey::from_secret(state.jwt_secret.as_bytes()), &Validation::new(Algorithm::HS256), ); - match decoded { + let response = match decoded { Ok(tok) => { - let c = tok.claims; - if let Some(mut redis) = state.redis.clone() { - let key = format!("plan:user:{}", c.sub); - let payload = serde_json::json!({ - "tier": c.tier, - "max_tunnels": c.max_tunnels, - "source": "auth-api" - }) - .to_string(); - let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?; - } + let claims = tok.claims; + let ent = match load_plan_for_user(&state, &claims.sub).await? { + Some(db_plan) => db_plan, + None => PlanEntitlement { + tier: claims.tier.clone(), + max_tunnels: claims.max_tunnels, + }, + }; + sync_plan_cache(&state, &claims.sub, &ent, "auth-validate").await?; metrics::counter!("auth_token_validate_total", "result" => "valid").increment(1); - metrics::histogram!("auth_validate_token_latency_ms") - .record(started.elapsed().as_secs_f64() * 1000.0); - Ok(Json(ValidateResponse { + ValidateResponse { valid: true, - user_id: Some(c.sub), - tier: Some(c.tier), - max_tunnels: Some(c.max_tunnels), - })) + user_id: Some(claims.sub), + tier: Some(ent.tier), + max_tunnels: Some(ent.max_tunnels), + } } - Err(_) => Ok(Json(ValidateResponse { - valid: false, - user_id: None, - tier: None, - max_tunnels: None, - })).map(|resp| { + Err(_) => { metrics::counter!("auth_token_validate_total", "result" => "invalid").increment(1); - metrics::histogram!("auth_validate_token_latency_ms") - .record(started.elapsed().as_secs_f64() * 1000.0); - resp - }), - } + ValidateResponse { + valid: false, + user_id: None, + tier: None, + max_tunnels: None, + } + } + }; + + metrics::histogram!("auth_validate_token_latency_ms") + .record(started.elapsed().as_secs_f64() * 1000.0); + Ok(Json(response)) } -#[tracing::instrument(skip(state))] +#[tracing::instrument(skip(state, headers, body))] async fn stripe_webhook( State(state): State, - Json(event): Json, + headers: HeaderMap, + body: String, ) -> Result { - let started = std::time::Instant::now(); + let started = Instant::now(); + verify_stripe_signature(state.stripe_webhook_secret.as_deref().map(|s| s.as_str()), &headers, &body)?; + + let event: StripeWebhookEvent = serde_json::from_str(&body).map_err(ApiError::bad_request)?; + let ent = PlanEntitlement { + tier: event.tier.clone(), + max_tunnels: event.max_tunnels, + }; + + if let Some(pg) = &state.pg { + upsert_subscription_from_webhook(pg, &event).await?; + } + sync_plan_cache(&state, &event.user_id, &ent, "stripe_webhook").await?; + + metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone()) + .increment(1); + metrics::histogram!("stripe_webhook_latency_ms") + .record(started.elapsed().as_secs_f64() * 1000.0); + Ok(StatusCode::NO_CONTENT) +} + +async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> { if let Some(mut redis) = state.redis.clone() { - let key = format!("plan:user:{}", event.user_id); + let key = format!("auth:jwt:jti:{}", claims.jti); + let ttl = (claims.exp as i64 - Utc::now().timestamp()).max(1); let payload = serde_json::json!({ - "tier": event.tier, - "max_tunnels": event.max_tunnels, - "source": "stripe_webhook", - "last_event_type": event.event_type, - "updated_at": Utc::now().timestamp(), + "user_id": claims.sub, + "plan_tier": claims.tier, + "max_tunnels": claims.max_tunnels + }) + .to_string(); + let _: () = redis + .set_ex(key, payload, ttl as u64) + .await + .map_err(ApiError::internal)?; + } + Ok(()) +} + +async fn sync_plan_cache( + state: &AppState, + user_id: &str, + ent: &PlanEntitlement, + source: &str, +) -> Result<(), ApiError> { + if let Some(mut redis) = state.redis.clone() { + let key = format!("plan:user:{user_id}"); + let payload = serde_json::json!({ + "tier": ent.tier, + "max_tunnels": ent.max_tunnels, + "source": source, + "updated_at": Utc::now().timestamp() }) .to_string(); let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?; } - metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone()).increment(1); - metrics::histogram!("stripe_webhook_latency_ms") - .record(started.elapsed().as_secs_f64() * 1000.0); - Ok(StatusCode::NO_CONTENT) + Ok(()) +} + +async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result, ApiError> { + let Some(pg) = &state.pg else { + return Ok(None); + }; + let _ = Uuid::parse_str(user_id).map_err(ApiError::bad_request)?; + let row = pg + .query_opt( + r#" + select p.id as plan_id, p.max_tunnels + from subscriptions s + join plans p on p.id = s.plan_id + where s.user_id = $1::uuid and s.status in ('active', 'trialing') + order by s.updated_at desc + limit 1 + "#, + &[&user_id], + ) + .await + .map_err(ApiError::internal)?; + + Ok(row.map(|r| PlanEntitlement { + tier: r.get::<_, String>("plan_id"), + max_tunnels: r.get::<_, i32>("max_tunnels") as u32, + })) +} + +async fn upsert_subscription_from_webhook(pg: &PgClient, event: &StripeWebhookEvent) -> Result<(), ApiError> { + let _ = Uuid::parse_str(&event.user_id).map_err(ApiError::bad_request)?; + + let period_end = event + .current_period_end + .and_then(|ts| Utc.timestamp_opt(ts, 0).single()) + .map(|dt| dt.naive_utc()); + + let provider_sub_id = event + .provider_subscription_id + .clone() + .unwrap_or_else(|| format!("sub-dev-{}", fastrand::u64(..))); + + pg.execute( + r#" + insert into subscriptions ( + user_id, plan_id, provider, provider_customer_id, provider_subscription_id, status, + current_period_end, updated_at + ) + values ($1::uuid, $2, 'stripe', $3, $4, $5, $6, now()) + on conflict (provider_subscription_id) + do update set + user_id = excluded.user_id, + plan_id = excluded.plan_id, + provider_customer_id = excluded.provider_customer_id, + status = excluded.status, + current_period_end = excluded.current_period_end, + updated_at = now() + "#, + &[ + &event.user_id, + &event.tier, + &event.provider_customer_id, + &provider_sub_id, + &event.status, + &period_end, + ], + ) + .await + .map_err(ApiError::internal)?; + + Ok(()) +} + +fn verify_stripe_signature( + secret: Option<&str>, + headers: &HeaderMap, + body: &str, +) -> Result<(), ApiError> { + let Some(secret) = secret else { + return Ok(()); + }; + + let header = headers + .get("stripe-signature") + .and_then(|h| h.to_str().ok()) + .ok_or_else(|| ApiError::unauthorized("missing stripe-signature header"))?; + + let mut timestamp: Option = None; + let mut sigs: Vec<&str> = Vec::new(); + for part in header.split(',') { + let mut kv = part.trim().splitn(2, '='); + let k = kv.next().unwrap_or_default(); + let v = kv.next().unwrap_or_default(); + match k { + "t" => timestamp = v.parse::().ok(), + "v1" => sigs.push(v), + _ => {} + } + } + + let ts = timestamp.ok_or_else(|| ApiError::unauthorized("invalid stripe signature timestamp"))?; + let now = Utc::now().timestamp(); + if (now - ts).abs() > 300 { + return Err(ApiError::unauthorized("stale stripe signature")); + } + + let signed_payload = format!("{ts}.{body}"); + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?; + mac.update(signed_payload.as_bytes()); + let expected = mac.finalize().into_bytes(); + + for sig in sigs { + let Ok(bytes) = hex::decode(sig) else { continue; }; + let mut verify_mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?; + verify_mac.update(signed_payload.as_bytes()); + if verify_mac.verify_slice(&bytes).is_ok() { + return Ok(()); + } + } + + let _ = expected; + Err(ApiError::unauthorized("stripe signature verification failed")) } #[derive(Debug)] @@ -264,6 +452,20 @@ impl ApiError { message: e.to_string(), } } + + fn bad_request(e: E) -> Self { + Self { + status: StatusCode::BAD_REQUEST, + message: e.to_string(), + } + } + + fn unauthorized>(msg: E) -> Self { + Self { + status: StatusCode::UNAUTHORIZED, + message: msg.into(), + } + } } impl IntoResponse for ApiError { @@ -279,3 +481,7 @@ fn default_tier() -> String { fn default_max_tunnels() -> u32 { 1 } + +fn default_subscription_status() -> String { + "active".to_string() +} diff --git a/db/schema.sql b/db/schema.sql index 13b4f79..8c00ec8 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -35,6 +35,9 @@ create table if not exists subscriptions ( ); create index if not exists subscriptions_user_id_idx on subscriptions(user_id); +create unique index if not exists subscriptions_provider_subscription_id_uidx + on subscriptions(provider_subscription_id) + where provider_subscription_id is not null; create table if not exists tunnels ( id uuid primary key default gen_random_uuid(),