Compare commits


6 Commits

8 changed files with 1525 additions and 121 deletions

Cargo.lock (generated, 901 changes): file diff suppressed because it is too large.

@@ -19,3 +19,6 @@ axum = "0.8"
redis = { version = "0.32", features = ["tokio-comp", "connection-manager"] }
jsonwebtoken = "10"
chrono = { version = "0.4", features = ["serde", "clock"] }
metrics = "0.24"
metrics-exporter-prometheus = "0.17"
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1"] }


@@ -15,3 +15,7 @@ tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
fastrand.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
tokio-postgres.workspace = true
uuid.workspace = true


@@ -1,4 +1,4 @@
use std::{net::SocketAddr, sync::Arc};
use std::{net::SocketAddr, sync::Arc, time::Instant};
use anyhow::{Context, Result};
use axum::{
@@ -10,14 +10,19 @@ use axum::{
};
use chrono::{Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use tokio_postgres::{Client as PgClient, NoTls};
use tracing::{info, warn};
use uuid::Uuid;
#[derive(Clone)]
struct AppState {
jwt_secret: Arc<String>,
redis: Option<redis::aio::ConnectionManager>,
pg: Option<Arc<PgClient>>,
metrics: PrometheusHandle,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -58,12 +63,9 @@ struct ValidateResponse {
max_tunnels: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct StripeWebhookEvent {
event_type: String,
user_id: String,
#[derive(Debug, Clone)]
struct PlanEntitlement {
tier: String,
#[serde(default = "default_max_tunnels")]
max_tunnels: u32,
}
@@ -78,6 +80,10 @@ async fn main() -> Result<()> {
let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
let metrics = PrometheusBuilder::new()
.install_recorder()
.context("install prometheus recorder")?;
let redis = if let Ok(url) = std::env::var("REDIS_URL") {
let client = redis::Client::open(url.clone()).context("open redis client")?;
match redis::aio::ConnectionManager::new(client).await {
@@ -94,16 +100,20 @@ async fn main() -> Result<()> {
None
};
let pg = connect_postgres().await?;
let state = AppState {
jwt_secret: Arc::new(jwt_secret),
redis,
pg,
metrics,
};
let app = Router::new()
.route("/healthz", get(healthz))
.route("/metrics", get(metrics_endpoint))
.route("/v1/token/dev", post(issue_dev_token))
.route("/v1/token/validate", post(validate_token))
.route("/v1/stripe/webhook", post(stripe_webhook))
.with_state(state);
let listener = tokio::net::TcpListener::bind(&bind)
@@ -115,14 +125,36 @@ async fn main() -> Result<()> {
Ok(())
}
async fn connect_postgres() -> Result<Option<Arc<PgClient>>> {
let Some(url) = std::env::var("DATABASE_URL").ok() else {
return Ok(None);
};
let (client, conn) = tokio_postgres::connect(&url, NoTls)
.await
.context("connect postgres")?;
tokio::spawn(async move {
if let Err(e) = conn.await {
warn!(error = %e, "postgres connection task ended");
}
});
info!("auth-api connected to postgres");
Ok(Some(Arc::new(client)))
}
async fn healthz() -> &'static str {
"ok"
}
async fn metrics_endpoint(State(state): State<AppState>) -> impl IntoResponse {
state.metrics.render()
}
#[tracing::instrument(skip(state))]
async fn issue_dev_token(
State(state): State<AppState>,
Json(req): Json<DevTokenRequest>,
) -> Result<Json<TokenResponse>, ApiError> {
let started = Instant::now();
let now = Utc::now();
let exp = now + Duration::hours(24);
let claims = Claims {
@@ -141,9 +173,68 @@ async fn issue_dev_token(
)
.map_err(ApiError::internal)?;
sync_jwt_cache(&state, &claims).await?;
metrics::counter!("auth_jwt_issued_total").increment(1);
metrics::histogram!("auth_issue_token_latency_ms")
.record(started.elapsed().as_secs_f64() * 1000.0);
Ok(Json(TokenResponse {
token,
expires_at: exp.timestamp(),
}))
}
#[tracing::instrument(skip(state, req))]
async fn validate_token(
State(state): State<AppState>,
Json(req): Json<ValidateRequest>,
) -> Result<Json<ValidateResponse>, ApiError> {
let started = Instant::now();
let decoded = decode::<Claims>(
&req.token,
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
&Validation::new(Algorithm::HS256),
);
let response = match decoded {
Ok(tok) => {
let claims = tok.claims;
let ent = match load_plan_for_user(&state, &claims.sub).await? {
Some(db_plan) => db_plan,
None => PlanEntitlement {
tier: claims.tier.clone(),
max_tunnels: claims.max_tunnels,
},
};
sync_plan_cache(&state, &claims.sub, &ent, "auth-validate").await?;
metrics::counter!("auth_token_validate_total", "result" => "valid").increment(1);
ValidateResponse {
valid: true,
user_id: Some(claims.sub),
tier: Some(ent.tier),
max_tunnels: Some(ent.max_tunnels),
}
}
Err(_) => {
metrics::counter!("auth_token_validate_total", "result" => "invalid").increment(1);
ValidateResponse {
valid: false,
user_id: None,
tier: None,
max_tunnels: None,
}
}
};
metrics::histogram!("auth_validate_token_latency_ms")
.record(started.elapsed().as_secs_f64() * 1000.0);
Ok(Json(response))
}
async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> {
if let Some(mut redis) = state.redis.clone() {
let key = format!("auth:jwt:jti:{}", claims.jti);
let ttl = (claims.exp as i64 - now.timestamp()).max(1);
let ttl = (claims.exp as i64 - Utc::now().timestamp()).max(1);
let payload = serde_json::json!({
"user_id": claims.sub,
"plan_tier": claims.tier,
@@ -155,69 +246,53 @@ async fn issue_dev_token(
.await
.map_err(ApiError::internal)?;
}
Ok(Json(TokenResponse {
token,
expires_at: exp.timestamp(),
}))
Ok(())
}
async fn validate_token(
State(state): State<AppState>,
Json(req): Json<ValidateRequest>,
) -> Result<Json<ValidateResponse>, ApiError> {
let decoded = decode::<Claims>(
&req.token,
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
&Validation::new(Algorithm::HS256),
);
match decoded {
Ok(tok) => {
let c = tok.claims;
if let Some(mut redis) = state.redis.clone() {
let key = format!("plan:user:{}", c.sub);
let payload = serde_json::json!({
"tier": c.tier,
"max_tunnels": c.max_tunnels,
"source": "auth-api"
})
.to_string();
let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
}
Ok(Json(ValidateResponse {
valid: true,
user_id: Some(c.sub),
tier: Some(c.tier),
max_tunnels: Some(c.max_tunnels),
}))
}
Err(_) => Ok(Json(ValidateResponse {
valid: false,
user_id: None,
tier: None,
max_tunnels: None,
})),
}
}
async fn stripe_webhook(
State(state): State<AppState>,
Json(event): Json<StripeWebhookEvent>,
) -> Result<impl IntoResponse, ApiError> {
async fn sync_plan_cache(
state: &AppState,
user_id: &str,
ent: &PlanEntitlement,
source: &str,
) -> Result<(), ApiError> {
if let Some(mut redis) = state.redis.clone() {
let key = format!("plan:user:{}", event.user_id);
let key = format!("plan:user:{user_id}");
let payload = serde_json::json!({
"tier": event.tier,
"max_tunnels": event.max_tunnels,
"source": "stripe_webhook",
"last_event_type": event.event_type,
"updated_at": Utc::now().timestamp(),
"tier": ent.tier,
"max_tunnels": ent.max_tunnels,
"source": source,
"updated_at": Utc::now().timestamp()
})
.to_string();
let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
}
Ok(StatusCode::NO_CONTENT)
Ok(())
}
async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result<Option<PlanEntitlement>, ApiError> {
let Some(pg) = &state.pg else {
return Ok(None);
};
let _ = Uuid::parse_str(user_id).map_err(ApiError::bad_request)?;
let row = pg
.query_opt(
r#"
select p.id as plan_id, p.max_tunnels
from subscriptions s
join plans p on p.id = s.plan_id
where s.user_id = $1::uuid and s.status in ('active', 'trialing')
order by s.updated_at desc
limit 1
"#,
&[&user_id],
)
.await
.map_err(ApiError::internal)?;
Ok(row.map(|r| PlanEntitlement {
tier: r.get::<_, String>("plan_id"),
max_tunnels: r.get::<_, i32>("max_tunnels") as u32,
}))
}
#[derive(Debug)]
@@ -233,6 +308,14 @@ impl ApiError {
message: e.to_string(),
}
}
fn bad_request<E: std::fmt::Display>(e: E) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: e.to_string(),
}
}
}
impl IntoResponse for ApiError {
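
For context on how the reworked validation path is meant to be consumed, here is a minimal, hypothetical client-side sketch. It is not part of this diff; it assumes the reqwest crate (with its json feature) plus serde_json, and the default AUTH_BIND of 0.0.0.0:8080, and it mirrors the ValidateRequest/ValidateResponse shapes shown above.

use serde::Deserialize;

// Local mirror of the response body returned by POST /v1/token/validate.
#[derive(Debug, Deserialize)]
struct ValidateResponse {
    valid: bool,
    user_id: Option<String>,
    tier: Option<String>,
    max_tunnels: Option<u32>,
}

// Ask auth-api whether a JWT is still valid and which plan it maps to.
async fn check_token(token: &str) -> anyhow::Result<ValidateResponse> {
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:8080/v1/token/validate")
        .json(&serde_json::json!({ "token": token }))
        .send()
        .await?
        .json::<ValidateResponse>()
        .await?;
    Ok(resp)
}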


@@ -57,6 +57,30 @@ pub struct RelayForwardPrelude {
pub initial_data: Vec<u8>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct R2rStreamData {
pub session_id: String,
pub stream_id: String,
pub data: Vec<u8>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct R2rStreamClosed {
pub session_id: String,
pub stream_id: String,
pub reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum R2rFrame {
Open(RelayForwardPrelude),
Data(R2rStreamData),
Close(R2rStreamClosed),
Ping,
Pong,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ClientFrame {
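
Since R2rFrame is adjacently tagged (serde tag = "type", content = "data"), each relay-to-relay frame carries an explicit variant name next to its payload. The actual on-the-wire framing is whatever the common crate's write_frame/read_frame helpers use; the sketch below only illustrates the serde data model via serde_json, with made-up values.

// Hedged sketch: how the adjacently tagged enum serializes.
let frame = R2rFrame::Data(R2rStreamData {
    session_id: "sess-123".into(),
    stream_id: "stream-456".into(),
    data: b"ping".to_vec(),
});
// -> {"type":"Data","data":{"session_id":"sess-123","stream_id":"stream-456","data":[112,105,110,103]}}
let wire = serde_json::to_string(&frame).unwrap();
// Unit variants carry no "data" key: {"type":"Ping"}
let keepalive = serde_json::to_string(&R2rFrame::Ping).unwrap();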


@@ -35,6 +35,9 @@ create table if not exists subscriptions (
);
create index if not exists subscriptions_user_id_idx on subscriptions(user_id);
create unique index if not exists subscriptions_provider_subscription_id_uidx
on subscriptions(provider_subscription_id)
where provider_subscription_id is not null;
create table if not exists tunnels (
id uuid primary key default gen_random_uuid(),


@@ -14,4 +14,6 @@ redis.workspace = true
serde_json.workspace = true
chrono.workspace = true
serde.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
common = { path = "../common" }


@@ -11,13 +11,14 @@ use common::{
minecraft::read_handshake_hostname_and_bytes,
protocol::{
ClientFrame, Heartbeat, IncomingTcp, RegisterAccepted, RegisterRequest, RelayForwardPrelude,
ServerFrame, StreamClosed, StreamData,
R2rFrame, R2rStreamClosed, R2rStreamData, ServerFrame, StreamClosed, StreamData,
},
};
use redis::AsyncCommands;
use serde::Deserialize;
use metrics_exporter_prometheus::PrometheusBuilder;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional},
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::{Mutex, Notify, RwLock, mpsc},
time::{MissedTickBehavior, interval, timeout},
@@ -106,6 +107,21 @@ impl RelayState {
type SharedState = Arc<RwLock<RelayState>>;
#[derive(Clone)]
struct R2rManager {
outbound: Arc<Mutex<HashMap<String, mpsc::Sender<R2rFrame>>>>,
ingress_stream_sinks: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
}
impl R2rManager {
fn new() -> Self {
Self {
outbound: Arc::new(Mutex::new(HashMap::new())),
ingress_stream_sinks: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[derive(Clone)]
struct RelayGuards {
player_ip: Arc<Mutex<HashMap<String, BucketState>>>,
@@ -118,6 +134,11 @@ struct RelayGuards {
reg_ip_burst: f64,
session_bw_rate_bytes: f64,
session_bw_burst_bytes: f64,
redis: Option<redis::aio::ConnectionManager>,
player_global_window_secs: u64,
player_global_limit: i64,
reg_global_window_secs: u64,
reg_global_limit: i64,
}
#[derive(Debug, Clone)]
@@ -170,7 +191,7 @@ impl BucketState {
}
impl RelayGuards {
fn from_env() -> Self {
async fn from_env() -> Self {
let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC")
.ok()
.and_then(|v| v.parse().ok())
@@ -200,6 +221,29 @@ impl RelayGuards {
.and_then(|v| v.parse().ok())
.unwrap_or(512.0)
* 1024.0;
let redis = match std::env::var("REDIS_URL") {
Ok(url) => match redis::Client::open(url) {
Ok(client) => redis::aio::ConnectionManager::new(client).await.ok(),
Err(_) => None,
},
Err(_) => None,
};
let player_global_window_secs = std::env::var("RELAY_PLAYER_GLOBAL_WINDOW_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10);
let player_global_limit = std::env::var("RELAY_PLAYER_GLOBAL_LIMIT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(50);
let reg_global_window_secs = std::env::var("RELAY_REG_GLOBAL_WINDOW_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(60);
let reg_global_limit = std::env::var("RELAY_REG_GLOBAL_LIMIT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10);
Self {
player_ip: Arc::new(Mutex::new(HashMap::new())),
@@ -212,6 +256,11 @@ impl RelayGuards {
reg_ip_burst,
session_bw_rate_bytes,
session_bw_burst_bytes,
redis,
player_global_window_secs,
player_global_limit,
reg_global_window_secs,
reg_global_limit,
}
}
@@ -233,7 +282,18 @@ impl RelayGuards {
let bucket = guard
.entry(ip.to_string())
.or_insert_with(|| BucketState::new(burst, rate));
bucket.reserve_delay(1).is_zero()
let local_ok = bucket.reserve_delay(1).is_zero();
drop(guard);
if !local_ok {
return false;
}
let (window_secs, limit, scope) = if player {
(self.player_global_window_secs, self.player_global_limit, "mc")
} else {
(self.reg_global_window_secs, self.reg_global_limit, "reg")
};
self.redis_allow_ip_window(ip, scope, window_secs, limit).await
}
async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) {
@@ -257,6 +317,31 @@ impl RelayGuards {
self.session_ingress.lock().await.remove(session_id);
self.session_egress.lock().await.remove(session_id);
}
async fn redis_allow_ip_window(
&self,
ip: &str,
scope: &str,
window_secs: u64,
limit: i64,
) -> bool {
let Some(mut conn) = self.redis.clone() else {
return true;
};
let key = format!("ratelimit:ip:{ip}:{scope}");
let res: redis::RedisResult<i64> = async {
let count: i64 = conn.incr(&key, 1).await?;
if count == 1 {
let _: bool = conn.expire(&key, window_secs as i64).await?;
}
Ok(count)
}
.await;
match res {
Ok(count) => count <= limit,
Err(_) => true,
}
}
}
#[derive(Debug, Clone, Deserialize)]
@@ -275,7 +360,8 @@ struct RelayInstanceRecord {
_instance_id: String,
#[serde(rename = "region")]
_region: Option<String>,
status: Option<String>,
#[serde(rename = "status")]
_status: Option<String>,
r2r_addr: Option<String>,
}
@@ -435,6 +521,7 @@ impl RedisRegistry {
#[tokio::main]
async fn main() -> Result<()> {
init_metrics()?;
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
@@ -444,7 +531,8 @@ async fn main() -> Result<()> {
let cfg = RelayConfig::from_env();
let registry = RedisRegistry::from_env(&cfg).await;
let guards = Arc::new(RelayGuards::from_env());
let guards = Arc::new(RelayGuards::from_env().await);
let r2r = Arc::new(R2rManager::new());
registry.register_instance().await;
let control_listener = TcpListener::bind(&cfg.control_bind)
@@ -458,6 +546,7 @@ async fn main() -> Result<()> {
.with_context(|| format!("bind r2r {}", cfg.r2r_bind))?;
info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started");
metrics::gauge!("relay_drain_state").set(0.0);
let shutdown = Arc::new(Notify::new());
let state: SharedState = Arc::new(RwLock::new(RelayState::new()));
@@ -477,6 +566,7 @@ async fn main() -> Result<()> {
state.clone(),
registry.clone(),
guards.clone(),
r2r.clone(),
shutdown.clone(),
));
let r2r_task = tokio::spawn(run_r2r_accept_loop(
@@ -484,6 +574,7 @@ async fn main() -> Result<()> {
cfg.clone(),
state.clone(),
guards.clone(),
r2r.clone(),
shutdown.clone(),
));
@@ -501,6 +592,7 @@ async fn main() -> Result<()> {
}
registry.set_draining().await;
metrics::gauge!("relay_drain_state").set(1.0);
shutdown.notify_waiters();
info!("draining relay");
tokio::time::sleep(Duration::from_secs(1)).await;
@@ -515,6 +607,7 @@ async fn run_registry_heartbeat(state: SharedState, registry: RedisRegistry, shu
_ = shutdown.notified() => break,
_ = ticker.tick() => {
let count = state.read().await.session_count();
metrics::gauge!("relay_active_tunnels").set(count as f64);
registry.heartbeat_instance(count).await;
}
}
@@ -537,6 +630,7 @@ async fn run_control_accept_loop(
Ok(v) => v,
Err(e) => { warn!(error = %e, "control accept failed"); continue; }
};
metrics::counter!("relay_control_accepts_total").increment(1);
let cfg = cfg.clone();
let state = state.clone();
let registry = registry.clone();
@@ -558,6 +652,7 @@ async fn run_player_accept_loop(
state: SharedState,
registry: RedisRegistry,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
shutdown: Arc<Notify>,
) -> Result<()> {
loop {
@@ -568,12 +663,14 @@ async fn run_player_accept_loop(
Ok(v) => v,
Err(e) => { warn!(error = %e, "player accept failed"); continue; }
};
metrics::counter!("relay_player_accepts_total").increment(1);
let cfg = cfg.clone();
let state = state.clone();
let registry = registry.clone();
let guards = guards.clone();
let r2r = r2r.clone();
tokio::spawn(async move {
if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await {
if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards, r2r).await {
debug!(peer = %addr, error = %e, "player connection closed");
}
});
@@ -588,6 +685,7 @@ async fn run_r2r_accept_loop(
cfg: RelayConfig,
state: SharedState,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
shutdown: Arc<Notify>,
) -> Result<()> {
loop {
@@ -598,11 +696,13 @@ async fn run_r2r_accept_loop(
Ok(v) => v,
Err(e) => { warn!(error = %e, "r2r accept failed"); continue; }
};
metrics::counter!("relay_r2r_accepts_total").increment(1);
let cfg = cfg.clone();
let state = state.clone();
let guards = guards.clone();
let r2r = r2r.clone();
tokio::spawn(async move {
if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await {
if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards, r2r).await {
warn!(peer = %addr, error = %e, "r2r connection ended with error");
}
});
@@ -612,6 +712,7 @@ async fn run_r2r_accept_loop(
Ok(())
}
#[tracing::instrument(skip(stream, state, registry, guards, cfg), fields(peer = %addr))]
async fn handle_control_conn(
stream: TcpStream,
addr: SocketAddr,
@@ -621,6 +722,7 @@ async fn handle_control_conn(
guards: Arc<RelayGuards>,
) -> Result<()> {
if !guards.allow_registration_ip(&addr.ip().to_string()).await {
metrics::counter!("relay_rate_limited_total", "scope" => "registration_ip").increment(1);
anyhow::bail!("registration rate limited for {}", addr.ip());
}
@@ -666,6 +768,7 @@ async fn handle_control_conn(
owner_instance_id: cfg.instance_id.clone(),
})).await?;
info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered");
metrics::counter!("relay_tunnel_registrations_total").increment(1);
let write_task = tokio::spawn(async move {
while let Some(frame) = rx.recv().await {
@@ -748,6 +851,7 @@ async fn control_read_loop(
}
}
#[tracing::instrument(skip(stream, cfg, state, registry, guards, r2r), fields(peer = %addr))]
async fn handle_player_conn(
mut stream: TcpStream,
addr: SocketAddr,
@@ -755,8 +859,10 @@ async fn handle_player_conn(
state: SharedState,
registry: RedisRegistry,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
) -> Result<()> {
if !guards.allow_player_ip(&addr.ip().to_string()).await {
metrics::counter!("relay_rate_limited_total", "scope" => "player_ip").increment(1);
debug!(peer = %addr, "player connect rate limited");
return Ok(());
}
@@ -784,83 +890,372 @@ async fn handle_player_conn(
debug!(peer = %addr, hostname = %hostname, session_id = %route.session_id, "route points to self but local session missing");
return Ok(());
}
return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry).await;
return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry, guards, r2r).await;
}
debug!(peer = %addr, hostname = %hostname, "no tunnel for hostname");
Ok(())
}
#[tracing::instrument(skip(stream, cfg, state, guards, r2r), fields(peer = %addr))]
async fn handle_r2r_conn(
mut stream: TcpStream,
stream: TcpStream,
addr: SocketAddr,
_cfg: RelayConfig,
cfg: RelayConfig,
state: SharedState,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
) -> Result<()> {
let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?;
if prelude.version != 1 {
anyhow::bail!("unsupported r2r prelude version {}", prelude.version);
}
if prelude.hop_count > 1 {
anyhow::bail!("invalid hop_count {}", prelude.hop_count);
}
let session = local_session_for_session_id(&state, &prelude.session_id).await;
let Some(session) = session else {
anyhow::bail!("owner session not found for {}", prelude.session_id);
};
attach_player_socket_to_session(
stream,
session,
prelude.fqdn.clone(),
prelude.peer_addr,
prelude.initial_data,
Some(prelude.stream_id),
"r2r",
guards,
)
.await
.with_context(|| format!("r2r attach failed from {addr}"))
handle_r2r_multiplex_conn(stream, addr, cfg, state, guards, r2r)
.await
.with_context(|| format!("r2r multiplex failed from {addr}"))
}
#[tracing::instrument(skip(player_stream, route, cfg, registry, guards, r2r), fields(peer = %player_addr, hostname = %hostname))]
async fn proxy_player_to_owner(
mut player_stream: TcpStream,
player_stream: TcpStream,
player_addr: SocketAddr,
hostname: String,
initial_data: Vec<u8>,
route: TunnelRouteRecord,
cfg: RelayConfig,
registry: RedisRegistry,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
) -> Result<()> {
let redis_lookup_started = Instant::now();
let owner = registry
.lookup_instance(&route.instance_id)
.await
.with_context(|| format!("owner instance {} not found in redis", route.instance_id))?;
metrics::histogram!("relay_redis_lookup_latency_ms")
.record(redis_lookup_started.elapsed().as_secs_f64() * 1000.0);
let r2r_addr = owner
.r2r_addr
.clone()
.with_context(|| format!("owner {} missing r2r_addr", route.instance_id))?;
let mut owner_stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&r2r_addr))
.await
.context("r2r connect timeout")??;
let r2r_connect_started = Instant::now();
metrics::histogram!("relay_r2r_connect_latency_ms")
.record(r2r_connect_started.elapsed().as_secs_f64() * 1000.0);
let prelude = RelayForwardPrelude {
version: 1,
session_id: route.session_id,
session_id: route.session_id.clone(),
fqdn: hostname.clone(),
stream_id: Uuid::new_v4().to_string(),
peer_addr: player_addr.to_string(),
origin_instance_id: cfg.instance_id,
origin_instance_id: cfg.instance_id.clone(),
hop_count: 1,
initial_data,
};
write_frame(&mut owner_stream, &prelude).await?;
proxy_player_to_owner_pooled(
player_stream,
player_addr,
hostname,
route.instance_id,
r2r_addr,
prelude,
route.session_id,
cfg,
guards,
r2r,
)
.await
}
let _ = copy_bidirectional(&mut player_stream, &mut owner_stream).await?;
info!(peer = %player_addr, hostname = %hostname, owner = %route.instance_id, "proxied player connection to owner relay");
async fn proxy_player_to_owner_pooled(
player_stream: TcpStream,
player_addr: SocketAddr,
hostname: String,
owner_instance_id: String,
owner_r2r_addr: String,
prelude: RelayForwardPrelude,
session_id: String,
cfg: RelayConfig,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
) -> Result<()> {
let stream_id = prelude.stream_id.clone();
let sender = get_or_connect_r2r_pool(
owner_instance_id.clone(),
owner_r2r_addr,
cfg,
guards,
r2r.clone(),
)
.await?;
let (player_read, player_write) = player_stream.into_split();
let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
r2r.ingress_stream_sinks
.write()
.await
.insert(stream_id.clone(), to_player_tx);
sender
.send(R2rFrame::Open(prelude))
.await
.context("send r2r open")?;
let tx = sender.clone();
let sid = session_id.clone();
let stid = stream_id.clone();
let sinks = r2r.ingress_stream_sinks.clone();
tokio::spawn(async move {
if let Err(e) =
run_ingress_player_reader_to_r2r(player_read, tx.clone(), sid.clone(), stid.clone()).await
{
debug!(stream_id = %stid, error = %e, "ingress player->r2r reader ended");
}
let _ = tx
.send(R2rFrame::Close(R2rStreamClosed {
session_id: sid,
stream_id: stid.clone(),
reason: Some("ingress_player_reader_closed".into()),
}))
.await;
let _ = sinks.write().await.remove(&stid);
});
let stid = stream_id.clone();
let sinks = r2r.ingress_stream_sinks.clone();
tokio::spawn(async move {
if let Err(e) = run_ingress_player_writer(player_write, to_player_rx).await {
debug!(stream_id = %stid, error = %e, "ingress r2r->player writer ended");
}
let _ = sinks.write().await.remove(&stid);
});
metrics::counter!("relay_r2r_forwards_total").increment(1);
info!(peer = %player_addr, hostname = %hostname, owner = %owner_instance_id, stream_id = %stream_id, "proxied player via pooled r2r channel");
Ok(())
}
async fn get_or_connect_r2r_pool(
owner_instance_id: String,
owner_r2r_addr: String,
cfg: RelayConfig,
guards: Arc<RelayGuards>,
r2r: Arc<R2rManager>,
) -> Result<mpsc::Sender<R2rFrame>> {
if let Some(existing) = r2r.outbound.lock().await.get(&owner_instance_id).cloned() {
return Ok(existing);
}
let connect_started = Instant::now();
let stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&owner_r2r_addr))
.await
.context("r2r connect timeout")??;
metrics::histogram!("relay_r2r_connect_latency_ms")
.record(connect_started.elapsed().as_secs_f64() * 1000.0);
let (mut reader, mut writer) = stream.into_split();
let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
let mut pools = r2r.outbound.lock().await;
if let Some(existing) = pools.get(&owner_instance_id).cloned() {
return Ok(existing);
}
pools.insert(owner_instance_id.clone(), tx.clone());
drop(pools);
let owner_for_reader = owner_instance_id.clone();
let r2r_for_reader = r2r.clone();
let guards_for_reader = guards.clone();
tokio::spawn(async move {
loop {
match read_frame::<_, R2rFrame>(&mut reader).await {
Ok(frame) => {
if let Err(e) = handle_r2r_inbound_frame(frame, &r2r_for_reader, &guards_for_reader).await {
debug!(owner = %owner_for_reader, error = %e, "r2r pooled inbound frame error");
break;
}
}
Err(e) => {
debug!(owner = %owner_for_reader, error = %e, "r2r pooled reader ended");
break;
}
}
}
r2r_for_reader.outbound.lock().await.remove(&owner_for_reader);
});
tokio::spawn(async move {
while let Some(frame) = rx.recv().await {
if let Err(e) = write_frame(&mut writer, &frame).await {
debug!(error = %e, "r2r pooled writer ended");
break;
}
}
});
Ok(tx)
}
async fn handle_r2r_multiplex_conn(
stream: TcpStream,
_addr: SocketAddr,
_cfg: RelayConfig,
state: SharedState,
guards: Arc<RelayGuards>,
_r2r: Arc<R2rManager>,
) -> Result<()> {
let (mut reader, mut writer) = stream.into_split();
let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
let _writer_task = tokio::spawn(async move {
while let Some(frame) = rx.recv().await {
write_frame(&mut writer, &frame).await?;
}
Ok::<(), anyhow::Error>(())
});
loop {
let frame: R2rFrame = read_frame(&mut reader).await?;
match frame {
R2rFrame::Open(prelude) => {
if prelude.version != 1 || prelude.hop_count > 1 {
continue;
}
if let Some(session) = local_session_for_session_id(&state, &prelude.session_id).await {
attach_virtual_r2r_stream_to_session(session, prelude, tx.clone()).await?;
} else {
let _ = tx.send(R2rFrame::Close(R2rStreamClosed {
session_id: prelude.session_id,
stream_id: prelude.stream_id,
reason: Some("owner_session_not_found".into()),
})).await;
}
}
R2rFrame::Data(data) => {
guards
.throttle_session_bytes(&data.session_id, SessionDir::EgressToClient, data.data.len())
.await;
if let Some(session) = local_session_for_session_id(&state, &data.session_id).await {
let _ = session
.tx
.send(ServerFrame::StreamData(StreamData { stream_id: data.stream_id, data: data.data }))
.await;
}
}
R2rFrame::Close(close) => {
if let Some(session) = local_session_for_session_id(&state, &close.session_id).await {
let _ = session
.tx
.send(ServerFrame::StreamClosed(StreamClosed { stream_id: close.stream_id.clone(), reason: close.reason.clone() }))
.await;
remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
} else {
remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
}
}
R2rFrame::Ping => {
let _ = tx.send(R2rFrame::Pong).await;
}
R2rFrame::Pong => {}
}
}
}
async fn attach_virtual_r2r_stream_to_session(
session: SessionHandle,
prelude: RelayForwardPrelude,
r2r_tx: mpsc::Sender<R2rFrame>,
) -> Result<()> {
let stream_id = prelude.stream_id.clone();
let session_id = session.session_id.clone();
let (to_r2r_tx, mut to_r2r_rx) = mpsc::channel::<Vec<u8>>(128);
session
.stream_sinks
.write()
.await
.insert(stream_id.clone(), to_r2r_tx);
session
.tx
.send(ServerFrame::IncomingTcp(IncomingTcp {
stream_id: stream_id.clone(),
session_id: session_id.clone(),
peer_addr: prelude.peer_addr.clone(),
hostname: prelude.fqdn.clone(),
initial_data: prelude.initial_data.clone(),
}))
.await
.context("send virtual r2r IncomingTcp to client")?;
tokio::spawn(async move {
while let Some(chunk) = to_r2r_rx.recv().await {
let _ = r2r_tx
.send(R2rFrame::Data(R2rStreamData {
session_id: session_id.clone(),
stream_id: stream_id.clone(),
data: chunk,
}))
.await;
}
let _ = r2r_tx
.send(R2rFrame::Close(R2rStreamClosed {
session_id,
stream_id,
reason: Some("owner_sink_closed".into()),
}))
.await;
});
Ok(())
}
async fn handle_r2r_inbound_frame(
frame: R2rFrame,
r2r: &R2rManager,
_guards: &RelayGuards,
) -> Result<()> {
match frame {
R2rFrame::Data(data) => {
if let Some(tx) = r2r.ingress_stream_sinks.read().await.get(&data.stream_id).cloned() {
let _ = tx.send(data.data).await;
}
}
R2rFrame::Close(close) => {
r2r.ingress_stream_sinks.write().await.remove(&close.stream_id);
}
R2rFrame::Ping | R2rFrame::Pong | R2rFrame::Open(_) => {}
}
Ok(())
}
async fn run_ingress_player_reader_to_r2r(
mut reader: tokio::net::tcp::OwnedReadHalf,
tx: mpsc::Sender<R2rFrame>,
session_id: String,
stream_id: String,
) -> Result<()> {
let mut buf = vec![0u8; 16 * 1024];
loop {
let n = reader.read(&mut buf).await?;
if n == 0 {
break;
}
tx.send(R2rFrame::Data(R2rStreamData {
session_id: session_id.clone(),
stream_id: stream_id.clone(),
data: buf[..n].to_vec(),
}))
.await
.context("send ingress data to r2r")?;
}
Ok(())
}
async fn run_ingress_player_writer(
mut writer: tokio::net::tcp::OwnedWriteHalf,
mut rx: mpsc::Receiver<Vec<u8>>,
) -> Result<()> {
while let Some(chunk) = rx.recv().await {
writer.write_all(&chunk).await?;
}
let _ = writer.shutdown().await;
Ok(())
}
@@ -944,6 +1339,7 @@ async fn attach_player_socket_to_session(
});
info!(peer = %peer_addr, hostname = %hostname, session_id = %session.session_id, stream_id = %stream_id, source, "player proxied via client stream");
metrics::gauge!("relay_active_player_conns").increment(1.0);
Ok(())
}
@@ -970,7 +1366,9 @@ async fn run_player_reader(
}))
.await
.context("send stream data to client")?;
metrics::counter!("relay_bytes_out_total").increment(n as u64);
}
metrics::gauge!("relay_active_player_conns").decrement(1.0);
Ok(())
}
@@ -980,6 +1378,7 @@ async fn run_player_writer(
) -> Result<()> {
while let Some(chunk) = rx.recv().await {
writer.write_all(&chunk).await?;
metrics::counter!("relay_bytes_in_total").increment(chunk.len() as u64);
}
let _ = writer.shutdown().await;
Ok(())
@@ -1061,3 +1460,14 @@ fn guess_advertise_addr(bind: &str) -> String {
"127.0.0.1:7001".to_string()
}
}
fn init_metrics() -> Result<()> {
if let Ok(bind) = std::env::var("RELAY_METRICS_BIND") {
let addr: std::net::SocketAddr = bind.parse().context("parse RELAY_METRICS_BIND")?;
PrometheusBuilder::new()
.with_http_listener(addr)
.install()
.context("install prometheus exporter")?;
}
Ok(())
}