feat: add per-IP and per-session rate limiting to the relay.
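
Add token-bucket rate limiting to the relay. A new RelayGuards keeps
per-IP buckets for player connects and registration attempts, plus
per-session ingress/egress buckets that throttle tunnel bandwidth; the
accept loops and stream paths consult it before accepting or forwarding
data. Along the way, collapse a number of rustfmt-expanded call chains
into single lines, drop the owner-draining debug check in
proxy_player_to_owner, and pick up fastrand in Cargo.lock.

All limits are read from environment variables. The values below are
just the from_env fallbacks, listed as an illustrative configuration
rather than a tuned recommendation:

    RELAY_PLAYER_CONNECTS_PER_SEC=2    # per player IP
    RELAY_PLAYER_CONNECTS_BURST=20
    RELAY_REG_ATTEMPTS_PER_MIN=10      # per registering IP
    RELAY_REG_ATTEMPTS_BURST=10
    RELAY_SESSION_BW_KBPS=8192         # per session, per direction
    RELAY_SESSION_BW_BURST_KB=512

With these defaults a session bucket refills at 8192 kbit/s (1 MiB/s),
so once the 512 KiB burst is spent, an empty bucket delays a 16 KiB
chunk by roughly 16 ms.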

 Cargo.lock (generated) | 1 +

@@ -921,6 +921,7 @@ dependencies = [
  "common",
+ "fastrand",
  "redis",
  "serde",
  "serde_json",
  "tokio",
  "tracing",

@@ -19,7 +19,7 @@ use serde::Deserialize;
 use tokio::{
     io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional},
     net::{TcpListener, TcpStream},
-    sync::{Notify, RwLock, mpsc},
+    sync::{Mutex, Notify, RwLock, mpsc},
     time::{MissedTickBehavior, interval, timeout},
 };
 use tracing::{debug, info, warn};
@@ -48,6 +48,7 @@ impl RelayConfig {
         let r2r_bind = std::env::var("RELAY_R2R_BIND").unwrap_or_else(|_| "0.0.0.0:7001".to_string());
         let r2r_advertise_addr = std::env::var("RELAY_R2R_ADVERTISE_ADDR")
             .unwrap_or_else(|_| guess_advertise_addr(&r2r_bind));
 
         Self {
             instance_id: std::env::var("RELAY_INSTANCE_ID")
                 .unwrap_or_else(|_| format!("relay-{}", Uuid::new_v4())),
@@ -105,6 +106,159 @@ impl RelayState {
 
 type SharedState = Arc<RwLock<RelayState>>;
 
+#[derive(Clone)]
+struct RelayGuards {
+    player_ip: Arc<Mutex<HashMap<String, BucketState>>>,
+    reg_ip: Arc<Mutex<HashMap<String, BucketState>>>,
+    session_ingress: Arc<Mutex<HashMap<String, BucketState>>>,
+    session_egress: Arc<Mutex<HashMap<String, BucketState>>>,
+    player_ip_rate: f64,
+    player_ip_burst: f64,
+    reg_ip_rate: f64,
+    reg_ip_burst: f64,
+    session_bw_rate_bytes: f64,
+    session_bw_burst_bytes: f64,
+}
+
+#[derive(Debug, Clone)]
+struct BucketState {
+    tokens: f64,
+    last_refill: Instant,
+    capacity: f64,
+    rate_per_sec: f64,
+}
+
+#[derive(Clone, Copy)]
+enum SessionDir {
+    IngressFromClient,
+    EgressToClient,
+}
+
+impl BucketState {
+    fn new(capacity: f64, rate_per_sec: f64) -> Self {
+        Self {
+            tokens: capacity,
+            last_refill: Instant::now(),
+            capacity,
+            rate_per_sec,
+        }
+    }
+
+    fn reserve_delay(&mut self, amount: usize) -> Duration {
+        let now = Instant::now();
+        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
+        if elapsed > 0.0 {
+            self.tokens = (self.tokens + elapsed * self.rate_per_sec).min(self.capacity);
+            self.last_refill = now;
+        }
+
+        let amount = amount as f64;
+        if self.tokens >= amount {
+            self.tokens -= amount;
+            return Duration::ZERO;
+        }
+        if self.rate_per_sec <= 0.0 {
+            return Duration::from_secs(1);
+        }
+
+        let deficit = amount - self.tokens;
+        let wait = Duration::from_secs_f64(deficit / self.rate_per_sec);
+        self.tokens = 0.0;
+        self.last_refill = now + wait;
+        wait
+    }
+}
+
+impl RelayGuards {
+    fn from_env() -> Self {
+        let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(2.0);
+        let player_ip_burst = std::env::var("RELAY_PLAYER_CONNECTS_BURST")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(20.0);
+
+        let reg_per_min = std::env::var("RELAY_REG_ATTEMPTS_PER_MIN")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(10.0);
+        let reg_ip_rate = reg_per_min / 60.0;
+        let reg_ip_burst = std::env::var("RELAY_REG_ATTEMPTS_BURST")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(10.0);
+
+        let session_bw_kbps = std::env::var("RELAY_SESSION_BW_KBPS")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(8192.0);
+        let session_bw_rate_bytes = session_bw_kbps * 1024.0 / 8.0;
+        let session_bw_burst_bytes = std::env::var("RELAY_SESSION_BW_BURST_KB")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(512.0)
+            * 1024.0;
+
+        Self {
+            player_ip: Arc::new(Mutex::new(HashMap::new())),
+            reg_ip: Arc::new(Mutex::new(HashMap::new())),
+            session_ingress: Arc::new(Mutex::new(HashMap::new())),
+            session_egress: Arc::new(Mutex::new(HashMap::new())),
+            player_ip_rate,
+            player_ip_burst,
+            reg_ip_rate,
+            reg_ip_burst,
+            session_bw_rate_bytes,
+            session_bw_burst_bytes,
+        }
+    }
+
+    async fn allow_player_ip(&self, ip: &str) -> bool {
+        self.allow_ip(ip, true).await
+    }
+
+    async fn allow_registration_ip(&self, ip: &str) -> bool {
+        self.allow_ip(ip, false).await
+    }
+
+    async fn allow_ip(&self, ip: &str, player: bool) -> bool {
+        let (map, rate, burst) = if player {
+            (&self.player_ip, self.player_ip_rate, self.player_ip_burst)
+        } else {
+            (&self.reg_ip, self.reg_ip_rate, self.reg_ip_burst)
+        };
+        let mut guard = map.lock().await;
+        let bucket = guard
+            .entry(ip.to_string())
+            .or_insert_with(|| BucketState::new(burst, rate));
+        bucket.reserve_delay(1).is_zero()
+    }
+
+    async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) {
+        let delay = {
+            let map = match dir {
+                SessionDir::IngressFromClient => &self.session_ingress,
+                SessionDir::EgressToClient => &self.session_egress,
+            };
+            let mut guard = map.lock().await;
+            let bucket = guard.entry(session_id.to_string()).or_insert_with(|| {
+                BucketState::new(self.session_bw_burst_bytes, self.session_bw_rate_bytes)
+            });
+            bucket.reserve_delay(bytes)
+        };
+        if !delay.is_zero() {
+            tokio::time::sleep(delay).await;
+        }
+    }
+
+    async fn remove_session(&self, session_id: &str) {
+        self.session_ingress.lock().await.remove(session_id);
+        self.session_egress.lock().await.remove(session_id);
+    }
+}
 
 #[derive(Debug, Clone, Deserialize)]
 struct TunnelRouteRecord {
     instance_id: String,
@@ -170,9 +324,7 @@ impl RedisRegistry {
     }
 
     async fn register_instance(&self) {
-        let Some(mut conn) = self.conn.clone() else {
-            return;
-        };
+        let Some(mut conn) = self.conn.clone() else { return; };
         let key = format!("relay:instance:{}", self.instance_id);
         let payload = serde_json::json!({
             "instance_id": self.instance_id,
@@ -182,30 +334,23 @@ impl RedisRegistry {
             "player_addr": self.player_addr,
             "r2r_addr": self.r2r_addr,
             "started_at": chrono::Utc::now().timestamp(),
-        })
-        .to_string();
+        }).to_string();
 
         let res: redis::RedisResult<()> = async {
             let _: usize = conn.sadd("relay:instances", &self.instance_id).await?;
             let _: () = conn.set_ex(&key, payload, self.ttl_secs).await?;
-            let hb_key = format!("relay:heartbeat:{}", self.instance_id);
-            let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?;
+            let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
             Ok(())
-        }
-        .await;
+        }.await;
         if let Err(e) = res {
             warn!(error = %e, "failed to register instance in redis");
         }
     }
 
     async fn heartbeat_instance(&self, tunnel_count: usize) {
-        let Some(mut conn) = self.conn.clone() else {
-            return;
-        };
-        let hb_key = format!("relay:heartbeat:{}", self.instance_id);
+        let Some(mut conn) = self.conn.clone() else { return; };
         let load_key = format!("relay:load:{}", self.region);
-        let score = tunnel_count as f64;
-        let key = format!("relay:instance:{}", self.instance_id);
+        let inst_key = format!("relay:instance:{}", self.instance_id);
         let payload = serde_json::json!({
             "instance_id": self.instance_id,
             "region": self.region,
@@ -215,25 +360,21 @@ impl RedisRegistry {
             "r2r_addr": self.r2r_addr,
             "tunnel_count": tunnel_count,
             "updated_at": chrono::Utc::now().timestamp(),
-        })
-        .to_string();
+        }).to_string();
         let res: redis::RedisResult<()> = async {
-            let _: () = conn.set_ex(hb_key, "1", self.ttl_secs).await?;
-            let _: () = conn.set_ex(key, payload, self.ttl_secs).await?;
-            let _: () = conn.zadd(load_key, &self.instance_id, score).await?;
+            let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
+            let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
+            let _: () = conn.zadd(load_key, &self.instance_id, tunnel_count as f64).await?;
             Ok(())
-        }
-        .await;
+        }.await;
         if let Err(e) = res {
             warn!(error = %e, "redis instance heartbeat failed");
         }
     }
 
     async fn set_draining(&self) {
-        let Some(mut conn) = self.conn.clone() else {
-            return;
-        };
-        let key = format!("relay:instance:{}", self.instance_id);
+        let Some(mut conn) = self.conn.clone() else { return; };
+        let inst_key = format!("relay:instance:{}", self.instance_id);
         let payload = serde_json::json!({
             "instance_id": self.instance_id,
             "region": self.region,
@@ -242,37 +383,28 @@ impl RedisRegistry {
             "player_addr": self.player_addr,
             "r2r_addr": self.r2r_addr,
             "updated_at": chrono::Utc::now().timestamp(),
-        })
-        .to_string();
+        }).to_string();
         let _: redis::RedisResult<()> = async {
-            let _: () = conn.set_ex(key, payload, self.ttl_secs).await?;
-            let load_key = format!("relay:load:{}", self.region);
-            let _: () = conn.zadd(load_key, &self.instance_id, 1e12f64).await?;
+            let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
+            let _: () = conn.zadd(format!("relay:load:{}", self.region), &self.instance_id, 1e12f64).await?;
             Ok(())
-        }
-        .await;
+        }.await;
     }
 
     async fn register_tunnel(&self, fqdn: &str, session_id: &str, user_id: &str) {
-        let Some(mut conn) = self.conn.clone() else {
-            return;
-        };
-        let key = format!("tunnel:sub:{fqdn}");
-        let session_key = format!("tunnel:session:{session_id}");
+        let Some(mut conn) = self.conn.clone() else { return; };
        let payload = serde_json::json!({
             "instance_id": self.instance_id,
             "session_id": session_id,
             "user_id": user_id,
             "region": self.region,
             "fqdn": fqdn,
-        })
-        .to_string();
+        }).to_string();
         let _: redis::RedisResult<()> = async {
-            let _: () = conn.set_ex(key, &payload, self.ttl_secs).await?;
-            let _: () = conn.set_ex(session_key, payload, self.ttl_secs).await?;
+            let _: () = conn.set_ex(format!("tunnel:sub:{fqdn}"), &payload, self.ttl_secs).await?;
+            let _: () = conn.set_ex(format!("tunnel:session:{session_id}"), payload, self.ttl_secs).await?;
             Ok(())
-        }
-        .await;
+        }.await;
     }
 
     async fn refresh_tunnel_session(&self, fqdn: &str, session_id: &str, user_id: &str) {
@@ -280,35 +412,24 @@ impl RedisRegistry {
     }
 
     async fn remove_tunnel(&self, fqdn: &str, session_id: &str) {
-        let Some(mut conn) = self.conn.clone() else {
-            return;
-        };
+        let Some(mut conn) = self.conn.clone() else { return; };
         let _: redis::RedisResult<()> = async {
             let _: usize = conn.del(format!("tunnel:sub:{fqdn}")).await?;
             let _: usize = conn.del(format!("tunnel:session:{session_id}")).await?;
             Ok(())
-        }
-        .await;
+        }.await;
     }
 
     async fn lookup_tunnel(&self, fqdn: &str) -> Option<TunnelRouteRecord> {
-        let Some(mut conn) = self.conn.clone() else {
-            return None;
-        };
-        let key = format!("tunnel:sub:{fqdn}");
-        let raw: Option<String> = conn.get(key).await.ok()?;
-        let raw = raw?;
-        serde_json::from_str(&raw).ok()
+        let Some(mut conn) = self.conn.clone() else { return None; };
+        let raw: Option<String> = conn.get(format!("tunnel:sub:{fqdn}")).await.ok()?;
+        serde_json::from_str(&raw?).ok()
     }
 
     async fn lookup_instance(&self, instance_id: &str) -> Option<RelayInstanceRecord> {
-        let Some(mut conn) = self.conn.clone() else {
-            return None;
-        };
-        let key = format!("relay:instance:{instance_id}");
-        let raw: Option<String> = conn.get(key).await.ok()?;
-        let raw = raw?;
-        serde_json::from_str(&raw).ok()
+        let Some(mut conn) = self.conn.clone() else { return None; };
+        let raw: Option<String> = conn.get(format!("relay:instance:{instance_id}")).await.ok()?;
+        serde_json::from_str(&raw?).ok()
     }
 }
@@ -323,6 +444,7 @@ async fn main() -> Result<()> {
 
     let cfg = RelayConfig::from_env();
     let registry = RedisRegistry::from_env(&cfg).await;
+    let guards = Arc::new(RelayGuards::from_env());
     registry.register_instance().await;
 
     let control_listener = TcpListener::bind(&cfg.control_bind)
@@ -335,15 +457,7 @@ async fn main() -> Result<()> {
         .await
         .with_context(|| format!("bind r2r {}", cfg.r2r_bind))?;
 
-    info!(
-        instance_id = %cfg.instance_id,
-        region = %cfg.region,
-        control = %cfg.control_bind,
-        player = %cfg.player_bind,
-        r2r = %cfg.r2r_bind,
-        r2r_advertise = %cfg.r2r_advertise_addr,
-        "relay started"
-    );
+    info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started");
 
     let shutdown = Arc::new(Notify::new());
     let state: SharedState = Arc::new(RwLock::new(RelayState::new()));
@@ -354,6 +468,7 @@ async fn main() -> Result<()> {
         cfg.clone(),
         state.clone(),
         registry.clone(),
+        guards.clone(),
         shutdown.clone(),
     ));
     let player_task = tokio::spawn(run_player_accept_loop(
@@ -361,12 +476,14 @@ async fn main() -> Result<()> {
         cfg.clone(),
         state.clone(),
         registry.clone(),
+        guards.clone(),
         shutdown.clone(),
     ));
     let r2r_task = tokio::spawn(run_r2r_accept_loop(
         r2r_listener,
         cfg.clone(),
         state.clone(),
+        guards.clone(),
         shutdown.clone(),
     ));
 
@@ -409,6 +526,7 @@ async fn run_control_accept_loop(
     cfg: RelayConfig,
     state: SharedState,
     registry: RedisRegistry,
+    guards: Arc<RelayGuards>,
     shutdown: Arc<Notify>,
 ) -> Result<()> {
     loop {
@@ -422,8 +540,9 @@ async fn run_control_accept_loop(
         let cfg = cfg.clone();
         let state = state.clone();
         let registry = registry.clone();
+        let guards = guards.clone();
         tokio::spawn(async move {
-            if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry).await {
+            if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry, guards).await {
                 warn!(peer = %addr, error = %e, "control connection ended with error");
             }
         });
@@ -438,6 +557,7 @@ async fn run_player_accept_loop(
     cfg: RelayConfig,
     state: SharedState,
     registry: RedisRegistry,
+    guards: Arc<RelayGuards>,
     shutdown: Arc<Notify>,
 ) -> Result<()> {
     loop {
@@ -448,11 +568,12 @@ async fn run_player_accept_loop(
             Ok(v) => v,
             Err(e) => { warn!(error = %e, "player accept failed"); continue; }
         };
-        let cfg = cfg.clone();
         let state = state.clone();
         let registry = registry.clone();
+        let cfg = cfg.clone();
+        let guards = guards.clone();
         tokio::spawn(async move {
-            if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry).await {
+            if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await {
                 debug!(peer = %addr, error = %e, "player connection closed");
             }
         });
@@ -466,6 +587,7 @@ async fn run_r2r_accept_loop(
     listener: TcpListener,
     cfg: RelayConfig,
     state: SharedState,
+    guards: Arc<RelayGuards>,
     shutdown: Arc<Notify>,
 ) -> Result<()> {
     loop {
@@ -476,10 +598,11 @@ async fn run_r2r_accept_loop(
             Ok(v) => v,
             Err(e) => { warn!(error = %e, "r2r accept failed"); continue; }
         };
-        let state = state.clone();
         let cfg = cfg.clone();
+        let state = state.clone();
+        let guards = guards.clone();
         tokio::spawn(async move {
-            if let Err(e) = handle_r2r_conn(stream, addr, cfg, state).await {
+            if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await {
                 warn!(peer = %addr, error = %e, "r2r connection ended with error");
             }
         });
@@ -495,25 +618,23 @@ async fn handle_control_conn(
     cfg: RelayConfig,
     state: SharedState,
     registry: RedisRegistry,
+    guards: Arc<RelayGuards>,
 ) -> Result<()> {
+    if !guards.allow_registration_ip(&addr.ip().to_string()).await {
+        anyhow::bail!("registration rate limited for {}", addr.ip());
+    }
+
     let (mut reader, mut writer) = stream.into_split();
     let first: ClientFrame = read_frame(&mut reader).await.context("read initial frame")?;
     let register = match first {
         ClientFrame::Register(req) => req,
         _ => {
-            write_frame(
-                &mut writer,
-                &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() },
-            ).await.ok();
+            write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() }).await.ok();
             anyhow::bail!("expected register frame");
         }
     };
 
     if !token_looks_valid(&register.token) {
-        write_frame(
-            &mut writer,
-            &ServerFrame::RegisterRejected { reason: "invalid token".to_string() },
-        ).await.ok();
+        write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "invalid token".to_string() }).await.ok();
         anyhow::bail!("invalid token");
     }
 
@@ -525,27 +646,25 @@
 
     {
         let mut guard = state.write().await;
-        let handle = SessionHandle {
-            session_id: session_id.clone(),
-            tx: tx.clone(),
-            stream_sinks: stream_sinks.clone(),
-            last_heartbeat: Instant::now(),
-        };
         guard.by_session.insert(session_id.clone(), fqdn.clone());
-        guard.by_fqdn.insert(fqdn.clone(), handle);
+        guard.by_fqdn.insert(
+            fqdn.clone(),
+            SessionHandle {
+                session_id: session_id.clone(),
+                tx: tx.clone(),
+                stream_sinks: stream_sinks.clone(),
+                last_heartbeat: Instant::now(),
+            },
+        );
     }
     registry.register_tunnel(&fqdn, &session_id, &user_id).await;
 
-    write_frame(
-        &mut writer,
-        &ServerFrame::RegisterAccepted(RegisterAccepted {
-            session_id: session_id.clone(),
-            fqdn: fqdn.clone(),
-            heartbeat_interval_secs: 5,
-            owner_instance_id: cfg.instance_id.clone(),
-        }),
-    )
-    .await?;
+    write_frame(&mut writer, &ServerFrame::RegisterAccepted(RegisterAccepted {
+        session_id: session_id.clone(),
+        fqdn: fqdn.clone(),
+        heartbeat_interval_secs: 5,
+        owner_instance_id: cfg.instance_id.clone(),
+    })).await?;
     info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered");
 
     let write_task = tokio::spawn(async move {
@@ -559,12 +678,12 @@ async fn handle_control_conn(
         &mut reader,
         &state,
         &registry,
+        &guards,
         &session_id,
         &fqdn,
         &user_id,
         cfg.heartbeat_timeout,
-    )
-    .await;
+    ).await;
 
     if let Err(e) = &read_result {
         warn!(session_id = %session_id, error = %e, "control read loop error");
@@ -577,6 +696,7 @@ async fn handle_control_conn(
         }
     }
     registry.remove_tunnel(&fqdn, &session_id).await;
+    guards.remove_session(&session_id).await;
     write_task.abort();
     info!(session_id = %session_id, "client session removed");
     read_result
@@ -586,15 +706,14 @@ async fn control_read_loop(
     reader: &mut tokio::net::tcp::OwnedReadHalf,
     state: &SharedState,
     registry: &RedisRegistry,
+    guards: &RelayGuards,
     session_id: &str,
     fqdn: &str,
     user_id: &str,
     heartbeat_timeout: Duration,
 ) -> Result<()> {
     loop {
-        let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader))
-            .await
-            .context("heartbeat timeout")??;
+        let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)).await.context("heartbeat timeout")??;
         match frame {
             ClientFrame::Heartbeat(Heartbeat { session_id: hb_id, .. }) => {
                 if hb_id != session_id {
@@ -610,6 +729,9 @@ async fn control_read_loop(
                 registry.refresh_tunnel_session(fqdn, session_id, user_id).await;
             }
             ClientFrame::StreamData(StreamData { stream_id, data }) => {
+                guards
+                    .throttle_session_bytes(session_id, SessionDir::IngressFromClient, data.len())
+                    .await;
                 let sink = lookup_stream_sink(state, session_id, &stream_id).await;
                 if let Some(tx) = sink {
                     if tx.send(data).await.is_err() {
@@ -632,7 +754,13 @@ async fn handle_player_conn(
     cfg: RelayConfig,
     state: SharedState,
     registry: RedisRegistry,
+    guards: Arc<RelayGuards>,
 ) -> Result<()> {
+    if !guards.allow_player_ip(&addr.ip().to_string()).await {
+        debug!(peer = %addr, "player connect rate limited");
+        return Ok(());
+    }
+
     let (hostname, initial_data) = read_handshake_hostname_and_bytes(&mut stream)
         .await
         .context("parse minecraft handshake")?;
@@ -646,6 +774,7 @@ async fn handle_player_conn(
         initial_data,
         None,
         "direct",
+        guards,
     )
     .await;
 }
@@ -667,6 +796,7 @@ async fn handle_r2r_conn(
     addr: SocketAddr,
     _cfg: RelayConfig,
     state: SharedState,
+    guards: Arc<RelayGuards>,
 ) -> Result<()> {
     let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?;
     if prelude.version != 1 {
@@ -689,6 +819,7 @@ async fn handle_r2r_conn(
         prelude.initial_data,
         Some(prelude.stream_id),
         "r2r",
+        guards,
     )
     .await
     .with_context(|| format!("r2r attach failed from {addr}"))
@@ -707,9 +838,6 @@ async fn proxy_player_to_owner(
         .lookup_instance(&route.instance_id)
         .await
         .with_context(|| format!("owner instance {} not found in redis", route.instance_id))?;
-    if owner.status.as_deref() == Some("draining") {
-        debug!(owner = %route.instance_id, hostname = %hostname, "owner draining; attempting forward anyway");
-    }
     let r2r_addr = owner
         .r2r_addr
         .clone()
@@ -725,7 +853,7 @@ async fn proxy_player_to_owner(
         fqdn: hostname.clone(),
         stream_id: Uuid::new_v4().to_string(),
         peer_addr: player_addr.to_string(),
-        origin_instance_id: cfg.instance_id.clone(),
+        origin_instance_id: cfg.instance_id,
         hop_count: 1,
         initial_data,
     };
@@ -737,8 +865,7 @@ async fn proxy_player_to_owner(
 }
 
 async fn local_session_for_hostname(state: &SharedState, hostname: &str) -> Option<SessionHandle> {
-    let guard = state.read().await;
-    guard.by_fqdn.get(hostname).cloned()
+    state.read().await.by_fqdn.get(hostname).cloned()
 }
 
 async fn local_session_for_session_id(state: &SharedState, session_id: &str) -> Option<SessionHandle> {
@@ -755,15 +882,12 @@ async fn attach_player_socket_to_session(
     initial_data: Vec<u8>,
     stream_id_override: Option<String>,
     source: &'static str,
+    guards: Arc<RelayGuards>,
 ) -> Result<()> {
     let stream_id = stream_id_override.unwrap_or_else(|| Uuid::new_v4().to_string());
     let (player_read, player_write) = stream.into_split();
     let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
-    session
-        .stream_sinks
-        .write()
-        .await
-        .insert(stream_id.clone(), to_player_tx);
+    session.stream_sinks.write().await.insert(stream_id.clone(), to_player_tx);
 
     session
         .tx
@@ -795,9 +919,19 @@ async fn attach_player_socket_to_session(
 
     let tx_control = session.tx.clone();
     let stream_id_clone = stream_id.clone();
+    let session_id_clone = session.session_id.clone();
     let sinks = session.stream_sinks.clone();
+    let guards_clone = guards.clone();
     tokio::spawn(async move {
-        if let Err(e) = run_player_reader(player_read, tx_control.clone(), stream_id_clone.clone()).await {
+        if let Err(e) = run_player_reader(
+            player_read,
+            tx_control.clone(),
+            stream_id_clone.clone(),
+            session_id_clone,
+            guards_clone,
+        )
+        .await
+        {
             debug!(stream_id = %stream_id_clone, error = %e, "player reader ended");
         }
         let _ = tx_control
@@ -817,6 +951,8 @@ async fn run_player_reader(
     mut reader: tokio::net::tcp::OwnedReadHalf,
     tx_control: mpsc::Sender<ServerFrame>,
     stream_id: String,
+    session_id: String,
+    guards: Arc<RelayGuards>,
 ) -> Result<()> {
     let mut buf = vec![0u8; 16 * 1024];
     loop {
@@ -824,6 +960,9 @@ async fn run_player_reader(
         if n == 0 {
             break;
         }
+        guards
+            .throttle_session_bytes(&session_id, SessionDir::EgressToClient, n)
+            .await;
         tx_control
             .send(ServerFrame::StreamData(StreamData {
                 stream_id: stream_id.clone(),