615 lines
18 KiB
Rust
Executable File
615 lines
18 KiB
Rust
Executable File
// Mock worker for testing - these functions are used by integration tests
|
|
#![allow(dead_code)]
|
|
|
|
use axum::{
|
|
extract::{Json, State},
|
|
http::StatusCode,
|
|
response::sse::{Event, KeepAlive},
|
|
response::{IntoResponse, Response, Sse},
|
|
routing::{get, post},
|
|
Router,
|
|
};
|
|
use futures_util::stream::{self, StreamExt};
|
|
use serde_json::json;
|
|
use std::convert::Infallible;
|
|
use std::sync::Arc;
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
use tokio::sync::RwLock;
|
|
use uuid::Uuid;
|
|
|
|
/// Configuration for mock worker behavior
|
|
#[derive(Clone)]
|
|
pub struct MockWorkerConfig {
|
|
pub port: u16,
|
|
pub worker_type: WorkerType,
|
|
pub health_status: HealthStatus,
|
|
pub response_delay_ms: u64,
|
|
pub fail_rate: f32,
|
|
}
|
|
|
|
/// Kind of worker the mock pretends to be.
///
/// The variant's `Debug` name is what the `/health` endpoint reports as
/// `worker_type`. `PartialEq`/`Eq` are derived so tests can compare values
/// directly (for example with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerType {
    /// Reported as `"Regular"`.
    Regular,
    /// Reported as `"Prefill"`.
    Prefill,
    /// Reported as `"Decode"`.
    Decode,
}
|
|
|
|
/// Health state the mock worker reports.
///
/// `PartialEq`/`Eq` are derived so tests can compare values directly
/// (for example with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HealthStatus {
    /// `/health` answers 200 with status `"healthy"`; `/health_generate`
    /// answers 200.
    Healthy,
    /// `/health` and `/health_generate` both answer 503.
    Unhealthy,
    /// `/health` answers 200 with status `"degraded"`, while
    /// `/health_generate` answers 503.
    Degraded,
}
|
|
|
|
/// Mock worker server for testing
|
|
pub struct MockWorker {
|
|
config: Arc<RwLock<MockWorkerConfig>>,
|
|
shutdown_handle: Option<tokio::task::JoinHandle<()>>,
|
|
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
|
}
|
|
|
|
impl MockWorker {
|
|
pub fn new(config: MockWorkerConfig) -> Self {
|
|
Self {
|
|
config: Arc::new(RwLock::new(config)),
|
|
shutdown_handle: None,
|
|
shutdown_tx: None,
|
|
}
|
|
}
|
|
|
|
/// Start the mock worker server
|
|
pub async fn start(&mut self) -> Result<String, Box<dyn std::error::Error>> {
|
|
let config = self.config.clone();
|
|
let port = config.read().await.port;
|
|
|
|
// If port is 0, find an available port
|
|
let port = if port == 0 {
|
|
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
|
|
let port = listener.local_addr()?.port();
|
|
drop(listener);
|
|
config.write().await.port = port;
|
|
port
|
|
} else {
|
|
port
|
|
};
|
|
|
|
let app = Router::new()
|
|
.route("/health", get(health_handler))
|
|
.route("/health_generate", get(health_generate_handler))
|
|
.route("/get_server_info", get(server_info_handler))
|
|
.route("/get_model_info", get(model_info_handler))
|
|
.route("/generate", post(generate_handler))
|
|
.route("/v1/chat/completions", post(chat_completions_handler))
|
|
.route("/v1/completions", post(completions_handler))
|
|
.route("/flush_cache", post(flush_cache_handler))
|
|
.route("/v1/models", get(v1_models_handler))
|
|
.with_state(config);
|
|
|
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
self.shutdown_tx = Some(shutdown_tx);
|
|
|
|
// Spawn the server in a separate task
|
|
let handle = tokio::spawn(async move {
|
|
let listener = match tokio::net::TcpListener::bind(("127.0.0.1", port)).await {
|
|
Ok(l) => l,
|
|
Err(e) => {
|
|
eprintln!("Failed to bind to port {}: {}", port, e);
|
|
return;
|
|
}
|
|
};
|
|
|
|
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
|
|
let _ = shutdown_rx.await;
|
|
});
|
|
|
|
if let Err(e) = server.await {
|
|
eprintln!("Server error: {}", e);
|
|
}
|
|
});
|
|
|
|
self.shutdown_handle = Some(handle);
|
|
|
|
// Wait for the server to start
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
|
|
|
let url = format!("http://127.0.0.1:{}", port);
|
|
Ok(url)
|
|
}
|
|
|
|
/// Stop the mock worker server
|
|
pub async fn stop(&mut self) {
|
|
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
|
let _ = shutdown_tx.send(());
|
|
}
|
|
|
|
if let Some(handle) = self.shutdown_handle.take() {
|
|
// Wait for the server to shut down
|
|
let _ = tokio::time::timeout(tokio::time::Duration::from_secs(5), handle).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for MockWorker {
|
|
fn drop(&mut self) {
|
|
// Clean shutdown when dropped
|
|
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
|
let _ = shutdown_tx.send(());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handler implementations
|
|
|
|
/// Check if request should fail based on configured fail_rate
|
|
async fn should_fail(config: &MockWorkerConfig) -> bool {
|
|
rand::random::<f32>() < config.fail_rate
|
|
}
|
|
|
|
async fn health_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
match config.health_status {
|
|
HealthStatus::Healthy => Json(json!({
|
|
"status": "healthy",
|
|
"timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
|
|
"worker_type": format!("{:?}", config.worker_type),
|
|
}))
|
|
.into_response(),
|
|
HealthStatus::Unhealthy => (
|
|
StatusCode::SERVICE_UNAVAILABLE,
|
|
Json(json!({
|
|
"status": "unhealthy",
|
|
"error": "Worker is not responding"
|
|
})),
|
|
)
|
|
.into_response(),
|
|
HealthStatus::Degraded => Json(json!({
|
|
"status": "degraded",
|
|
"warning": "High load detected"
|
|
}))
|
|
.into_response(),
|
|
}
|
|
}
|
|
|
|
async fn health_generate_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": "Random failure for testing"
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
if matches!(config.health_status, HealthStatus::Healthy) {
|
|
Json(json!({
|
|
"status": "ok",
|
|
"queue_length": 0,
|
|
"processing_time_ms": config.response_delay_ms
|
|
}))
|
|
.into_response()
|
|
} else {
|
|
(
|
|
StatusCode::SERVICE_UNAVAILABLE,
|
|
Json(json!({
|
|
"error": "Generation service unavailable"
|
|
})),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn server_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": "Random failure for testing"
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
Json(json!({
|
|
"model_path": "mock-model-path",
|
|
"tokenizer_path": "mock-tokenizer-path",
|
|
"port": config.port,
|
|
"host": "127.0.0.1",
|
|
"max_num_batched_tokens": 32768,
|
|
"max_prefill_tokens": 16384,
|
|
"mem_fraction_static": 0.88,
|
|
"tp_size": 1,
|
|
"dp_size": 1,
|
|
"stream_interval": 8,
|
|
"dtype": "float16",
|
|
"device": "cuda",
|
|
"enable_flashinfer": true,
|
|
"enable_p2p_check": true,
|
|
"context_length": 32768,
|
|
"chat_template": null,
|
|
"disable_radix_cache": false,
|
|
"enable_torch_compile": false,
|
|
"trust_remote_code": false,
|
|
"show_time_cost": false,
|
|
"waiting_queue_size": 0,
|
|
"running_queue_size": 0,
|
|
"req_to_token_ratio": 1.2,
|
|
"min_running_requests": 0,
|
|
"max_running_requests": 2048,
|
|
"max_req_num": 8192,
|
|
"max_batch_tokens": 32768,
|
|
"schedule_policy": "lpm",
|
|
"schedule_conservativeness": 1.0,
|
|
"version": "0.3.0",
|
|
"internal_states": [{
|
|
"waiting_queue_size": 0,
|
|
"running_queue_size": 0
|
|
}]
|
|
}))
|
|
.into_response()
|
|
}
|
|
|
|
async fn model_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": "Random failure for testing"
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
Json(json!({
|
|
"model_path": "mock-model-path",
|
|
"tokenizer_path": "mock-tokenizer-path",
|
|
"is_generation": true,
|
|
"preferred_sampling_params": {
|
|
"temperature": 0.7,
|
|
"top_p": 0.9,
|
|
"top_k": 40,
|
|
"max_tokens": 2048
|
|
}
|
|
}))
|
|
.into_response()
|
|
}
|
|
|
|
async fn generate_handler(
|
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
|
Json(payload): Json<serde_json::Value>,
|
|
) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": "Random failure for testing"
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
if config.response_delay_ms > 0 {
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
|
|
}
|
|
|
|
let is_stream = payload
|
|
.get("stream")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
|
|
if is_stream {
|
|
let stream_delay = config.response_delay_ms;
|
|
|
|
// Check if it's a batch request
|
|
let is_batch = payload.get("text").and_then(|t| t.as_array()).is_some();
|
|
|
|
let batch_size = if is_batch {
|
|
payload
|
|
.get("text")
|
|
.and_then(|t| t.as_array())
|
|
.map(|arr| arr.len())
|
|
.unwrap_or(1)
|
|
} else {
|
|
1
|
|
};
|
|
|
|
let mut events = Vec::new();
|
|
|
|
// Generate events for each item in batch
|
|
for i in 0..batch_size {
|
|
let timestamp_start = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs_f64();
|
|
|
|
let data = json!({
|
|
"text": format!("Mock response {}", i + 1),
|
|
"meta_info": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 5,
|
|
"completion_tokens_wo_jump_forward": 5,
|
|
"input_token_logprobs": null,
|
|
"output_token_logprobs": null,
|
|
"first_token_latency": stream_delay as f64 / 1000.0,
|
|
"time_to_first_token": stream_delay as f64 / 1000.0,
|
|
"time_per_output_token": 0.01,
|
|
"end_time": timestamp_start + (stream_delay as f64 / 1000.0),
|
|
"start_time": timestamp_start,
|
|
"finish_reason": {
|
|
"type": "stop",
|
|
"reason": "length"
|
|
}
|
|
},
|
|
"stage": "mid"
|
|
});
|
|
|
|
events.push(Ok::<_, Infallible>(Event::default().data(data.to_string())));
|
|
}
|
|
|
|
// Add [DONE] event
|
|
events.push(Ok(Event::default().data("[DONE]")));
|
|
|
|
let stream = stream::iter(events);
|
|
|
|
Sse::new(stream)
|
|
.keep_alive(KeepAlive::default())
|
|
.into_response()
|
|
} else {
|
|
Json(json!({
|
|
"text": "This is a mock response.",
|
|
"meta_info": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 5,
|
|
"completion_tokens_wo_jump_forward": 5,
|
|
"input_token_logprobs": null,
|
|
"output_token_logprobs": null,
|
|
"first_token_latency": config.response_delay_ms as f64 / 1000.0,
|
|
"time_to_first_token": config.response_delay_ms as f64 / 1000.0,
|
|
"time_per_output_token": 0.01,
|
|
"finish_reason": {
|
|
"type": "stop",
|
|
"reason": "length"
|
|
}
|
|
}
|
|
}))
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn chat_completions_handler(
|
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
|
Json(payload): Json<serde_json::Value>,
|
|
) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": {
|
|
"message": "Random failure for testing",
|
|
"type": "internal_error",
|
|
"code": "internal_error"
|
|
}
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
if config.response_delay_ms > 0 {
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
|
|
}
|
|
|
|
let is_stream = payload
|
|
.get("stream")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
|
|
let timestamp = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
|
|
if is_stream {
|
|
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
|
|
|
let stream = stream::once(async move {
|
|
let chunk = json!({
|
|
"id": request_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": timestamp,
|
|
"model": "mock-model",
|
|
"choices": [{
|
|
"index": 0,
|
|
"delta": {
|
|
"content": "This is a mock chat response."
|
|
},
|
|
"finish_reason": null
|
|
}]
|
|
});
|
|
|
|
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
|
|
})
|
|
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
|
|
|
Sse::new(stream)
|
|
.keep_alive(KeepAlive::default())
|
|
.into_response()
|
|
} else {
|
|
Json(json!({
|
|
"id": format!("chatcmpl-{}", Uuid::new_v4()),
|
|
"object": "chat.completion",
|
|
"created": timestamp,
|
|
"model": "mock-model",
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "This is a mock chat response."
|
|
},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 5,
|
|
"total_tokens": 15
|
|
}
|
|
}))
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn completions_handler(
|
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
|
Json(payload): Json<serde_json::Value>,
|
|
) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": {
|
|
"message": "Random failure for testing",
|
|
"type": "internal_error",
|
|
"code": "internal_error"
|
|
}
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
if config.response_delay_ms > 0 {
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
|
|
}
|
|
|
|
let is_stream = payload
|
|
.get("stream")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
|
|
let timestamp = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
|
|
if is_stream {
|
|
let request_id = format!("cmpl-{}", Uuid::new_v4());
|
|
|
|
let stream = stream::once(async move {
|
|
let chunk = json!({
|
|
"id": request_id,
|
|
"object": "text_completion",
|
|
"created": timestamp,
|
|
"model": "mock-model",
|
|
"choices": [{
|
|
"text": "This is a mock completion.",
|
|
"index": 0,
|
|
"logprobs": null,
|
|
"finish_reason": null
|
|
}]
|
|
});
|
|
|
|
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
|
|
})
|
|
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
|
|
|
Sse::new(stream)
|
|
.keep_alive(KeepAlive::default())
|
|
.into_response()
|
|
} else {
|
|
Json(json!({
|
|
"id": format!("cmpl-{}", Uuid::new_v4()),
|
|
"object": "text_completion",
|
|
"created": timestamp,
|
|
"model": "mock-model",
|
|
"choices": [{
|
|
"text": "This is a mock completion.",
|
|
"index": 0,
|
|
"logprobs": null,
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 5,
|
|
"total_tokens": 15
|
|
}
|
|
}))
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn flush_cache_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": "Random failure for testing"
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
Json(json!({
|
|
"message": "Cache flushed successfully"
|
|
}))
|
|
.into_response()
|
|
}
|
|
|
|
async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
|
let config = config.read().await;
|
|
|
|
if should_fail(&config).await {
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({
|
|
"error": {
|
|
"message": "Random failure for testing",
|
|
"type": "internal_error",
|
|
"code": "internal_error"
|
|
}
|
|
})),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
let timestamp = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
|
|
Json(json!({
|
|
"object": "list",
|
|
"data": [{
|
|
"id": "mock-model",
|
|
"object": "model",
|
|
"created": timestamp,
|
|
"owned_by": "organization-owner"
|
|
}]
|
|
}))
|
|
.into_response()
|
|
}
|
|
|
|
impl Default for MockWorkerConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
port: 0,
|
|
worker_type: WorkerType::Regular,
|
|
health_status: HealthStatus::Healthy,
|
|
response_delay_ms: 0,
|
|
fail_rate: 0.0,
|
|
}
|
|
}
|
|
}
|