refactor: remove stripe webhook flow for free product
This commit is contained in:
9
Cargo.lock
generated
9
Cargo.lock
generated
@@ -72,15 +72,12 @@ dependencies = [
|
||||
"axum",
|
||||
"chrono",
|
||||
"fastrand",
|
||||
"hex",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"redis",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tracing",
|
||||
@@ -577,12 +574,6 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
|
||||
@@ -22,6 +22,3 @@ chrono = { version = "0.4", features = ["serde", "clock"] }
|
||||
metrics = "0.24"
|
||||
metrics-exporter-prometheus = "0.17"
|
||||
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1"] }
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
|
||||
@@ -18,7 +18,4 @@ fastrand.workspace = true
|
||||
metrics.workspace = true
|
||||
metrics-exporter-prometheus.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
hmac.workspace = true
|
||||
sha2.workspace = true
|
||||
hex.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -4,27 +4,22 @@ use anyhow::{Context, Result};
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
};
|
||||
use chrono::{Duration, TimeZone, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||
use redis::AsyncCommands;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use tokio_postgres::{Client as PgClient, NoTls};
|
||||
use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
jwt_secret: Arc<String>,
|
||||
stripe_webhook_secret: Option<Arc<String>>,
|
||||
redis: Option<redis::aio::ConnectionManager>,
|
||||
pg: Option<Arc<PgClient>>,
|
||||
metrics: PrometheusHandle,
|
||||
@@ -68,20 +63,6 @@ struct ValidateResponse {
|
||||
max_tunnels: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StripeWebhookEvent {
|
||||
event_type: String,
|
||||
user_id: String,
|
||||
tier: String,
|
||||
#[serde(default = "default_max_tunnels")]
|
||||
max_tunnels: u32,
|
||||
#[serde(default = "default_subscription_status")]
|
||||
status: String,
|
||||
provider_customer_id: Option<String>,
|
||||
provider_subscription_id: Option<String>,
|
||||
current_period_end: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PlanEntitlement {
|
||||
tier: String,
|
||||
@@ -99,7 +80,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
||||
let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").ok().map(Arc::new);
|
||||
let metrics = PrometheusBuilder::new()
|
||||
.install_recorder()
|
||||
.context("install prometheus recorder")?;
|
||||
@@ -124,7 +104,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
let state = AppState {
|
||||
jwt_secret: Arc::new(jwt_secret),
|
||||
stripe_webhook_secret,
|
||||
redis,
|
||||
pg,
|
||||
metrics,
|
||||
@@ -135,7 +114,6 @@ async fn main() -> Result<()> {
|
||||
.route("/metrics", get(metrics_endpoint))
|
||||
.route("/v1/token/dev", post(issue_dev_token))
|
||||
.route("/v1/token/validate", post(validate_token))
|
||||
.route("/v1/stripe/webhook", post(stripe_webhook))
|
||||
.with_state(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&bind)
|
||||
@@ -253,33 +231,6 @@ async fn validate_token(
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(state, headers, body))]
|
||||
async fn stripe_webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let started = Instant::now();
|
||||
verify_stripe_signature(state.stripe_webhook_secret.as_deref().map(|s| s.as_str()), &headers, &body)?;
|
||||
|
||||
let event: StripeWebhookEvent = serde_json::from_str(&body).map_err(ApiError::bad_request)?;
|
||||
let ent = PlanEntitlement {
|
||||
tier: event.tier.clone(),
|
||||
max_tunnels: event.max_tunnels,
|
||||
};
|
||||
|
||||
if let Some(pg) = &state.pg {
|
||||
upsert_subscription_from_webhook(pg, &event).await?;
|
||||
}
|
||||
sync_plan_cache(&state, &event.user_id, &ent, "stripe_webhook").await?;
|
||||
|
||||
metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone())
|
||||
.increment(1);
|
||||
metrics::histogram!("stripe_webhook_latency_ms")
|
||||
.record(started.elapsed().as_secs_f64() * 1000.0);
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> {
|
||||
if let Some(mut redis) = state.redis.clone() {
|
||||
let key = format!("auth:jwt:jti:{}", claims.jti);
|
||||
@@ -344,101 +295,6 @@ async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result<Option<Pl
|
||||
}))
|
||||
}
|
||||
|
||||
async fn upsert_subscription_from_webhook(pg: &PgClient, event: &StripeWebhookEvent) -> Result<(), ApiError> {
|
||||
let _ = Uuid::parse_str(&event.user_id).map_err(ApiError::bad_request)?;
|
||||
|
||||
let period_end = event
|
||||
.current_period_end
|
||||
.and_then(|ts| Utc.timestamp_opt(ts, 0).single())
|
||||
.map(|dt| dt.naive_utc());
|
||||
|
||||
let provider_sub_id = event
|
||||
.provider_subscription_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("sub-dev-{}", fastrand::u64(..)));
|
||||
|
||||
pg.execute(
|
||||
r#"
|
||||
insert into subscriptions (
|
||||
user_id, plan_id, provider, provider_customer_id, provider_subscription_id, status,
|
||||
current_period_end, updated_at
|
||||
)
|
||||
values ($1::uuid, $2, 'stripe', $3, $4, $5, $6, now())
|
||||
on conflict (provider_subscription_id)
|
||||
do update set
|
||||
user_id = excluded.user_id,
|
||||
plan_id = excluded.plan_id,
|
||||
provider_customer_id = excluded.provider_customer_id,
|
||||
status = excluded.status,
|
||||
current_period_end = excluded.current_period_end,
|
||||
updated_at = now()
|
||||
"#,
|
||||
&[
|
||||
&event.user_id,
|
||||
&event.tier,
|
||||
&event.provider_customer_id,
|
||||
&provider_sub_id,
|
||||
&event.status,
|
||||
&period_end,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.map_err(ApiError::internal)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_stripe_signature(
|
||||
secret: Option<&str>,
|
||||
headers: &HeaderMap,
|
||||
body: &str,
|
||||
) -> Result<(), ApiError> {
|
||||
let Some(secret) = secret else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let header = headers
|
||||
.get("stripe-signature")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or_else(|| ApiError::unauthorized("missing stripe-signature header"))?;
|
||||
|
||||
let mut timestamp: Option<i64> = None;
|
||||
let mut sigs: Vec<&str> = Vec::new();
|
||||
for part in header.split(',') {
|
||||
let mut kv = part.trim().splitn(2, '=');
|
||||
let k = kv.next().unwrap_or_default();
|
||||
let v = kv.next().unwrap_or_default();
|
||||
match k {
|
||||
"t" => timestamp = v.parse::<i64>().ok(),
|
||||
"v1" => sigs.push(v),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let ts = timestamp.ok_or_else(|| ApiError::unauthorized("invalid stripe signature timestamp"))?;
|
||||
let now = Utc::now().timestamp();
|
||||
if (now - ts).abs() > 300 {
|
||||
return Err(ApiError::unauthorized("stale stripe signature"));
|
||||
}
|
||||
|
||||
let signed_payload = format!("{ts}.{body}");
|
||||
let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
|
||||
mac.update(signed_payload.as_bytes());
|
||||
let expected = mac.finalize().into_bytes();
|
||||
|
||||
for sig in sigs {
|
||||
let Ok(bytes) = hex::decode(sig) else { continue; };
|
||||
let mut verify_mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
|
||||
verify_mac.update(signed_payload.as_bytes());
|
||||
if verify_mac.verify_slice(&bytes).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let _ = expected;
|
||||
Err(ApiError::unauthorized("stripe signature verification failed"))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ApiError {
|
||||
status: StatusCode,
|
||||
@@ -460,12 +316,6 @@ impl ApiError {
|
||||
}
|
||||
}
|
||||
|
||||
fn unauthorized<E: Into<String>>(msg: E) -> Self {
|
||||
Self {
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
@@ -481,7 +331,3 @@ fn default_tier() -> String {
|
||||
fn default_max_tunnels() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_subscription_status() -> String {
|
||||
"active".to_string()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user