// Source listing metadata: 1474 lines, 49 KiB, Rust.
use std::{
|
|
collections::HashMap,
|
|
net::SocketAddr,
|
|
sync::Arc,
|
|
time::{Duration, Instant},
|
|
};
|
|
|
|
use anyhow::{Context, Result};
|
|
use common::{
|
|
codec::{read_frame, write_frame},
|
|
minecraft::read_handshake_hostname_and_bytes,
|
|
protocol::{
|
|
ClientFrame, Heartbeat, IncomingTcp, RegisterAccepted, RegisterRequest, RelayForwardPrelude,
|
|
R2rFrame, R2rStreamClosed, R2rStreamData, ServerFrame, StreamClosed, StreamData,
|
|
},
|
|
};
|
|
use redis::AsyncCommands;
|
|
use serde::Deserialize;
|
|
use metrics_exporter_prometheus::PrometheusBuilder;
|
|
use tokio::{
|
|
io::{AsyncReadExt, AsyncWriteExt},
|
|
net::{TcpListener, TcpStream},
|
|
sync::{Mutex, Notify, RwLock, mpsc},
|
|
time::{MissedTickBehavior, interval, timeout},
|
|
};
|
|
use tracing::{debug, info, warn};
|
|
use uuid::Uuid;
|
|
|
|
/// Runtime configuration for one relay instance, sourced from environment
/// variables (see `RelayConfig::from_env` for variable names and defaults).
#[derive(Clone)]
struct RelayConfig {
    /// Unique id of this relay process (RELAY_INSTANCE_ID, or a generated UUID).
    instance_id: String,
    /// Region label used for load-tracking keys in redis (RELAY_REGION).
    region: String,
    /// Bind address of the client control-plane listener.
    control_bind: String,
    /// Bind address of the Minecraft player listener.
    player_bind: String,
    /// Bind address of the relay-to-relay (r2r) listener.
    r2r_bind: String,
    /// Address other relays should dial to reach our r2r listener; may differ
    /// from `r2r_bind` when binding 0.0.0.0 or sitting behind NAT.
    r2r_advertise_addr: String,
    /// Base DNS domain under which tunnel FQDNs are assigned.
    domain: String,
    /// Max silence allowed on a control connection before the session drops.
    heartbeat_timeout: Duration,
    /// TTL applied to redis registry keys (instance and tunnel records).
    registry_ttl_secs: u64,
    /// Dial timeout when opening a pooled r2r connection to another relay.
    r2r_connect_timeout: Duration,
}
|
|
|
|
impl RelayConfig {
|
|
fn from_env() -> Self {
|
|
let control_bind = std::env::var("RELAY_CONTROL_BIND")
|
|
.unwrap_or_else(|_| "0.0.0.0:7000".to_string());
|
|
let player_bind =
|
|
std::env::var("RELAY_PLAYER_BIND").unwrap_or_else(|_| "0.0.0.0:25565".to_string());
|
|
let r2r_bind = std::env::var("RELAY_R2R_BIND").unwrap_or_else(|_| "0.0.0.0:7001".to_string());
|
|
let r2r_advertise_addr = std::env::var("RELAY_R2R_ADVERTISE_ADDR")
|
|
.unwrap_or_else(|_| guess_advertise_addr(&r2r_bind));
|
|
|
|
Self {
|
|
instance_id: std::env::var("RELAY_INSTANCE_ID")
|
|
.unwrap_or_else(|_| format!("relay-{}", Uuid::new_v4())),
|
|
region: std::env::var("RELAY_REGION").unwrap_or_else(|_| "eu".to_string()),
|
|
control_bind,
|
|
player_bind,
|
|
r2r_bind,
|
|
r2r_advertise_addr,
|
|
domain: std::env::var("RELAY_BASE_DOMAIN").unwrap_or_else(|_| "dvv.one".to_string()),
|
|
heartbeat_timeout: Duration::from_secs(
|
|
std::env::var("RELAY_HEARTBEAT_TIMEOUT_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse().ok())
|
|
.unwrap_or(30),
|
|
),
|
|
registry_ttl_secs: std::env::var("RELAY_REGISTRY_TTL_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse().ok())
|
|
.unwrap_or(20),
|
|
r2r_connect_timeout: Duration::from_secs(
|
|
std::env::var("RELAY_R2R_CONNECT_TIMEOUT_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse().ok())
|
|
.unwrap_or(3),
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Per-tunnel handle stored in `RelayState`; cloned by player connections to
/// reach the owning control connection.
#[derive(Clone)]
struct SessionHandle {
    /// Unique id assigned at registration time.
    session_id: String,
    /// Queue of frames destined for the client's control-socket writer task.
    tx: mpsc::Sender<ServerFrame>,
    /// Per-stream sinks: stream_id -> sender feeding that player's socket.
    stream_sinks: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
    /// Time of the last heartbeat seen on the control connection.
    /// NOTE(review): no reaper consuming this field is visible in this chunk —
    /// confirm stale sessions are swept elsewhere.
    last_heartbeat: Instant,
}
|
|
|
|
/// In-memory routing table for tunnels owned by this relay instance.
struct RelayState {
    /// fqdn -> live session handle.
    by_fqdn: HashMap<String, SessionHandle>,
    /// session_id -> fqdn (reverse index used for heartbeats and cleanup).
    by_session: HashMap<String, String>,
}
|
|
|
|
impl RelayState {
|
|
fn new() -> Self {
|
|
Self {
|
|
by_fqdn: HashMap::new(),
|
|
by_session: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
fn session_count(&self) -> usize {
|
|
self.by_session.len()
|
|
}
|
|
}
|
|
|
|
/// Shared, concurrently-readable handle to the relay's routing table.
type SharedState = Arc<RwLock<RelayState>>;
|
|
|
|
/// Manages relay-to-relay connectivity for this instance.
#[derive(Clone)]
struct R2rManager {
    /// Pooled outbound r2r channels keyed by owner instance id; frames sent
    /// here are written onto one multiplexed TCP connection per peer.
    outbound: Arc<Mutex<HashMap<String, mpsc::Sender<R2rFrame>>>>,
    /// Return-path sinks for player sockets accepted locally but owned by a
    /// remote relay: stream_id -> sender feeding bytes back to that player.
    ingress_stream_sinks: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
}
|
|
|
|
impl R2rManager {
|
|
fn new() -> Self {
|
|
Self {
|
|
outbound: Arc::new(Mutex::new(HashMap::new())),
|
|
ingress_stream_sinks: Arc::new(RwLock::new(HashMap::new())),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Abuse guards: local token buckets for connects, registrations and session
/// bandwidth, plus optional redis-backed fixed windows shared fleet-wide.
#[derive(Clone)]
struct RelayGuards {
    /// Per-IP buckets for player connect attempts.
    player_ip: Arc<Mutex<HashMap<String, BucketState>>>,
    /// Per-IP buckets for tunnel registration attempts.
    reg_ip: Arc<Mutex<HashMap<String, BucketState>>>,
    /// Per-session byte buckets for client -> relay traffic.
    session_ingress: Arc<Mutex<HashMap<String, BucketState>>>,
    /// Per-session byte buckets for relay -> client traffic.
    session_egress: Arc<Mutex<HashMap<String, BucketState>>>,
    /// Sustained player connects per second, per IP.
    player_ip_rate: f64,
    /// Burst allowance for player connects, per IP.
    player_ip_burst: f64,
    /// Sustained registrations per second, per IP (configured per minute).
    reg_ip_rate: f64,
    /// Burst allowance for registrations, per IP.
    reg_ip_burst: f64,
    /// Sustained per-session bandwidth, bytes per second.
    session_bw_rate_bytes: f64,
    /// Per-session bandwidth burst, in bytes.
    session_bw_burst_bytes: f64,
    /// Optional shared connection for fleet-wide windows; None = local-only.
    redis: Option<redis::aio::ConnectionManager>,
    /// Fixed-window length for the global per-IP player-connect limit.
    player_global_window_secs: u64,
    /// Max player connects per IP per window, fleet-wide.
    player_global_limit: i64,
    /// Fixed-window length for the global per-IP registration limit.
    reg_global_window_secs: u64,
    /// Max registrations per IP per window, fleet-wide.
    reg_global_limit: i64,
}
|
|
|
|
/// Token-bucket state for one rate-limited key (an IP or a session id).
#[derive(Debug, Clone)]
struct BucketState {
    /// Tokens currently available; refilled lazily on each reservation.
    tokens: f64,
    /// Timestamp of the last refill. May sit in the FUTURE when the bucket
    /// has gone into debt (see `reserve_delay`).
    last_refill: Instant,
    /// Maximum tokens the bucket can hold (the burst size).
    capacity: f64,
    /// Refill rate in tokens per second.
    rate_per_sec: f64,
}
|
|
|
|
/// Direction of session traffic for bandwidth accounting; selects which
/// per-session bucket map `RelayGuards::throttle_session_bytes` charges.
#[derive(Clone, Copy)]
enum SessionDir {
    /// Bytes arriving from the tunnel client (see `control_read_loop`).
    IngressFromClient,
    /// Bytes being sent toward the tunnel client.
    EgressToClient,
}
|
|
|
|
impl BucketState {
    /// A full bucket: starts with `capacity` tokens, refilling at
    /// `rate_per_sec` tokens per second.
    fn new(capacity: f64, rate_per_sec: f64) -> Self {
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
            capacity,
            rate_per_sec,
        }
    }

    /// Reserve `amount` tokens, returning how long the caller must wait for
    /// the reservation to be "paid for" (`Duration::ZERO` when the tokens
    /// were immediately available).
    ///
    /// When the bucket cannot cover the amount it goes into debt: tokens are
    /// zeroed and `last_refill` is pushed into the FUTURE (`now + wait`), so
    /// no refill accrues until the returned wait elapses
    /// (`saturating_duration_since` yields zero until then). Calls made
    /// while in debt therefore extend the penalty further.
    fn reserve_delay(&mut self, amount: usize) -> Duration {
        let now = Instant::now();
        // Lazy refill based on time since the last refill; zero while in debt.
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.rate_per_sec).min(self.capacity);
            self.last_refill = now;
        }

        let amount = amount as f64;
        if self.tokens >= amount {
            self.tokens -= amount;
            return Duration::ZERO;
        }
        // A non-positive rate can never cover the deficit; back off a flat
        // second rather than computing an infinite/NaN wait.
        if self.rate_per_sec <= 0.0 {
            return Duration::from_secs(1);
        }

        // Debt path: charge everything we have and report the time needed to
        // earn the remainder.
        let deficit = amount - self.tokens;
        let wait = Duration::from_secs_f64(deficit / self.rate_per_sec);
        self.tokens = 0.0;
        self.last_refill = now + wait;
        wait
    }
}
|
|
|
|
impl RelayGuards {
    /// Build all guards from environment variables, optionally attaching a
    /// redis connection (REDIS_URL) for cluster-wide fixed windows. Unset or
    /// unparsable values fall back to the defaults shown inline; redis
    /// failures silently degrade to local-only limiting.
    async fn from_env() -> Self {
        // Per-IP player connect bucket: sustained rate + burst.
        let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(2.0);
        let player_ip_burst = std::env::var("RELAY_PLAYER_CONNECTS_BURST")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(20.0);

        // Registration limit is configured per minute but stored per second.
        let reg_per_min = std::env::var("RELAY_REG_ATTEMPTS_PER_MIN")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10.0);
        let reg_ip_rate = reg_per_min / 60.0;
        let reg_ip_burst = std::env::var("RELAY_REG_ATTEMPTS_BURST")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10.0);

        // Session bandwidth: configured in kilobits/s, stored as bytes/s.
        let session_bw_kbps = std::env::var("RELAY_SESSION_BW_KBPS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(8192.0);
        let session_bw_rate_bytes = session_bw_kbps * 1024.0 / 8.0;
        let session_bw_burst_bytes = std::env::var("RELAY_SESSION_BW_BURST_KB")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(512.0)
            * 1024.0;
        // Redis is optional; any failure here means local-only limiting.
        let redis = match std::env::var("REDIS_URL") {
            Ok(url) => match redis::Client::open(url) {
                Ok(client) => redis::aio::ConnectionManager::new(client).await.ok(),
                Err(_) => None,
            },
            Err(_) => None,
        };
        let player_global_window_secs = std::env::var("RELAY_PLAYER_GLOBAL_WINDOW_SECS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10);
        let player_global_limit = std::env::var("RELAY_PLAYER_GLOBAL_LIMIT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(50);
        let reg_global_window_secs = std::env::var("RELAY_REG_GLOBAL_WINDOW_SECS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(60);
        let reg_global_limit = std::env::var("RELAY_REG_GLOBAL_LIMIT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10);

        Self {
            player_ip: Arc::new(Mutex::new(HashMap::new())),
            reg_ip: Arc::new(Mutex::new(HashMap::new())),
            session_ingress: Arc::new(Mutex::new(HashMap::new())),
            session_egress: Arc::new(Mutex::new(HashMap::new())),
            player_ip_rate,
            player_ip_burst,
            reg_ip_rate,
            reg_ip_burst,
            session_bw_rate_bytes,
            session_bw_burst_bytes,
            redis,
            player_global_window_secs,
            player_global_limit,
            reg_global_window_secs,
            reg_global_limit,
        }
    }

    /// True when `ip` may open another player connection right now.
    async fn allow_player_ip(&self, ip: &str) -> bool {
        self.allow_ip(ip, true).await
    }

    /// True when `ip` may attempt another tunnel registration right now.
    async fn allow_registration_ip(&self, ip: &str) -> bool {
        self.allow_ip(ip, false).await
    }

    /// Two-stage admission check: a local per-IP token bucket first, then the
    /// fleet-wide redis window. Note that `reserve_delay` mutates the bucket
    /// even on denial (zeroing tokens and pushing `last_refill` forward), so
    /// a flooding IP keeps extending its own penalty.
    async fn allow_ip(&self, ip: &str, player: bool) -> bool {
        let (map, rate, burst) = if player {
            (&self.player_ip, self.player_ip_rate, self.player_ip_burst)
        } else {
            (&self.reg_ip, self.reg_ip_rate, self.reg_ip_burst)
        };
        let mut guard = map.lock().await;
        let bucket = guard
            .entry(ip.to_string())
            .or_insert_with(|| BucketState::new(burst, rate));
        // Admit only when the token was immediately available.
        let local_ok = bucket.reserve_delay(1).is_zero();
        drop(guard);
        if !local_ok {
            return false;
        }

        // Second stage: the global window shared via redis ("mc" scope for
        // player connects, "reg" for registrations).
        let (window_secs, limit, scope) = if player {
            (self.player_global_window_secs, self.player_global_limit, "mc")
        } else {
            (self.reg_global_window_secs, self.reg_global_limit, "reg")
        };
        self.redis_allow_ip_window(ip, scope, window_secs, limit).await
    }

    /// Charge `bytes` against the session's bandwidth bucket for `dir` and,
    /// when the bucket is in debt, sleep for the returned pacing delay
    /// (throttles by delaying, never by dropping data).
    async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) {
        let delay = {
            let map = match dir {
                SessionDir::IngressFromClient => &self.session_ingress,
                SessionDir::EgressToClient => &self.session_egress,
            };
            let mut guard = map.lock().await;
            let bucket = guard.entry(session_id.to_string()).or_insert_with(|| {
                BucketState::new(self.session_bw_burst_bytes, self.session_bw_rate_bytes)
            });
            bucket.reserve_delay(bytes)
        };
        // Sleep outside the lock so other sessions are not blocked.
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
    }

    /// Drop both bandwidth buckets for a finished session.
    async fn remove_session(&self, session_id: &str) {
        self.session_ingress.lock().await.remove(session_id);
        self.session_egress.lock().await.remove(session_id);
    }

    /// Fleet-wide fixed-window counter: INCR a per-IP/per-scope key, set its
    /// TTL on first increment, admit while the count is within `limit`.
    ///
    /// Fails OPEN: no redis connection or any redis error admits the request.
    /// NOTE(review): INCR and EXPIRE are not atomic — if the process dies
    /// between them the key persists with no TTL and never resets; a Lua
    /// script or `SET .. EX NX` + INCR would close that gap.
    async fn redis_allow_ip_window(
        &self,
        ip: &str,
        scope: &str,
        window_secs: u64,
        limit: i64,
    ) -> bool {
        let Some(mut conn) = self.redis.clone() else {
            return true;
        };
        let key = format!("ratelimit:ip:{ip}:{scope}");
        let res: redis::RedisResult<i64> = async {
            let count: i64 = conn.incr(&key, 1).await?;
            if count == 1 {
                // First hit in this window: start the window's TTL.
                let _: bool = conn.expire(&key, window_secs as i64).await?;
            }
            Ok(count)
        }
        .await;
        match res {
            Ok(count) => count <= limit,
            Err(_) => true,
        }
    }
}
|
|
|
|
/// Tunnel routing record read back from redis (`tunnel:sub:{fqdn}` /
/// `tunnel:session:{id}`; written by `RedisRegistry::register_tunnel`).
#[derive(Debug, Clone, Deserialize)]
struct TunnelRouteRecord {
    /// Relay instance that owns the tunnel's control connection.
    instance_id: String,
    /// Session id of the tunnel on the owning instance.
    session_id: String,
    // Present in the stored JSON but unused here; the rename keeps the
    // wire field name while the underscore silences dead-code lints.
    #[serde(rename = "user_id")]
    _user_id: Option<String>,
    #[serde(rename = "fqdn")]
    _fqdn: Option<String>,
}
|
|
|
|
/// Relay instance record read back from `relay:instance:{id}` in redis.
/// Underscore fields are deserialized for wire fidelity but unused here.
#[derive(Debug, Clone, Deserialize)]
struct RelayInstanceRecord {
    #[serde(rename = "instance_id")]
    _instance_id: String,
    #[serde(rename = "region")]
    _region: Option<String>,
    #[serde(rename = "status")]
    _status: Option<String>,
    /// Address to dial for relay-to-relay forwarding to this instance.
    r2r_addr: Option<String>,
}
|
|
|
|
/// Publishes this relay's presence and its tunnels to redis, and resolves
/// routes owned by other instances. All operations are best-effort: with no
/// redis connection writes become no-ops and lookups return `None`.
#[derive(Clone)]
struct RedisRegistry {
    /// Multiplexed redis connection; `None` when REDIS_URL is absent/broken.
    conn: Option<redis::aio::ConnectionManager>,
    instance_id: String,
    region: String,
    /// Control listener address published in instance records.
    /// NOTE(review): this is the BIND address (see `from_env`); if binding
    /// 0.0.0.0 it is not dialable by peers — confirm consumers only use
    /// `r2r_addr`, which carries the advertise address.
    control_addr: String,
    /// Player listener address published in instance records (also the bind).
    player_addr: String,
    /// Address peers dial for r2r traffic (the advertise address, not bind).
    r2r_addr: String,
    /// TTL for all registry keys; records expire if heartbeats stop.
    ttl_secs: u64,
}
|
|
|
|
impl RedisRegistry {
    /// Connect to redis (REDIS_URL) and capture the addresses this instance
    /// will advertise. Connection failures are logged and tolerated: the
    /// registry then runs in no-op mode.
    async fn from_env(cfg: &RelayConfig) -> Self {
        let conn = match std::env::var("REDIS_URL") {
            Ok(url) => match redis::Client::open(url.clone()) {
                Ok(client) => match redis::aio::ConnectionManager::new(client).await {
                    Ok(cm) => {
                        info!("connected to redis");
                        Some(cm)
                    }
                    Err(e) => {
                        warn!(error = %e, "redis connection manager failed; continuing without redis");
                        None
                    }
                },
                Err(e) => {
                    warn!(error = %e, "invalid REDIS_URL; continuing without redis");
                    None
                }
            },
            Err(_) => None,
        };

        Self {
            conn,
            instance_id: cfg.instance_id.clone(),
            region: cfg.region.clone(),
            control_addr: cfg.control_bind.clone(),
            player_addr: cfg.player_bind.clone(),
            r2r_addr: cfg.r2r_advertise_addr.clone(),
            ttl_secs: cfg.registry_ttl_secs,
        }
    }

    /// Announce this instance: add it to the `relay:instances` set and write
    /// its instance and heartbeat keys with the registry TTL.
    async fn register_instance(&self) {
        let Some(mut conn) = self.conn.clone() else { return; };
        let key = format!("relay:instance:{}", self.instance_id);
        // The JSON field names here are a wire contract with other relays
        // (see `RelayInstanceRecord`); do not rename them.
        let payload = serde_json::json!({
            "instance_id": self.instance_id,
            "region": self.region,
            "status": "active",
            "control_addr": self.control_addr,
            "player_addr": self.player_addr,
            "r2r_addr": self.r2r_addr,
            "started_at": chrono::Utc::now().timestamp(),
        }).to_string();

        let res: redis::RedisResult<()> = async {
            let _: usize = conn.sadd("relay:instances", &self.instance_id).await?;
            let _: () = conn.set_ex(&key, payload, self.ttl_secs).await?;
            let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
            Ok(())
        }.await;
        if let Err(e) = res {
            warn!(error = %e, "failed to register instance in redis");
        }
    }

    /// Refresh liveness: re-write the heartbeat and instance keys (sliding
    /// their TTL) and publish `tunnel_count` as this instance's score in the
    /// per-region load zset used for placement.
    async fn heartbeat_instance(&self, tunnel_count: usize) {
        let Some(mut conn) = self.conn.clone() else { return; };
        let load_key = format!("relay:load:{}", self.region);
        let inst_key = format!("relay:instance:{}", self.instance_id);
        let payload = serde_json::json!({
            "instance_id": self.instance_id,
            "region": self.region,
            "status": "active",
            "control_addr": self.control_addr,
            "player_addr": self.player_addr,
            "r2r_addr": self.r2r_addr,
            "tunnel_count": tunnel_count,
            "updated_at": chrono::Utc::now().timestamp(),
        }).to_string();
        let res: redis::RedisResult<()> = async {
            let _: () = conn.set_ex(format!("relay:heartbeat:{}", self.instance_id), "1", self.ttl_secs).await?;
            let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
            let _: () = conn.zadd(load_key, &self.instance_id, tunnel_count as f64).await?;
            Ok(())
        }.await;
        if let Err(e) = res {
            warn!(error = %e, "redis instance heartbeat failed");
        }
    }

    /// Mark this instance as draining: flip the published status and push its
    /// load score to an effectively-infinite value (1e12) so placement stops
    /// picking it. Errors are deliberately ignored (shutdown path).
    async fn set_draining(&self) {
        let Some(mut conn) = self.conn.clone() else { return; };
        let inst_key = format!("relay:instance:{}", self.instance_id);
        let payload = serde_json::json!({
            "instance_id": self.instance_id,
            "region": self.region,
            "status": "draining",
            "control_addr": self.control_addr,
            "player_addr": self.player_addr,
            "r2r_addr": self.r2r_addr,
            "updated_at": chrono::Utc::now().timestamp(),
        }).to_string();
        let _: redis::RedisResult<()> = async {
            let _: () = conn.set_ex(inst_key, payload, self.ttl_secs).await?;
            let _: () = conn.zadd(format!("relay:load:{}", self.region), &self.instance_id, 1e12f64).await?;
            Ok(())
        }.await;
    }

    /// Publish a tunnel route under both its fqdn and session-id keys, with
    /// the registry TTL. Best-effort: errors are ignored.
    async fn register_tunnel(&self, fqdn: &str, session_id: &str, user_id: &str) {
        let Some(mut conn) = self.conn.clone() else { return; };
        // Field names are read back by `TunnelRouteRecord`; keep them stable.
        let payload = serde_json::json!({
            "instance_id": self.instance_id,
            "session_id": session_id,
            "user_id": user_id,
            "region": self.region,
            "fqdn": fqdn,
        }).to_string();
        let _: redis::RedisResult<()> = async {
            let _: () = conn.set_ex(format!("tunnel:sub:{fqdn}"), &payload, self.ttl_secs).await?;
            let _: () = conn.set_ex(format!("tunnel:session:{session_id}"), payload, self.ttl_secs).await?;
            Ok(())
        }.await;
    }

    /// Slide a tunnel's TTL forward by re-writing its records (called on
    /// every client heartbeat).
    async fn refresh_tunnel_session(&self, fqdn: &str, session_id: &str, user_id: &str) {
        self.register_tunnel(fqdn, session_id, user_id).await;
    }

    /// Delete both redis keys for a finished tunnel. Best-effort.
    async fn remove_tunnel(&self, fqdn: &str, session_id: &str) {
        let Some(mut conn) = self.conn.clone() else { return; };
        let _: redis::RedisResult<()> = async {
            let _: usize = conn.del(format!("tunnel:sub:{fqdn}")).await?;
            let _: usize = conn.del(format!("tunnel:session:{session_id}")).await?;
            Ok(())
        }.await;
    }

    /// Resolve the route for `fqdn`, or `None` when the key is missing, redis
    /// is unavailable/erroring, or the stored JSON fails to parse.
    async fn lookup_tunnel(&self, fqdn: &str) -> Option<TunnelRouteRecord> {
        let Some(mut conn) = self.conn.clone() else { return None; };
        let raw: Option<String> = conn.get(format!("tunnel:sub:{fqdn}")).await.ok()?;
        serde_json::from_str(&raw?).ok()
    }

    /// Resolve another relay's instance record, with the same `None`-on-any-
    /// failure semantics as `lookup_tunnel`.
    async fn lookup_instance(&self, instance_id: &str) -> Option<RelayInstanceRecord> {
        let Some(mut conn) = self.conn.clone() else { return None; };
        let raw: Option<String> = conn.get(format!("relay:instance:{instance_id}")).await.ok()?;
        serde_json::from_str(&raw?).ok()
    }
}
|
|
|
|
#[tokio::main]
async fn main() -> Result<()> {
    // Metrics first so every later component can record freely.
    init_metrics()?;
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "relay=info".into()),
        )
        .init();

    let cfg = RelayConfig::from_env();
    let registry = RedisRegistry::from_env(&cfg).await;
    let guards = Arc::new(RelayGuards::from_env().await);
    let r2r = Arc::new(R2rManager::new());
    registry.register_instance().await;

    // Bind all three listeners up front so startup fails fast on port clashes.
    let control_listener = TcpListener::bind(&cfg.control_bind)
        .await
        .with_context(|| format!("bind control {}", cfg.control_bind))?;
    let player_listener = TcpListener::bind(&cfg.player_bind)
        .await
        .with_context(|| format!("bind player {}", cfg.player_bind))?;
    let r2r_listener = TcpListener::bind(&cfg.r2r_bind)
        .await
        .with_context(|| format!("bind r2r {}", cfg.r2r_bind))?;

    info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started");
    metrics::gauge!("relay_drain_state").set(0.0);

    let shutdown = Arc::new(Notify::new());
    let state: SharedState = Arc::new(RwLock::new(RelayState::new()));

    // Spawn the four long-lived tasks: registry heartbeat plus the three
    // accept loops (control plane, players, relay-to-relay).
    let heartbeat_task = tokio::spawn(run_registry_heartbeat(state.clone(), registry.clone(), shutdown.clone()));
    let control_task = tokio::spawn(run_control_accept_loop(
        control_listener,
        cfg.clone(),
        state.clone(),
        registry.clone(),
        guards.clone(),
        shutdown.clone(),
    ));
    let player_task = tokio::spawn(run_player_accept_loop(
        player_listener,
        cfg.clone(),
        state.clone(),
        registry.clone(),
        guards.clone(),
        r2r.clone(),
        shutdown.clone(),
    ));
    let r2r_task = tokio::spawn(run_r2r_accept_loop(
        r2r_listener,
        cfg.clone(),
        state.clone(),
        guards.clone(),
        r2r.clone(),
        shutdown.clone(),
    ));

    // Pin the join handles so select! can poll them by mutable reference
    // without consuming them.
    tokio::pin!(heartbeat_task);
    tokio::pin!(control_task);
    tokio::pin!(player_task);
    tokio::pin!(r2r_task);

    // Run until ctrl-c, or until any core task exits (a dead accept loop
    // means the relay is no longer serving and should drain).
    tokio::select! {
        _ = tokio::signal::ctrl_c() => info!("shutdown signal received"),
        res = &mut control_task => warn!("control accept loop ended: {:?}", res),
        res = &mut player_task => warn!("player accept loop ended: {:?}", res),
        res = &mut r2r_task => warn!("r2r accept loop ended: {:?}", res),
        res = &mut heartbeat_task => warn!("registry heartbeat task ended: {:?}", res),
    }

    // Drain: publish draining status (pushing our load score out of the pick
    // order), signal all loops, then allow a short grace period.
    registry.set_draining().await;
    metrics::gauge!("relay_drain_state").set(1.0);
    shutdown.notify_waiters();
    info!("draining relay");
    tokio::time::sleep(Duration::from_secs(1)).await;
    Ok(())
}
|
|
|
|
async fn run_registry_heartbeat(state: SharedState, registry: RedisRegistry, shutdown: Arc<Notify>) {
|
|
let mut ticker = interval(Duration::from_secs(5));
|
|
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.notified() => break,
|
|
_ = ticker.tick() => {
|
|
let count = state.read().await.session_count();
|
|
metrics::gauge!("relay_active_tunnels").set(count as f64);
|
|
registry.heartbeat_instance(count).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn run_control_accept_loop(
|
|
listener: TcpListener,
|
|
cfg: RelayConfig,
|
|
state: SharedState,
|
|
registry: RedisRegistry,
|
|
guards: Arc<RelayGuards>,
|
|
shutdown: Arc<Notify>,
|
|
) -> Result<()> {
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.notified() => break,
|
|
res = listener.accept() => {
|
|
let (stream, addr) = match res {
|
|
Ok(v) => v,
|
|
Err(e) => { warn!(error = %e, "control accept failed"); continue; }
|
|
};
|
|
metrics::counter!("relay_control_accepts_total").increment(1);
|
|
let cfg = cfg.clone();
|
|
let state = state.clone();
|
|
let registry = registry.clone();
|
|
let guards = guards.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_control_conn(stream, addr, cfg, state, registry, guards).await {
|
|
warn!(peer = %addr, error = %e, "control connection ended with error");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_player_accept_loop(
|
|
listener: TcpListener,
|
|
cfg: RelayConfig,
|
|
state: SharedState,
|
|
registry: RedisRegistry,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
shutdown: Arc<Notify>,
|
|
) -> Result<()> {
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.notified() => break,
|
|
res = listener.accept() => {
|
|
let (stream, addr) = match res {
|
|
Ok(v) => v,
|
|
Err(e) => { warn!(error = %e, "player accept failed"); continue; }
|
|
};
|
|
metrics::counter!("relay_player_accepts_total").increment(1);
|
|
let cfg = cfg.clone();
|
|
let state = state.clone();
|
|
let registry = registry.clone();
|
|
let guards = guards.clone();
|
|
let r2r = r2r.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards, r2r).await {
|
|
debug!(peer = %addr, error = %e, "player connection closed");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_r2r_accept_loop(
|
|
listener: TcpListener,
|
|
cfg: RelayConfig,
|
|
state: SharedState,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
shutdown: Arc<Notify>,
|
|
) -> Result<()> {
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.notified() => break,
|
|
res = listener.accept() => {
|
|
let (stream, addr) = match res {
|
|
Ok(v) => v,
|
|
Err(e) => { warn!(error = %e, "r2r accept failed"); continue; }
|
|
};
|
|
metrics::counter!("relay_r2r_accepts_total").increment(1);
|
|
let cfg = cfg.clone();
|
|
let state = state.clone();
|
|
let guards = guards.clone();
|
|
let r2r = r2r.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards, r2r).await {
|
|
warn!(peer = %addr, error = %e, "r2r connection ended with error");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(stream, state, registry, guards, cfg), fields(peer = %addr))]
|
|
async fn handle_control_conn(
|
|
stream: TcpStream,
|
|
addr: SocketAddr,
|
|
cfg: RelayConfig,
|
|
state: SharedState,
|
|
registry: RedisRegistry,
|
|
guards: Arc<RelayGuards>,
|
|
) -> Result<()> {
|
|
if !guards.allow_registration_ip(&addr.ip().to_string()).await {
|
|
metrics::counter!("relay_rate_limited_total", "scope" => "registration_ip").increment(1);
|
|
anyhow::bail!("registration rate limited for {}", addr.ip());
|
|
}
|
|
|
|
let (mut reader, mut writer) = stream.into_split();
|
|
let first: ClientFrame = read_frame(&mut reader).await.context("read initial frame")?;
|
|
let register = match first {
|
|
ClientFrame::Register(req) => req,
|
|
_ => {
|
|
write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "expected Register frame".to_string() }).await.ok();
|
|
anyhow::bail!("expected register frame");
|
|
}
|
|
};
|
|
if !token_looks_valid(®ister.token) {
|
|
write_frame(&mut writer, &ServerFrame::RegisterRejected { reason: "invalid token".to_string() }).await.ok();
|
|
anyhow::bail!("invalid token");
|
|
}
|
|
|
|
let (tx, mut rx) = mpsc::channel::<ServerFrame>(512);
|
|
let session_id = Uuid::new_v4().to_string();
|
|
let fqdn = assign_fqdn(&cfg, ®ister);
|
|
let user_id = fake_user_id_from_token(®ister.token);
|
|
let stream_sinks = Arc::new(RwLock::new(HashMap::<String, mpsc::Sender<Vec<u8>>>::new()));
|
|
|
|
{
|
|
let mut guard = state.write().await;
|
|
guard.by_session.insert(session_id.clone(), fqdn.clone());
|
|
guard.by_fqdn.insert(
|
|
fqdn.clone(),
|
|
SessionHandle {
|
|
session_id: session_id.clone(),
|
|
tx: tx.clone(),
|
|
stream_sinks: stream_sinks.clone(),
|
|
last_heartbeat: Instant::now(),
|
|
},
|
|
);
|
|
}
|
|
registry.register_tunnel(&fqdn, &session_id, &user_id).await;
|
|
|
|
write_frame(&mut writer, &ServerFrame::RegisterAccepted(RegisterAccepted {
|
|
session_id: session_id.clone(),
|
|
fqdn: fqdn.clone(),
|
|
heartbeat_interval_secs: 5,
|
|
owner_instance_id: cfg.instance_id.clone(),
|
|
})).await?;
|
|
info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered");
|
|
metrics::counter!("relay_tunnel_registrations_total").increment(1);
|
|
|
|
let write_task = tokio::spawn(async move {
|
|
while let Some(frame) = rx.recv().await {
|
|
write_frame(&mut writer, &frame).await?;
|
|
}
|
|
Ok::<(), anyhow::Error>(())
|
|
});
|
|
|
|
let read_result = control_read_loop(
|
|
&mut reader,
|
|
&state,
|
|
®istry,
|
|
&guards,
|
|
&session_id,
|
|
&fqdn,
|
|
&user_id,
|
|
cfg.heartbeat_timeout,
|
|
).await;
|
|
|
|
if let Err(e) = &read_result {
|
|
warn!(session_id = %session_id, error = %e, "control read loop error");
|
|
}
|
|
|
|
{
|
|
let mut guard = state.write().await;
|
|
if let Some(fqdn) = guard.by_session.remove(&session_id) {
|
|
guard.by_fqdn.remove(&fqdn);
|
|
}
|
|
}
|
|
registry.remove_tunnel(&fqdn, &session_id).await;
|
|
guards.remove_session(&session_id).await;
|
|
write_task.abort();
|
|
info!(session_id = %session_id, "client session removed");
|
|
read_result
|
|
}
|
|
|
|
/// Drive the read side of a control connection until it errors, times out,
/// or the client misbehaves. Each iteration must observe some frame within
/// `heartbeat_timeout`, so the timeout doubles as the liveness check.
///
/// # Errors
/// Heartbeat timeout, frame decode failure, session-id mismatch on a
/// heartbeat, or a duplicate Register frame.
async fn control_read_loop(
    reader: &mut tokio::net::tcp::OwnedReadHalf,
    state: &SharedState,
    registry: &RedisRegistry,
    guards: &RelayGuards,
    session_id: &str,
    fqdn: &str,
    user_id: &str,
    heartbeat_timeout: Duration,
) -> Result<()> {
    loop {
        // Outer `?` handles the timeout, inner `?` the read/decode error.
        let frame: ClientFrame = timeout(heartbeat_timeout, read_frame(reader)).await.context("heartbeat timeout")??;
        match frame {
            ClientFrame::Heartbeat(Heartbeat { session_id: hb_id, .. }) => {
                if hb_id != session_id {
                    anyhow::bail!("heartbeat session mismatch");
                }
                // Refresh the in-memory liveness stamp...
                let mut guard = state.write().await;
                if let Some(route_fqdn) = guard.by_session.get(session_id).cloned()
                    && let Some(handle) = guard.by_fqdn.get_mut(&route_fqdn)
                {
                    handle.last_heartbeat = Instant::now();
                }
                drop(guard);
                // ...and re-publish the tunnel's redis records so their TTL
                // keeps sliding while the client is alive.
                registry.refresh_tunnel_session(fqdn, session_id, user_id).await;
            }
            ClientFrame::StreamData(StreamData { stream_id, data }) => {
                // Ingress bandwidth limiting: may sleep before forwarding.
                guards
                    .throttle_session_bytes(session_id, SessionDir::IngressFromClient, data.len())
                    .await;
                let sink = lookup_stream_sink(state, session_id, &stream_id).await;
                if let Some(tx) = sink {
                    // A closed receiver means the player socket is gone;
                    // drop the stale sink so later frames short-circuit.
                    if tx.send(data).await.is_err() {
                        remove_stream_sink(state, session_id, &stream_id).await;
                    }
                }
            }
            ClientFrame::StreamClosed(StreamClosed { stream_id, .. }) => {
                remove_stream_sink(state, session_id, &stream_id).await;
            }
            // Pongs only prove liveness, which the timeout already observed.
            ClientFrame::Pong => {}
            ClientFrame::Register(_) => anyhow::bail!("unexpected Register frame after registration"),
        }
    }
}
|
|
|
|
/// Route a freshly accepted player socket: rate-limit by IP, parse the
/// Minecraft handshake for the target hostname, then try (in order) a
/// locally-owned session, a redis route to another relay, or drop.
///
/// Rejections return `Ok(())` and close the socket silently — Minecraft
/// clients have no channel for an error message at this stage.
#[tracing::instrument(skip(stream, cfg, state, registry, guards, r2r), fields(peer = %addr))]
async fn handle_player_conn(
    mut stream: TcpStream,
    addr: SocketAddr,
    cfg: RelayConfig,
    state: SharedState,
    registry: RedisRegistry,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<()> {
    if !guards.allow_player_ip(&addr.ip().to_string()).await {
        metrics::counter!("relay_rate_limited_total", "scope" => "player_ip").increment(1);
        debug!(peer = %addr, "player connect rate limited");
        return Ok(());
    }

    // `initial_data` carries the handshake bytes already consumed, so they
    // can be replayed to whichever backend ends up serving the player.
    let (hostname, initial_data) = read_handshake_hostname_and_bytes(&mut stream)
        .await
        .context("parse minecraft handshake")?;

    // Fast path: the tunnel is owned by this instance.
    if let Some(session) = local_session_for_hostname(&state, &hostname).await {
        return attach_player_socket_to_session(
            stream,
            session,
            hostname,
            addr.to_string(),
            initial_data,
            None,
            "direct",
            guards,
        )
        .await;
    }

    // Otherwise consult redis for the owning instance.
    if let Some(route) = registry.lookup_tunnel(&hostname).await {
        if route.instance_id == cfg.instance_id {
            // Stale redis record: it names us as owner but the session is
            // gone locally. Drop rather than forward to ourselves.
            debug!(peer = %addr, hostname = %hostname, session_id = %route.session_id, "route points to self but local session missing");
            return Ok(());
        }
        return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry, guards, r2r).await;
    }

    debug!(peer = %addr, hostname = %hostname, "no tunnel for hostname");
    Ok(())
}
|
|
|
|
#[tracing::instrument(skip(stream, cfg, state, guards, r2r), fields(peer = %addr))]
|
|
async fn handle_r2r_conn(
|
|
stream: TcpStream,
|
|
addr: SocketAddr,
|
|
cfg: RelayConfig,
|
|
state: SharedState,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
) -> Result<()> {
|
|
handle_r2r_multiplex_conn(stream, addr, cfg, state, guards, r2r)
|
|
.await
|
|
.with_context(|| format!("r2r multiplex failed from {addr}"))
|
|
}
|
|
|
|
#[tracing::instrument(skip(player_stream, route, cfg, registry, guards, r2r), fields(peer = %player_addr, hostname = %hostname))]
|
|
async fn proxy_player_to_owner(
|
|
player_stream: TcpStream,
|
|
player_addr: SocketAddr,
|
|
hostname: String,
|
|
initial_data: Vec<u8>,
|
|
route: TunnelRouteRecord,
|
|
cfg: RelayConfig,
|
|
registry: RedisRegistry,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
) -> Result<()> {
|
|
let redis_lookup_started = Instant::now();
|
|
let owner = registry
|
|
.lookup_instance(&route.instance_id)
|
|
.await
|
|
.with_context(|| format!("owner instance {} not found in redis", route.instance_id))?;
|
|
metrics::histogram!("relay_redis_lookup_latency_ms")
|
|
.record(redis_lookup_started.elapsed().as_secs_f64() * 1000.0);
|
|
let r2r_addr = owner
|
|
.r2r_addr
|
|
.clone()
|
|
.with_context(|| format!("owner {} missing r2r_addr", route.instance_id))?;
|
|
|
|
let r2r_connect_started = Instant::now();
|
|
metrics::histogram!("relay_r2r_connect_latency_ms")
|
|
.record(r2r_connect_started.elapsed().as_secs_f64() * 1000.0);
|
|
|
|
let prelude = RelayForwardPrelude {
|
|
version: 1,
|
|
session_id: route.session_id.clone(),
|
|
fqdn: hostname.clone(),
|
|
stream_id: Uuid::new_v4().to_string(),
|
|
peer_addr: player_addr.to_string(),
|
|
origin_instance_id: cfg.instance_id.clone(),
|
|
hop_count: 1,
|
|
initial_data,
|
|
};
|
|
proxy_player_to_owner_pooled(
|
|
player_stream,
|
|
player_addr,
|
|
hostname,
|
|
route.instance_id,
|
|
r2r_addr,
|
|
prelude,
|
|
route.session_id,
|
|
cfg,
|
|
guards,
|
|
r2r,
|
|
)
|
|
.await
|
|
}
|
|
|
|
async fn proxy_player_to_owner_pooled(
|
|
player_stream: TcpStream,
|
|
player_addr: SocketAddr,
|
|
hostname: String,
|
|
owner_instance_id: String,
|
|
owner_r2r_addr: String,
|
|
prelude: RelayForwardPrelude,
|
|
session_id: String,
|
|
cfg: RelayConfig,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
) -> Result<()> {
|
|
let stream_id = prelude.stream_id.clone();
|
|
let sender = get_or_connect_r2r_pool(
|
|
owner_instance_id.clone(),
|
|
owner_r2r_addr,
|
|
cfg,
|
|
guards,
|
|
r2r.clone(),
|
|
)
|
|
.await?;
|
|
|
|
let (player_read, player_write) = player_stream.into_split();
|
|
let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
|
|
r2r.ingress_stream_sinks
|
|
.write()
|
|
.await
|
|
.insert(stream_id.clone(), to_player_tx);
|
|
|
|
sender
|
|
.send(R2rFrame::Open(prelude))
|
|
.await
|
|
.context("send r2r open")?;
|
|
|
|
let tx = sender.clone();
|
|
let sid = session_id.clone();
|
|
let stid = stream_id.clone();
|
|
let sinks = r2r.ingress_stream_sinks.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) =
|
|
run_ingress_player_reader_to_r2r(player_read, tx.clone(), sid.clone(), stid.clone()).await
|
|
{
|
|
debug!(stream_id = %stid, error = %e, "ingress player->r2r reader ended");
|
|
}
|
|
let _ = tx
|
|
.send(R2rFrame::Close(R2rStreamClosed {
|
|
session_id: sid,
|
|
stream_id: stid.clone(),
|
|
reason: Some("ingress_player_reader_closed".into()),
|
|
}))
|
|
.await;
|
|
let _ = sinks.write().await.remove(&stid);
|
|
});
|
|
|
|
let stid = stream_id.clone();
|
|
let sinks = r2r.ingress_stream_sinks.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = run_ingress_player_writer(player_write, to_player_rx).await {
|
|
debug!(stream_id = %stid, error = %e, "ingress r2r->player writer ended");
|
|
}
|
|
let _ = sinks.write().await.remove(&stid);
|
|
});
|
|
|
|
metrics::counter!("relay_r2r_forwards_total").increment(1);
|
|
info!(peer = %player_addr, hostname = %hostname, owner = %owner_instance_id, stream_id = %stream_id, "proxied player via pooled r2r channel");
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_or_connect_r2r_pool(
|
|
owner_instance_id: String,
|
|
owner_r2r_addr: String,
|
|
cfg: RelayConfig,
|
|
guards: Arc<RelayGuards>,
|
|
r2r: Arc<R2rManager>,
|
|
) -> Result<mpsc::Sender<R2rFrame>> {
|
|
if let Some(existing) = r2r.outbound.lock().await.get(&owner_instance_id).cloned() {
|
|
return Ok(existing);
|
|
}
|
|
|
|
let connect_started = Instant::now();
|
|
let stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&owner_r2r_addr))
|
|
.await
|
|
.context("r2r connect timeout")??;
|
|
metrics::histogram!("relay_r2r_connect_latency_ms")
|
|
.record(connect_started.elapsed().as_secs_f64() * 1000.0);
|
|
|
|
let (mut reader, mut writer) = stream.into_split();
|
|
let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
|
|
|
|
let mut pools = r2r.outbound.lock().await;
|
|
if let Some(existing) = pools.get(&owner_instance_id).cloned() {
|
|
return Ok(existing);
|
|
}
|
|
pools.insert(owner_instance_id.clone(), tx.clone());
|
|
drop(pools);
|
|
|
|
let owner_for_reader = owner_instance_id.clone();
|
|
let r2r_for_reader = r2r.clone();
|
|
let guards_for_reader = guards.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
match read_frame::<_, R2rFrame>(&mut reader).await {
|
|
Ok(frame) => {
|
|
if let Err(e) = handle_r2r_inbound_frame(frame, &r2r_for_reader, &guards_for_reader).await {
|
|
debug!(owner = %owner_for_reader, error = %e, "r2r pooled inbound frame error");
|
|
break;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
debug!(owner = %owner_for_reader, error = %e, "r2r pooled reader ended");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
r2r_for_reader.outbound.lock().await.remove(&owner_for_reader);
|
|
});
|
|
|
|
tokio::spawn(async move {
|
|
while let Some(frame) = rx.recv().await {
|
|
if let Err(e) = write_frame(&mut writer, &frame).await {
|
|
debug!(error = %e, "r2r pooled writer ended");
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(tx)
|
|
}
|
|
|
|
async fn handle_r2r_multiplex_conn(
|
|
stream: TcpStream,
|
|
_addr: SocketAddr,
|
|
_cfg: RelayConfig,
|
|
state: SharedState,
|
|
guards: Arc<RelayGuards>,
|
|
_r2r: Arc<R2rManager>,
|
|
) -> Result<()> {
|
|
let (mut reader, mut writer) = stream.into_split();
|
|
let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
|
|
|
|
let _writer_task = tokio::spawn(async move {
|
|
while let Some(frame) = rx.recv().await {
|
|
write_frame(&mut writer, &frame).await?;
|
|
}
|
|
Ok::<(), anyhow::Error>(())
|
|
});
|
|
|
|
loop {
|
|
let frame: R2rFrame = read_frame(&mut reader).await?;
|
|
match frame {
|
|
R2rFrame::Open(prelude) => {
|
|
if prelude.version != 1 || prelude.hop_count > 1 {
|
|
continue;
|
|
}
|
|
if let Some(session) = local_session_for_session_id(&state, &prelude.session_id).await {
|
|
attach_virtual_r2r_stream_to_session(session, prelude, tx.clone()).await?;
|
|
} else {
|
|
let _ = tx.send(R2rFrame::Close(R2rStreamClosed {
|
|
session_id: prelude.session_id,
|
|
stream_id: prelude.stream_id,
|
|
reason: Some("owner_session_not_found".into()),
|
|
})).await;
|
|
}
|
|
}
|
|
R2rFrame::Data(data) => {
|
|
guards
|
|
.throttle_session_bytes(&data.session_id, SessionDir::EgressToClient, data.data.len())
|
|
.await;
|
|
if let Some(session) = local_session_for_session_id(&state, &data.session_id).await {
|
|
let _ = session
|
|
.tx
|
|
.send(ServerFrame::StreamData(StreamData { stream_id: data.stream_id, data: data.data }))
|
|
.await;
|
|
}
|
|
}
|
|
R2rFrame::Close(close) => {
|
|
if let Some(session) = local_session_for_session_id(&state, &close.session_id).await {
|
|
let _ = session
|
|
.tx
|
|
.send(ServerFrame::StreamClosed(StreamClosed { stream_id: close.stream_id.clone(), reason: close.reason.clone() }))
|
|
.await;
|
|
remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
|
|
} else {
|
|
remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
|
|
}
|
|
}
|
|
R2rFrame::Ping => {
|
|
let _ = tx.send(R2rFrame::Pong).await;
|
|
}
|
|
R2rFrame::Pong => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn attach_virtual_r2r_stream_to_session(
|
|
session: SessionHandle,
|
|
prelude: RelayForwardPrelude,
|
|
r2r_tx: mpsc::Sender<R2rFrame>,
|
|
) -> Result<()> {
|
|
let stream_id = prelude.stream_id.clone();
|
|
let session_id = session.session_id.clone();
|
|
let (to_r2r_tx, mut to_r2r_rx) = mpsc::channel::<Vec<u8>>(128);
|
|
session
|
|
.stream_sinks
|
|
.write()
|
|
.await
|
|
.insert(stream_id.clone(), to_r2r_tx);
|
|
|
|
session
|
|
.tx
|
|
.send(ServerFrame::IncomingTcp(IncomingTcp {
|
|
stream_id: stream_id.clone(),
|
|
session_id: session_id.clone(),
|
|
peer_addr: prelude.peer_addr.clone(),
|
|
hostname: prelude.fqdn.clone(),
|
|
initial_data: prelude.initial_data.clone(),
|
|
}))
|
|
.await
|
|
.context("send virtual r2r IncomingTcp to client")?;
|
|
|
|
tokio::spawn(async move {
|
|
while let Some(chunk) = to_r2r_rx.recv().await {
|
|
let _ = r2r_tx
|
|
.send(R2rFrame::Data(R2rStreamData {
|
|
session_id: session_id.clone(),
|
|
stream_id: stream_id.clone(),
|
|
data: chunk,
|
|
}))
|
|
.await;
|
|
}
|
|
let _ = r2r_tx
|
|
.send(R2rFrame::Close(R2rStreamClosed {
|
|
session_id,
|
|
stream_id,
|
|
reason: Some("owner_sink_closed".into()),
|
|
}))
|
|
.await;
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_r2r_inbound_frame(
|
|
frame: R2rFrame,
|
|
r2r: &R2rManager,
|
|
_guards: &RelayGuards,
|
|
) -> Result<()> {
|
|
match frame {
|
|
R2rFrame::Data(data) => {
|
|
if let Some(tx) = r2r.ingress_stream_sinks.read().await.get(&data.stream_id).cloned() {
|
|
let _ = tx.send(data.data).await;
|
|
}
|
|
}
|
|
R2rFrame::Close(close) => {
|
|
r2r.ingress_stream_sinks.write().await.remove(&close.stream_id);
|
|
}
|
|
R2rFrame::Ping | R2rFrame::Pong | R2rFrame::Open(_) => {}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_ingress_player_reader_to_r2r(
|
|
mut reader: tokio::net::tcp::OwnedReadHalf,
|
|
tx: mpsc::Sender<R2rFrame>,
|
|
session_id: String,
|
|
stream_id: String,
|
|
) -> Result<()> {
|
|
let mut buf = vec![0u8; 16 * 1024];
|
|
loop {
|
|
let n = reader.read(&mut buf).await?;
|
|
if n == 0 {
|
|
break;
|
|
}
|
|
tx.send(R2rFrame::Data(R2rStreamData {
|
|
session_id: session_id.clone(),
|
|
stream_id: stream_id.clone(),
|
|
data: buf[..n].to_vec(),
|
|
}))
|
|
.await
|
|
.context("send ingress data to r2r")?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_ingress_player_writer(
|
|
mut writer: tokio::net::tcp::OwnedWriteHalf,
|
|
mut rx: mpsc::Receiver<Vec<u8>>,
|
|
) -> Result<()> {
|
|
while let Some(chunk) = rx.recv().await {
|
|
writer.write_all(&chunk).await?;
|
|
}
|
|
let _ = writer.shutdown().await;
|
|
Ok(())
|
|
}
|
|
|
|
/// Looks up a locally registered session by its fully-qualified hostname.
async fn local_session_for_hostname(state: &SharedState, hostname: &str) -> Option<SessionHandle> {
    let guard = state.read().await;
    guard.by_fqdn.get(hostname).cloned()
}
|
|
|
|
/// Resolves a session id to its handle via the session-id -> fqdn index.
async fn local_session_for_session_id(state: &SharedState, session_id: &str) -> Option<SessionHandle> {
    let guard = state.read().await;
    guard
        .by_session
        .get(session_id)
        .and_then(|fqdn| guard.by_fqdn.get(fqdn))
        .cloned()
}
|
|
|
|
/// Bridges an accepted player TCP socket onto an established tunnel session:
/// registers a per-stream sink, announces the stream to the tunnel client,
/// and spawns reader/writer pump tasks that run until either side closes.
///
/// `stream_id_override` lets r2r forwarding reuse a peer-assigned stream id;
/// otherwise a fresh UUID is generated. `source` is a label used only in the
/// final log line (which accept path produced this connection).
///
/// # Errors
/// Fails when the `IncomingTcp` announcement cannot be delivered to the
/// tunnel client. NOTE(review): in that case the sink inserted just above is
/// not removed here — presumably session teardown clears it; confirm.
async fn attach_player_socket_to_session(
    stream: TcpStream,
    session: SessionHandle,
    hostname: String,
    peer_addr: String,
    initial_data: Vec<u8>,
    stream_id_override: Option<String>,
    source: &'static str,
    guards: Arc<RelayGuards>,
) -> Result<()> {
    let stream_id = stream_id_override.unwrap_or_else(|| Uuid::new_v4().to_string());
    let (player_read, player_write) = stream.into_split();
    // Channel carrying client->player bytes; registering it makes inbound
    // StreamData frames for this stream reach the player's socket.
    let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
    session.stream_sinks.write().await.insert(stream_id.clone(), to_player_tx);

    // Announce the new stream to the tunnel client, replaying any bytes
    // already read from the player before routing (`initial_data`).
    session
        .tx
        .send(ServerFrame::IncomingTcp(IncomingTcp {
            stream_id: stream_id.clone(),
            session_id: session.session_id.clone(),
            peer_addr: peer_addr.clone(),
            hostname: hostname.clone(),
            initial_data,
        }))
        .await
        .context("send IncomingTcp to client")?;

    // Writer pump: client -> player. When it ends (error or channel close),
    // tell the client the stream is closed and unregister the sink.
    let tx_control = session.tx.clone();
    let stream_id_clone = stream_id.clone();
    let sinks = session.stream_sinks.clone();
    tokio::spawn(async move {
        if let Err(e) = run_player_writer(player_write, to_player_rx).await {
            debug!(stream_id = %stream_id_clone, error = %e, "player writer ended");
        }
        let _ = tx_control
            .send(ServerFrame::StreamClosed(StreamClosed {
                stream_id: stream_id_clone.clone(),
                reason: Some("player_writer_closed".into()),
            }))
            .await;
        let _ = remove_stream_sink_by_store(sinks, &stream_id_clone).await;
    });

    // Reader pump: player -> client, throttled per session. Same cleanup
    // protocol as the writer task (StreamClosed + sink removal).
    let tx_control = session.tx.clone();
    let stream_id_clone = stream_id.clone();
    let session_id_clone = session.session_id.clone();
    let sinks = session.stream_sinks.clone();
    let guards_clone = guards.clone();
    tokio::spawn(async move {
        if let Err(e) = run_player_reader(
            player_read,
            tx_control.clone(),
            stream_id_clone.clone(),
            session_id_clone,
            guards_clone,
        )
        .await
        {
            debug!(stream_id = %stream_id_clone, error = %e, "player reader ended");
        }
        let _ = tx_control
            .send(ServerFrame::StreamClosed(StreamClosed {
                stream_id: stream_id_clone.clone(),
                reason: Some("player_reader_closed".into()),
            }))
            .await;
        let _ = remove_stream_sink_by_store(sinks, &stream_id_clone).await;
    });

    info!(peer = %peer_addr, hostname = %hostname, session_id = %session.session_id, stream_id = %stream_id, source, "player proxied via client stream");
    metrics::gauge!("relay_active_player_conns").increment(1.0);
    Ok(())
}
|
|
|
|
async fn run_player_reader(
|
|
mut reader: tokio::net::tcp::OwnedReadHalf,
|
|
tx_control: mpsc::Sender<ServerFrame>,
|
|
stream_id: String,
|
|
session_id: String,
|
|
guards: Arc<RelayGuards>,
|
|
) -> Result<()> {
|
|
let mut buf = vec![0u8; 16 * 1024];
|
|
loop {
|
|
let n = reader.read(&mut buf).await?;
|
|
if n == 0 {
|
|
break;
|
|
}
|
|
guards
|
|
.throttle_session_bytes(&session_id, SessionDir::EgressToClient, n)
|
|
.await;
|
|
tx_control
|
|
.send(ServerFrame::StreamData(StreamData {
|
|
stream_id: stream_id.clone(),
|
|
data: buf[..n].to_vec(),
|
|
}))
|
|
.await
|
|
.context("send stream data to client")?;
|
|
metrics::counter!("relay_bytes_out_total").increment(n as u64);
|
|
}
|
|
metrics::gauge!("relay_active_player_conns").decrement(1.0);
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_player_writer(
|
|
mut writer: tokio::net::tcp::OwnedWriteHalf,
|
|
mut rx: mpsc::Receiver<Vec<u8>>,
|
|
) -> Result<()> {
|
|
while let Some(chunk) = rx.recv().await {
|
|
writer.write_all(&chunk).await?;
|
|
metrics::counter!("relay_bytes_in_total").increment(chunk.len() as u64);
|
|
}
|
|
let _ = writer.shutdown().await;
|
|
Ok(())
|
|
}
|
|
|
|
async fn lookup_stream_sink(
|
|
state: &SharedState,
|
|
session_id: &str,
|
|
stream_id: &str,
|
|
) -> Option<mpsc::Sender<Vec<u8>>> {
|
|
let store = {
|
|
let guard = state.read().await;
|
|
let fqdn = guard.by_session.get(session_id)?.clone();
|
|
guard.by_fqdn.get(&fqdn)?.stream_sinks.clone()
|
|
};
|
|
store.read().await.get(stream_id).cloned()
|
|
}
|
|
|
|
async fn remove_stream_sink(state: &SharedState, session_id: &str, stream_id: &str) {
|
|
let store = {
|
|
let guard = state.read().await;
|
|
let Some(fqdn) = guard.by_session.get(session_id).cloned() else { return; };
|
|
let Some(handle) = guard.by_fqdn.get(&fqdn) else { return; };
|
|
handle.stream_sinks.clone()
|
|
};
|
|
let _ = remove_stream_sink_by_store(store, stream_id).await;
|
|
}
|
|
|
|
async fn remove_stream_sink_by_store(
|
|
store: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
|
|
stream_id: &str,
|
|
) -> Option<mpsc::Sender<Vec<u8>>> {
|
|
store.write().await.remove(stream_id)
|
|
}
|
|
|
|
/// A token "looks valid" when it contains at least one non-whitespace
/// character; no cryptographic or format validation happens here.
fn token_looks_valid(token: &str) -> bool {
    token.chars().any(|c| !c.is_whitespace())
}
|
|
|
|
/// Derives a placeholder user id from the last (up to) six characters of the
/// token, e.g. `"user-123456"`. Stand-in until real auth is wired up.
fn fake_user_id_from_token(token: &str) -> String {
    let chars: Vec<char> = token.chars().collect();
    let start = chars.len().saturating_sub(6);
    let suffix: String = chars[start..].iter().collect();
    format!("user-{suffix}")
}
|
|
|
|
fn assign_fqdn(cfg: &RelayConfig, req: &RegisterRequest) -> String {
|
|
let label = req
|
|
.requested_subdomain
|
|
.as_ref()
|
|
.filter(|s| !s.trim().is_empty())
|
|
.cloned()
|
|
.unwrap_or_else(random_label);
|
|
format!("{}.{}.{}", sanitize_label(&label), cfg.region, cfg.domain)
|
|
}
|
|
|
|
fn random_label() -> String {
|
|
const ADJ: &[&str] = &["sleepy", "swift", "brave", "quiet", "mossy"];
|
|
const NOUN: &[&str] = &["creeper", "ghast", "axolotl", "wolf", "beacon"];
|
|
format!(
|
|
"{}-{}-{}",
|
|
ADJ[fastrand::usize(..ADJ.len())],
|
|
NOUN[fastrand::usize(..NOUN.len())],
|
|
fastrand::u16(..9999)
|
|
)
|
|
}
|
|
|
|
/// Normalizes a user-supplied subdomain label: keeps only ASCII
/// alphanumerics and hyphens, lowercases, and strips leading/trailing
/// hyphens. May return an empty string if nothing survives.
fn sanitize_label(input: &str) -> String {
    let mut kept = String::with_capacity(input.len());
    for c in input.chars() {
        if c.is_ascii_alphanumeric() || c == '-' {
            kept.push(c.to_ascii_lowercase());
        }
    }
    kept.trim_matches('-').to_string()
}
|
|
|
|
/// Derives a loopback advertise address from a bind string by reusing its
/// port; falls back to `127.0.0.1:7001` when no `:port` suffix is present.
/// Only a development default — production should set the advertise address
/// explicitly.
fn guess_advertise_addr(bind: &str) -> String {
    match bind.rsplit_once(':') {
        Some((_, port)) => format!("127.0.0.1:{port}"),
        None => String::from("127.0.0.1:7001"),
    }
}
|
|
|
|
fn init_metrics() -> Result<()> {
|
|
if let Ok(bind) = std::env::var("RELAY_METRICS_BIND") {
|
|
let addr: std::net::SocketAddr = bind.parse().context("parse RELAY_METRICS_BIND")?;
|
|
PrometheusBuilder::new()
|
|
.with_http_listener(addr)
|
|
.install()
|
|
.context("install prometheus exporter")?;
|
|
}
|
|
Ok(())
|
|
}
|