// sglang.0.4.8.post1/sglang/sgl-router/src/router.rs
//
// 1331 lines
// 51 KiB
// Rust

use crate::pd_router::PDRouter;
use crate::pd_types::PDSelectionPolicy;
use crate::tree::Tree;
use ::metrics::{counter, gauge, histogram};
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use futures_util::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use std::time::Instant;
use tokio;
use tracing::{debug, error, info, warn};
/// Copies the headers of `req` into owned `(name, value)` string pairs.
///
/// A header is skipped when `HeaderValue::to_str` fails on its value.
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    let mut copied = Vec::new();
    for (name, value) in req.headers().iter() {
        if let Ok(text) = value.to_str() {
            copied.push((name.to_string(), text.to_string()));
        }
    }
    copied
}
/// Request router; each variant holds the state for one routing policy.
///
/// The three "regular" variants share a lock-protected list of worker URLs, while
/// `PrefillDecode` delegates everything to a `PDRouter`.
#[derive(Debug)]
pub enum Router {
/// Cycles through the workers using a shared atomic index.
RoundRobin {
// Worker base URLs; mutated by `add_worker` / `remove_worker`.
worker_urls: Arc<RwLock<Vec<String>>>,
// Index of the next worker to use; advanced on every selection.
current_index: AtomicUsize,
// Health-check timeout (seconds) used by `add_worker`.
timeout_secs: u64,
// Delay (seconds) between health-check polls in `add_worker`.
interval_secs: u64,
},
/// Picks a worker at random for each request.
Random {
worker_urls: Arc<RwLock<Vec<String>>>,
timeout_secs: u64,
interval_secs: u64,
},
/// Prefill/decode mode; routing is handled by the wrapped `PDRouter`.
PrefillDecode {
pd_router: Arc<PDRouter>,
},
CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue with Balance Thresholds)
The router dynamically switches between these strategies based on load conditions:
- Uses load balancing when the system is imbalanced
- Uses cache-aware routing when the system is balanced
A system is considered imbalanced if both conditions are met:
1. (max - min) > abs_threshold
2. max > rel_threshold * min
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker when the system is detected to be imbalanced.
Configuration Parameters:
------------------------
1. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
2. balance_abs_threshold: (integer)
Absolute difference threshold for load imbalance detection.
System is potentially imbalanced if (max_load - min_load) > abs_threshold
3. balance_rel_threshold: (float)
Relative ratio threshold for load imbalance detection.
System is potentially imbalanced if max_load > min_load * rel_threshold
Used in conjunction with abs_threshold to determine final imbalance state.
4. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
5. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Arc<RwLock<Vec<String>>>,
// Approximate prefix tree shared with the background eviction thread.
tree: Arc<Mutex<Tree>>,
// Per-worker count of requests currently in flight.
running_queue: Arc<Mutex<HashMap<String, usize>>>,
// Per-worker count of requests routed so far.
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
timeout_secs: u64,
interval_secs: u64,
// Handle of the eviction thread spawned in `Router::new`; it is never joined in this file.
_eviction_thread: Option<thread::JoinHandle<()>>,
},
}
/// Configuration consumed by `Router::new` to pick and parameterize a routing policy.
///
/// In every variant `timeout_secs` / `interval_secs` are the worker health-check timeout and
/// polling interval, both in seconds.
#[derive(Debug, Clone)]
pub enum PolicyConfig {
/// Builds `Router::Random`.
RandomConfig {
timeout_secs: u64,
interval_secs: u64,
},
/// Builds `Router::RoundRobin`.
RoundRobinConfig {
timeout_secs: u64,
interval_secs: u64,
},
/// Builds `Router::CacheAware`.
CacheAwareConfig {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
// Sleep time (seconds) between runs of the background eviction thread.
eviction_interval_secs: u64,
// Size limit handed to `Tree::evict_tenant_by_size` on each eviction run.
max_tree_size: usize,
timeout_secs: u64,
interval_secs: u64,
},
/// Builds `Router::PrefillDecode`; all fields are forwarded to `PDRouter::new`.
PrefillDecodeConfig {
selection_policy: PDSelectionPolicy,
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
decode_urls: Vec<String>,
timeout_secs: u64,
interval_secs: u64,
},
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Update active workers gauge
gauge!("sgl_router_active_workers").set(worker_urls.len() as f64);
// Get timeout and interval from policy config
let (timeout_secs, interval_secs) = match &policy_config {
PolicyConfig::RandomConfig {
timeout_secs,
interval_secs,
} => (*timeout_secs, *interval_secs),
PolicyConfig::RoundRobinConfig {
timeout_secs,
interval_secs,
} => (*timeout_secs, *interval_secs),
PolicyConfig::CacheAwareConfig {
timeout_secs,
interval_secs,
..
} => (*timeout_secs, *interval_secs),
PolicyConfig::PrefillDecodeConfig {
timeout_secs,
interval_secs,
..
} => (*timeout_secs, *interval_secs),
};
// For PrefillDecode, we need to handle workers differently
match &policy_config {
PolicyConfig::PrefillDecodeConfig { .. } => {
// PD mode doesn't use the worker_urls parameter
// We'll validate PD workers separately
}
_ => {
// Wait until all workers are healthy for regular modes
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
}
// Create router based on policy...
Ok(match policy_config {
PolicyConfig::RandomConfig {
timeout_secs,
interval_secs,
} => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
timeout_secs,
interval_secs,
},
PolicyConfig::RoundRobinConfig {
timeout_secs,
interval_secs,
} => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
timeout_secs,
interval_secs,
},
PolicyConfig::CacheAwareConfig {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
timeout_secs,
interval_secs,
} => {
let mut running_queue = HashMap::new();
for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let running_queue_clone = Arc::clone(&running_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_by_size(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
info!("Processed Queue: {:?}", locked_processed_queue);
// Print the running queue
let locked_running_queue = running_queue_clone.lock().unwrap();
info!("Running Queue: {:?}", locked_running_queue);
}
});
for url in &worker_urls {
tree.lock().unwrap().insert("", url);
}
Router::CacheAware {
worker_urls: Arc::new(RwLock::new(worker_urls)),
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
timeout_secs,
interval_secs,
_eviction_thread: Some(eviction_thread),
}
}
PolicyConfig::PrefillDecodeConfig {
selection_policy,
prefill_urls,
decode_urls,
timeout_secs,
interval_secs,
} => {
// Create PDRouter instance
let pd_router = PDRouter::new(
prefill_urls,
decode_urls,
selection_policy,
timeout_secs,
interval_secs,
)?;
Router::PrefillDecode {
pd_router: Arc::new(pd_router),
}
}
})
}
/// Returns a shared handle to the worker URL list.
///
/// In `PrefillDecode` mode a fresh, empty list is returned because workers are managed by the
/// `PDRouter` instead.
pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
    let shared = match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls, .. }
        | Router::CacheAware { worker_urls, .. } => worker_urls,
        Router::PrefillDecode { .. } => return Arc::new(RwLock::new(Vec::new())),
    };
    Arc::clone(shared)
}
pub fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
error!(
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
timeout_secs, worker_urls
);
return Err(format!(
"Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
timeout_secs, worker_urls
));
}
let mut all_healthy = true;
let mut unhealthy_workers = Vec::new();
for url in worker_urls {
match sync_client.get(&format!("{}/health", url)).send() {
Ok(res) => {
if !res.status().is_success() {
let msg = format!(
"Worker heatlh check is pending with status {}",
res.status()
);
info!("{}", msg);
all_healthy = false;
unhealthy_workers.push((url, msg));
}
}
Err(_) => {
let msg = format!("Worker is not ready yet");
info!("{}", msg);
all_healthy = false;
unhealthy_workers.push((url, msg));
}
}
}
if all_healthy {
info!("All workers are healthy");
return Ok(());
} else {
info!("Initializing workers:");
for (url, reason) in &unhealthy_workers {
info!(" {} - {}", url, reason);
}
thread::sleep(Duration::from_secs(interval_secs));
}
}
}
/// Returns the first worker URL in the list.
///
/// The list is read under a single lock acquisition, so a concurrent `remove_worker` cannot
/// empty it between the emptiness check and the element access.
///
/// # Errors
/// Returns an error string when no workers are registered, or when the router is in
/// `PrefillDecode` mode (routing is handled by `PDRouter` there).
fn select_first_worker(&self) -> Result<String, String> {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls, .. }
        | Router::CacheAware { worker_urls, .. } => worker_urls
            .read()
            .unwrap()
            .first()
            .cloned()
            .ok_or_else(|| "No workers are available".to_string()),
        Router::PrefillDecode { .. } => {
            // For PD mode, we don't need this method as routing is handled by PDRouter
            Err("PrefillDecode mode doesn't use select_first_worker".to_string())
        }
    }
}
/// Issues `GET {worker_url}{route}` and converts the worker's reply into an `HttpResponse`.
///
/// Headers from `req` are forwarded (minus Content-Type and Content-Length) unless `route` is
/// `/health`, which needs no authorization. Transport or body-read failures become a 500
/// response. Request metrics are recorded for every route except `/health`.
pub async fn send_request(
    &self,
    client: &reqwest::Client,
    worker_url: &str,
    route: &str,
    req: &HttpRequest,
) -> HttpResponse {
    let start = Instant::now();
    let is_health_probe = route == "/health";

    let mut request_builder = client.get(format!("{}{}", worker_url, route));
    if !is_health_probe {
        let forwarded = copy_request_headers(req).into_iter().filter(|(name, _)| {
            let lowered = name.to_lowercase();
            lowered != "content-type" && lowered != "content-length"
        });
        for (name, value) in forwarded {
            request_builder = request_builder.header(name, value);
        }
    }

    let response = match request_builder.send().await {
        Err(e) => HttpResponse::InternalServerError().body(format!(
            "Failed to send request to worker {}: {}",
            worker_url, e
        )),
        Ok(res) => {
            // Unknown upstream status codes are reported as 500.
            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
            match res.bytes().await {
                Err(e) => HttpResponse::InternalServerError()
                    .body(format!("Failed to read response body: {}", e)),
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
            }
        }
    };

    // Record request metrics
    if !is_health_probe {
        let elapsed = start.elapsed();
        counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
        histogram!("sgl_router_request_duration_seconds", "route" => route.to_string())
            .record(elapsed.as_secs_f64());
        if !response.status().is_success() {
            counter!("sgl_router_request_errors_total", "route" => route.to_string())
                .increment(1);
        }
    }
    response
}
/// Sends a GET for `route` to the first worker, retrying on failure.
///
/// Each selected worker gets up to `MAX_REQUEST_RETRIES` attempts. A failed attempt whose
/// worker still passes `/health` is treated as a bad request and the worker's response is
/// returned unchanged. A worker that fails all its attempts is removed and the next "first"
/// worker is tried, until `MAX_TOTAL_RETRIES` failures have accumulated.
pub async fn route_to_first(
    &self,
    client: &reqwest::Client,
    route: &str,
    req: &HttpRequest,
) -> HttpResponse {
    const MAX_REQUEST_RETRIES: u32 = 3;
    const MAX_TOTAL_RETRIES: u32 = 6;

    let mut total_retries = 0;
    while total_retries < MAX_TOTAL_RETRIES {
        let worker_url = match self.select_first_worker() {
            Ok(url) => url,
            Err(e) => return HttpResponse::InternalServerError().body(e),
        };

        // Try the same worker multiple times
        for attempt in 1..=MAX_REQUEST_RETRIES {
            if total_retries >= 1 {
                info!("Retrying request after {} failed attempts", total_retries);
            }

            let response = self.send_request(client, &worker_url, route, req).await;
            if response.status().is_success() {
                return response;
            }

            // if the worker is healthy, it means the request is bad, so return the error response
            let health_response = self.send_request(client, &worker_url, "/health", req).await;
            if health_response.status().is_success() {
                return response;
            }

            warn!(
                "Request to {} failed (attempt {}/{})",
                worker_url, attempt, MAX_REQUEST_RETRIES
            );
            total_retries += 1;
        }

        // Every attempt against this worker failed.
        warn!("Removing failed worker: {}", worker_url);
        self.remove_worker(&worker_url);
    }

    HttpResponse::InternalServerError().body("All retry attempts failed")
}
/// Broadcasts a POST for `route` to every worker concurrently, forwarding the headers of `req`.
///
/// Responds 200 only if every worker answered with a success status, 500 otherwise.
/// `PrefillDecode` mode is not supported and yields 501.
pub async fn route_to_all(
    &self,
    client: &reqwest::Client,
    route: &str,
    req: &HttpRequest,
) -> HttpResponse {
    if let Router::PrefillDecode { .. } = self {
        // For PD mode, route_to_all is not supported directly
        // It should be handled by PDRouter if needed
        return HttpResponse::NotImplemented()
            .body("route_to_all not implemented for PrefillDecode mode");
    }
    let worker_urls = self.get_worker_urls().read().unwrap().clone();

    // One pending request per worker, each carrying a copy of the original headers.
    let pending: Vec<_> = worker_urls
        .iter()
        .map(|worker_url| {
            copy_request_headers(req)
                .into_iter()
                .fold(
                    client.post(format!("{}{}", worker_url, route)),
                    |builder, (name, value)| builder.header(name, value),
                )
                .send()
        })
        .collect();

    // Wait for all responses
    let results = futures_util::future::join_all(pending).await;

    let any_failed = results.iter().any(|outcome| match outcome {
        Ok(res) => !res.status().is_success(),
        Err(_) => true,
    });
    if any_failed {
        HttpResponse::InternalServerError().body("Operation failed on one or more servers")
    } else {
        HttpResponse::Ok().body("Operation completed on all servers")
    }
}
/// Reports the load of every worker as JSON with `prefill` and `decode` arrays.
///
/// `PrefillDecode` mode delegates to `PDRouter::get_loads`. For the other modes all workers
/// are listed under `decode` and `prefill` is always empty. A worker whose load cannot be
/// fetched is reported with load `-1`.
pub async fn get_all_loads(
    &self,
    client: &reqwest::Client,
    _req: &HttpRequest,
) -> HttpResponse {
    // For PD mode, delegate to PDRouter
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.get_loads(client).await;
    }

    // Non-PD routers have no prefill servers; every worker is reported as a decode engine.
    let decode_urls = self.get_worker_urls().read().unwrap().clone();
    let prefill_loads: Vec<serde_json::Value> = Vec::new();

    let mut decode_loads = Vec::new();
    for url in &decode_urls {
        let load = self.get_worker_load(client, url).await.unwrap_or(-1);
        decode_loads.push(serde_json::json!({
            "engine": format!("(Decode@{})", url),
            "load": load as i64
        }));
    }

    HttpResponse::Ok().json(serde_json::json!({
        "prefill": prefill_loads,
        "decode": decode_loads
    }))
}
// Routes an already-deserialized generation request to a worker chosen by the routing policy.
//
// A worker is selected once per outer-loop iteration and tried up to MAX_REQUEST_RETRIES
// times. A failed attempt whose worker still passes `/health` is treated as a bad request:
// the worker's response is returned as-is and counted as an error. A worker that fails all
// of its attempts is removed, and a new worker is selected until MAX_TOTAL_RETRIES failures
// have accumulated, at which point a 500 is returned.
//
// PrefillDecode mode is rejected with a 500; PD traffic uses the dedicated typed handlers.
pub async fn route_typed_request<
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
typed_req: &T,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
.body("PD routing should use specialized typed handlers"),
_ => {
// Handle retries like the original implementation
// `start` covers all attempts, so the recorded duration includes time spent on retries.
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
// Select worker based on text
// NOTE(review): for CacheAware this call increments the worker's running count once per
// selection, while send_typed_request decrements it once per non-streaming attempt —
// confirm the counts stay balanced across retries.
let worker_url = self.select_generate_worker_from_text(&text);
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
// Counted for every attempt made after at least one earlier failure.
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
counter!("sgl_router_retries_total", "route" => route.to_string())
.increment(1);
}
// Send typed request directly
let response = self
.send_typed_request(
client,
req,
typed_req,
route,
&worker_url,
is_stream,
)
.await;
if response.status().is_success() {
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
return response;
}
}
// Reaching this point means both the request and the health probe failed.
warn!(
"Generate request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
// After the last per-worker attempt, drop the worker and reselect in the outer loop.
if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
// Retry budget exhausted without a success or a healthy worker.
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
HttpResponse::InternalServerError().body("All retry attempts failed")
}
}
}
// Helper method to select worker from text
fn select_generate_worker_from_text(&self, text: &str) -> String {
match self {
Router::RoundRobin {
worker_urls,
current_index,
..
} => {
let idx = current_index
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
)
.unwrap();
worker_urls.read().unwrap()[idx].clone()
}
Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
Router::CacheAware {
worker_urls,
tree,
running_queue,
processed_queue,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
..
} => {
let tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
// Get current load statistics
let max_load = *running_queue.values().max().unwrap_or(&0);
let min_load = *running_queue.values().min().unwrap_or(&0);
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
// 2. max > rel_threshold * min
let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
&& (max_load as f32) > (min_load as f32 * balance_rel_threshold);
let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state
info!(
"Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\
Current running queue: {:?}",
max_load, min_load, running_queue
);
counter!("sgl_router_load_balancing_events_total").increment(1);
gauge!("sgl_router_max_load").set(max_load as f64);
gauge!("sgl_router_min_load").set(min_load as f64);
// Use shortest queue routing when load is imbalanced
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched_rate =
matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > *cache_threshold {
counter!("sgl_router_cache_hits_total").increment(1);
matched_worker.to_string()
} else {
counter!("sgl_router_cache_misses_total").increment(1);
tree.get_smallest_tenant()
}
};
// Update queues and tree
*running_queue.get_mut(&selected_url).unwrap() += 1;
*processed_queue
.lock()
.unwrap()
.get_mut(&selected_url)
.unwrap() += 1;
gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
.set(*running_queue.get(&selected_url).unwrap() as f64);
counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);
tree.insert(&text, &selected_url);
selected_url
}
Router::PrefillDecode { .. } => {
// For PD mode, we don't use this method
return "PD_MODE_ERROR".to_string();
}
}
}
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
typed_req: &T,
route: &str,
worker_url: &str,
is_stream: bool,
) -> HttpResponse {
let start = Instant::now();
// Debug: Log what we're sending
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
debug!("Sending request to {}: {}", route, json_str);
}
let mut request_builder = client
.post(format!("{}{}", worker_url, route))
.json(typed_req); // Use json() directly with typed request
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
// Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
request_builder = request_builder.header(&name, &value);
}
}
let res = match request_builder.send().await {
Ok(res) => res,
Err(e) => {
error!("Failed to send request to {}: {}", worker_url, e);
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
}
};
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
}
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
}
}
}
// Record metrics
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.to_string();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
debug!("Streaming is done!!")
}
}),
)
} else {
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
}))
}
}
/// Registers `worker_url` once it passes a health check.
///
/// `GET {worker_url}/health` is polled every `interval_secs` seconds until it returns a
/// success status or `timeout_secs` seconds have elapsed. On success the URL is appended to
/// the worker list and, for `CacheAware`, its queues and tree entry are initialised.
///
/// # Errors
/// Returns an error string when the router is in `PrefillDecode` mode, the HTTP client cannot
/// be built, the worker does not become healthy in time, or the worker is already registered.
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
    let (timeout_secs, interval_secs) = match self {
        Router::Random {
            timeout_secs,
            interval_secs,
            ..
        }
        | Router::RoundRobin {
            timeout_secs,
            interval_secs,
            ..
        }
        | Router::CacheAware {
            timeout_secs,
            interval_secs,
            ..
        } => (*timeout_secs, *interval_secs),
        Router::PrefillDecode { .. } => {
            // For PD mode, we don't support adding workers via this method
            return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
        }
    };

    let max_wait = Duration::from_secs(timeout_secs);
    let poll_interval = Duration::from_secs(interval_secs);
    let health_url = format!("{}/health", worker_url);

    let started = std::time::Instant::now();
    let client = reqwest::Client::builder()
        .timeout(max_wait)
        .build()
        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

    loop {
        if started.elapsed() > max_wait {
            error!(
                "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                timeout_secs, worker_url
            );
            return Err(format!(
                "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                timeout_secs, worker_url
            ));
        }

        match client.get(&health_url).send().await {
            Ok(res) if res.status().is_success() => {
                match self {
                    Router::RoundRobin { worker_urls, .. }
                    | Router::Random { worker_urls, .. }
                    | Router::CacheAware { worker_urls, .. } => {
                        info!("Worker {} health check passed", worker_url);
                        let mut urls = worker_urls.write().unwrap();
                        if urls.iter().any(|known| known == worker_url) {
                            return Err(format!("Worker {} already exists", worker_url));
                        }
                        info!("Added worker: {}", worker_url);
                        urls.push(worker_url.to_string());
                        gauge!("sgl_router_active_workers").set(urls.len() as f64);
                    }
                    Router::PrefillDecode { .. } => {
                        return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
                    }
                }

                // If cache aware, initialize the queues for the new worker
                if let Router::CacheAware {
                    running_queue,
                    processed_queue,
                    tree,
                    ..
                } = self
                {
                    let key = worker_url.to_string();
                    // Both counters start at zero.
                    running_queue.lock().unwrap().insert(key.clone(), 0);
                    processed_queue.lock().unwrap().insert(key, 0);
                    // Add worker to tree
                    tree.lock().unwrap().insert("", worker_url);
                }
                return Ok(format!("Successfully added worker: {}", worker_url));
            }
            Ok(res) => {
                info!(
                    "Worker {} health check is pending with status: {}.",
                    worker_url,
                    res.status()
                );
            }
            Err(e) => {
                info!(
                    "Worker {} health check is pending with error: {}",
                    worker_url, e
                );
            }
        }

        // Not healthy yet: if the url does not have http or https prefix, warn users,
        // then wait before polling again.
        let has_scheme = worker_url.starts_with("http://") || worker_url.starts_with("https://");
        if !has_scheme {
            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
        }
        tokio::time::sleep(poll_interval).await;
    }
}
/// Removes `worker_url` from the worker list and, for `CacheAware`, from the tree and both
/// queues.
///
/// Logs a warning and does nothing when the worker is unknown or the router is in
/// `PrefillDecode` mode.
pub fn remove_worker(&self, worker_url: &str) {
    let worker_urls = match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls, .. }
        | Router::CacheAware { worker_urls, .. } => worker_urls,
        Router::PrefillDecode { .. } => {
            warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
            return;
        }
    };

    {
        // The write guard is scoped so it is released before the tree/queue cleanup below.
        let mut urls = worker_urls.write().unwrap();
        match urls.iter().position(|url| url == worker_url) {
            Some(index) => {
                urls.remove(index);
                info!("Removed worker: {}", worker_url);
                gauge!("sgl_router_active_workers").set(urls.len() as f64);
            }
            None => {
                warn!("Worker {} not found, skipping removal", worker_url);
                return;
            }
        }
    }

    // if cache aware, remove the worker from the tree
    if let Router::CacheAware {
        tree,
        running_queue,
        processed_queue,
        ..
    } = self
    {
        tree.lock().unwrap().remove_tenant(worker_url);
        running_queue.lock().unwrap().remove(worker_url);
        processed_queue.lock().unwrap().remove(worker_url);
        info!(
            "Removed worker from tree and cleaned up queues: {}",
            worker_url
        );
    }
}
/// Add a worker with PD mode support
pub async fn add_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
bootstrap_port: Option<u16>,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.add_prefill_server(worker_url.to_string(), bootstrap_port)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.add_decode_server(worker_url.to_string())
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("add_pd_worker only supported in PD mode".to_string()),
}
}
/// Remove a worker with PD mode support
pub async fn remove_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.remove_prefill_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.remove_decode_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("remove_pd_worker only supported in PD mode".to_string()),
}
}
/// Fetches the integer `load` field from `GET {worker_url}/get_load`.
///
/// Returns `None` (after a debug log where applicable) when the request fails, the status is
/// not a success, the body cannot be read or parsed as JSON, or `load` is missing or not an
/// integer.
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
    let res = match client.get(&format!("{}/get_load", worker_url)).send().await {
        Ok(res) => res,
        Err(e) => {
            debug!("Failed to get load from {}: {}", worker_url, e);
            return None;
        }
    };

    if !res.status().is_success() {
        debug!(
            "Worker {} returned non-success status: {}",
            worker_url,
            res.status()
        );
        return None;
    }

    let bytes = match res.bytes().await {
        Ok(bytes) => bytes,
        Err(e) => {
            debug!("Failed to read load response from {}: {}", worker_url, e);
            return None;
        }
    };

    let data = match serde_json::from_slice::<serde_json::Value>(&bytes) {
        Ok(data) => data,
        Err(e) => {
            debug!("Failed to parse load response from {}: {}", worker_url, e);
            return None;
        }
    };

    let load = data.get("load")?.as_i64()?;
    Some(load as isize)
}
// PD-specific wrapper methods that delegate to PDRouter

