refactor: remove stripe webhook flow for free product

This commit is contained in:
L
2026-02-24 08:26:38 +00:00
parent 28918880da
commit 230a9212fe
4 changed files with 2 additions and 171 deletions

View File

@@ -18,7 +18,4 @@ fastrand.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
tokio-postgres.workspace = true
hmac.workspace = true
sha2.workspace = true
hex.workspace = true
uuid.workspace = true

View File

@@ -4,27 +4,22 @@ use anyhow::{Context, Result};
use axum::{
Json, Router,
extract::State,
http::{HeaderMap, StatusCode},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use chrono::{Duration, TimeZone, Utc};
use hmac::{Hmac, Mac};
use chrono::{Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tokio_postgres::{Client as PgClient, NoTls};
use tracing::{info, warn};
use uuid::Uuid;
type HmacSha256 = Hmac<Sha256>;
#[derive(Clone)]
struct AppState {
jwt_secret: Arc<String>,
stripe_webhook_secret: Option<Arc<String>>,
redis: Option<redis::aio::ConnectionManager>,
pg: Option<Arc<PgClient>>,
metrics: PrometheusHandle,
@@ -68,20 +63,6 @@ struct ValidateResponse {
max_tunnels: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct StripeWebhookEvent {
event_type: String,
user_id: String,
tier: String,
#[serde(default = "default_max_tunnels")]
max_tunnels: u32,
#[serde(default = "default_subscription_status")]
status: String,
provider_customer_id: Option<String>,
provider_subscription_id: Option<String>,
current_period_end: Option<i64>,
}
#[derive(Debug, Clone)]
struct PlanEntitlement {
tier: String,
@@ -99,7 +80,6 @@ async fn main() -> Result<()> {
let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").ok().map(Arc::new);
let metrics = PrometheusBuilder::new()
.install_recorder()
.context("install prometheus recorder")?;
@@ -124,7 +104,6 @@ async fn main() -> Result<()> {
let state = AppState {
jwt_secret: Arc::new(jwt_secret),
stripe_webhook_secret,
redis,
pg,
metrics,
@@ -135,7 +114,6 @@ async fn main() -> Result<()> {
.route("/metrics", get(metrics_endpoint))
.route("/v1/token/dev", post(issue_dev_token))
.route("/v1/token/validate", post(validate_token))
.route("/v1/stripe/webhook", post(stripe_webhook))
.with_state(state);
let listener = tokio::net::TcpListener::bind(&bind)
@@ -253,33 +231,6 @@ async fn validate_token(
Ok(Json(response))
}
#[tracing::instrument(skip(state, headers, body))]
async fn stripe_webhook(
State(state): State<AppState>,
headers: HeaderMap,
body: String,
) -> Result<impl IntoResponse, ApiError> {
let started = Instant::now();
verify_stripe_signature(state.stripe_webhook_secret.as_deref().map(|s| s.as_str()), &headers, &body)?;
let event: StripeWebhookEvent = serde_json::from_str(&body).map_err(ApiError::bad_request)?;
let ent = PlanEntitlement {
tier: event.tier.clone(),
max_tunnels: event.max_tunnels,
};
if let Some(pg) = &state.pg {
upsert_subscription_from_webhook(pg, &event).await?;
}
sync_plan_cache(&state, &event.user_id, &ent, "stripe_webhook").await?;
metrics::counter!("stripe_webhook_events_total", "event_type" => event.event_type.clone())
.increment(1);
metrics::histogram!("stripe_webhook_latency_ms")
.record(started.elapsed().as_secs_f64() * 1000.0);
Ok(StatusCode::NO_CONTENT)
}
async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> {
if let Some(mut redis) = state.redis.clone() {
let key = format!("auth:jwt:jti:{}", claims.jti);
@@ -344,101 +295,6 @@ async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result<Option<Pl
}))
}
async fn upsert_subscription_from_webhook(pg: &PgClient, event: &StripeWebhookEvent) -> Result<(), ApiError> {
let _ = Uuid::parse_str(&event.user_id).map_err(ApiError::bad_request)?;
let period_end = event
.current_period_end
.and_then(|ts| Utc.timestamp_opt(ts, 0).single())
.map(|dt| dt.naive_utc());
let provider_sub_id = event
.provider_subscription_id
.clone()
.unwrap_or_else(|| format!("sub-dev-{}", fastrand::u64(..)));
pg.execute(
r#"
insert into subscriptions (
user_id, plan_id, provider, provider_customer_id, provider_subscription_id, status,
current_period_end, updated_at
)
values ($1::uuid, $2, 'stripe', $3, $4, $5, $6, now())
on conflict (provider_subscription_id)
do update set
user_id = excluded.user_id,
plan_id = excluded.plan_id,
provider_customer_id = excluded.provider_customer_id,
status = excluded.status,
current_period_end = excluded.current_period_end,
updated_at = now()
"#,
&[
&event.user_id,
&event.tier,
&event.provider_customer_id,
&provider_sub_id,
&event.status,
&period_end,
],
)
.await
.map_err(ApiError::internal)?;
Ok(())
}
fn verify_stripe_signature(
secret: Option<&str>,
headers: &HeaderMap,
body: &str,
) -> Result<(), ApiError> {
let Some(secret) = secret else {
return Ok(());
};
let header = headers
.get("stripe-signature")
.and_then(|h| h.to_str().ok())
.ok_or_else(|| ApiError::unauthorized("missing stripe-signature header"))?;
let mut timestamp: Option<i64> = None;
let mut sigs: Vec<&str> = Vec::new();
for part in header.split(',') {
let mut kv = part.trim().splitn(2, '=');
let k = kv.next().unwrap_or_default();
let v = kv.next().unwrap_or_default();
match k {
"t" => timestamp = v.parse::<i64>().ok(),
"v1" => sigs.push(v),
_ => {}
}
}
let ts = timestamp.ok_or_else(|| ApiError::unauthorized("invalid stripe signature timestamp"))?;
let now = Utc::now().timestamp();
if (now - ts).abs() > 300 {
return Err(ApiError::unauthorized("stale stripe signature"));
}
let signed_payload = format!("{ts}.{body}");
let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
mac.update(signed_payload.as_bytes());
let expected = mac.finalize().into_bytes();
for sig in sigs {
let Ok(bytes) = hex::decode(sig) else { continue; };
let mut verify_mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(ApiError::internal)?;
verify_mac.update(signed_payload.as_bytes());
if verify_mac.verify_slice(&bytes).is_ok() {
return Ok(());
}
}
let _ = expected;
Err(ApiError::unauthorized("stripe signature verification failed"))
}
#[derive(Debug)]
struct ApiError {
status: StatusCode,
@@ -460,12 +316,6 @@ impl ApiError {
}
}
fn unauthorized<E: Into<String>>(msg: E) -> Self {
Self {
status: StatusCode::UNAUTHORIZED,
message: msg.into(),
}
}
}
impl IntoResponse for ApiError {
@@ -481,7 +331,3 @@ fn default_tier() -> String {
fn default_max_tunnels() -> u32 {
1
}
fn default_subscription_status() -> String {
"active".to_string()
}