Compare commits

6 Commits: 09205f8db2...main

| Author | SHA1 | Date |
|---|---|---|
| | 230a9212fe | |
| | 28918880da | |
| | 37090d80b0 | |
| | a45a9b0392 | |
| | 4ce94a5b17 | |
| | fe8376dd6d | |
Cargo.lock (generated, 901 lines changed)
File diff suppressed because it is too large.
@@ -19,3 +19,6 @@ axum = "0.8"
redis = { version = "0.32", features = ["tokio-comp", "connection-manager"] }
jsonwebtoken = "10"
chrono = { version = "0.4", features = ["serde", "clock"] }
metrics = "0.24"
metrics-exporter-prometheus = "0.17"
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1"] }

@@ -15,3 +15,7 @@ tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
fastrand.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
tokio-postgres.workspace = true
uuid.workspace = true
@@ -1,4 +1,4 @@
-use std::{net::SocketAddr, sync::Arc};
+use std::{net::SocketAddr, sync::Arc, time::Instant};

use anyhow::{Context, Result};
use axum::{

@@ -10,14 +10,19 @@ use axum::{
};
use chrono::{Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use tokio_postgres::{Client as PgClient, NoTls};
use tracing::{info, warn};
use uuid::Uuid;

#[derive(Clone)]
struct AppState {
    jwt_secret: Arc<String>,
    redis: Option<redis::aio::ConnectionManager>,
    pg: Option<Arc<PgClient>>,
    metrics: PrometheusHandle,
}

@@ -58,12 +63,9 @@ struct ValidateResponse {
    max_tunnels: Option<u32>,
}

-#[derive(Debug, Deserialize)]
-struct StripeWebhookEvent {
-    event_type: String,
-    user_id: String,
+#[derive(Debug, Clone)]
+struct PlanEntitlement {
+    tier: String,
-    #[serde(default = "default_max_tunnels")]
    max_tunnels: u32,
}

@@ -78,6 +80,10 @@ async fn main() -> Result<()> {
    let bind = std::env::var("AUTH_BIND").unwrap_or_else(|_| "0.0.0.0:8080".into());
    let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
    let metrics = PrometheusBuilder::new()
        .install_recorder()
        .context("install prometheus recorder")?;

    let redis = if let Ok(url) = std::env::var("REDIS_URL") {
        let client = redis::Client::open(url.clone()).context("open redis client")?;
        match redis::aio::ConnectionManager::new(client).await {

@@ -94,16 +100,20 @@ async fn main() -> Result<()> {
        None
    };

    let pg = connect_postgres().await?;

    let state = AppState {
        jwt_secret: Arc::new(jwt_secret),
        redis,
        pg,
        metrics,
    };

    let app = Router::new()
        .route("/healthz", get(healthz))
        .route("/metrics", get(metrics_endpoint))
        .route("/v1/token/dev", post(issue_dev_token))
        .route("/v1/token/validate", post(validate_token))
        .route("/v1/stripe/webhook", post(stripe_webhook))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind(&bind)

@@ -115,14 +125,36 @@ async fn main() -> Result<()> {
    Ok(())
}

async fn connect_postgres() -> Result<Option<Arc<PgClient>>> {
    let Some(url) = std::env::var("DATABASE_URL").ok() else {
        return Ok(None);
    };
    let (client, conn) = tokio_postgres::connect(&url, NoTls)
        .await
        .context("connect postgres")?;
    tokio::spawn(async move {
        if let Err(e) = conn.await {
            warn!(error = %e, "postgres connection task ended");
        }
    });
    info!("auth-api connected to postgres");
    Ok(Some(Arc::new(client)))
}

async fn healthz() -> &'static str {
    "ok"
}

async fn metrics_endpoint(State(state): State<AppState>) -> impl IntoResponse {
    state.metrics.render()
}

#[tracing::instrument(skip(state))]
async fn issue_dev_token(
    State(state): State<AppState>,
    Json(req): Json<DevTokenRequest>,
) -> Result<Json<TokenResponse>, ApiError> {
    let started = Instant::now();
    let now = Utc::now();
    let exp = now + Duration::hours(24);
    let claims = Claims {

@@ -141,9 +173,68 @@ async fn issue_dev_token
    )
    .map_err(ApiError::internal)?;

    sync_jwt_cache(&state, &claims).await?;
    metrics::counter!("auth_jwt_issued_total").increment(1);
    metrics::histogram!("auth_issue_token_latency_ms")
        .record(started.elapsed().as_secs_f64() * 1000.0);

    Ok(Json(TokenResponse {
        token,
        expires_at: exp.timestamp(),
    }))
}

#[tracing::instrument(skip(state, req))]
async fn validate_token(
    State(state): State<AppState>,
    Json(req): Json<ValidateRequest>,
) -> Result<Json<ValidateResponse>, ApiError> {
    let started = Instant::now();
    let decoded = decode::<Claims>(
        &req.token,
        &DecodingKey::from_secret(state.jwt_secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    );

    let response = match decoded {
        Ok(tok) => {
            let claims = tok.claims;
            let ent = match load_plan_for_user(&state, &claims.sub).await? {
                Some(db_plan) => db_plan,
                None => PlanEntitlement {
                    tier: claims.tier.clone(),
                    max_tunnels: claims.max_tunnels,
                },
            };
            sync_plan_cache(&state, &claims.sub, &ent, "auth-validate").await?;
            metrics::counter!("auth_token_validate_total", "result" => "valid").increment(1);
            ValidateResponse {
                valid: true,
                user_id: Some(claims.sub),
                tier: Some(ent.tier),
                max_tunnels: Some(ent.max_tunnels),
            }
        }
        Err(_) => {
            metrics::counter!("auth_token_validate_total", "result" => "invalid").increment(1);
            ValidateResponse {
                valid: false,
                user_id: None,
                tier: None,
                max_tunnels: None,
            }
        }
    };

    metrics::histogram!("auth_validate_token_latency_ms")
        .record(started.elapsed().as_secs_f64() * 1000.0);
    Ok(Json(response))
}

async fn sync_jwt_cache(state: &AppState, claims: &Claims) -> Result<(), ApiError> {
    if let Some(mut redis) = state.redis.clone() {
        let key = format!("auth:jwt:jti:{}", claims.jti);
-        let ttl = (claims.exp as i64 - now.timestamp()).max(1);
+        let ttl = (claims.exp as i64 - Utc::now().timestamp()).max(1);
        let payload = serde_json::json!({
            "user_id": claims.sub,
            "plan_tier": claims.tier,
@@ -155,69 +246,53 @@ async fn issue_dev_token
        .await
        .map_err(ApiError::internal)?;
    }

-    Ok(Json(TokenResponse {
-        token,
-        expires_at: exp.timestamp(),
-    }))
+    Ok(())
}

-async fn validate_token(
-    State(state): State<AppState>,
-    Json(req): Json<ValidateRequest>,
-) -> Result<Json<ValidateResponse>, ApiError> {
-    let decoded = decode::<Claims>(
-        &req.token,
-        &DecodingKey::from_secret(state.jwt_secret.as_bytes()),
-        &Validation::new(Algorithm::HS256),
-    );
-
-    match decoded {
-        Ok(tok) => {
-            let c = tok.claims;
-            if let Some(mut redis) = state.redis.clone() {
-                let key = format!("plan:user:{}", c.sub);
-                let payload = serde_json::json!({
-                    "tier": c.tier,
-                    "max_tunnels": c.max_tunnels,
-                    "source": "auth-api"
-                })
-                .to_string();
-                let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
-            }
-            Ok(Json(ValidateResponse {
-                valid: true,
-                user_id: Some(c.sub),
-                tier: Some(c.tier),
-                max_tunnels: Some(c.max_tunnels),
-            }))
-        }
-        Err(_) => Ok(Json(ValidateResponse {
-            valid: false,
-            user_id: None,
-            tier: None,
-            max_tunnels: None,
-        })),
-    }
-}
-
-async fn stripe_webhook(
-    State(state): State<AppState>,
-    Json(event): Json<StripeWebhookEvent>,
-) -> Result<impl IntoResponse, ApiError> {
+async fn sync_plan_cache(
+    state: &AppState,
+    user_id: &str,
+    ent: &PlanEntitlement,
+    source: &str,
+) -> Result<(), ApiError> {
    if let Some(mut redis) = state.redis.clone() {
-        let key = format!("plan:user:{}", event.user_id);
+        let key = format!("plan:user:{user_id}");
        let payload = serde_json::json!({
-            "tier": event.tier,
-            "max_tunnels": event.max_tunnels,
-            "source": "stripe_webhook",
-            "last_event_type": event.event_type,
-            "updated_at": Utc::now().timestamp(),
+            "tier": ent.tier,
+            "max_tunnels": ent.max_tunnels,
+            "source": source,
+            "updated_at": Utc::now().timestamp()
        })
        .to_string();
        let _: () = redis.set_ex(key, payload, 300).await.map_err(ApiError::internal)?;
    }
-    Ok(StatusCode::NO_CONTENT)
+    Ok(())
}

async fn load_plan_for_user(state: &AppState, user_id: &str) -> Result<Option<PlanEntitlement>, ApiError> {
    let Some(pg) = &state.pg else {
        return Ok(None);
    };
    let _ = Uuid::parse_str(user_id).map_err(ApiError::bad_request)?;
    let row = pg
        .query_opt(
            r#"
            select p.id as plan_id, p.max_tunnels
            from subscriptions s
            join plans p on p.id = s.plan_id
            where s.user_id = $1::uuid and s.status in ('active', 'trialing')
            order by s.updated_at desc
            limit 1
            "#,
            &[&user_id],
        )
        .await
        .map_err(ApiError::internal)?;

    Ok(row.map(|r| PlanEntitlement {
        tier: r.get::<_, String>("plan_id"),
        max_tunnels: r.get::<_, i32>("max_tunnels") as u32,
    }))
}

#[derive(Debug)]

@@ -233,6 +308,14 @@ impl ApiError {
            message: e.to_string(),
        }
    }

    fn bad_request<E: std::fmt::Display>(e: E) -> Self {
        Self {
            status: StatusCode::BAD_REQUEST,
            message: e.to_string(),
        }
    }
}

impl IntoResponse for ApiError {
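For reference, the `issue_dev_token`/`validate_token` handlers above reduce to a plain HS256 round trip through `jsonwebtoken`. A minimal, self-contained sketch under the same API names the diff uses (`DemoClaims`, the inline secret, and the trimmed claim shape are illustrative only, not part of this change; assumes `jsonwebtoken`, `serde`, and `chrono` on the dependency list):

```rust
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct DemoClaims {
    sub: String, // user id
    exp: i64,    // expiry as a unix timestamp; Validation checks this claim by default
}

fn main() -> Result<(), jsonwebtoken::errors::Error> {
    let secret = b"dev-secret-change-me";
    let claims = DemoClaims {
        sub: "user-123".into(),
        exp: chrono::Utc::now().timestamp() + 24 * 3600, // 24h, mirroring Duration::hours(24)
    };
    // Issue: sign the claims with the shared HS256 secret (Header::default() is HS256).
    let token = encode(&Header::default(), &claims, &EncodingKey::from_secret(secret))?;
    // Validate: decoding verifies both the signature and the exp claim.
    let decoded = decode::<DemoClaims>(
        &token,
        &DecodingKey::from_secret(secret),
        &Validation::new(Algorithm::HS256),
    )?;
    assert_eq!(decoded.claims.sub, "user-123");
    Ok(())
}
```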
@@ -57,6 +57,30 @@ pub struct RelayForwardPrelude {
    pub initial_data: Vec<u8>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct R2rStreamData {
    pub session_id: String,
    pub stream_id: String,
    pub data: Vec<u8>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct R2rStreamClosed {
    pub session_id: String,
    pub stream_id: String,
    pub reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum R2rFrame {
    Open(RelayForwardPrelude),
    Data(R2rStreamData),
    Close(R2rStreamClosed),
    Ping,
    Pong,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ClientFrame {
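`R2rFrame` is adjacently tagged (`tag = "type"`, `content = "data"`), so each frame serializes as a two-field object: the variant name under `type` and its payload under `data`. A small sketch of that encoding with `serde_json` (the enum is re-declared locally in trimmed form purely for illustration; the relay's actual `read_frame`/`write_frame` codec may use a different serde format, JSON just makes the tagging visible):

```rust
use serde::{Deserialize, Serialize};

// Trimmed stand-in for the R2rFrame enum above, only to show the wire shape.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", content = "data")]
enum Frame {
    Data { stream_id: String, data: Vec<u8> },
    Ping,
}

fn main() -> serde_json::Result<()> {
    let frame = Frame::Data { stream_id: "s-1".into(), data: vec![1, 2, 3] };
    let json = serde_json::to_string(&frame)?;
    // Adjacent tagging: variant name under "type", payload under "data".
    assert_eq!(json, r#"{"type":"Data","data":{"stream_id":"s-1","data":[1,2,3]}}"#);
    // Unit variants like Ping/Pong carry only the tag.
    assert_eq!(serde_json::to_string(&Frame::Ping)?, r#"{"type":"Ping"}"#);
    // Round-trips back into the same variant.
    let back: Frame = serde_json::from_str(&json)?;
    assert_eq!(back, frame);
    Ok(())
}
```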
@@ -35,6 +35,9 @@ create table if not exists subscriptions (
);

create index if not exists subscriptions_user_id_idx on subscriptions(user_id);
create unique index if not exists subscriptions_provider_subscription_id_uidx
    on subscriptions(provider_subscription_id)
    where provider_subscription_id is not null;

create table if not exists tunnels (
    id uuid primary key default gen_random_uuid(),

@@ -14,4 +14,6 @@ redis.workspace = true
serde_json.workspace = true
chrono.workspace = true
serde.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
common = { path = "../common" }
@@ -11,13 +11,14 @@ use common::{
    minecraft::read_handshake_hostname_and_bytes,
    protocol::{
        ClientFrame, Heartbeat, IncomingTcp, RegisterAccepted, RegisterRequest, RelayForwardPrelude,
-        ServerFrame, StreamClosed, StreamData,
+        R2rFrame, R2rStreamClosed, R2rStreamData, ServerFrame, StreamClosed, StreamData,
    },
};
use redis::AsyncCommands;
use serde::Deserialize;
use metrics_exporter_prometheus::PrometheusBuilder;
use tokio::{
-    io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional},
+    io::{AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    sync::{Mutex, Notify, RwLock, mpsc},
    time::{MissedTickBehavior, interval, timeout},

@@ -106,6 +107,21 @@ impl RelayState {

type SharedState = Arc<RwLock<RelayState>>;

#[derive(Clone)]
struct R2rManager {
    outbound: Arc<Mutex<HashMap<String, mpsc::Sender<R2rFrame>>>>,
    ingress_stream_sinks: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
}

impl R2rManager {
    fn new() -> Self {
        Self {
            outbound: Arc::new(Mutex::new(HashMap::new())),
            ingress_stream_sinks: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

#[derive(Clone)]
struct RelayGuards {
    player_ip: Arc<Mutex<HashMap<String, BucketState>>>,

@@ -118,6 +134,11 @@ struct RelayGuards {
    reg_ip_burst: f64,
    session_bw_rate_bytes: f64,
    session_bw_burst_bytes: f64,
    redis: Option<redis::aio::ConnectionManager>,
    player_global_window_secs: u64,
    player_global_limit: i64,
    reg_global_window_secs: u64,
    reg_global_limit: i64,
}

#[derive(Debug, Clone)]

@@ -170,7 +191,7 @@ impl BucketState {
}

impl RelayGuards {
-    fn from_env() -> Self {
+    async fn from_env() -> Self {
        let player_ip_rate = std::env::var("RELAY_PLAYER_CONNECTS_PER_SEC")
            .ok()
            .and_then(|v| v.parse().ok())

@@ -200,6 +221,29 @@ impl RelayGuards {
            .and_then(|v| v.parse().ok())
            .unwrap_or(512.0)
            * 1024.0;
        let redis = match std::env::var("REDIS_URL") {
            Ok(url) => match redis::Client::open(url) {
                Ok(client) => redis::aio::ConnectionManager::new(client).await.ok(),
                Err(_) => None,
            },
            Err(_) => None,
        };
        let player_global_window_secs = std::env::var("RELAY_PLAYER_GLOBAL_WINDOW_SECS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10);
        let player_global_limit = std::env::var("RELAY_PLAYER_GLOBAL_LIMIT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(50);
        let reg_global_window_secs = std::env::var("RELAY_REG_GLOBAL_WINDOW_SECS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(60);
        let reg_global_limit = std::env::var("RELAY_REG_GLOBAL_LIMIT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10);

        Self {
            player_ip: Arc::new(Mutex::new(HashMap::new())),
@@ -212,6 +256,11 @@ impl RelayGuards {
            reg_ip_burst,
            session_bw_rate_bytes,
            session_bw_burst_bytes,
            redis,
            player_global_window_secs,
            player_global_limit,
            reg_global_window_secs,
            reg_global_limit,
        }
    }

@@ -233,7 +282,18 @@ impl RelayGuards {
        let bucket = guard
            .entry(ip.to_string())
            .or_insert_with(|| BucketState::new(burst, rate));
-        bucket.reserve_delay(1).is_zero()
+        let local_ok = bucket.reserve_delay(1).is_zero();
        drop(guard);
        if !local_ok {
            return false;
        }

        let (window_secs, limit, scope) = if player {
            (self.player_global_window_secs, self.player_global_limit, "mc")
        } else {
            (self.reg_global_window_secs, self.reg_global_limit, "reg")
        };
        self.redis_allow_ip_window(ip, scope, window_secs, limit).await
    }

    async fn throttle_session_bytes(&self, session_id: &str, dir: SessionDir, bytes: usize) {

@@ -257,6 +317,31 @@ impl RelayGuards {
        self.session_ingress.lock().await.remove(session_id);
        self.session_egress.lock().await.remove(session_id);
    }

    async fn redis_allow_ip_window(
        &self,
        ip: &str,
        scope: &str,
        window_secs: u64,
        limit: i64,
    ) -> bool {
        let Some(mut conn) = self.redis.clone() else {
            return true;
        };
        let key = format!("ratelimit:ip:{ip}:{scope}");
        let res: redis::RedisResult<i64> = async {
            let count: i64 = conn.incr(&key, 1).await?;
            if count == 1 {
                let _: bool = conn.expire(&key, window_secs as i64).await?;
            }
            Ok(count)
        }
        .await;
        match res {
            Ok(count) => count <= limit,
            Err(_) => true,
        }
    }
}

#[derive(Debug, Clone, Deserialize)]
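`redis_allow_ip_window` above is a fixed-window counter: the first `INCR` in a window creates the key and stamps its TTL, callers pass while the count stays at or below the limit, and the check fails open on Redis errors. A dependency-free sketch of the same window semantics, with an in-memory map standing in for Redis (illustrative only):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// In-memory stand-in for the INCR + EXPIRE pattern in redis_allow_ip_window.
struct FixedWindow {
    counters: HashMap<String, (Instant, i64)>, // key -> (window start, count)
    window: Duration,
    limit: i64,
}

impl FixedWindow {
    fn allow(&mut self, key: &str) -> bool {
        let now = Instant::now();
        let entry = self.counters.entry(key.to_string()).or_insert((now, 0));
        if now.duration_since(entry.0) >= self.window {
            *entry = (now, 0); // window elapsed: reset, like the Redis key's TTL lapsing
        }
        entry.1 += 1; // the INCR step
        entry.1 <= self.limit
    }
}

fn main() {
    let mut rl = FixedWindow { counters: HashMap::new(), window: Duration::from_secs(10), limit: 3 };
    // First three connects in the window pass, the rest are rejected.
    let results: Vec<bool> = (0..5).map(|_| rl.allow("203.0.113.7:mc")).collect();
    assert_eq!(results, [true, true, true, false, false]);
    println!("{results:?}");
}
```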
@@ -275,7 +360,8 @@ struct RelayInstanceRecord {
    _instance_id: String,
    #[serde(rename = "region")]
    _region: Option<String>,
-    status: Option<String>,
+    #[serde(rename = "status")]
+    _status: Option<String>,
    r2r_addr: Option<String>,
}

@@ -435,6 +521,7 @@ impl RedisRegistry {

#[tokio::main]
async fn main() -> Result<()> {
    init_metrics()?;
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()

@@ -444,7 +531,8 @@ async fn main() -> Result<()> {

    let cfg = RelayConfig::from_env();
    let registry = RedisRegistry::from_env(&cfg).await;
-    let guards = Arc::new(RelayGuards::from_env());
+    let guards = Arc::new(RelayGuards::from_env().await);
    let r2r = Arc::new(R2rManager::new());
    registry.register_instance().await;

    let control_listener = TcpListener::bind(&cfg.control_bind)
@@ -458,6 +546,7 @@ async fn main() -> Result<()> {
        .with_context(|| format!("bind r2r {}", cfg.r2r_bind))?;

    info!(instance_id = %cfg.instance_id, region = %cfg.region, control = %cfg.control_bind, player = %cfg.player_bind, r2r = %cfg.r2r_bind, r2r_advertise = %cfg.r2r_advertise_addr, "relay started");
    metrics::gauge!("relay_drain_state").set(0.0);

    let shutdown = Arc::new(Notify::new());
    let state: SharedState = Arc::new(RwLock::new(RelayState::new()));

@@ -477,6 +566,7 @@ async fn main() -> Result<()> {
        state.clone(),
        registry.clone(),
        guards.clone(),
        r2r.clone(),
        shutdown.clone(),
    ));
    let r2r_task = tokio::spawn(run_r2r_accept_loop(

@@ -484,6 +574,7 @@ async fn main() -> Result<()> {
        cfg.clone(),
        state.clone(),
        guards.clone(),
        r2r.clone(),
        shutdown.clone(),
    ));

@@ -501,6 +592,7 @@ async fn main() -> Result<()> {
    }

    registry.set_draining().await;
    metrics::gauge!("relay_drain_state").set(1.0);
    shutdown.notify_waiters();
    info!("draining relay");
    tokio::time::sleep(Duration::from_secs(1)).await;

@@ -515,6 +607,7 @@ async fn run_registry_heartbeat(state: SharedState, registry: RedisRegistry, shu
            _ = shutdown.notified() => break,
            _ = ticker.tick() => {
                let count = state.read().await.session_count();
                metrics::gauge!("relay_active_tunnels").set(count as f64);
                registry.heartbeat_instance(count).await;
            }
        }

@@ -537,6 +630,7 @@ async fn run_control_accept_loop(
            Ok(v) => v,
            Err(e) => { warn!(error = %e, "control accept failed"); continue; }
        };
        metrics::counter!("relay_control_accepts_total").increment(1);
        let cfg = cfg.clone();
        let state = state.clone();
        let registry = registry.clone();
@@ -558,6 +652,7 @@ async fn run_player_accept_loop(
    state: SharedState,
    registry: RedisRegistry,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
    shutdown: Arc<Notify>,
) -> Result<()> {
    loop {

@@ -568,12 +663,14 @@ async fn run_player_accept_loop(
            Ok(v) => v,
            Err(e) => { warn!(error = %e, "player accept failed"); continue; }
        };
        metrics::counter!("relay_player_accepts_total").increment(1);
        let cfg = cfg.clone();
        let state = state.clone();
        let registry = registry.clone();
        let guards = guards.clone();
        let r2r = r2r.clone();
        tokio::spawn(async move {
-            if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards).await {
+            if let Err(e) = handle_player_conn(stream, addr, cfg, state, registry, guards, r2r).await {
                debug!(peer = %addr, error = %e, "player connection closed");
            }
        });

@@ -588,6 +685,7 @@ async fn run_r2r_accept_loop(
    cfg: RelayConfig,
    state: SharedState,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
    shutdown: Arc<Notify>,
) -> Result<()> {
    loop {

@@ -598,11 +696,13 @@ async fn run_r2r_accept_loop(
            Ok(v) => v,
            Err(e) => { warn!(error = %e, "r2r accept failed"); continue; }
        };
        metrics::counter!("relay_r2r_accepts_total").increment(1);
        let cfg = cfg.clone();
        let state = state.clone();
        let guards = guards.clone();
        let r2r = r2r.clone();
        tokio::spawn(async move {
-            if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards).await {
+            if let Err(e) = handle_r2r_conn(stream, addr, cfg, state, guards, r2r).await {
                warn!(peer = %addr, error = %e, "r2r connection ended with error");
            }
        });

@@ -612,6 +712,7 @@ async fn run_r2r_accept_loop(
    Ok(())
}

#[tracing::instrument(skip(stream, state, registry, guards, cfg), fields(peer = %addr))]
async fn handle_control_conn(
    stream: TcpStream,
    addr: SocketAddr,

@@ -621,6 +722,7 @@ async fn handle_control_conn(
    guards: Arc<RelayGuards>,
) -> Result<()> {
    if !guards.allow_registration_ip(&addr.ip().to_string()).await {
        metrics::counter!("relay_rate_limited_total", "scope" => "registration_ip").increment(1);
        anyhow::bail!("registration rate limited for {}", addr.ip());
    }

@@ -666,6 +768,7 @@ async fn handle_control_conn(
        owner_instance_id: cfg.instance_id.clone(),
    })).await?;
    info!(peer = %addr, user_id = %user_id, fqdn = %fqdn, session_id = %session_id, "client registered");
    metrics::counter!("relay_tunnel_registrations_total").increment(1);

    let write_task = tokio::spawn(async move {
        while let Some(frame) = rx.recv().await {

@@ -748,6 +851,7 @@ async fn control_read_loop(
    }
}

#[tracing::instrument(skip(stream, cfg, state, registry, guards, r2r), fields(peer = %addr))]
async fn handle_player_conn(
    mut stream: TcpStream,
    addr: SocketAddr,

@@ -755,8 +859,10 @@ async fn handle_player_conn(
    state: SharedState,
    registry: RedisRegistry,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<()> {
    if !guards.allow_player_ip(&addr.ip().to_string()).await {
        metrics::counter!("relay_rate_limited_total", "scope" => "player_ip").increment(1);
        debug!(peer = %addr, "player connect rate limited");
        return Ok(());
    }
@@ -784,83 +890,372 @@ async fn handle_player_conn(
            debug!(peer = %addr, hostname = %hostname, session_id = %route.session_id, "route points to self but local session missing");
            return Ok(());
        }
-        return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry).await;
+        return proxy_player_to_owner(stream, addr, hostname, initial_data, route, cfg, registry, guards, r2r).await;
    }

    debug!(peer = %addr, hostname = %hostname, "no tunnel for hostname");
    Ok(())
}

#[tracing::instrument(skip(stream, cfg, state, guards, r2r), fields(peer = %addr))]
async fn handle_r2r_conn(
-    mut stream: TcpStream,
+    stream: TcpStream,
    addr: SocketAddr,
-    _cfg: RelayConfig,
+    cfg: RelayConfig,
    state: SharedState,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<()> {
-    let prelude: RelayForwardPrelude = read_frame(&mut stream).await.context("read r2r prelude")?;
-    if prelude.version != 1 {
-        anyhow::bail!("unsupported r2r prelude version {}", prelude.version);
-    }
-    if prelude.hop_count > 1 {
-        anyhow::bail!("invalid hop_count {}", prelude.hop_count);
-    }
-
-    let session = local_session_for_session_id(&state, &prelude.session_id).await;
-    let Some(session) = session else {
-        anyhow::bail!("owner session not found for {}", prelude.session_id);
-    };
-
-    attach_player_socket_to_session(
-        stream,
-        session,
-        prelude.fqdn.clone(),
-        prelude.peer_addr,
-        prelude.initial_data,
-        Some(prelude.stream_id),
-        "r2r",
-        guards,
-    )
-    .await
-    .with_context(|| format!("r2r attach failed from {addr}"))
+    handle_r2r_multiplex_conn(stream, addr, cfg, state, guards, r2r)
+        .await
+        .with_context(|| format!("r2r multiplex failed from {addr}"))
}
#[tracing::instrument(skip(player_stream, route, cfg, registry, guards, r2r), fields(peer = %player_addr, hostname = %hostname))]
async fn proxy_player_to_owner(
-    mut player_stream: TcpStream,
+    player_stream: TcpStream,
    player_addr: SocketAddr,
    hostname: String,
    initial_data: Vec<u8>,
    route: TunnelRouteRecord,
    cfg: RelayConfig,
    registry: RedisRegistry,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<()> {
    let redis_lookup_started = Instant::now();
    let owner = registry
        .lookup_instance(&route.instance_id)
        .await
        .with_context(|| format!("owner instance {} not found in redis", route.instance_id))?;
    metrics::histogram!("relay_redis_lookup_latency_ms")
        .record(redis_lookup_started.elapsed().as_secs_f64() * 1000.0);
    let r2r_addr = owner
        .r2r_addr
        .clone()
        .with_context(|| format!("owner {} missing r2r_addr", route.instance_id))?;

-    let mut owner_stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&r2r_addr))
-        .await
-        .context("r2r connect timeout")??;
-    let r2r_connect_started = Instant::now();
-    metrics::histogram!("relay_r2r_connect_latency_ms")
-        .record(r2r_connect_started.elapsed().as_secs_f64() * 1000.0);

    let prelude = RelayForwardPrelude {
        version: 1,
-        session_id: route.session_id,
+        session_id: route.session_id.clone(),
        fqdn: hostname.clone(),
        stream_id: Uuid::new_v4().to_string(),
        peer_addr: player_addr.to_string(),
-        origin_instance_id: cfg.instance_id,
+        origin_instance_id: cfg.instance_id.clone(),
        hop_count: 1,
        initial_data,
    };
-    write_frame(&mut owner_stream, &prelude).await?;
-    let _ = copy_bidirectional(&mut player_stream, &mut owner_stream).await?;
-    info!(peer = %player_addr, hostname = %hostname, owner = %route.instance_id, "proxied player connection to owner relay");
+    proxy_player_to_owner_pooled(
+        player_stream,
+        player_addr,
+        hostname,
+        route.instance_id,
+        r2r_addr,
+        prelude,
+        route.session_id,
+        cfg,
+        guards,
+        r2r,
+    )
+    .await
}
async fn proxy_player_to_owner_pooled(
    player_stream: TcpStream,
    player_addr: SocketAddr,
    hostname: String,
    owner_instance_id: String,
    owner_r2r_addr: String,
    prelude: RelayForwardPrelude,
    session_id: String,
    cfg: RelayConfig,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<()> {
    let stream_id = prelude.stream_id.clone();
    let sender = get_or_connect_r2r_pool(
        owner_instance_id.clone(),
        owner_r2r_addr,
        cfg,
        guards,
        r2r.clone(),
    )
    .await?;

    let (player_read, player_write) = player_stream.into_split();
    let (to_player_tx, to_player_rx) = mpsc::channel::<Vec<u8>>(128);
    r2r.ingress_stream_sinks
        .write()
        .await
        .insert(stream_id.clone(), to_player_tx);

    sender
        .send(R2rFrame::Open(prelude))
        .await
        .context("send r2r open")?;

    let tx = sender.clone();
    let sid = session_id.clone();
    let stid = stream_id.clone();
    let sinks = r2r.ingress_stream_sinks.clone();
    tokio::spawn(async move {
        if let Err(e) =
            run_ingress_player_reader_to_r2r(player_read, tx.clone(), sid.clone(), stid.clone()).await
        {
            debug!(stream_id = %stid, error = %e, "ingress player->r2r reader ended");
        }
        let _ = tx
            .send(R2rFrame::Close(R2rStreamClosed {
                session_id: sid,
                stream_id: stid.clone(),
                reason: Some("ingress_player_reader_closed".into()),
            }))
            .await;
        let _ = sinks.write().await.remove(&stid);
    });

    let stid = stream_id.clone();
    let sinks = r2r.ingress_stream_sinks.clone();
    tokio::spawn(async move {
        if let Err(e) = run_ingress_player_writer(player_write, to_player_rx).await {
            debug!(stream_id = %stid, error = %e, "ingress r2r->player writer ended");
        }
        let _ = sinks.write().await.remove(&stid);
    });

    metrics::counter!("relay_r2r_forwards_total").increment(1);
    info!(peer = %player_addr, hostname = %hostname, owner = %owner_instance_id, stream_id = %stream_id, "proxied player via pooled r2r channel");
    Ok(())
}
async fn get_or_connect_r2r_pool(
    owner_instance_id: String,
    owner_r2r_addr: String,
    cfg: RelayConfig,
    guards: Arc<RelayGuards>,
    r2r: Arc<R2rManager>,
) -> Result<mpsc::Sender<R2rFrame>> {
    if let Some(existing) = r2r.outbound.lock().await.get(&owner_instance_id).cloned() {
        return Ok(existing);
    }

    let connect_started = Instant::now();
    let stream = timeout(cfg.r2r_connect_timeout, TcpStream::connect(&owner_r2r_addr))
        .await
        .context("r2r connect timeout")??;
    metrics::histogram!("relay_r2r_connect_latency_ms")
        .record(connect_started.elapsed().as_secs_f64() * 1000.0);

    let (mut reader, mut writer) = stream.into_split();
    let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);

    let mut pools = r2r.outbound.lock().await;
    if let Some(existing) = pools.get(&owner_instance_id).cloned() {
        return Ok(existing);
    }
    pools.insert(owner_instance_id.clone(), tx.clone());
    drop(pools);

    let owner_for_reader = owner_instance_id.clone();
    let r2r_for_reader = r2r.clone();
    let guards_for_reader = guards.clone();
    tokio::spawn(async move {
        loop {
            match read_frame::<_, R2rFrame>(&mut reader).await {
                Ok(frame) => {
                    if let Err(e) = handle_r2r_inbound_frame(frame, &r2r_for_reader, &guards_for_reader).await {
                        debug!(owner = %owner_for_reader, error = %e, "r2r pooled inbound frame error");
                        break;
                    }
                }
                Err(e) => {
                    debug!(owner = %owner_for_reader, error = %e, "r2r pooled reader ended");
                    break;
                }
            }
        }
        r2r_for_reader.outbound.lock().await.remove(&owner_for_reader);
    });

    tokio::spawn(async move {
        while let Some(frame) = rx.recv().await {
            if let Err(e) = write_frame(&mut writer, &frame).await {
                debug!(error = %e, "r2r pooled writer ended");
                break;
            }
        }
    });

    Ok(tx)
}
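`get_or_connect_r2r_pool` follows a check / connect / re-check shape: an optimistic lookup, a TCP connect with the lock released, then a second lookup under the lock before inserting, so two tasks racing toward the same owner converge on one pooled sender. A minimal sketch of that pattern under stated assumptions (the `connect_handle` stand-in and all names here are illustrative, not from the diff):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

type Pool = Arc<Mutex<HashMap<String, Arc<String>>>>;

// Illustrative stand-in: "connecting" just allocates a handle.
async fn connect_handle(addr: &str) -> Arc<String> {
    Arc::new(format!("conn->{addr}"))
}

async fn get_or_connect(pool: Pool, key: &str) -> Arc<String> {
    // Fast path: reuse an existing handle; never hold the lock across I/O.
    if let Some(existing) = pool.lock().await.get(key).cloned() {
        return existing;
    }
    let fresh = connect_handle(key).await; // connect with the lock released
    let mut guard = pool.lock().await;
    // Re-check: another task may have won the race while we were connecting.
    if let Some(existing) = guard.get(key).cloned() {
        return existing; // drop our duplicate, converge on the winner
    }
    guard.insert(key.to_string(), fresh.clone());
    fresh
}

#[tokio::main]
async fn main() {
    let pool: Pool = Arc::new(Mutex::new(HashMap::new()));
    let a = get_or_connect(pool.clone(), "relay-b:7001").await;
    let b = get_or_connect(pool.clone(), "relay-b:7001").await;
    assert!(Arc::ptr_eq(&a, &b)); // second call reused the pooled handle
}
```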
async fn handle_r2r_multiplex_conn(
    stream: TcpStream,
    _addr: SocketAddr,
    _cfg: RelayConfig,
    state: SharedState,
    guards: Arc<RelayGuards>,
    _r2r: Arc<R2rManager>,
) -> Result<()> {
    let (mut reader, mut writer) = stream.into_split();
    let (tx, mut rx) = mpsc::channel::<R2rFrame>(2048);

    let _writer_task = tokio::spawn(async move {
        while let Some(frame) = rx.recv().await {
            write_frame(&mut writer, &frame).await?;
        }
        Ok::<(), anyhow::Error>(())
    });

    loop {
        let frame: R2rFrame = read_frame(&mut reader).await?;
        match frame {
            R2rFrame::Open(prelude) => {
                if prelude.version != 1 || prelude.hop_count > 1 {
                    continue;
                }
                if let Some(session) = local_session_for_session_id(&state, &prelude.session_id).await {
                    attach_virtual_r2r_stream_to_session(session, prelude, tx.clone()).await?;
                } else {
                    let _ = tx.send(R2rFrame::Close(R2rStreamClosed {
                        session_id: prelude.session_id,
                        stream_id: prelude.stream_id,
                        reason: Some("owner_session_not_found".into()),
                    })).await;
                }
            }
            R2rFrame::Data(data) => {
                guards
                    .throttle_session_bytes(&data.session_id, SessionDir::EgressToClient, data.data.len())
                    .await;
                if let Some(session) = local_session_for_session_id(&state, &data.session_id).await {
                    let _ = session
                        .tx
                        .send(ServerFrame::StreamData(StreamData { stream_id: data.stream_id, data: data.data }))
                        .await;
                }
            }
            R2rFrame::Close(close) => {
                if let Some(session) = local_session_for_session_id(&state, &close.session_id).await {
                    let _ = session
                        .tx
                        .send(ServerFrame::StreamClosed(StreamClosed { stream_id: close.stream_id.clone(), reason: close.reason.clone() }))
                        .await;
                    remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
                } else {
                    remove_stream_sink(&state, &close.session_id, &close.stream_id).await;
                }
            }
            R2rFrame::Ping => {
                let _ = tx.send(R2rFrame::Pong).await;
            }
            R2rFrame::Pong => {}
        }
    }
}
async fn attach_virtual_r2r_stream_to_session(
    session: SessionHandle,
    prelude: RelayForwardPrelude,
    r2r_tx: mpsc::Sender<R2rFrame>,
) -> Result<()> {
    let stream_id = prelude.stream_id.clone();
    let session_id = session.session_id.clone();
    let (to_r2r_tx, mut to_r2r_rx) = mpsc::channel::<Vec<u8>>(128);
    session
        .stream_sinks
        .write()
        .await
        .insert(stream_id.clone(), to_r2r_tx);

    session
        .tx
        .send(ServerFrame::IncomingTcp(IncomingTcp {
            stream_id: stream_id.clone(),
            session_id: session_id.clone(),
            peer_addr: prelude.peer_addr.clone(),
            hostname: prelude.fqdn.clone(),
            initial_data: prelude.initial_data.clone(),
        }))
        .await
        .context("send virtual r2r IncomingTcp to client")?;

    tokio::spawn(async move {
        while let Some(chunk) = to_r2r_rx.recv().await {
            let _ = r2r_tx
                .send(R2rFrame::Data(R2rStreamData {
                    session_id: session_id.clone(),
                    stream_id: stream_id.clone(),
                    data: chunk,
                }))
                .await;
        }
        let _ = r2r_tx
            .send(R2rFrame::Close(R2rStreamClosed {
                session_id,
                stream_id,
                reason: Some("owner_sink_closed".into()),
            }))
            .await;
    });

    Ok(())
}
async fn handle_r2r_inbound_frame(
    frame: R2rFrame,
    r2r: &R2rManager,
    _guards: &RelayGuards,
) -> Result<()> {
    match frame {
        R2rFrame::Data(data) => {
            if let Some(tx) = r2r.ingress_stream_sinks.read().await.get(&data.stream_id).cloned() {
                let _ = tx.send(data.data).await;
            }
        }
        R2rFrame::Close(close) => {
            r2r.ingress_stream_sinks.write().await.remove(&close.stream_id);
        }
        R2rFrame::Ping | R2rFrame::Pong | R2rFrame::Open(_) => {}
    }
    Ok(())
}

async fn run_ingress_player_reader_to_r2r(
    mut reader: tokio::net::tcp::OwnedReadHalf,
    tx: mpsc::Sender<R2rFrame>,
    session_id: String,
    stream_id: String,
) -> Result<()> {
    let mut buf = vec![0u8; 16 * 1024];
    loop {
        let n = reader.read(&mut buf).await?;
        if n == 0 {
            break;
        }
        tx.send(R2rFrame::Data(R2rStreamData {
            session_id: session_id.clone(),
            stream_id: stream_id.clone(),
            data: buf[..n].to_vec(),
        }))
        .await
        .context("send ingress data to r2r")?;
    }
    Ok(())
}

async fn run_ingress_player_writer(
    mut writer: tokio::net::tcp::OwnedWriteHalf,
    mut rx: mpsc::Receiver<Vec<u8>>,
) -> Result<()> {
    while let Some(chunk) = rx.recv().await {
        writer.write_all(&chunk).await?;
    }
    let _ = writer.shutdown().await;
    Ok(())
}
@@ -944,6 +1339,7 @@ async fn attach_player_socket_to_session(
    });

    info!(peer = %peer_addr, hostname = %hostname, session_id = %session.session_id, stream_id = %stream_id, source, "player proxied via client stream");
    metrics::gauge!("relay_active_player_conns").increment(1.0);
    Ok(())
}

@@ -970,7 +1366,9 @@ async fn run_player_reader(
            }))
            .await
            .context("send stream data to client")?;
        metrics::counter!("relay_bytes_out_total").increment(n as u64);
    }
    metrics::gauge!("relay_active_player_conns").decrement(1.0);
    Ok(())
}

@@ -980,6 +1378,7 @@ async fn run_player_writer(
) -> Result<()> {
    while let Some(chunk) = rx.recv().await {
        writer.write_all(&chunk).await?;
        metrics::counter!("relay_bytes_in_total").increment(chunk.len() as u64);
    }
    let _ = writer.shutdown().await;
    Ok(())

@@ -1061,3 +1460,14 @@ fn guess_advertise_addr(bind: &str) -> String {
        "127.0.0.1:7001".to_string()
    }
}

fn init_metrics() -> Result<()> {
    if let Ok(bind) = std::env::var("RELAY_METRICS_BIND") {
        let addr: std::net::SocketAddr = bind.parse().context("parse RELAY_METRICS_BIND")?;
        PrometheusBuilder::new()
            .with_http_listener(addr)
            .install()
            .context("install prometheus exporter")?;
    }
    Ok(())
}