/// Delegates to `PDRouter::health_generate` using the PD router's own HTTP client; responds
/// 500 when the router is not in `PrefillDecode` mode.
pub async fn route_pd_health_generate(
    &self,
    _client: &reqwest::Client,
    _req: &HttpRequest,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.health_generate(&pd_router.http_client).await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates a typed generate request to `PDRouter::route_generate` using the PD router's own
/// HTTP client; responds 500 when the router is not in `PrefillDecode` mode.
pub async fn route_pd_generate_typed(
    &self,
    _client: &reqwest::Client,
    req: &HttpRequest,
    typed_req: crate::pd_types::GenerateReqInput,
    route: &str,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router
            .route_generate(&pd_router.http_client, req, typed_req, route)
            .await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates a typed chat request to `PDRouter::route_chat` using the PD router's own HTTP
/// client; responds 500 when the router is not in `PrefillDecode` mode.
pub async fn route_pd_chat_typed(
    &self,
    _client: &reqwest::Client,
    req: &HttpRequest,
    typed_req: crate::pd_types::ChatReqInput,
    route: &str,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router
            .route_chat(&pd_router.http_client, req, typed_req, route)
            .await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates to `PDRouter::get_server_info` using the PD router's own HTTP client; responds
/// 500 when the router is not in `PrefillDecode` mode.
pub async fn get_pd_server_info(
    &self,
    _client: &reqwest::Client,
    _req: &HttpRequest,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.get_server_info(&pd_router.http_client).await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates to `PDRouter::get_models` using the PD router's own HTTP client; responds 500
/// when the router is not in `PrefillDecode` mode.
pub async fn get_pd_models(
    &self,
    _client: &reqwest::Client,
    req: &HttpRequest,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.get_models(&pd_router.http_client, req).await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates to `PDRouter::flush_cache` using the PD router's own HTTP client; responds 500
/// when the router is not in `PrefillDecode` mode.
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.flush_cache(&pd_router.http_client).await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
/// Delegates to `PDRouter::get_model_info` using the PD router's own HTTP client; responds
/// 500 when the router is not in `PrefillDecode` mode.
pub async fn get_pd_model_info(
    &self,
    _client: &reqwest::Client,
    req: &HttpRequest,
) -> HttpResponse {
    if let Router::PrefillDecode { pd_router } = self {
        return pd_router.get_model_info(&pd_router.http_client, req).await;
    }
    HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    const WORKER_1: &str = "http://worker1:8080";
    const WORKER_2: &str = "http://worker2:8080";

    /// Builds a `Random` router over two fixed worker URLs without running any health check.
    fn create_test_regular_router() -> Router {
        let workers: Vec<String> = vec![WORKER_1.to_string(), WORKER_2.to_string()];
        Router::Random {
            worker_urls: Arc::new(RwLock::new(workers)),
            timeout_secs: 5,
            interval_secs: 1,
        }
    }

    // TODO: PD-mode counterparts are not covered yet because `PDRouter::new` requires health
    // checks, which fail in tests; they need a mock server or a different test setup:
    //   - test_router_get_worker_urls_pd_mode (get_worker_urls returns an empty list)
    //   - test_add_pd_worker_with_pd_router_regular_type
    //   - test_remove_pd_worker_with_pd_router_regular_type
    //   - test_select_first_worker_pd_mode

    #[test]
    fn test_router_get_worker_urls_regular() {
        let shared = create_test_regular_router().get_worker_urls();
        let urls = shared.read().unwrap();
        assert_eq!(urls.len(), 2);
        assert!(urls.iter().any(|url| url == WORKER_1));
        assert!(urls.iter().any(|url| url == WORKER_2));
    }

    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        let router = create_test_regular_router();
        let err = router
            .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
            .await
            .expect_err("add_pd_worker must fail on a non-PD router");
        assert!(err.contains("add_pd_worker only supported in PD mode"));
    }

    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        let router = create_test_regular_router();
        let err = router
            .remove_pd_worker("http://worker:8080", PodType::Decode)
            .await
            .expect_err("remove_pd_worker must fail on a non-PD router");
        assert!(err.contains("remove_pd_worker only supported in PD mode"));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let first = create_test_regular_router()
            .select_first_worker()
            .expect("a router with workers must yield the first one");
        assert_eq!(first, WORKER_1);
    }

    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        let no_workers: Vec<String> = Vec::new();
        assert!(Router::wait_for_healthy_workers(&no_workers, 1, 1).is_ok());
    }

    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // This test will timeout quickly since the URLs are invalid
        let unreachable = ["http://nonexistent:8080".to_string()];
        let err = Router::wait_for_healthy_workers(&unreachable, 1, 1)
            .expect_err("unreachable workers must time out");
        assert!(err.contains("Timeout"));
    }
}