feat: scaffold relay client auth workspace
This commit is contained in:
17
auth-api/Cargo.toml
Normal file
17
auth-api/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "auth-api"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum.workspace = true
|
||||
chrono.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
redis.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
fastrand.workspace = true
|
||||
250
auth-api/src/main.rs
Normal file
250
auth-api/src/main.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use redis::AsyncCommands;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
jwt_secret: Arc<String>,
|
||||
redis: Option<redis::aio::ConnectionManager>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
tier: String,
|
||||
max_tunnels: u32,
|
||||
exp: usize,
|
||||
iat: usize,
|
||||
jti: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DevTokenRequest {
|
||||
user_id: String,
|
||||
#[serde(default = "default_tier")]
|
||||
tier: String,
|
||||
#[serde(default = "default_max_tunnels")]
|
||||
max_tunnels: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenResponse {
|
||||
token: String,
|
||||
expires_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ValidateRequest {
|
||||
token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ValidateResponse {
|
||||
valid: bool,
|
||||
user_id: Option<String>,
|
||||
tier: Option<String>,
|
||||
max_tunnels: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StripeWebhookEvent {
|
||||
event_type: String,
|
||||
user_id: String,
|
||||
tier: String,
|
||||
#[serde(default = "default_max_tunnels")]
|
||||
max_tunnels: u32,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "auth_api=info".into()),
|
||||
)
|
||||
.init();
|
||||
|
||||
let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
||||
let redis = if let Ok(url) = std::env::var("REDIS_URL") {
|
||||
let client = redis::Client::open(url.clone()).context("open redis client")?;
|
||||
match redis::aio::ConnectionManager::new(client).await {
|
||||
Ok(conn) => {
|
||||
info!("auth-api connected to redis");
|
||||
Some(conn)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "auth-api redis unavailable; continuing without cache");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let state = AppState {
|
||||
jwt_secret: Arc::new(jwt_secret),
|
||||
redis,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/healthz", get(healthz))
|
||||
.route("/v1/token/dev", post(issue_dev_token))
|
||||
.route("/v1/token/validate", post(validate_token))
|
||||
.route("/v1/stripe/webhook", post(stripe_webhook))
|
||||
.with_state(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&bind)
|
||||
.await
|
||||
.with_context(|| format!("bind {bind}"))?;
|
||||
let local_addr: SocketAddr = listener.local_addr()?;
|
||||
info!(addr = %local_addr, "auth-api listening");
|
||||
axum::serve(listener, app).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn healthz() -> &'static str {
|
||||
"ok"
|
||||
}
|
||||
|
||||
async fn issue_dev_token(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<DevTokenRequest>,
|
||||
) -> Result<Json<TokenResponse>, ApiError> {
|
||||
let now = Utc::now();
|
||||
let exp = now + Duration::hours(24);
|
||||
let claims = Claims {
|
||||
sub: req.user_id,
|
||||
tier: req.tier,
|
||||
max_tunnels: req.max_tunnels,
|
||||
exp: exp.timestamp() as usize,
|
||||
iat: now.timestamp() as usize,
|
||||
jti: format!("jti-{}", fastrand::u64(..)),
|
||||
};
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(state.jwt_secret.as_bytes()),
|
||||
)
|
||||
.map_err(ApiError::internal)?;
|
||||
|
||||
if let Some(mut redis) = state.redis.clone() {
|
||||
let key = format!("auth:jwt:jti:{}", claims.jti);
|
||||
let ttl = (claims.exp as i64 - now.timestamp()).max(1);
|
||||
let payload = serde_json::json!({
|
||||
"user_id": claims.sub,
|
||||
"plan_tier": claims.tier,
|
||||
"max_tunnels": claims.max_tunnels
|
||||
})
|
||||
.to_string();
|
||||
let _: () = redis
|
||||
.set_ex(key, payload, ttl as u64)
|
||||
.await
|
||||
.map_err(ApiError::internal)?;
|
||||
}
|
||||
|
||||
Ok(Json(TokenResponse {
|
||||
token,
|
||||
expires_at: exp.timestamp(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn validate_token(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<ValidateRequest>,
|
||||
) -> Result<Json<ValidateResponse>, ApiError> {
|
||||
let decoded = decode::<Claims>(
|
||||
&req.token,
|
||||
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
|
||||
&Validation::new(Algorithm::HS256),
|
||||
);
|
||||
|
||||
match decoded {
|
||||
Ok(tok) => {
|
||||
let c = tok.claims;
|
||||
if let Some(mut redis) = state.redis.clone() {
|
||||
let key = format!("plan:user:{}", c.sub);
|
||||
let payload = serde_json::json!({
|
||||
"tier": c.tier,
|
||||
"max_tunnels": c.max_tunnels,
|
||||
"source": "auth-api"
|
||||
})
|
||||
.to_string();
|
||||
let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
|
||||
}
|
||||
Ok(Json(ValidateResponse {
|
||||
valid: true,
|
||||
user_id: Some(c.sub),
|
||||
tier: Some(c.tier),
|
||||
max_tunnels: Some(c.max_tunnels),
|
||||
}))
|
||||
}
|
||||
Err(_) => Ok(Json(ValidateResponse {
|
||||
valid: false,
|
||||
user_id: None,
|
||||
tier: None,
|
||||
max_tunnels: None,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn stripe_webhook(
|
||||
State(state): State<AppState>,
|
||||
Json(event): Json<StripeWebhookEvent>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
if let Some(mut redis) = state.redis.clone() {
|
||||
let key = format!("plan:user:{}", event.user_id);
|
||||
let payload = serde_json::json!({
|
||||
"tier": event.tier,
|
||||
"max_tunnels": event.max_tunnels,
|
||||
"source": "stripe_webhook",
|
||||
"last_event_type": event.event_type,
|
||||
"updated_at": Utc::now().timestamp(),
|
||||
})
|
||||
.to_string();
|
||||
let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
|
||||
}
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ApiError {
|
||||
status: StatusCode,
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
fn internal<E: std::fmt::Display>(e: E) -> Self {
|
||||
Self {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
message: e.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
(self.status, self.message).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_tier() -> String {
|
||||
"free".to_string()
|
||||
}
|
||||
|
||||
fn default_max_tunnels() -> u32 {
|
||||
1
|
||||
}
|
||||
Reference in New Issue
Block a user