use crate::strategy_lb::EngineInfo; use serde::{Deserialize, Serialize}; use serde_json::Value; #[derive(Debug, Deserialize, Serialize)] #[serde(untagged)] pub enum SingleOrBatch { Single(T), Batch(Vec), } pub type InputIds = SingleOrBatch>; pub type InputText = SingleOrBatch; pub type BootstrapHost = SingleOrBatch; pub type BootstrapPort = SingleOrBatch>; pub type BootstrapRoom = SingleOrBatch; #[typetag::serde(tag = "type")] pub trait Bootstrap { fn is_stream(&self) -> bool; fn get_batch_size(&self) -> Result, actix_web::Error>; fn set_bootstrap_info( &mut self, bootstrap_host: BootstrapHost, bootstrap_port: BootstrapPort, bootstrap_room: BootstrapRoom, ); fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> { let batch_size = self.get_batch_size()?; if let Some(batch_size) = batch_size { self.set_bootstrap_info( BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]), BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]), BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::()).collect()), ); } else { self.set_bootstrap_info( BootstrapHost::Single(prefill_info.get_hostname()), BootstrapPort::Single(prefill_info.bootstrap_port), BootstrapRoom::Single(rand::random::()), ); } Ok(()) } } #[derive(Debug, Deserialize, Serialize)] pub struct GenerateReqInput { pub text: Option, pub input_ids: Option, #[serde(default)] pub stream: bool, pub bootstrap_host: Option, pub bootstrap_port: Option, pub bootstrap_room: Option, #[serde(flatten)] pub other: Value, } impl GenerateReqInput { pub fn get_batch_size(&self) -> Result, actix_web::Error> { if self.text.is_some() && self.input_ids.is_some() { return Err(actix_web::error::ErrorBadRequest( "Both text and input_ids are present in the request".to_string(), )); } if let Some(InputText::Batch(texts)) = &self.text { return Ok(Some(texts.len())); } if let Some(InputIds::Batch(ids)) = &self.input_ids { return Ok(Some(ids.len())); } Ok(None) } } #[typetag::serde] impl Bootstrap for GenerateReqInput { fn is_stream(&self) -> bool { self.stream } fn get_batch_size(&self) -> Result, actix_web::Error> { self.get_batch_size() } fn set_bootstrap_info( &mut self, bootstrap_host: BootstrapHost, bootstrap_port: BootstrapPort, bootstrap_room: BootstrapRoom, ) { self.bootstrap_host = Some(bootstrap_host); self.bootstrap_port = Some(bootstrap_port); self.bootstrap_room = Some(bootstrap_room); } } #[derive(Debug, Deserialize, Serialize)] pub struct ChatReqInput { #[serde(default)] pub stream: bool, pub bootstrap_host: Option, pub bootstrap_port: Option, pub bootstrap_room: Option, #[serde(flatten)] pub other: Value, } #[typetag::serde] impl Bootstrap for ChatReqInput { fn is_stream(&self) -> bool { self.stream } fn get_batch_size(&self) -> Result, actix_web::Error> { Ok(None) } fn set_bootstrap_info( &mut self, bootstrap_host: BootstrapHost, bootstrap_port: BootstrapPort, bootstrap_room: BootstrapRoom, ) { self.bootstrap_host = Some(bootstrap_host); self.bootstrap_port = Some(bootstrap_port); self.bootstrap_room = Some(bootstrap_room); } }