sglang.0.4.8.post1/sglang/sgl-pdlb/src/strategy_lb.rs

183 lines
5.1 KiB
Rust

use rand::Rng;
use serde_json::json;
/// Role of an engine: either a prefill engine or a decode engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineType {
    Prefill,
    Decode,
}

/// Address and role of a single backend engine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EngineInfo {
    /// Whether this engine is a prefill or a decode engine.
    pub engine_type: EngineType,
    /// Base URL of the engine, e.g. `http://10.0.0.1:30000`.
    pub url: String,
    /// Bootstrap port; the constructors only ever set it for prefill engines.
    pub bootstrap_port: Option<u16>,
}

impl EngineInfo {
    /// Creates a prefill engine entry with an optional bootstrap port.
    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
        EngineInfo {
            engine_type: EngineType::Prefill,
            url,
            bootstrap_port,
        }
    }

    /// Creates a decode engine entry; decode engines carry no bootstrap port.
    pub fn new_decode(url: String) -> Self {
        EngineInfo {
            engine_type: EngineType::Decode,
            url,
            bootstrap_port: None,
        }
    }

    /// Joins the engine's base URL with `api_path` using exactly one `/`
    /// between them.
    ///
    /// A trailing `/` on the base URL is ignored so that
    /// `http://host/` + `/get_load` does not become `http://host//get_load`.
    pub fn api_path(&self, api_path: &str) -> String {
        let base = self.url.trim_end_matches('/');
        if api_path.starts_with('/') {
            format!("{}{}", base, api_path)
        } else {
            format!("{}/{}", base, api_path)
        }
    }

    /// Extracts the host part of the engine URL.
    ///
    /// The scheme (`http://` or `https://`), any `user:password@` prefix, the
    /// port, and everything from the first `/`, `?` or `#` onwards are removed.
    /// A bracketed IPv6 literal is returned with its brackets, e.g. `[::1]`.
    pub fn get_hostname(&self) -> String {
        // `strip_prefix` removes the scheme at most once, unlike
        // `trim_start_matches`, which would strip repeated occurrences.
        let without_scheme = self
            .url
            .strip_prefix("http://")
            .or_else(|| self.url.strip_prefix("https://"))
            .unwrap_or(&self.url);

        // The authority ends at the first path, query or fragment delimiter.
        let authority = without_scheme
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(without_scheme);

        // Skip optional userinfo (`user:password@`).
        let host_port = authority.rsplit('@').next().unwrap_or(authority);

        let host = if host_port.starts_with('[') {
            // IPv6 literals contain ':' themselves, so cut at the closing
            // bracket instead of at the first colon.
            match host_port.find(']') {
                Some(end) => &host_port[..=end],
                None => host_port,
            }
        } else {
            host_port.split(':').next().unwrap_or(host_port)
        };
        host.to_string()
    }
}

/// Formats the engine as `(<EngineType>@<url>)`.
///
/// `to_string()` is provided through this implementation.
impl std::fmt::Display for EngineInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "({:?}@{})", self.engine_type, self.url)
    }
}
/// Load reading associated with a single engine.
#[derive(Debug, Clone)]
pub struct EngineLoad {
    /// Engine the reading belongs to.
    pub engine_info: EngineInfo,
    /// Reported load; `-1` is used when no usable value was available.
    pub load: isize,
}
impl EngineLoad {
    /// Builds a load reading for `engine_info` from a JSON object.
    ///
    /// The value is read from the `"load"` key. If the key is missing, is not
    /// an integer, or does not fit in `isize`, the load is set to `-1`.
    pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
        let load = json
            .get("load")
            .and_then(serde_json::Value::as_i64)
            // A checked conversion avoids silently truncating on targets
            // where `isize` is narrower than `i64`.
            .and_then(|raw| isize::try_from(raw).ok())
            .unwrap_or(-1);
        EngineLoad {
            engine_info: engine_info.clone(),
            load,
        }
    }

    /// Serializes the reading as `{"engine": "<engine>", "load": <load>}`.
    pub fn to_json(&self) -> serde_json::Value {
        json!({
            "engine": self.engine_info.to_string(),
            "load": self.load,
        })
    }
}

