239 lines
7.0 KiB
Rust
239 lines
7.0 KiB
Rust
//! Mock servers for testing
|
|
|
|
#![allow(dead_code)]
|
|
|
|
use axum::{
|
|
body::Body,
|
|
extract::{Request, State},
|
|
http::{HeaderValue, StatusCode},
|
|
response::sse::{Event, KeepAlive},
|
|
response::{IntoResponse, Response, Sse},
|
|
routing::post,
|
|
Json, Router,
|
|
};
|
|
use futures_util::stream::{self, StreamExt};
|
|
use serde_json::json;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use tokio::net::TcpListener;
|
|
|
|
/// Mock OpenAI API server for testing
|
|
pub struct MockOpenAIServer {
|
|
addr: SocketAddr,
|
|
_handle: tokio::task::JoinHandle<()>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct MockServerState {
|
|
require_auth: bool,
|
|
expected_auth: Option<String>,
|
|
}
|
|
|
|
impl MockOpenAIServer {
|
|
/// Create and start a new mock OpenAI server
|
|
pub async fn new() -> Self {
|
|
Self::new_with_auth(None).await
|
|
}
|
|
|
|
/// Create and start a new mock OpenAI server with optional auth requirement
|
|
pub async fn new_with_auth(expected_auth: Option<String>) -> Self {
|
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let addr = listener.local_addr().unwrap();
|
|
|
|
let state = Arc::new(MockServerState {
|
|
require_auth: expected_auth.is_some(),
|
|
expected_auth,
|
|
});
|
|
|
|
let app = Router::new()
|
|
.route("/v1/chat/completions", post(mock_chat_completions))
|
|
.route("/v1/completions", post(mock_completions))
|
|
.route("/v1/models", post(mock_models).get(mock_models))
|
|
.with_state(state);
|
|
|
|
let handle = tokio::spawn(async move {
|
|
axum::serve(listener, app).await.unwrap();
|
|
});
|
|
|
|
// Give the server a moment to start
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
|
|
|
Self {
|
|
addr,
|
|
_handle: handle,
|
|
}
|
|
}
|
|
|
|
/// Get the base URL for this mock server
|
|
pub fn base_url(&self) -> String {
|
|
format!("http://{}", self.addr)
|
|
}
|
|
}
|
|
|
|
/// Mock chat completions endpoint
|
|
async fn mock_chat_completions(req: Request<Body>) -> Response {
|
|
let (_, body) = req.into_parts();
|
|
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
|
Ok(bytes) => bytes,
|
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
|
};
|
|
|
|
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
|
Ok(req) => req,
|
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
|
};
|
|
|
|
// Extract model from request or use default (owned String to satisfy 'static in stream)
|
|
let model: String = request
|
|
.get("model")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("gpt-3.5-turbo")
|
|
.to_string();
|
|
|
|
// If stream requested, return SSE
|
|
let is_stream = request
|
|
.get("stream")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
|
|
if is_stream {
|
|
let created = 1677652288u64;
|
|
// Single chunk then [DONE]
|
|
let model_chunk = model.clone();
|
|
let event_stream = stream::once(async move {
|
|
let chunk = json!({
|
|
"id": "chatcmpl-123456789",
|
|
"object": "chat.completion.chunk",
|
|
"created": created,
|
|
"model": model_chunk,
|
|
"choices": [{
|
|
"index": 0,
|
|
"delta": {
|
|
"content": "Hello!"
|
|
},
|
|
"finish_reason": null
|
|
}]
|
|
});
|
|
Ok::<_, std::convert::Infallible>(Event::default().data(chunk.to_string()))
|
|
})
|
|
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
|
|
|
Sse::new(event_stream)
|
|
.keep_alive(KeepAlive::default())
|
|
.into_response()
|
|
} else {
|
|
// Create a mock non-streaming response
|
|
let response = json!({
|
|
"id": "chatcmpl-123456789",
|
|
"object": "chat.completion",
|
|
"created": 1677652288,
|
|
"model": model,
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "Hello! I'm a mock OpenAI assistant. How can I help you today?"
|
|
},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 9,
|
|
"completion_tokens": 12,
|
|
"total_tokens": 21
|
|
}
|
|
});
|
|
|
|
Json(response).into_response()
|
|
}
|
|
}
|
|
|
|
/// Mock completions endpoint (legacy)
|
|
async fn mock_completions(req: Request<Body>) -> Response {
|
|
let (_, body) = req.into_parts();
|
|
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
|
Ok(bytes) => bytes,
|
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
|
};
|
|
|
|
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
|
Ok(req) => req,
|
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
|
};
|
|
|
|
let model = request["model"].as_str().unwrap_or("text-davinci-003");
|
|
|
|
let response = json!({
|
|
"id": "cmpl-123456789",
|
|
"object": "text_completion",
|
|
"created": 1677652288,
|
|
"model": model,
|
|
"choices": [{
|
|
"text": " This is a mock completion response.",
|
|
"index": 0,
|
|
"logprobs": null,
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 5,
|
|
"completion_tokens": 7,
|
|
"total_tokens": 12
|
|
}
|
|
});
|
|
|
|
Json(response).into_response()
|
|
}
|
|
|
|
/// Mock models endpoint
|
|
async fn mock_models(State(state): State<Arc<MockServerState>>, req: Request<Body>) -> Response {
|
|
// Optionally enforce Authorization header
|
|
if state.require_auth {
|
|
let auth = req
|
|
.headers()
|
|
.get("authorization")
|
|
.or_else(|| req.headers().get("Authorization"))
|
|
.and_then(|v| v.to_str().ok())
|
|
.map(|s| s.to_string());
|
|
let auth_ok = match (&state.expected_auth, auth) {
|
|
(Some(expected), Some(got)) => &got == expected,
|
|
(None, Some(_)) => true,
|
|
_ => false,
|
|
};
|
|
if !auth_ok {
|
|
let mut response = Response::new(Body::from(
|
|
json!({
|
|
"error": {
|
|
"message": "Unauthorized",
|
|
"type": "invalid_request_error"
|
|
}
|
|
})
|
|
.to_string(),
|
|
));
|
|
*response.status_mut() = StatusCode::UNAUTHORIZED;
|
|
response
|
|
.headers_mut()
|
|
.insert("WWW-Authenticate", HeaderValue::from_static("Bearer"));
|
|
return response;
|
|
}
|
|
}
|
|
|
|
let response = json!({
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"id": "gpt-4",
|
|
"object": "model",
|
|
"created": 1677610602,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "gpt-3.5-turbo",
|
|
"object": "model",
|
|
"created": 1677610602,
|
|
"owned_by": "openai"
|
|
}
|
|
]
|
|
});
|
|
|
|
Json(response).into_response()
|
|
}
|