chore: update Rust dependencies.

This commit is contained in:
L
2026-02-23 23:26:57 +00:00
parent e7ef7fdf70
commit 09205f8db2
2 changed files with 261 additions and 121 deletions

1
Cargo.lock generated
View File

@@ -921,6 +921,7 @@ dependencies = [
"common", "common",
"fastrand", "fastrand",
"redis", "redis",
"serde",
"serde_json", "serde_json",
"tokio", "tokio",
"tracing", "tracing",

View File

@@ -19,7 +19,7 @@ use serde::Deserialize;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional}, io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional},
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
sync::{Notify, RwLock, mpsc}, sync::{Mutex, Notify, RwLock, mpsc},
time::{MissedTickBehavior, interval, timeout}, time::{MissedTickBehavior, interval, timeout},
}; };
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
@@ -48,6 +48,7 @@ impl RelayConfig {
let r2r_bind = std::env::var("RELAY_R2R_BIND").unwrap_or_else(|_| "0.0.0.0:7001".to_string()); let r2r_bind = std::env::var("RELAY_R2R_BIND").unwrap_or_else(|_| "0.0.0.0:7001".to_string());
let r2r_advertise_addr = std::env::var("RELAY_R2R_ADVERTISE_ADDR") let r2r_advertise_addr = std::env::var("RELAY_R2R_ADVERTISE_ADDR")
.unwrap_or_else(|_| guess_advertise_addr(&r2r_bind)); .unwrap_or_else(|_| guess_advertise_addr(&r2r_bind));
Self { Self {
instance_id: std::env::var("RELAY_INSTANCE_ID") instance_id: std::env::var("RELAY_INSTANCE_ID")
.unwrap_or_else(|_| format!("relay-{}", Uuid::new_v4())), .unwrap_or_else(|_| format!("relay-{}", Uuid::new_v4())),
@@ -105,6 +106,159 @@ impl RelayState {
type SharedState = Arc<RwLock<RelayState>>; type SharedState = Arc<RwLock<RelayState>>;
#[derive(Clone)]
struct RelayGuards {
player_ip: Arc<Mutex<HashMap<String, BucketState>>>,
reg_ip: Arc<Mutex<HashMap<String, BucketState>>>,
session_ingress: Arc<Mutex<HashMap<String, BucketState>>>,
session_egress: Arc<Mutex<HashMap<String, BucketState>>>,
player_ip_rate: f64,
player_ip_burst: f64,
reg_ip_rate: f64,
reg_ip_burst: f64,
session_bw_rate_bytes: f64,
session_bw_burst_bytes: f64,
}
#[derive(Debug, Clone)]
struct BucketState {
tokens: f64,
last_refill: Instant,
capacity: f64,
rate_per_sec: f64,
}
#[derive(Clone, Copy)]
enum SessionDir {
IngressFromClient,
EgressToClient,
}
impl BucketState {
fn new(capacity: f64, rate_per_sec: f64) -> Self {
Self {
tokens: capacity,
last_refill: Instant::now(),
capacity,
rate_per_sec,
}
}
fn reserve_delay(&mut self, amount: usize) -> Duration {
let now = Instant::now();
let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
if elapsed > 0.0 {
self.tokens = (self.tokens + elapsed * self.rate_per_sec).min(self.capacity);
self.last_refill = now;
}
let amount = amount as f64;
if self.tokens >= amount {
self.tokens -= amount;
return Duration::ZERO;
}
if self.rate_per_sec <= 0.0 {
return Duration::from_secs(1);
}
let deficit = amount - self.tokens;
let wait = Duration::from_secs_f64(deficit / self.rate_per_sec);
self.tokens = 0.0;
self.last_refill = now + wait;
wait
}
}
impl RelayGuards {
fn from_env() -> Self {
let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(2.0);
let player_ip_burst = std::env::var("RELAY_PLAYER_CONNECTS_BURST")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(20.0);
let reg_per_min = std::env::var("RELAY_REG_ATTEMPTS_PER_MIN")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10.0);
let reg_ip_rate = reg_per_min / 60.0;
let reg_ip_burst = std::env::var("RELAY_REG_ATTEMPTS_BURST")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10.0);
let session_bw_kbps = std::env::var("RELAY_SESSION_BW_KBPS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(8192.0);
let session_bw_rate_bytes = session_bw_kbps * 1024.0 / 8.0;
let session_bw_burst_bytes = std::env::var("RELAY_SESSION_BW_BURST_KB")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(512.0)
* 1024.0;
Self {
player_ip: Arc::new(Mutex::new(HashMap::new())),
reg_ip: Arc::new(Mutex::new(HashMap::new())),
session_ingress: Arc::new(Mutex::new(HashMap::new())),
session_egress: Arc::new(Mutex::new(HashMap::new())),
player_ip_rate,
player_ip_burst,
reg_ip_rate,
reg_ip_burst,
session_bw_rate_bytes,
session_bw_burst_bytes,
}
}
async fn allow_player_ip(&self, ip: &str) -> bool {
self.allow_ip(ip, true).await
}
async fn allow_registration_ip(&self, ip: &str) -> bool {
self.allow_ip(ip, false).await
}
async fn allow_ip(&self, ip: &str, player: bool) -> bool {
let (map, rate, burst) = if player {
(&self.player_ip, self.player_ip_rate, self.player_ip_burst)
} else {
(&self.reg_ip, self.reg_ip_rate, self.reg_ip_burst)
};
let mut guard = map.lock().await;
let bucket = guard
.entry(ip.to_string())
.or_insert_with(|| BucketState::new(burst, rate));
bucket.reserve_delay(1).is_zero()
}
async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) {
let delay = {
let map = match dir {
SessionDir::IngressFromClient => &self.session_ingress,
SessionDir::EgressToClient => &self.session_egress,
};
let mut guard = map.lock().await;
let bucket = guard.entry(session_id.to_string()).or_insert_with(|| {
BucketState::new(self.session_bw_burst_bytes, self.session_bw_rate_bytes)
});
bucket.reserve_delay(bytes)
};
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
}
async fn remove_session(&self, session_id: &str) {
self.session_ingress.lock().await.remove(session_id);
self.session_egress.lock().await.remove(session_id);
}
}
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct TunnelRouteRecord { struct TunnelRouteRecord {
instance_id: String, instance_id: String,
@@ -170,9 +324,7 @@ impl RedisRegistry {
} }
async fn register_instance(&self) { async fn register_instance(&self) {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return; };
return;
};
let key = format!("relay:instance:{}", self.instance_id); let key = format!("relay:instance:{}", self.instance_id);
let payload = serde_json::json!({ let payload = serde_json::json!({
"instance_id": self.instance_id, "instance_id": self.instance_id,
@@ -182,30 +334,23 @@ impl RedisRegistry {
"player_addr": self.player_addr, "player_addr": self.player_addr,
"r2r_addr": self.r2r_addr, "r2r_addr": self.r2r_addr,
"started_at": chrono::Utc::now().timestamp(), "started_at": chrono::Utc::now().timestamp(),
}) }).to_string();
.to_string();
let res: redis::RedisResult<()> = async { let res: redis::RedisResult<()> = async {
let _: usize = conn.sadd("relay:instances", &self.instance_id).await?; let _: usize = conn.sadd("relay:instances", &self.instance_id).await?;
let _: () = conn.set_ex(&key, payload, self.ttl_secs).await?; let _: () = conn.set_ex(&key, payload, self.ttl_secs).await?;
let hb_key = format!("relay:heartbeat:{}", self.instance_id); let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?;
Ok(()) Ok(())
} }.await;
.await;
if let Err(e) = res { if let Err(e) = res {
warn!(error = %e, "failed to register instance in redis"); warn!(error = %e, "failed to register instance in redis");
} }
} }
async fn heartbeat_instance(&self, tunnel_count: usize) { async fn heartbeat_instance(&self, tunnel_count: usize) {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return; };
return;
};
let hb_key = format!("relay:heartbeat:{}", self.instance_id);
let load_key = format!("relay:load:{}", self.region); let load_key = format!("relay:load:{}", self.region);
let score = tunnel_count as f64; let inst_key = format!("relay:instance:{}", self.instance_id);
let key = format!("relay:instance:{}", self.instance_id);
let payload = serde_json::json!({ let payload = serde_json::json!({
"instance_id": self.instance_id, "instance_id": self.instance_id,
"region": self.region, "region": self.region,
@@ -215,25 +360,21 @@ impl RedisRegistry {
"r2r_addr": self.r2r_addr, "r2r_addr": self.r2r_addr,
"tunnel_count": tunnel_count, "tunnel_count": tunnel_count,
"updated_at": chrono::Utc::now().timestamp(), "updated_at": chrono::Utc::now().timestamp(),
}) }).to_string();
.to_string();
let res: redis::RedisResult<()> = async { let res: redis::RedisResult<()> = async {
let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?; let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
let _: () = conn.set_ex(key, payload, self.ttl_secs).await?; let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
let _: () = conn.zadd(load_key, &self.instance_id, score).await?; let _: () = conn.zadd(load_key, &self.instance_id, tunnel_count as f64).await?;
Ok(()) Ok(())
} }.await;
.await;
if let Err(e) = res { if let Err(e) = res {
warn!(error = %e, "redis instance heartbeat failed"); warn!(error = %e, "redis instance heartbeat failed");
} }
} }
async fn set_draining(&self) { async fn set_draining(&self) {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return; };
return; let inst_key = format!("relay:instance:{}", self.instance_id);
};
let key = format!("relay:instance:{}", self.instance_id);
let payload = serde_json::json!({ let payload = serde_json::json!({
"instance_id": self.instance_id, "instance_id": self.instance_id,
"region": self.region, "region": self.region,
@@ -242,37 +383,28 @@ impl RedisRegistry {
"player_addr": self.player_addr, "player_addr": self.player_addr,
"r2r_addr": self.r2r_addr, "r2r_addr": self.r2r_addr,
"updated_at": chrono::Utc::now().timestamp(), "updated_at": chrono::Utc::now().timestamp(),
}) }).to_string();
.to_string();
let _: redis::RedisResult<()> = async { let _: redis::RedisResult<()> = async {
let _: () = conn.set_ex(key, payload, self.ttl_secs).await?; let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
let load_key = format!("relay:load:{}", self.region); let _: () = conn.zadd(format!("relay:load:{}", self.region), &self.instance_id, 1e12f64).await?;
let _: () = conn.zadd(load_key, &self.instance_id, 1e12f64).await?;
Ok(()) Ok(())
} }.await;
.await;
} }
async fn register_tunnel(&self, fqdn: &str, session_id: &str, user_id: &str) { async fn register_tunnel(&self, fqdn: &str, session_id: &str, user_id: &str) {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return; };
return;
};
let key = format!("tunnel:sub:{fqdn}");
let session_key = format!("tunnel:session:{session_id}");
let payload = serde_json::json!({ let payload = serde_json::json!({
"instance_id": self.instance_id, "instance_id": self.instance_id,
"session_id": session_id, "session_id": session_id,
"user_id": user_id, "user_id": user_id,
"region": self.region, "region": self.region,
"fqdn": fqdn, "fqdn": fqdn,
}) }).to_string();
.to_string();
let _: redis::RedisResult<()> = async { let _: redis::RedisResult<()> = async {
let _: () = conn.set_ex(key, &payload, self.ttl_secs).await?; let _: () = conn.set_ex(format!("tunnel:sub:{fqdn}"), &payload, self.ttl_secs).await?;
let _: () = conn.set_ex(session_key, payload, self.ttl_secs).await?; let _: () = conn.set_ex(format!("tunnel:session:{session_id}"), payload, self.ttl_secs).await?;
Ok(()) Ok(())
} }.await;
.await;
} }
async fn refresh_tunnel_session(&self, fqdn: &str, session_id: &str, user_id: &str) { async fn refresh_tunnel_session(&self, fqdn: &str, session_id: &str, user_id: &str) {
@@ -280,35 +412,24 @@ impl RedisRegistry {
} }
async fn remove_tunnel(&self, fqdn: &str, session_id: &str) { async fn remove_tunnel(&self, fqdn: &str, session_id: &str) {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return; };
return;
};
let _: redis::RedisResult<()> = async { let _: redis::RedisResult<()> = async {
let _: usize = conn.del(format!("tunnel:sub:{fqdn}")).await?; let _: usize = conn.del(format!("tunnel:sub:{fqdn}")).await?;
let _: usize = conn.del(format!("tunnel:session:{session_id}")).await?; let _: usize = conn.del(format!("tunnel:session:{session_id}")).await?;
Ok(()) Ok(())
} }.await;
.await;
} }
async fn lookup_tunnel(&self, fqdn: &str) -> Option<TunnelRouteRecord> { async fn lookup_tunnel(&self, fqdn: &str) -> Option<TunnelRouteRecord> {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return None; };
return None; let raw: Option<String> = conn.get(format!("tunnel:sub:{fqdn}")).await.ok()?;
}; serde_json::from_str(&raw?).ok()
let key = format!("tunnel:sub:{fqdn}");
let raw: Option<String> = conn.get(key).await.ok()?;
let raw = raw?;
serde_json::from_str(&raw).ok()
} }
async fn lookup_instance(&self, instance_id: &str) -> Option<RelayInstanceRecord> { async fn lookup_instance(&self, instance_id: &str) -> Option<RelayInstanceRecord> {
let Some(mut conn) = self.conn.clone() else { let Some(mut conn) = self.conn.clone() else { return None; };
return None; let raw: Option<String> = conn.get(format!("relay:instance:{instance_id}")).await.ok()?;
}; serde_json::from_str(&raw?).ok()
let key = format!("relay:instance:{instance_id}");
let raw: Option<String> = conn.get(key).await.ok()?;
let raw = raw?;
serde_json::from_str(&raw).ok()
} }
} }
@@ -323,6 +444,7 @@ async fn main() -> Result<()> {
let cfg = RelayConfig::from_env(); let cfg = RelayConfig::from_env();
let registry = RedisRegistry::from_env(&cfg).await; let registry = RedisRegistry::from_env(&cfg).await;
let guards = Arc::new(RelayGuards::from_env());
registry.register_instance().await; registry.register_instance().await;
let control_listener = TcpListener::bind(&cfg.control_bind) let control_listener = TcpListener::bind(&cfg.control_bind)
@@ -335,15 +457,7 @@ async fn main() -> Result<()> {
.await .await
.with_context(|| format!("bind r2r {}", cfg.r2r_bind))?; .with_context(|| format!("bind r2r {}", cfg.r2r_bind))?;
info!( info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started");
instance_id = %cfg.instance_id,
region = %cfg.region,
control = %cfg.control_bind,
player = %cfg.player_bind,
r2r = %cfg.r2r_bind,
r2r_advertise = %cfg.r2r_advertise_addr,
"relay started"
);
let shutdown = Arc::new(Notify::new()); let shutdown = Arc::new(Notify::new());
let state: SharedState = Arc::new(RwLock::new(RelayState::new())); let state: SharedState = Arc::new(RwLock::new(RelayState::new()));
@@ -354,6 +468,7 @@ async fn main() -> Result<()> {
cfg.clone(), cfg.clone(),
state.clone(), state.clone(),
registry.clone(), registry.clone(),
guards.clone(),
shutdown.clone(), shutdown.clone(),
)); ));
let player_task = tokio::spawn(run_player_accept_loop( let player_task = tokio::spawn(run_player_accept_loop(
@@ -361,12 +476,14 @@ async fn main() -> Result<()> {
cfg.clone(), cfg.clone(),
state.clone(), state.clone(),
registry.clone(), registry.clone(),
guards.clone(),
shutdown.clone(), shutdown.clone(),
)); ));
let r2r_task = tokio::spawn(run_r2r_accept_loop( let r2r_task = tokio::spawn(run_r2r_accept_loop(
r2r_listener, r2r_listener,
cfg.clone(), cfg.clone(),
state.clone(), state.clone(),
guards.clone(),
shutdown.clone(), shutdown.clone(),
)); ));
@@ -409,6 +526,7 @@ async fn run_control_accept_loop(
cfg: RelayConfig, cfg: RelayConfig,
state: SharedState, state: SharedState,
registry: RedisRegistry, registry: RedisRegistry,
guards: Arc<RelayGuards>,
shutdown: Arc<Notify>, shutdown: Arc<Notify>,
) -> Result<()> { ) -> Result<()> {
loop { loop {
@@ -422,8 +540,9 @@ async fn run_control_accept_loop(
let cfg = cfg.clone(); let cfg = cfg.clone();
let state = state.clone(); let state = state.clone();
let registry = registry.clone(); let registry = registry.clone();
let guards = guards.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry).await { if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry, guards).await {
warn!(peer = %addr, error = %e, "control connection ended with error"); warn!(peer = %addr, error = %e, "control connection ended with error");
} }
}); });
@@ -438,6 +557,7 @@ async fn run_player_accept_loop(
cfg: RelayConfig, cfg: RelayConfig,
state: SharedState, state: SharedState,
registry: RedisRegistry, registry: RedisRegistry,
guards: Arc<RelayGuards>,
shutdown: Arc<Notify>, shutdown: Arc<Notify>,
) -> Result<()> { ) -> Result<()> {
loop { loop {
@@ -448,11 +568,12 @@ async fn run_player_accept_loop(
Ok(v) => v, Ok(v) => v,
Err(e) => { warn!(error = %e, "player accept failed"); continue; } Err(e) => { warn!(error = %e, "player accept failed"); continue; }
}; };
let cfg = cfg.clone();
let state = state.clone(); let state = state.clone();
let registry = registry.clone(); let registry = registry.clone();
let cfg = cfg.clone(); let guards = guards.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry).await { if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await {
debug!(peer = %addr, error = %e, "player connection closed"); debug!(peer = %addr, error = %e, "player connection closed");
} }
}); });
@@ -466,6 +587,7 @@ async fn run_r2r_accept_loop(
listener: TcpListener, listener: TcpListener,
cfg: RelayConfig, cfg: RelayConfig,
state: SharedState, state: SharedState,
guards: Arc<RelayGuards>,
shutdown: Arc<Notify>, shutdown: Arc<Notify>,
) -> Result<()> { ) -> Result<()> {
loop { loop {
@@ -476,10 +598,11 @@ async fn run_r2r_accept_loop(
Ok(v) => v, Ok(v) => v,
Err(e) => { warn!(error = %e, "r2r accept failed"); continue; } Err(e) => { warn!(error = %e, "r2r accept failed"); continue; }
}; };
let state = state.clone();
let cfg = cfg.clone(); let cfg = cfg.clone();
let state = state.clone();
let guards = guards.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_r2r_conn(stream, addr, cfg, state).await { if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await {
warn!(peer = %addr, error = %e, "r2r connection ended with error"); warn!(peer = %addr, error = %e, "r2r connection ended with error");
} }
}); });
@@ -495,25 +618,23 @@ async fn handle_control_conn(
cfg: RelayConfig, cfg: RelayConfig,
state: SharedState, state: SharedState,
registry: RedisRegistry, registry: RedisRegistry,
guards: Arc<RelayGuards>,
) -> Result<()> { ) -> Result<()> {
if !guards.allow_registration_ip(&addr.ip().to_string()).await {
anyhow::bail!("registration rate limited for {}", addr.ip());
}
let (mut reader, mut writer) = stream.into_split(); let (mut reader, mut writer) = stream.into_split();
let first: ClientFrame = read_frame(&mut reader).await.context("read initial frame")?; let first: ClientFrame = read_frame(&mut reader).await.context("read initial frame")?;
let register = match first { let register = match first {
ClientFrame::Register(req) => req, ClientFrame::Register(req) => req,
_ => { _ => {
write_frame( write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() }).await.ok();
&mut writer,
&ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() },
).await.ok();
anyhow::bail!("expected register frame"); anyhow::bail!("expected register frame");
} }
}; };
if !token_looks_valid(&register.token) { if !token_looks_valid(&register.token) {
write_frame( write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "invalid token".to_string() }).await.ok();
&mut writer,
&ServerFrame::RegisterRejected { reason: "invalid token".to_string() },
).await.ok();
anyhow::bail!("invalid token"); anyhow::bail!("invalid token");
} }
@@ -525,27 +646,25 @@ async fn handle_control_conn(
{ {
let mut guard = state.write().await; let mut guard = state.write().await;
let handle = SessionHandle {
session_id: session_id.clone(),
tx: tx.clone(),
stream_sinks: stream_sinks.clone(),
last_heartbeat: Instant::now(),
};
guard.by_session.insert(session_id.clone(), fqdn.clone()); guard.by_session.insert(session_id.clone(), fqdn.clone());
guard.by_fqdn.insert(fqdn.clone(), handle); guard.by_fqdn.insert(
fqdn.clone(),
SessionHandle {
session_id: session_id.clone(),
tx: tx.clone(),
stream_sinks: stream_sinks.clone(),
last_heartbeat: Instant::now(),
},
);
} }
registry.register_tunnel(&fqdn, &session_id, &user_id).await; registry.register_tunnel(&fqdn, &session_id, &user_id).await;
write_frame( write_frame(&mut writer, &ServerFrame::RegisterAccepted(RegisterAccepted {
&mut writer, session_id: session_id.clone(),
&ServerFrame::RegisterAccepted(RegisterAccepted { fqdn: fqdn.clone(),
session_id: session_id.clone(), heartbeat_interval_secs: 5,
fqdn: fqdn.clone(), owner_instance_id: cfg.instance_id.clone(),
heartbeat_interval_secs: 5, })).await?;
owner_instance_id: cfg.instance_id.clone(),
}),
)
.await?;
info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered"); info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered");
let write_task = tokio::spawn(async move { let write_task = tokio::spawn(async move {
@@ -559,12 +678,12 @@ async fn handle_control_conn(
&mut reader, &mut reader,
&state, &state,
&registry, &registry,
&guards,
&session_id, &session_id,
&fqdn, &fqdn,
&user_id, &user_id,
cfg.heartbeat_timeout, cfg.heartbeat_timeout,
) ).await;
.await;
if let Err(e) = &read_result { if let Err(e) = &read_result {
warn!(session_id = %session_id, error = %e, "control read loop error"); warn!(session_id = %session_id, error = %e, "control read loop error");
@@ -577,6 +696,7 @@ async fn handle_control_conn(
} }
} }
registry.remove_tunnel(&fqdn, &session_id).await; registry.remove_tunnel(&fqdn, &session_id).await;
guards.remove_session(&session_id).await;
write_task.abort(); write_task.abort();
info!(session_id = %session_id, "client session removed"); info!(session_id = %session_id, "client session removed");
read_result read_result
@@ -586,15 +706,14 @@ async fn control_read_loop(
reader: &mut tokio::net::tcp::OwnedReadHalf, reader: &mut tokio::net::tcp::OwnedReadHalf,
state: &SharedState, state: &SharedState,
registry: &RedisRegistry, registry: &RedisRegistry,
guards: &RelayGuards,
session_id: &str, session_id: &str,
fqdn: &str, fqdn: &str,
user_id: &str, user_id: &str,
heartbeat_timeout: Duration, heartbeat_timeout: Duration,
) -> Result<()> { ) -> Result<()> {
loop { loop {
let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)) let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)).await.context("heartbeat timeout")??;
.await
.context("heartbeat timeout")??;
match frame { match frame {
ClientFrame::Heartbeat(Heartbeat { session_id: hb_id, .. }) => { ClientFrame::Heartbeat(Heartbeat { session_id: hb_id, .. }) => {
if hb_id != session_id { if hb_id != session_id {
@@ -610,6 +729,9 @@ async fn control_read_loop(
registry.refresh_tunnel_session(fqdn, session_id, user_id).await; registry.refresh_tunnel_session(fqdn, session_id, user_id).await;
} }
ClientFrame::StreamData(StreamData { stream_id, data }) => { ClientFrame::StreamData(StreamData { stream_id, data }) => {
guards
.throttle_session_bytes(session_id, SessionDir::IngressFromClient, data.len())
.await;
let sink = lookup_stream_sink(state, session_id, &stream_id).await; let sink = lookup_stream_sink(state, session_id, &stream_id).await;
if let Some(tx) = sink { if let Some(tx) = sink {
if tx.send(data).await.is_err() { if tx.send(data).await.is_err() {
@@ -632,7 +754,13 @@ async fn handle_player_conn(
cfg: RelayConfig, cfg: RelayConfig,
state: SharedState, state: SharedState,
registry: RedisRegistry, registry: RedisRegistry,
guards: Arc<RelayGuards>,
) -> Result<()> { ) -> Result<()> {
if !guards.allow_player_ip(&addr.ip().to_string()).await {
debug!(peer = %addr, "player connect rate limited");
return Ok(());
}
let (hostname, initial_data) = read_handshake_hostname_and_bytes(&mut stream) let (hostname, initial_data) = read_handshake_hostname_and_bytes(&mut stream)
.await .await
.context("parse minecraft handshake")?; .context("parse minecraft handshake")?;
@@ -646,6 +774,7 @@ async fn handle_player_conn(
initial_data, initial_data,
None, None,
"direct", "direct",
guards,
) )
.await; .await;
} }
@@ -667,6 +796,7 @@ async fn handle_r2r_conn(
addr: SocketAddr, addr: SocketAddr,
_cfg: RelayConfig, _cfg: RelayConfig,
state: SharedState, state: SharedState,
guards: Arc<RelayGuards>,
) -> Result<()> { ) -> Result<()> {
let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?; let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?;
if prelude.version != 1 { if prelude.version != 1 {
@@ -689,6 +819,7 @@ async fn handle_r2r_conn(
prelude.initial_data, prelude.initial_data,
Some(prelude.stream_id), Some(prelude.stream_id),
"r2r", "r2r",
guards,
) )
.await .await
.with_context(|| format!("r2r attach failed from {addr}")) .with_context(|| format!("r2r attach failed from {addr}"))
@@ -707,9 +838,6 @@ async fn proxy_player_to_owner(
.lookup_instance(&route.instance_id) .lookup_instance(&route.instance_id)
.await .await
.with_context(|| format!("owner instance {} not found in redis", route.instance_id))?; .with_context(|| format!("owner instance {} not found in redis", route.instance_id))?;
if owner.status.as_deref() == Some("draining") {
debug!(owner = %route.instance_id, hostname = %hostname, "owner draining; attempting forward anyway");
}
let r2r_addr = owner let r2r_addr = owner
.r2r_addr .r2r_addr
.clone() .clone()
@@ -725,7 +853,7 @@ async fn proxy_player_to_owner(
fqdn: hostname.clone(), fqdn: hostname.clone(),
stream_id: Uuid::new_v4().to_string(), stream_id: Uuid::new_v4().to_string(),
peer_addr: player_addr.to_string(), peer_addr: player_addr.to_string(),
origin_instance_id: cfg.instance_id.clone(), origin_instance_id: cfg.instance_id,
hop_count: 1, hop_count: 1,
initial_data, initial_data,
}; };
@@ -737,8 +865,7 @@ async fn proxy_player_to_owner(
} }
async fn local_session_for_hostname(state: &SharedState, hostname: &str) -> Option<SessionHandle> { async fn local_session_for_hostname(state: &SharedState, hostname: &str) -> Option<SessionHandle> {
let guard = state.read().await; state.read().await.by_fqdn.get(hostname).cloned()
guard.by_fqdn.get(hostname).cloned()
} }
async fn local_session_for_session_id(state: &SharedState, session_id: &str) -> Option<SessionHandle> { async fn local_session_for_session_id(state: &SharedState, session_id: &str) -> Option<SessionHandle> {
@@ -755,15 +882,12 @@ async fn attach_player_socket_to_session(
initial_data: Vec<u8>, initial_data: Vec<u8>,
stream_id_override: Option<String>, stream_id_override: Option<String>,
source: &'static str, source: &'static str,
guards: Arc<RelayGuards>,
) -> Result<()> { ) -> Result<()> {
let stream_id = stream_id_override.unwrap_or_else(|| Uuid::new_v4().to_string()); let stream_id = stream_id_override.unwrap_or_else(|| Uuid::new_v4().to_string());
let (player_read, player_write) = stream.into_split(); let (player_read, player_write) = stream.into_split();
let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128); let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
session session.stream_sinks.write().await.insert(stream_id.clone(), to_player_tx);
.stream_sinks
.write()
.await
.insert(stream_id.clone(), to_player_tx);
session session
.tx .tx
@@ -795,9 +919,19 @@ async fn attach_player_socket_to_session(
let tx_control = session.tx.clone(); let tx_control = session.tx.clone();
let stream_id_clone = stream_id.clone(); let stream_id_clone = stream_id.clone();
let session_id_clone = session.session_id.clone();
let sinks = session.stream_sinks.clone(); let sinks = session.stream_sinks.clone();
let guards_clone = guards.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = run_player_reader(player_read, tx_control.clone(), stream_id_clone.clone()).await { if let Err(e) = run_player_reader(
player_read,
tx_control.clone(),
stream_id_clone.clone(),
session_id_clone,
guards_clone,
)
.await
{
debug!(stream_id = %stream_id_clone, error = %e, "player reader ended"); debug!(stream_id = %stream_id_clone, error = %e, "player reader ended");
} }
let _ = tx_control let _ = tx_control
@@ -817,6 +951,8 @@ async fn run_player_reader(
mut reader: tokio::net::tcp::OwnedReadHalf, mut reader: tokio::net::tcp::OwnedReadHalf,
tx_control: mpsc::Sender<ServerFrame>, tx_control: mpsc::Sender<ServerFrame>,
stream_id: String, stream_id: String,
session_id: String,
guards: Arc<RelayGuards>,
) -> Result<()> { ) -> Result<()> {
let mut buf = vec![0u8; 16 * 1024]; let mut buf = vec![0u8; 16 * 1024];
loop { loop {
@@ -824,6 +960,9 @@ async fn run_player_reader(
if n == 0 { if n == 0 {
break; break;
} }
guards
.throttle_session_bytes(&session_id, SessionDir::EgressToClient, n)
.await;
tx_control tx_control
.send(ServerFrame::StreamData(StreamData { .send(ServerFrame::StreamData(StreamData {
stream_id: stream_id.clone(), stream_id: stream_id.clone(),