/// Formats the reading as `<engine>: <load>`.
///
/// `to_string()` is provided through this implementation.
impl std::fmt::Display for EngineLoad {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `to_string()` is called explicitly so this compiles whether the
        // engine's string form comes from an inherent method or from `Display`.
        write!(f, "{}: {}", self.engine_info.to_string(), self.load)
    }
}
/// Strategy used to choose a prefill/decode server pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LBPolicy {
    /// Pick one server from each pool uniformly at random.
    Random,
    /// Sample two servers from each pool and keep the less loaded one.
    PowerOfTwo,
}
/// Load balancer state: a selection policy plus the prefill and decode
/// server pools it chooses from.
#[derive(Debug, Clone)]
pub struct StrategyLB {
/// Policy that `select_pair` dispatches on.
pub policy: LBPolicy,
/// Pool from which the prefill half of a pair is chosen.
pub prefill_servers: Vec<EngineInfo>,
/// Pool from which the decode half of a pair is chosen.
pub decode_servers: Vec<EngineInfo>,
}
impl StrategyLB {
    /// Creates a load balancer over the given prefill and decode pools.
    pub fn new(
        policy: LBPolicy,
        prefill_servers: Vec<EngineInfo>,
        decode_servers: Vec<EngineInfo>,
    ) -> Self {
        StrategyLB {
            policy,
            prefill_servers,
            decode_servers,
        }
    }

    /// Returns a copy of the first prefill server.
    ///
    /// # Panics
    ///
    /// Panics if either the prefill or the decode pool is empty.
    pub fn get_one_server(&self) -> EngineInfo {
        assert!(!self.prefill_servers.is_empty());
        assert!(!self.decode_servers.is_empty());
        self.prefill_servers[0].clone()
    }

    /// Returns copies of all servers: prefill servers first, then decode
    /// servers.
    pub fn get_all_servers(&self) -> Vec<EngineInfo> {
        self.prefill_servers
            .iter()
            .chain(self.decode_servers.iter())
            .cloned()
            .collect()
    }

    /// Selects a `(prefill, decode)` pair according to the configured policy.
    ///
    /// # Panics
    ///
    /// Panics if either pool is empty, because an empty index range cannot be
    /// sampled.
    pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
        match self.policy {
            LBPolicy::Random => self.select_pd_pair_random(),
            LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
        }
    }

    /// Picks one prefill and one decode server uniformly at random.
    fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
        let mut rng = rand::rng();
        let prefill_index = rng.random_range(0..self.prefill_servers.len());
        let decode_index = rng.random_range(0..self.decode_servers.len());
        (
            self.prefill_servers[prefill_index].clone(),
            self.decode_servers[decode_index].clone(),
        )
    }

    /// Queries `<engine url>/get_load` and returns the reported load.
    ///
    /// Returns `None` instead of panicking when the request fails, the status
    /// is not `200 OK`, the body is not valid JSON, or the `"load"` field is
    /// missing, not an integer, or does not fit in `isize`. An unreachable
    /// engine therefore cannot bring down the balancer.
    async fn get_load_from_engine(
        &self,
        client: &reqwest::Client,
        engine_info: &EngineInfo,
    ) -> Option<isize> {
        let url = engine_info.api_path("/get_load");
        let response = client.get(url).send().await.ok()?;
        if response.status() != reqwest::StatusCode::OK {
            return None;
        }
        let data = response.json::<serde_json::Value>().await.ok()?;
        let load = data.get("load")?.as_i64()?;
        isize::try_from(load).ok()
    }

    /// Returns `true` when `first` should be preferred over `second`.
    ///
    /// A known load always beats an unknown one (`None`); comparing the
    /// `Option`s directly would order `None` lowest and thus favor engines
    /// that failed to report. On ties, or when both are unknown, the second
    /// candidate is kept.
    fn is_less_loaded(first: Option<isize>, second: Option<isize>) -> bool {
        match (first, second) {
            (Some(a), Some(b)) => a < b,
            (Some(_), None) => true,
            (None, _) => false,
        }
    }

    /// Samples two servers from `servers` (with replacement) and returns the
    /// less loaded one.
    async fn pick_less_loaded(
        &self,
        client: &reqwest::Client,
        servers: &[EngineInfo],
    ) -> EngineInfo {
        // The RNG handle is confined to this block so it is dropped before
        // the first `.await` and is not held inside the future.
        let (first_index, second_index) = {
            let mut rng = rand::rng();
            (
                rng.random_range(0..servers.len()),
                rng.random_range(0..servers.len()),
            )
        };
        // Both draws hit the same server: the outcome is fixed, so skip the
        // two load queries.
        if first_index == second_index {
            return servers[first_index].clone();
        }
        let first = &servers[first_index];
        let second = &servers[second_index];
        let first_load = self.get_load_from_engine(client, first).await;
        let second_load = self.get_load_from_engine(client, second).await;
        if Self::is_less_loaded(first_load, second_load) {
            first.clone()
        } else {
            second.clone()
        }
    }

    /// Power-of-two-choices selection: for each pool, sample two candidates
    /// and keep the one reporting the lower load.
    async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
        let prefill = self.pick_less_loaded(client, &self.prefill_servers).await;
        let decode = self.pick_less_loaded(client, &self.decode_servers).await;
        (prefill, decode)
    }
}