// sglang.0.4.8.post1/sglang/sgl-pdlb/src/lb_state.rs
//
// 176 lines
// 6.0 KiB
// Rust

use crate::io_struct::Bootstrap;
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
use actix_web::HttpResponse;
use bytes::Bytes;
use futures::{Stream, StreamExt, future::join_all};
use reqwest::{Method, StatusCode};
use std::pin::Pin;
/// Payload of a reply relayed from a backend engine.
///
/// The payload is either read completely into memory up front, or handed on
/// as a stream of chunks that can be forwarded to the client as they arrive.
pub enum ProxyResponseBody {
    /// The whole body, already buffered.
    Full(Bytes),
    /// The body as a pinned, boxed stream; each item is a chunk or an error.
    Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
}
/// Reply obtained from a backend engine: its HTTP status and its body.
pub struct ProxyResponse {
    /// Status code reported by the backend engine.
    pub status: StatusCode,
    /// Body reported by the backend engine, buffered or streamed.
    pub body: ProxyResponseBody,
}
impl ProxyResponse {
    /// Interprets the buffered body as a JSON document.
    ///
    /// Only the body is examined. `status` is ignored, so an error reply whose
    /// body is valid JSON is parsed like any other.
    ///
    /// # Errors
    ///
    /// * `400 Bad Request` when the body is a [`ProxyResponseBody::Stream`],
    ///   which this method does not support.
    /// * The `serde_json` parse error, converted into [`actix_web::Error`], when
    ///   the buffered bytes are not valid JSON.
    pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
        match &self.body {
            ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
                "Stream response is not supported",
            )),
            ProxyResponseBody::Full(raw) => {
                let parsed: serde_json::Value = serde_json::from_slice(raw)?;
                Ok(parsed)
            }
        }
    }
}
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
fn into(self) -> Result<HttpResponse, actix_web::Error> {
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
})?;
match self.body {
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
.status(status)
.content_type("application/octet-stream")
.streaming(body)),
}
}
}
/// Settings used to build an [`LBState`] via `LBState::new`.
#[derive(Debug, Clone)]
pub struct LBConfig {
    /// Host for the load balancer; not read in this file (presumably the bind address).
    pub host: String,
    /// Port for the load balancer; not read in this file (presumably the bind port).
    pub port: u16,
    /// Selection policy name; `LBState::new` accepts `"random"` or `"po2"`.
    pub policy: String,
    /// Prefill engines as `(url, optional port)` pairs, passed to
    /// `EngineInfo::new_prefill`. NOTE(review): the port is presumably a
    /// bootstrap port; confirm in `strategy_lb`.
    pub prefill_infos: Vec<(String, Option<u16>)>,
    /// Decode engine URLs, passed to `EngineInfo::new_decode`.
    pub decode_infos: Vec<String>,
    /// Copied verbatim into `LBState::log_interval`; units are defined by its consumer.
    pub log_interval: u64,
    /// Per-request timeout of the HTTP client, in seconds.
    pub timeout: u64,
}
/// Runtime state of the prefill/decode load balancer.
#[derive(Debug, Clone)]
pub struct LBState {
    /// Policy-driven selector over the registered prefill and decode engines.
    pub strategy_lb: StrategyLB,
    /// HTTP client through which all requests are forwarded to the engines.
    pub client: reqwest::Client,
    /// Value copied from `LBConfig::log_interval` when the state is built.
    pub log_interval: u64,
}
impl LBState {
pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(lb_config.timeout))
.build()?;
let policy = match lb_config.policy.as_str() {
"random" => LBPolicy::Random,
"po2" => LBPolicy::PowerOfTwo,
_ => anyhow::bail!("Invalid policy"),
};
let prefill_servers = lb_config
.prefill_infos
.into_iter()
.map(|(url, port)| EngineInfo::new_prefill(url, port))
.collect();
let decode_servers = lb_config
.decode_infos
.into_iter()
.map(|url| EngineInfo::new_decode(url))
.collect();
let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
Ok(Self {
strategy_lb: lb,
client,
log_interval: lb_config.log_interval,
})
}
pub async fn route_one(
&self,
engine_info: &EngineInfo,
method: Method,
api_path: &str,
request: Option<&serde_json::Value>,
stream: bool,
) -> Result<ProxyResponse, actix_web::Error> {
let url = engine_info.api_path(api_path);
let request = request.unwrap_or(&serde_json::Value::Null);
let task = self.client.request(method, url).json(request).send();
let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
// FIXME: handle error status code (map status code to error)
let status = resp.status();
let body = if stream {
let resp_stream = resp.bytes_stream().map(|r| {
r.map_err(actix_web::error::ErrorBadGateway)
.map(Bytes::from)
});
ProxyResponseBody::Stream(Box::pin(resp_stream))
} else {
let body = resp
.bytes()
.await
.map_err(actix_web::error::ErrorBadGateway)?;
ProxyResponseBody::Full(body)
};
Ok(ProxyResponse { status, body })
}
pub async fn route_collect(
&self,
engines: &Vec<EngineInfo>,
method: Method,
api_path: &str,
request: Option<&serde_json::Value>,
) -> Result<Vec<ProxyResponse>, actix_web::Error> {
let tasks = engines
.iter()
.map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
let responses = join_all(tasks).await;
responses
.into_iter()
.map(|r| r.map_err(actix_web::error::ErrorBadGateway))
.collect()
}
pub async fn generate(
&self,
api_path: &str,
mut req: Box<dyn Bootstrap>,
) -> Result<HttpResponse, actix_web::Error> {
let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
let stream = req.is_stream();
req.add_bootstrap_info(&prefill)?;
let json = serde_json::to_value(req)?;
let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
let (_, decode_response) = tokio::join!(prefill_task, decode_task);
decode_response?.into()
}
pub async fn get_engine_loads(
&self,
) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
let servers = self.strategy_lb.get_all_servers();
let responses = self
.route_collect(&servers, Method::GET, "/get_load", None)
.await?;
let loads = responses
.into_iter()
.enumerate()
.map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
.collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
let mut prefill_loads = Vec::new();
let mut decode_loads = Vec::new();
for load in loads {
match load.engine_info.engine_type {
EngineType::Prefill => prefill_loads.push(load),
EngineType::Decode => decode_loads.push(load),
}
}
Ok((prefill_loads, decode_loads))
}
}