diff --git a/common/src/protocol.rs b/common/src/protocol.rs
index b489abe..46595ce 100644
--- a/common/src/protocol.rs
+++ b/common/src/protocol.rs
@@ -57,6 +57,30 @@ pub struct RelayForwardPrelude {
     pub initial_data: Vec<u8>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct R2rStreamData {
+    pub session_id: String,
+    pub stream_id: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct R2rStreamClosed {
+    pub session_id: String,
+    pub stream_id: String,
+    pub reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", content = "data")]
+pub enum R2rFrame {
+    Open(RelayForwardPrelude),
+    Data(R2rStreamData),
+    Close(R2rStreamClosed),
+    Ping,
+    Pong,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "type", content = "data")]
 pub enum ClientFrame {
diff --git a/relay/src/main.rs b/relay/src/main.rs
index 6750e4b..46d86b7 100644
--- a/relay/src/main.rs
+++ b/relay/src/main.rs
@@ -11,14 +11,14 @@ use common::{
     minecraft::read_handshake_hostname_and_bytes,
     protocol::{
         ClientFrame, Heartbeat, IncomingTcp, RegisterAccepted, RegisterRequest, RelayForwardPrelude,
-        ServerFrame, StreamClosed, StreamData,
+        R2rFrame, R2rStreamClosed, R2rStreamData, ServerFrame, StreamClosed, StreamData,
     },
 };
 use redis::AsyncCommands;
 use serde::Deserialize;
 use metrics_exporter_prometheus::PrometheusBuilder;
 use tokio::{
-    io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional},
+    io::{AsyncReadExt, AsyncWriteExt},
     net::{TcpListener, TcpStream},
     sync::{Mutex, Notify, RwLock, mpsc},
     time::{MissedTickBehavior, interval, timeout},
@@ -107,6 +107,21 @@ impl RelayState {
 
 type SharedState = Arc<RwLock<RelayState>>;
 
+#[derive(Clone)]
+struct R2rManager {
+    outbound: Arc<Mutex<HashMap<String, mpsc::Sender<R2rFrame>>>>,
+    ingress_stream_sinks: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
+}
+
+impl R2rManager {
+    fn new() -> Self {
+        Self {
+            outbound: Arc::new(Mutex::new(HashMap::new())),
+            ingress_stream_sinks: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+}
+
 #[derive(Clone)]
 struct RelayGuards {
     player_ip: Arc<Mutex<HashMap<String, Instant>>>,
@@ -448,6 +463,7 @@ async fn main() -> Result<()> {
     let cfg = RelayConfig::from_env();
     let registry = RedisRegistry::from_env(&cfg).await;
     let guards = Arc::new(RelayGuards::from_env());
+    let r2r = Arc::new(R2rManager::new());
     registry.register_instance().await;
 
     let control_listener = TcpListener::bind(&cfg.control_bind)
@@ -481,6 +497,7 @@
         state.clone(),
         registry.clone(),
         guards.clone(),
+        r2r.clone(),
         shutdown.clone(),
     ));
     let r2r_task = tokio::spawn(run_r2r_accept_loop(
@@ -488,6 +505,7 @@
         cfg.clone(),
         state.clone(),
         guards.clone(),
+        r2r.clone(),
         shutdown.clone(),
     ));
 
@@ -565,6 +583,7 @@ async fn run_player_accept_loop(
     state: SharedState,
     registry: RedisRegistry,
     guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
     shutdown: Arc<Notify>,
 ) -> Result<()> {
     loop {
@@ -580,8 +599,9 @@
             let state = state.clone();
             let registry = registry.clone();
             let guards = guards.clone();
+            let r2r = r2r.clone();
             tokio::spawn(async move {
-                if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await {
+                if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards, r2r).await {
                     debug!(peer = %addr, error = %e, "player connection closed");
                 }
             });
@@ -596,6 +616,7 @@ async fn run_r2r_accept_loop(
     cfg: RelayConfig,
     state: SharedState,
     guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
     shutdown: Arc<Notify>,
 ) -> Result<()> {
     loop {
@@ -610,8 +631,9 @@
             let cfg = cfg.clone();
             let state = state.clone();
             let guards = guards.clone();
+            let r2r = r2r.clone();
             tokio::spawn(async move {
-                if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await {
+                if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards, r2r).await {
                     warn!(peer = %addr, error = %e, "r2r connection ended with error");
                 }
             });
@@ -760,7 +782,7 @@ async fn control_read_loop(
     }
 }
 
-#[tracing::instrument(skip(stream, cfg, state, registry, guards), fields(peer = %addr))]
+#[tracing::instrument(skip(stream, cfg, state, registry, guards, r2r), fields(peer = %addr))]
 async fn handle_player_conn(
     mut stream: TcpStream,
     addr: SocketAddr,
@@ -768,6 +790,7 @@
     state: SharedState,
     registry: RedisRegistry,
     guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
 ) -> Result<()> {
     if !guards.allow_player_ip(&addr.ip().to_string()).await {
         metrics::counter!("relay_rate_limited_total", "scope" => "player_ip").increment(1);
@@ -798,57 +821,38 @@
             debug!(peer = %addr, hostname = %hostname, session_id = %route.session_id, "route points to self but local session missing");
             return Ok(());
         }
-        return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry).await;
+        return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry, guards, r2r).await;
     }
 
     debug!(peer = %addr, hostname = %hostname, "no tunnel for hostname");
     Ok(())
 }
 
-#[tracing::instrument(skip(stream, _cfg, state, guards), fields(peer = %addr))]
+#[tracing::instrument(skip(stream, cfg, state, guards, r2r), fields(peer = %addr))]
 async fn handle_r2r_conn(
-    mut stream: TcpStream,
+    stream: TcpStream,
     addr: SocketAddr,
-    _cfg: RelayConfig,
+    cfg: RelayConfig,
     state: SharedState,
     guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
 ) -> Result<()> {
-    let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?;
-    if prelude.version != 1 {
-        anyhow::bail!("unsupported r2r prelude version {}", prelude.version);
-    }
-    if prelude.hop_count > 1 {
-        anyhow::bail!("invalid hop_count {}", prelude.hop_count);
-    }
-
-    let session = local_session_for_session_id(&state, &prelude.session_id).await;
-    let Some(session) = session else {
-        anyhow::bail!("owner session not found for {}", prelude.session_id);
-    };
-
-    attach_player_socket_to_session(
-        stream,
-        session,
-        prelude.fqdn.clone(),
-        prelude.peer_addr,
-        prelude.initial_data,
-        Some(prelude.stream_id),
-        "r2r",
-        guards,
-    )
-    .await
-    .with_context(|| format!("r2r attach failed from {addr}"))
+    handle_r2r_multiplex_conn(stream, addr, cfg, state, guards, r2r)
+        .await
+        .with_context(|| format!("r2r multiplex failed from {addr}"))
 }
 
-#[tracing::instrument(skip(player_stream, route, cfg, registry), fields(peer = %player_addr, hostname = %hostname))]
+#[tracing::instrument(skip(player_stream, route, cfg, registry, guards, r2r), fields(peer = %player_addr, hostname = %hostname))]
 async fn proxy_player_to_owner(
-    mut player_stream: TcpStream,
+    player_stream: TcpStream,
     player_addr: SocketAddr,
     hostname: String,
     initial_data: Vec<u8>,
     route: TunnelRouteRecord,
     cfg: RelayConfig,
     registry: RedisRegistry,
+    guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
 ) -> Result<()> {
@@ -863,27 +867,326 @@
     let redis_lookup_started = Instant::now();
     let owner = registry
         .resolve_instance(&route.instance_id)
         .await
         .context("resolve owner instance")?;
     metrics::histogram!("relay_redis_lookup_latency_ms")
         .record(redis_lookup_started.elapsed().as_secs_f64() * 1000.0);
     let r2r_addr = owner
         .r2r_addr
         .with_context(|| format!("owner {} missing r2r_addr", route.instance_id))?;
 
-    let r2r_connect_started = Instant::now();
-    let mut owner_stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&r2r_addr))
-        .await
-        .context("r2r connect timeout")??;
-    metrics::histogram!("relay_r2r_connect_latency_ms")
-        .record(r2r_connect_started.elapsed().as_secs_f64() * 1000.0);
-
     let prelude = RelayForwardPrelude {
         version: 1,
-        session_id: route.session_id,
+        session_id: route.session_id.clone(),
         fqdn: hostname.clone(),
         stream_id: Uuid::new_v4().to_string(),
         peer_addr: player_addr.to_string(),
-        origin_instance_id: cfg.instance_id,
+        origin_instance_id: cfg.instance_id.clone(),
         hop_count: 1,
         initial_data,
     };
-    write_frame(&mut owner_stream, &prelude).await?;
-
-    let _ = copy_bidirectional(&mut player_stream, &mut owner_stream).await?;
-    metrics::counter!("relay_r2r_forwards_total").increment(1);
-    info!(peer = %player_addr, hostname = %hostname, owner = %route.instance_id, "proxied player connection to owner relay");
-    Ok(())
-}
+    proxy_player_to_owner_pooled(
+        player_stream,
+        player_addr,
+        hostname,
+        route.instance_id,
+        r2r_addr,
+        prelude,
+        route.session_id,
+        cfg,
+        guards,
+        r2r,
+    )
+    .await
+}
+
+async fn proxy_player_to_owner_pooled(
+    player_stream: TcpStream,
+    player_addr: SocketAddr,
+    hostname: String,
+    owner_instance_id: String,
+    owner_r2r_addr: String,
+    prelude: RelayForwardPrelude,
+    session_id: String,
+    cfg: RelayConfig,
+    guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
+) -> Result<()> {
+    let stream_id = prelude.stream_id.clone();
+    let sender = get_or_connect_r2r_pool(
+        owner_instance_id.clone(),
+        owner_r2r_addr,
+        cfg,
+        guards,
+        r2r.clone(),
+    )
+    .await?;
+
+    let (player_read, player_write) = player_stream.into_split();
+    let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
+    r2r.ingress_stream_sinks
+        .write()
+        .await
+        .insert(stream_id.clone(), to_player_tx);
+
+    sender
+        .send(R2rFrame::Open(prelude))
+        .await
+        .context("send r2r open")?;
+
+    let tx = sender.clone();
+    let sid = session_id.clone();
+    let stid = stream_id.clone();
+    let sinks = r2r.ingress_stream_sinks.clone();
+    tokio::spawn(async move {
+        if let Err(e) =
+            run_ingress_player_reader_to_r2r(player_read, tx.clone(), sid.clone(), stid.clone()).await
+        {
+            debug!(stream_id = %stid, error = %e, "ingress player->r2r reader ended");
+        }
+        let _ = tx
+            .send(R2rFrame::Close(R2rStreamClosed {
+                session_id: sid,
+                stream_id: stid.clone(),
+                reason: Some("ingress_player_reader_closed".into()),
+            }))
+            .await;
+        let _ = sinks.write().await.remove(&stid);
+    });
+
+    let stid = stream_id.clone();
+    let sinks = r2r.ingress_stream_sinks.clone();
+    tokio::spawn(async move {
+        if let Err(e) = run_ingress_player_writer(player_write, to_player_rx).await {
+            debug!(stream_id = %stid, error = %e, "ingress r2r->player writer ended");
+        }
+        let _ = sinks.write().await.remove(&stid);
+    });
+
+    metrics::counter!("relay_r2r_forwards_total").increment(1);
+    info!(peer = %player_addr, hostname = %hostname, owner = %owner_instance_id, stream_id = %stream_id, "proxied player via pooled r2r channel");
+    Ok(())
+}
+
+async fn get_or_connect_r2r_pool(
+    owner_instance_id: String,
+    owner_r2r_addr: String,
+    cfg: RelayConfig,
+    guards: Arc<RelayGuards>,
+    r2r: Arc<R2rManager>,
+) -> Result<mpsc::Sender<R2rFrame>> {
+    if let Some(existing) = r2r.outbound.lock().await.get(&owner_instance_id).cloned() {
+        return Ok(existing);
+    }
+
+    let connect_started = Instant::now();
+    let stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&owner_r2r_addr))
+        .await
+        .context("r2r connect timeout")??;
+    metrics::histogram!("relay_r2r_connect_latency_ms")
+        .record(connect_started.elapsed().as_secs_f64() * 1000.0);
+
+    let (mut reader, mut writer) = stream.into_split();
+    let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
+
+    let mut pools = r2r.outbound.lock().await;
+    if let Some(existing) = pools.get(&owner_instance_id).cloned() {
+        return Ok(existing);
+    }
+    pools.insert(owner_instance_id.clone(), tx.clone());
+    drop(pools);
+
+    let owner_for_reader = owner_instance_id.clone();
+    let r2r_for_reader = r2r.clone();
+    let guards_for_reader = guards.clone();
+    tokio::spawn(async move {
+        loop {
+            match read_frame::<_, R2rFrame>(&mut reader).await {
+                Ok(frame) => {
+                    if let Err(e) = handle_r2r_inbound_frame(frame, &r2r_for_reader, &guards_for_reader).await {
+                        debug!(owner = %owner_for_reader, error = %e, "r2r pooled inbound frame error");
+                        break;
+                    }
+                }
+                Err(e) => {
+                    debug!(owner = %owner_for_reader, error = %e, "r2r pooled reader ended");
+                    break;
+                }
+            }
+        }
+        r2r_for_reader.outbound.lock().await.remove(&owner_for_reader);
+    });
+
+    tokio::spawn(async move {
+        while let Some(frame) = rx.recv().await {
+            if let Err(e) = write_frame(&mut writer, &frame).await {
+                debug!(error = %e, "r2r pooled writer ended");
+                break;
+            }
+        }
+    });
+
+    Ok(tx)
+}
+
+async fn handle_r2r_multiplex_conn(
+    stream: TcpStream,
+    _addr: SocketAddr,
+    _cfg: RelayConfig,
+    state: SharedState,
+    guards: Arc<RelayGuards>,
+    _r2r: Arc<R2rManager>,
+) -> Result<()> {
+    let (mut reader, mut writer) = stream.into_split();
+    let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);
+
+    let _writer_task = tokio::spawn(async move {
+        while let Some(frame) = rx.recv().await {
+            write_frame(&mut writer, &frame).await?;
+        }
+        Ok::<(), anyhow::Error>(())
+    });
+
+    loop {
+        let frame: R2rFrame = read_frame(&mut reader).await?;
+        match frame {
+            R2rFrame::Open(prelude) => {
+                if prelude.version != 1 || prelude.hop_count > 1 {
+                    continue;
+                }
+                if let Some(session) = local_session_for_session_id(&state, &prelude.session_id).await {
+                    attach_virtual_r2r_stream_to_session(session, prelude, tx.clone()).await?;
+                } else {
+                    let _ = tx.send(R2rFrame::Close(R2rStreamClosed {
+                        session_id: prelude.session_id,
+                        stream_id: prelude.stream_id,
+                        reason: Some("owner_session_not_found".into()),
+                    })).await;
+                }
+            }
+            R2rFrame::Data(data) => {
+                guards
+                    .throttle_session_bytes(&data.session_id, SessionDir::EgressToClient, data.data.len())
+                    .await;
+                if let Some(session) = local_session_for_session_id(&state, &data.session_id).await {
+                    let _ = session
+                        .tx
+                        .send(ServerFrame::StreamData(StreamData { stream_id: data.stream_id, data: data.data }))
+                        .await;
+                }
+            }
+            R2rFrame::Close(close) => {
+                if let Some(session) = local_session_for_session_id(&state, &close.session_id).await {
+                    let _ = session
+                        .tx
+                        .send(ServerFrame::StreamClosed(StreamClosed { stream_id: close.stream_id.clone(), reason: close.reason.clone() }))
+                        .await;
+                    remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
+                } else {
+                    remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
+                }
+            }
+            R2rFrame::Ping => {
+                let _ = tx.send(R2rFrame::Pong).await;
+            }
+            R2rFrame::Pong => {}
+        }
+    }
+}
+
+async fn attach_virtual_r2r_stream_to_session(
+    session: SessionHandle,
+    prelude: RelayForwardPrelude,
+    r2r_tx: mpsc::Sender<R2rFrame>,
+) -> Result<()> {
+    let stream_id = prelude.stream_id.clone();
+    let session_id = session.session_id.clone();
+    let (to_r2r_tx, mut to_r2r_rx) = mpsc::channel::<Vec<u8>>(128);
+    session
+        .stream_sinks
+        .write()
+        .await
+        .insert(stream_id.clone(), to_r2r_tx);
+
+    session
+        .tx
+        .send(ServerFrame::IncomingTcp(IncomingTcp {
+            stream_id: stream_id.clone(),
+            session_id: session_id.clone(),
+            peer_addr: prelude.peer_addr.clone(),
+            hostname: prelude.fqdn.clone(),
+            initial_data: prelude.initial_data.clone(),
+        }))
+        .await
+        .context("send virtual r2r IncomingTcp to client")?;
+
+    tokio::spawn(async move {
+        while let Some(chunk) = to_r2r_rx.recv().await {
+            let _ = r2r_tx
+                .send(R2rFrame::Data(R2rStreamData {
+                    session_id: session_id.clone(),
+                    stream_id: stream_id.clone(),
+                    data: chunk,
+                }))
+                .await;
+        }
+        let _ = r2r_tx
+            .send(R2rFrame::Close(R2rStreamClosed {
+                session_id,
+                stream_id,
+                reason: Some("owner_sink_closed".into()),
+            }))
+            .await;
+    });
+
+    Ok(())
+}
+
+async fn handle_r2r_inbound_frame(
+    frame: R2rFrame,
+    r2r: &R2rManager,
+    _guards: &RelayGuards,
+) -> Result<()> {
+    match frame {
+        R2rFrame::Data(data) => {
+            if let Some(tx) = r2r.ingress_stream_sinks.read().await.get(&data.stream_id).cloned() {
+                let _ = tx.send(data.data).await;
+            }
+        }
+        R2rFrame::Close(close) => {
+            r2r.ingress_stream_sinks.write().await.remove(&close.stream_id);
+        }
+        R2rFrame::Ping | R2rFrame::Pong | R2rFrame::Open(_) => {}
+    }
+    Ok(())
+}
+
+async fn run_ingress_player_reader_to_r2r(
+    mut reader: tokio::net::tcp::OwnedReadHalf,
+    tx: mpsc::Sender<R2rFrame>,
+    session_id: String,
+    stream_id: String,
+) -> Result<()> {
+    let mut buf = vec![0u8; 16 * 1024];
+    loop {
+        let n = reader.read(&mut buf).await?;
+        if n == 0 {
+            break;
+        }
+        tx.send(R2rFrame::Data(R2rStreamData {
+            session_id: session_id.clone(),
+            stream_id: stream_id.clone(),
+            data: buf[..n].to_vec(),
+        }))
+        .await
+        .context("send ingress data to r2r")?;
+    }
+    Ok(())
+}
+
+async fn run_ingress_player_writer(
+    mut writer: tokio::net::tcp::OwnedWriteHalf,
+    mut rx: mpsc::Receiver<Vec<u8>>,
+) -> Result<()> {
+    while let Some(chunk) = rx.recv().await {
+        writer.write_all(&chunk).await?;
+    }
+    let _ = writer.shutdown().await;
+    Ok(())
+}