diff --git a/Cargo.lock b/Cargo.lock index e6f6360..f60da91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -921,6 +921,7 @@ dependencies = [ "common", "fastrand", "redis", + "serde", "serde_json", "tokio", "tracing", diff --git a/relay/src/main.rs b/relay/src/main.rs index 0476173..f0e4c97 100644 --- a/relay/src/main.rs +++ b/relay/src/main.rs @@ -19,7 +19,7 @@ use serde::Deserialize; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional}, net::{TcpListener, TcpStream}, - sync::{Notify, RwLock, mpsc}, + sync::{Mutex, Notify, RwLock, mpsc}, time::{MissedTickBehavior, interval, timeout}, }; use tracing::{debug, info, warn}; @@ -48,6 +48,7 @@ impl RelayConfig { let r2r_bind = std::env::var("RELAY_R2R_BIND").unwrap_or_else(|_| "0.0.0.0:7001".to_string()); let r2r_advertise_addr = std::env::var("RELAY_R2R_ADVERTISE_ADDR") .unwrap_or_else(|_| guess_advertise_addr(&r2r_bind)); + Self { instance_id: std::env::var("RELAY_INSTANCE_ID") .unwrap_or_else(|_| format!("relay-{}", Uuid::new_v4())), @@ -105,6 +106,159 @@ impl RelayState { type SharedState = Arc>; +#[derive(Clone)] +struct RelayGuards { + player_ip: Arc>>, + reg_ip: Arc>>, + session_ingress: Arc>>, + session_egress: Arc>>, + player_ip_rate: f64, + player_ip_burst: f64, + reg_ip_rate: f64, + reg_ip_burst: f64, + session_bw_rate_bytes: f64, + session_bw_burst_bytes: f64, +} + +#[derive(Debug, Clone)] +struct BucketState { + tokens: f64, + last_refill: Instant, + capacity: f64, + rate_per_sec: f64, +} + +#[derive(Clone, Copy)] +enum SessionDir { + IngressFromClient, + EgressToClient, +} + +impl BucketState { + fn new(capacity: f64, rate_per_sec: f64) -> Self { + Self { + tokens: capacity, + last_refill: Instant::now(), + capacity, + rate_per_sec, + } + } + + fn reserve_delay(&mut self, amount: usize) -> Duration { + let now = Instant::now(); + let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64(); + if elapsed > 0.0 { + self.tokens = (self.tokens + elapsed * 
self.rate_per_sec).min(self.capacity); + self.last_refill = now; + } + + let amount = amount as f64; + if self.tokens >= amount { + self.tokens -= amount; + return Duration::ZERO; + } + if self.rate_per_sec <= 0.0 { + return Duration::from_secs(1); + } + + let deficit = amount - self.tokens; + let wait = Duration::from_secs_f64(deficit / self.rate_per_sec); + self.tokens = 0.0; + self.last_refill = now + wait; + wait + } +} + +impl RelayGuards { + fn from_env() -> Self { + let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(2.0); + let player_ip_burst = std::env::var("RELAY_PLAYER_CONNECTS_BURST") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(20.0); + + let reg_per_min = std::env::var("RELAY_REG_ATTEMPTS_PER_MIN") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10.0); + let reg_ip_rate = reg_per_min / 60.0; + let reg_ip_burst = std::env::var("RELAY_REG_ATTEMPTS_BURST") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10.0); + + let session_bw_kbps = std::env::var("RELAY_SESSION_BW_KBPS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(8192.0); + let session_bw_rate_bytes = session_bw_kbps * 1024.0 / 8.0; + let session_bw_burst_bytes = std::env::var("RELAY_SESSION_BW_BURST_KB") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(512.0) + * 1024.0; + + Self { + player_ip: Arc::new(Mutex::new(HashMap::new())), + reg_ip: Arc::new(Mutex::new(HashMap::new())), + session_ingress: Arc::new(Mutex::new(HashMap::new())), + session_egress: Arc::new(Mutex::new(HashMap::new())), + player_ip_rate, + player_ip_burst, + reg_ip_rate, + reg_ip_burst, + session_bw_rate_bytes, + session_bw_burst_bytes, + } + } + + async fn allow_player_ip(&self, ip: &str) -> bool { + self.allow_ip(ip, true).await + } + + async fn allow_registration_ip(&self, ip: &str) -> bool { + self.allow_ip(ip, false).await + } + + async fn allow_ip(&self, ip: &str, player: bool) -> bool { + let (map, rate, burst) = if player 
{ + (&self.player_ip, self.player_ip_rate, self.player_ip_burst) + } else { + (&self.reg_ip, self.reg_ip_rate, self.reg_ip_burst) + }; + let mut guard = map.lock().await; + let bucket = guard + .entry(ip.to_string()) + .or_insert_with(|| BucketState::new(burst, rate)); + bucket.reserve_delay(1).is_zero() + } + + async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) { + let delay = { + let map = match dir { + SessionDir::IngressFromClient => &self.session_ingress, + SessionDir::EgressToClient => &self.session_egress, + }; + let mut guard = map.lock().await; + let bucket = guard.entry(session_id.to_string()).or_insert_with(|| { + BucketState::new(self.session_bw_burst_bytes, self.session_bw_rate_bytes) + }); + bucket.reserve_delay(bytes) + }; + if !delay.is_zero() { + tokio::time::sleep(delay).await; + } + } + + async fn remove_session(&self, session_id: &str) { + self.session_ingress.lock().await.remove(session_id); + self.session_egress.lock().await.remove(session_id); + } +} + #[derive(Debug, Clone, Deserialize)] struct TunnelRouteRecord { instance_id: String, @@ -170,9 +324,7 @@ impl RedisRegistry { } async fn register_instance(&self) { - let Some(mut conn) = self.conn.clone() else { - return; - }; + let Some(mut conn) = self.conn.clone() else { return; }; let key = format!("relay:instance:{}", self.instance_id); let payload = serde_json::json!({ "instance_id": self.instance_id, @@ -182,30 +334,23 @@ impl RedisRegistry { "player_addr": self.player_addr, "r2r_addr": self.r2r_addr, "started_at": chrono::Utc::now().timestamp(), - }) - .to_string(); + }).to_string(); let res: redis::RedisResult<()> = async { let _: usize = conn.sadd("relay:instances", &self.instance_id).await?; let _: () = conn.set_ex(&key, payload, self.ttl_secs).await?; - let hb_key = format!("relay:heartbeat:{}", self.instance_id); - let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?; + let _: () = conn.set_ex(format!("relay:heartbeat:{}", 
self.instance_id), "1", self.ttl_secs).await?; Ok(()) - } - .await; + }.await; if let Err(e) = res { warn!(error = %e, "failed to register instance in redis"); } } async fn heartbeat_instance(&self, tunnel_count: usize) { - let Some(mut conn) = self.conn.clone() else { - return; - }; - let hb_key = format!("relay:heartbeat:{}", self.instance_id); + let Some(mut conn) = self.conn.clone() else { return; }; let load_key = format!("relay:load:{}", self.region); - let score = tunnel_count as f64; - let key = format!("relay:instance:{}", self.instance_id); + let inst_key = format!("relay:instance:{}", self.instance_id); let payload = serde_json::json!({ "instance_id": self.instance_id, "region": self.region, @@ -215,25 +360,21 @@ impl RedisRegistry { "r2r_addr": self.r2r_addr, "tunnel_count": tunnel_count, "updated_at": chrono::Utc::now().timestamp(), - }) - .to_string(); + }).to_string(); let res: redis::RedisResult<()> = async { - let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?; - let _: () = conn.set_ex(key, payload, self.ttl_secs).await?; - let _: () = conn.zadd(load_key, &self.instance_id, score).await?; + let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?; + let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?; + let _: () = conn.zadd(load_key, &self.instance_id, tunnel_count as f64).await?; Ok(()) - } - .await; + }.await; if let Err(e) = res { warn!(error = %e, "redis instance heartbeat failed"); } } async fn set_draining(&self) { - let Some(mut conn) = self.conn.clone() else { - return; - }; - let key = format!("relay:instance:{}", self.instance_id); + let Some(mut conn) = self.conn.clone() else { return; }; + let inst_key = format!("relay:instance:{}", self.instance_id); let payload = serde_json::json!({ "instance_id": self.instance_id, "region": self.region, @@ -242,37 +383,28 @@ impl RedisRegistry { "player_addr": self.player_addr, "r2r_addr": self.r2r_addr, "updated_at": 
chrono::Utc::now().timestamp(), - }) - .to_string(); + }).to_string(); let _: redis::RedisResult<()> = async { - let _: () = conn.set_ex(key, payload, self.ttl_secs).await?; - let load_key = format!("relay:load:{}", self.region); - let _: () = conn.zadd(load_key, &self.instance_id, 1e12f64).await?; + let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?; + let _: () = conn.zadd(format!("relay:load:{}", self.region), &self.instance_id, 1e12f64).await?; Ok(()) - } - .await; + }.await; } async fn register_tunnel(&self, fqdn: &str, session_id: &str, user_id: &str) { - let Some(mut conn) = self.conn.clone() else { - return; - }; - let key = format!("tunnel:sub:{fqdn}"); - let session_key = format!("tunnel:session:{session_id}"); + let Some(mut conn) = self.conn.clone() else { return; }; let payload = serde_json::json!({ "instance_id": self.instance_id, "session_id": session_id, "user_id": user_id, "region": self.region, "fqdn": fqdn, - }) - .to_string(); + }).to_string(); let _: redis::RedisResult<()> = async { - let _: () = conn.set_ex(key, &payload, self.ttl_secs).await?; - let _: () = conn.set_ex(session_key, payload, self.ttl_secs).await?; + let _: () = conn.set_ex(format!("tunnel:sub:{fqdn}"), &payload, self.ttl_secs).await?; + let _: () = conn.set_ex(format!("tunnel:session:{session_id}"), payload, self.ttl_secs).await?; Ok(()) - } - .await; + }.await; } async fn refresh_tunnel_session(&self, fqdn: &str, session_id: &str, user_id: &str) { @@ -280,35 +412,24 @@ impl RedisRegistry { } async fn remove_tunnel(&self, fqdn: &str, session_id: &str) { - let Some(mut conn) = self.conn.clone() else { - return; - }; + let Some(mut conn) = self.conn.clone() else { return; }; let _: redis::RedisResult<()> = async { let _: usize = conn.del(format!("tunnel:sub:{fqdn}")).await?; let _: usize = conn.del(format!("tunnel:session:{session_id}")).await?; Ok(()) - } - .await; + }.await; } async fn lookup_tunnel(&self, fqdn: &str) -> Option { - let Some(mut conn) = 
self.conn.clone() else { - return None; - }; - let key = format!("tunnel:sub:{fqdn}"); - let raw: Option = conn.get(key).await.ok()?; - let raw = raw?; - serde_json::from_str(&raw).ok() + let Some(mut conn) = self.conn.clone() else { return None; }; + let raw: Option = conn.get(format!("tunnel:sub:{fqdn}")).await.ok()?; + serde_json::from_str(&raw?).ok() } async fn lookup_instance(&self, instance_id: &str) -> Option { - let Some(mut conn) = self.conn.clone() else { - return None; - }; - let key = format!("relay:instance:{instance_id}"); - let raw: Option = conn.get(key).await.ok()?; - let raw = raw?; - serde_json::from_str(&raw).ok() + let Some(mut conn) = self.conn.clone() else { return None; }; + let raw: Option = conn.get(format!("relay:instance:{instance_id}")).await.ok()?; + serde_json::from_str(&raw?).ok() } } @@ -323,6 +444,7 @@ async fn main() -> Result<()> { let cfg = RelayConfig::from_env(); let registry = RedisRegistry::from_env(&cfg).await; + let guards = Arc::new(RelayGuards::from_env()); registry.register_instance().await; let control_listener = TcpListener::bind(&cfg.control_bind) @@ -335,15 +457,7 @@ async fn main() -> Result<()> { .await .with_context(|| format!("bind r2r {}", cfg.r2r_bind))?; - info!( - instance_id = %cfg.instance_id, - region = %cfg.region, - control = %cfg.control_bind, - player = %cfg.player_bind, - r2r = %cfg.r2r_bind, - r2r_advertise = %cfg.r2r_advertise_addr, - "relay started" - ); + info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started"); let shutdown = Arc::new(Notify::new()); let state: SharedState = Arc::new(RwLock::new(RelayState::new())); @@ -354,6 +468,7 @@ async fn main() -> Result<()> { cfg.clone(), state.clone(), registry.clone(), + guards.clone(), shutdown.clone(), )); let player_task = tokio::spawn(run_player_accept_loop( @@ -361,12 +476,14 @@ async fn main() -> Result<()> { 
cfg.clone(), state.clone(), registry.clone(), + guards.clone(), shutdown.clone(), )); let r2r_task = tokio::spawn(run_r2r_accept_loop( r2r_listener, cfg.clone(), state.clone(), + guards.clone(), shutdown.clone(), )); @@ -409,6 +526,7 @@ async fn run_control_accept_loop( cfg: RelayConfig, state: SharedState, registry: RedisRegistry, + guards: Arc, shutdown: Arc, ) -> Result<()> { loop { @@ -422,8 +540,9 @@ async fn run_control_accept_loop( let cfg = cfg.clone(); let state = state.clone(); let registry = registry.clone(); + let guards = guards.clone(); tokio::spawn(async move { - if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry).await { + if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry, guards).await { warn!(peer = %addr, error = %e, "control connection ended with error"); } }); @@ -438,6 +557,7 @@ async fn run_player_accept_loop( cfg: RelayConfig, state: SharedState, registry: RedisRegistry, + guards: Arc, shutdown: Arc, ) -> Result<()> { loop { @@ -448,11 +568,12 @@ async fn run_player_accept_loop( Ok(v) => v, Err(e) => { warn!(error = %e, "player accept failed"); continue; } }; + let cfg = cfg.clone(); let state = state.clone(); let registry = registry.clone(); - let cfg = cfg.clone(); + let guards = guards.clone(); tokio::spawn(async move { - if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry).await { + if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await { debug!(peer = %addr, error = %e, "player connection closed"); } }); @@ -466,6 +587,7 @@ async fn run_r2r_accept_loop( listener: TcpListener, cfg: RelayConfig, state: SharedState, + guards: Arc, shutdown: Arc, ) -> Result<()> { loop { @@ -476,10 +598,11 @@ async fn run_r2r_accept_loop( Ok(v) => v, Err(e) => { warn!(error = %e, "r2r accept failed"); continue; } }; - let state = state.clone(); let cfg = cfg.clone(); + let state = state.clone(); + let guards = guards.clone(); tokio::spawn(async move { - if let Err(e) = 
handle_r2r_conn(stream, addr, cfg, state).await { + if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await { warn!(peer = %addr, error = %e, "r2r connection ended with error"); } }); @@ -495,25 +618,24 @@ async fn handle_control_conn( cfg: RelayConfig, state: SharedState, registry: RedisRegistry, + guards: Arc<RelayGuards>, ) -> Result<()> { + if !guards.allow_registration_ip(&addr.ip().to_string()).await { + anyhow::bail!("registration rate limited for {}", addr.ip()); + } + let (mut reader, mut writer) = stream.into_split(); let first: ClientFrame = read_frame(&mut reader).await.context("read initial frame")?; let register = match first { ClientFrame::Register(req) => req, _ => { - write_frame( - &mut writer, - &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() }, - ).await.ok(); + write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() }).await.ok(); anyhow::bail!("expected register frame"); } }; - if !token_looks_valid(&register.token) { - write_frame( - &mut writer, - &ServerFrame::RegisterRejected { reason: "invalid token".to_string() }, - ).await.ok(); + if !token_looks_valid(&register.token) { + write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "invalid token".to_string() }).await.ok(); anyhow::bail!("invalid token"); } @@ -525,27 +646,25 @@ async fn handle_control_conn( { let mut guard = state.write().await; - let handle = SessionHandle { - session_id: session_id.clone(), - tx: tx.clone(), - stream_sinks: stream_sinks.clone(), - last_heartbeat: Instant::now(), - }; guard.by_session.insert(session_id.clone(), fqdn.clone()); - guard.by_fqdn.insert(fqdn.clone(), handle); + guard.by_fqdn.insert( + fqdn.clone(), + SessionHandle { + session_id: session_id.clone(), + tx: tx.clone(), + stream_sinks: stream_sinks.clone(), + last_heartbeat: Instant::now(), + }, + ); } registry.register_tunnel(&fqdn, &session_id, &user_id).await; - write_frame( - &mut writer, - &ServerFrame::RegisterAccepted(RegisterAccepted {
session_id: session_id.clone(), - fqdn: fqdn.clone(), - heartbeat_interval_secs: 5, - owner_instance_id: cfg.instance_id.clone(), - }), - ) - .await?; + write_frame(&mut writer, &ServerFrame::RegisterAccepted(RegisterAccepted { + session_id: session_id.clone(), + fqdn: fqdn.clone(), + heartbeat_interval_secs: 5, + owner_instance_id: cfg.instance_id.clone(), + })).await?; info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered"); let write_task = tokio::spawn(async move { @@ -559,12 +678,12 @@ async fn handle_control_conn( &mut reader, &state, ®istry, + &guards, &session_id, &fqdn, &user_id, cfg.heartbeat_timeout, - ) - .await; + ).await; if let Err(e) = &read_result { warn!(session_id = %session_id, error = %e, "control read loop error"); @@ -577,6 +696,7 @@ async fn handle_control_conn( } } registry.remove_tunnel(&fqdn, &session_id).await; + guards.remove_session(&session_id).await; write_task.abort(); info!(session_id = %session_id, "client session removed"); read_result @@ -586,15 +706,14 @@ async fn control_read_loop( reader: &mut tokio::net::tcp::OwnedReadHalf, state: &SharedState, registry: &RedisRegistry, + guards: &RelayGuards, session_id: &str, fqdn: &str, user_id: &str, heartbeat_timeout: Duration, ) -> Result<()> { loop { - let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)) - .await - .context("heartbeat timeout")??; + let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)).await.context("heartbeat timeout")??; match frame { ClientFrame::Heartbeat(Heartbeat { session_id: hb_id, .. 
}) => { if hb_id != session_id { @@ -610,6 +729,9 @@ async fn control_read_loop( registry.refresh_tunnel_session(fqdn, session_id, user_id).await; } ClientFrame::StreamData(StreamData { stream_id, data }) => { + guards + .throttle_session_bytes(session_id, SessionDir::IngressFromClient, data.len()) + .await; let sink = lookup_stream_sink(state, session_id, &stream_id).await; if let Some(tx) = sink { if tx.send(data).await.is_err() { @@ -632,7 +754,13 @@ async fn handle_player_conn( cfg: RelayConfig, state: SharedState, registry: RedisRegistry, + guards: Arc, ) -> Result<()> { + if !guards.allow_player_ip(&addr.ip().to_string()).await { + debug!(peer = %addr, "player connect rate limited"); + return Ok(()); + } + let (hostname, initial_data) = read_handshake_hostname_and_bytes(&mut stream) .await .context("parse minecraft handshake")?; @@ -646,6 +774,7 @@ async fn handle_player_conn( initial_data, None, "direct", + guards, ) .await; } @@ -667,6 +796,7 @@ async fn handle_r2r_conn( addr: SocketAddr, _cfg: RelayConfig, state: SharedState, + guards: Arc, ) -> Result<()> { let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?; if prelude.version != 1 { @@ -689,6 +819,7 @@ async fn handle_r2r_conn( prelude.initial_data, Some(prelude.stream_id), "r2r", + guards, ) .await .with_context(|| format!("r2r attach failed from {addr}")) @@ -707,9 +838,6 @@ async fn proxy_player_to_owner( .lookup_instance(&route.instance_id) .await .with_context(|| format!("owner instance {} not found in redis", route.instance_id))?; - if owner.status.as_deref() == Some("draining") { - debug!(owner = %route.instance_id, hostname = %hostname, "owner draining; attempting forward anyway"); - } let r2r_addr = owner .r2r_addr .clone() @@ -725,7 +853,7 @@ async fn proxy_player_to_owner( fqdn: hostname.clone(), stream_id: Uuid::new_v4().to_string(), peer_addr: player_addr.to_string(), - origin_instance_id: cfg.instance_id.clone(), + origin_instance_id: 
cfg.instance_id, hop_count: 1, initial_data, }; @@ -737,8 +865,7 @@ async fn proxy_player_to_owner( } async fn local_session_for_hostname(state: &SharedState, hostname: &str) -> Option { - let guard = state.read().await; - guard.by_fqdn.get(hostname).cloned() + state.read().await.by_fqdn.get(hostname).cloned() } async fn local_session_for_session_id(state: &SharedState, session_id: &str) -> Option { @@ -755,15 +882,12 @@ async fn attach_player_socket_to_session( initial_data: Vec, stream_id_override: Option, source: &'static str, + guards: Arc, ) -> Result<()> { let stream_id = stream_id_override.unwrap_or_else(|| Uuid::new_v4().to_string()); let (player_read, player_write) = stream.into_split(); let (to_player_tx, to_player_rx) = mpsc::channel::>(128); - session - .stream_sinks - .write() - .await - .insert(stream_id.clone(), to_player_tx); + session.stream_sinks.write().await.insert(stream_id.clone(), to_player_tx); session .tx @@ -795,9 +919,19 @@ async fn attach_player_socket_to_session( let tx_control = session.tx.clone(); let stream_id_clone = stream_id.clone(); + let session_id_clone = session.session_id.clone(); let sinks = session.stream_sinks.clone(); + let guards_clone = guards.clone(); tokio::spawn(async move { - if let Err(e) = run_player_reader(player_read, tx_control.clone(), stream_id_clone.clone()).await { + if let Err(e) = run_player_reader( + player_read, + tx_control.clone(), + stream_id_clone.clone(), + session_id_clone, + guards_clone, + ) + .await + { debug!(stream_id = %stream_id_clone, error = %e, "player reader ended"); } let _ = tx_control @@ -817,6 +951,8 @@ async fn run_player_reader( mut reader: tokio::net::tcp::OwnedReadHalf, tx_control: mpsc::Sender, stream_id: String, + session_id: String, + guards: Arc, ) -> Result<()> { let mut buf = vec![0u8; 16 * 1024]; loop { @@ -824,6 +960,9 @@ async fn run_player_reader( if n == 0 { break; } + guards + .throttle_session_bytes(&session_id, SessionDir::EgressToClient, n) + .await; tx_control 
.send(ServerFrame::StreamData(StreamData { stream_id: stream_id.clone(),