mod common; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; use sglang_router_rs::config::{ CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; /// Test context that manages mock workers struct TestContext { workers: Vec, router: Arc, } impl TestContext { async fn new(worker_configs: Vec) -> Self { let mut config = RouterConfig { mode: RoutingMode::Regular { worker_urls: vec![], }, policy: PolicyConfig::Random, host: "127.0.0.1".to_string(), port: 3003, max_payload_size: 256 * 1024 * 1024, request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, dp_aware: false, api_key: None, discovery: None, metrics: None, log_dir: None, log_level: None, request_id_headers: None, max_concurrent_requests: 64, queue_size: 0, queue_timeout_secs: 60, rate_limit_tokens_per_second: None, cors_allowed_origins: vec![], retry: RetryConfig::default(), circuit_breaker: CircuitBreakerConfig::default(), disable_retries: false, disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, }; let mut workers = Vec::new(); let mut worker_urls = Vec::new(); for worker_config in worker_configs { let mut worker = MockWorker::new(worker_config); let url = worker.start().await.unwrap(); worker_urls.push(url); workers.push(worker); } if !workers.is_empty() { tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; } config.mode = RoutingMode::Regular { worker_urls }; let app_context = common::create_test_context(config); let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = Arc::from(router); if !workers.is_empty() { tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; } Self { workers, router } } async fn shutdown(mut self) { // Small delay to ensure any pending operations complete tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; for worker in &mut self.workers { worker.stop().await; } // Another small delay to ensure cleanup completes tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; } async fn make_request( &self, endpoint: &str, body: serde_json::Value, ) -> Result { let client = Client::new(); // Get any worker URL for testing let worker_urls = self.router.get_worker_urls(); if worker_urls.is_empty() { return Err("No available workers".to_string()); } let worker_url = &worker_urls[0]; let response = client .post(format!("{}{}", worker_url, endpoint)) .json(&body) .send() .await .map_err(|e| format!("Request failed: {}", e))?; if !response.status().is_success() { return Err(format!("Request failed with status: {}", response.status())); } response .json::() .await .map_err(|e| format!("Failed to parse response: {}", e)) } } #[cfg(test)] mod request_format_tests { use super::*; #[tokio::test] async fn test_generate_request_formats() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19001, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test 1: Basic text request let payload = json!({ "text": "Hello, world!", "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); // Test 2: Request with sampling parameters let payload = json!({ "text": "Tell me a story", "sampling_params": { "temperature": 0.7, "max_new_tokens": 100, "top_p": 0.9 }, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); // Test 3: Request with input_ids let payload = json!({ "input_ids": [1, 2, 3, 4, 5], "sampling_params": { "temperature": 0.0, "max_new_tokens": 50 }, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); ctx.shutdown().await; } #[tokio::test] async fn test_v1_chat_completions_formats() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19002, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test 1: Basic chat completion let payload = json!({ "model": "test-model", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ], "stream": false }); let result = ctx.make_request("/v1/chat/completions", payload).await; assert!(result.is_ok()); let response = result.unwrap(); assert!(response.get("choices").is_some()); assert!(response.get("id").is_some()); assert_eq!( response.get("object").and_then(|v| v.as_str()), Some("chat.completion") ); // Test 2: Chat completion with parameters let payload = json!({ "model": "test-model", "messages": [ {"role": "user", "content": "Tell me a joke"} ], "temperature": 0.8, "max_tokens": 150, "top_p": 0.95, "stream": false }); let result = ctx.make_request("/v1/chat/completions", payload).await; assert!(result.is_ok()); ctx.shutdown().await; } #[tokio::test] async fn test_v1_completions_formats() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19003, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test 1: Basic completion let payload = json!({ "model": "test-model", "prompt": "Once upon a time", "max_tokens": 50, "stream": false }); let result = ctx.make_request("/v1/completions", payload).await; assert!(result.is_ok()); let response = result.unwrap(); assert!(response.get("choices").is_some()); assert_eq!( response.get("object").and_then(|v| v.as_str()), Some("text_completion") ); // Test 2: Completion with array prompt let payload = json!({ "model": "test-model", "prompt": ["First prompt", "Second prompt"], "temperature": 0.5, "stream": false }); let result = ctx.make_request("/v1/completions", payload).await; assert!(result.is_ok()); // Test 3: Completion with logprobs let payload = json!({ "model": "test-model", "prompt": "The capital of France is", "max_tokens": 10, "logprobs": 5, "stream": false }); let result = ctx.make_request("/v1/completions", payload).await; assert!(result.is_ok()); ctx.shutdown().await; } #[tokio::test] async fn test_batch_requests() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19004, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test batch text generation let payload = json!({ "text": ["First text", "Second text", "Third text"], "sampling_params": { "temperature": 0.7, "max_new_tokens": 50 }, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); // Test batch with input_ids let payload = json!({ "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); ctx.shutdown().await; } #[tokio::test] async fn test_special_parameters() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19005, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test with return_logprob let payload = json!({ "text": "Test", "return_logprob": true, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); // Test with json_schema let payload = json!({ "text": "Generate JSON", "sampling_params": { "temperature": 0.0, "json_schema": "$$ANY$$" }, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); // Test with ignore_eos let payload = json!({ "text": "Continue forever", "sampling_params": { "temperature": 0.7, "max_new_tokens": 100, "ignore_eos": true }, "stream": false }); let result = ctx.make_request("/generate", payload).await; assert!(result.is_ok()); ctx.shutdown().await; } #[tokio::test] async fn test_error_handling() { let ctx = TestContext::new(vec![MockWorkerConfig { port: 19006, worker_type: WorkerType::Regular, health_status: HealthStatus::Healthy, response_delay_ms: 0, fail_rate: 0.0, }]) .await; // Test with empty body - should still work with mock worker let payload = json!({}); let result = ctx.make_request("/generate", payload).await; // Mock worker accepts empty body assert!(result.is_ok()); ctx.shutdown().await; } }