feat(mpc-service): 实现混合传输模式 (WebSocket + HTTP轮询)

- 优先尝试 WebSocket 连接 (5秒超时)
- WebSocket 失败自动降级到 HTTP 轮询
- HTTP 轮询间隔 100ms,总超时 5分钟
- 新增 getTransportMode() 方法查看当前传输模式
- 修复 message-router 404 导致的 socket hang up 问题

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Developer 2025-12-04 00:00:49 -08:00
parent a701f55342
commit 178a5c9f8b
1 changed files with 324 additions and 100 deletions

View File

@ -1,13 +1,13 @@
/** /**
* MPC Message Router Client * MPC Message Router Client
* *
* WebSocket client for real-time message exchange between MPC parties. * Hybrid client for message exchange between MPC parties.
* Strategy: Try WebSocket first, fallback to HTTP polling if WebSocket fails.
*/ */
import { Injectable, Logger, OnModuleInit, OnModuleDestroy } from '@nestjs/common'; import { Injectable, Logger, OnModuleInit, OnModuleDestroy } from '@nestjs/common';
import { ConfigService } from '@nestjs/config'; import { ConfigService } from '@nestjs/config';
import * as WebSocket from 'ws'; import WebSocket from 'ws';
import { EventEmitter } from 'events';
export interface MPCMessage { export interface MPCMessage {
fromParty: string; fromParty: string;
@ -29,197 +29,421 @@ export interface MessageStream {
close(): void; close(): void;
} }
type TransportMode = 'websocket' | 'http';
interface ConnectionState {
sessionId: string;
partyId: string;
mode: TransportMode;
closed: boolean;
ws?: WebSocket;
lastPollTime?: number;
}
@Injectable() @Injectable()
export class MPCMessageRouterClient implements OnModuleInit, OnModuleDestroy { export class MPCMessageRouterClient implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(MPCMessageRouterClient.name); private readonly logger = new Logger(MPCMessageRouterClient.name);
private wsUrl: string; private wsUrl: string;
private connections: Map<string, WebSocket> = new Map(); private httpUrl: string;
private connections: Map<string, ConnectionState> = new Map();
// Configuration
private readonly WS_CONNECT_TIMEOUT_MS = 5000; // 5 seconds to establish WebSocket
private readonly POLL_INTERVAL_MS = 100; // Poll every 100ms (HTTP fallback)
private readonly POLL_TIMEOUT_MS = 300000; // 5 minute total timeout
private readonly REQUEST_TIMEOUT_MS = 10000; // 10 second per-request timeout
constructor(private readonly configService: ConfigService) {} constructor(private readonly configService: ConfigService) {}
onModuleInit() { onModuleInit() {
this.wsUrl = this.configService.get<string>('MPC_MESSAGE_ROUTER_WS_URL') || ''; this.wsUrl = this.configService.get<string>('MPC_MESSAGE_ROUTER_WS_URL') || '';
this.httpUrl = this.wsUrl.replace('ws://', 'http://').replace('wss://', 'https://');
if (!this.wsUrl) { if (!this.wsUrl) {
this.logger.warn('MPC_MESSAGE_ROUTER_WS_URL not configured'); this.logger.warn('MPC_MESSAGE_ROUTER_WS_URL not configured');
} else {
this.logger.log(`Message router configured - WS: ${this.wsUrl}, HTTP: ${this.httpUrl}`);
} }
} }
onModuleDestroy() { onModuleDestroy() {
// Close all WebSocket connections for (const [key, state] of this.connections) {
for (const [key, ws] of this.connections) { this.logger.debug(`Closing connection: ${key}`);
this.logger.debug(`Closing WebSocket connection: ${key}`); state.closed = true;
ws.close(); if (state.ws) {
state.ws.close();
}
} }
this.connections.clear(); this.connections.clear();
} }
/** /**
* Subscribe to messages for a session/party * Subscribe to messages - tries WebSocket first, falls back to HTTP polling
*/ */
async subscribeMessages(sessionId: string, partyId: string): Promise<MessageStream> { async subscribeMessages(sessionId: string, partyId: string): Promise<MessageStream> {
const connectionKey = `${sessionId}:${partyId}`; const connectionKey = `${sessionId}:${partyId}`;
this.logger.log(`Subscribing to messages: ${connectionKey}`); this.logger.log(`Subscribing to messages: ${connectionKey}`);
// Try WebSocket first
try {
const stream = await this.tryWebSocketSubscribe(sessionId, partyId);
this.logger.log(`WebSocket connection established for ${connectionKey}`);
return stream;
} catch (wsError) {
this.logger.warn(`WebSocket failed for ${connectionKey}, falling back to HTTP polling: ${wsError}`);
}
// Fallback to HTTP polling
return this.httpPollingSubscribe(sessionId, partyId);
}
/**
* Try to establish WebSocket connection with timeout
*/
private async tryWebSocketSubscribe(sessionId: string, partyId: string): Promise<MessageStream> {
const connectionKey = `${sessionId}:${partyId}`;
const url = `${this.wsUrl}/sessions/${sessionId}/messages?party_id=${partyId}`; const url = `${this.wsUrl}/sessions/${sessionId}/messages?party_id=${partyId}`;
const ws = new WebSocket(url);
this.connections.set(connectionKey, ws); return new Promise((resolve, reject) => {
const ws = new WebSocket(url);
const connectTimeout = setTimeout(() => {
ws.close();
reject(new Error('WebSocket connection timeout'));
}, this.WS_CONNECT_TIMEOUT_MS);
const state: ConnectionState = {
sessionId,
partyId,
mode: 'websocket',
closed: false,
ws,
};
const messageQueue: MPCMessage[] = [];
const waiters: Array<{
resolve: (value: { value: MPCMessage; done: false } | { done: true; value: undefined }) => void;
reject: (error: Error) => void;
}> = [];
let error: Error | null = null;
ws.on('open', () => {
clearTimeout(connectTimeout);
this.connections.set(connectionKey, state);
this.logger.debug(`WebSocket connected: ${connectionKey}`);
resolve({
next: () => {
return new Promise((res, rej) => {
if (error) {
rej(error);
return;
}
if (messageQueue.length > 0) {
res({ value: messageQueue.shift()!, done: false });
return;
}
if (state.closed) {
res({ done: true, value: undefined });
return;
}
waiters.push({ resolve: res, reject: rej });
});
},
close: () => {
state.closed = true;
ws.close();
this.connections.delete(connectionKey);
this.logger.debug(`WebSocket closed: ${connectionKey}`);
},
});
});
ws.on('message', (data: Buffer) => {
try {
const parsed = JSON.parse(data.toString());
const message: MPCMessage = {
fromParty: parsed.from_party,
toParties: parsed.to_parties,
roundNumber: parsed.round_number,
payload: Buffer.from(parsed.payload, 'base64'),
};
if (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.resolve({ value: message, done: false });
} else {
messageQueue.push(message);
}
} catch (err) {
this.logger.error('Failed to parse WebSocket message', err);
}
});
ws.on('error', (err) => {
clearTimeout(connectTimeout);
this.logger.error(`WebSocket error: ${connectionKey}`, err);
error = err instanceof Error ? err : new Error(String(err));
// If not yet connected, reject the promise
if (!this.connections.has(connectionKey)) {
reject(error);
return;
}
// Reject all waiting consumers
while (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.reject(error);
}
});
ws.on('close', () => {
clearTimeout(connectTimeout);
this.logger.debug(`WebSocket closed: ${connectionKey}`);
state.closed = true;
this.connections.delete(connectionKey);
// If not yet connected, reject
if (!this.connections.has(connectionKey) && !error) {
reject(new Error('WebSocket closed before connection established'));
return;
}
// Resolve all waiting consumers with done
while (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.resolve({ done: true, value: undefined });
}
});
});
}
/**
* HTTP polling fallback
*/
private async httpPollingSubscribe(sessionId: string, partyId: string): Promise<MessageStream> {
const connectionKey = `${sessionId}:${partyId}`;
this.logger.log(`Starting HTTP polling for ${connectionKey}`);
const state: ConnectionState = {
sessionId,
partyId,
mode: 'http',
closed: false,
lastPollTime: 0,
};
this.connections.set(connectionKey, state);
const messageQueue: MPCMessage[] = []; const messageQueue: MPCMessage[] = [];
const waiters: Array<{
resolve: (value: { value: MPCMessage; done: false } | { done: true; value: undefined }) => void;
reject: (error: Error) => void;
}> = [];
let closed = false;
let error: Error | null = null; let error: Error | null = null;
const startTime = Date.now();
ws.on('open', () => { // Background polling
this.logger.debug(`WebSocket connected: ${connectionKey}`); const pollLoop = async () => {
}); while (!state.closed) {
if (Date.now() - startTime > this.POLL_TIMEOUT_MS) {
ws.on('message', (data: Buffer) => { this.logger.warn(`HTTP polling timeout for ${connectionKey}`);
try { state.closed = true;
const parsed = JSON.parse(data.toString()); break;
const message: MPCMessage = {
fromParty: parsed.from_party,
toParties: parsed.to_parties,
roundNumber: parsed.round_number,
payload: Buffer.from(parsed.payload, 'base64'),
};
// If there's a waiting consumer, deliver immediately
if (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.resolve({ value: message, done: false });
} else {
// Otherwise queue the message
messageQueue.push(message);
} }
} catch (err) {
this.logger.error('Failed to parse message', err);
}
});
ws.on('error', (err) => { try {
this.logger.error(`WebSocket error: ${connectionKey}`, err); const messages = await this.fetchPendingMessages(sessionId, partyId);
for (const msg of messages) {
messageQueue.push(msg);
}
} catch (err) {
this.logger.debug(`HTTP poll error for ${connectionKey}: ${err}`);
}
if (!state.closed) {
await this.sleep(this.POLL_INTERVAL_MS);
}
}
};
pollLoop().catch((err) => {
this.logger.error(`HTTP polling failed for ${connectionKey}`, err);
error = err instanceof Error ? err : new Error(String(err)); error = err instanceof Error ? err : new Error(String(err));
// Reject all waiting consumers
while (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.reject(error);
}
});
ws.on('close', () => {
this.logger.debug(`WebSocket closed: ${connectionKey}`);
closed = true;
this.connections.delete(connectionKey);
// Resolve all waiting consumers with done
while (waiters.length > 0) {
const waiter = waiters.shift()!;
waiter.resolve({ done: true, value: undefined });
}
}); });
return { return {
next: () => { next: async () => {
return new Promise((resolve, reject) => { const waitStart = Date.now();
if (error) { const maxWait = 30000;
reject(error);
return; while (Date.now() - waitStart < maxWait) {
} if (error) throw error;
if (messageQueue.length > 0) { if (messageQueue.length > 0) {
resolve({ value: messageQueue.shift()!, done: false }); return { value: messageQueue.shift()!, done: false as const };
return;
} }
if (closed) { if (state.closed && messageQueue.length === 0) {
resolve({ done: true, value: undefined }); return { done: true as const, value: undefined };
return;
} }
// Wait for next message await this.sleep(50);
waiters.push({ resolve, reject }); }
});
if (state.closed) {
return { done: true as const, value: undefined };
}
// No messages received in time, but not closed - keep waiting
return { done: true as const, value: undefined };
}, },
close: () => { close: () => {
if (!closed) { state.closed = true;
ws.close(); this.connections.delete(connectionKey);
} this.logger.debug(`HTTP polling stopped: ${connectionKey}`);
}, },
}; };
} }
/** /**
* Send a message to other parties * Fetch pending messages via HTTP
*/
private async fetchPendingMessages(sessionId: string, partyId: string): Promise<MPCMessage[]> {
const url = `${this.httpUrl}/api/v1/messages/pending?session_id=${sessionId}&party_id=${partyId}`;
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), this.REQUEST_TIMEOUT_MS);
try {
const response = await fetch(url, {
method: 'GET',
headers: { 'Accept': 'application/json' },
signal: controller.signal,
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
const data = await response.json();
const messages: MPCMessage[] = [];
if (data.messages && Array.isArray(data.messages)) {
for (const msg of data.messages) {
messages.push({
fromParty: msg.from_party,
toParties: msg.to_parties,
roundNumber: msg.round_number,
payload: Buffer.from(msg.payload, 'base64'),
});
}
}
return messages;
} finally {
clearTimeout(timeoutId);
}
}
/**
* Send message - uses WebSocket if connected, otherwise HTTP
*/ */
async sendMessage(request: SendMessageRequest): Promise<void> { async sendMessage(request: SendMessageRequest): Promise<void> {
const connectionKey = `${request.sessionId}:${request.fromParty}`; const connectionKey = `${request.sessionId}:${request.fromParty}`;
const ws = this.connections.get(connectionKey); const state = this.connections.get(connectionKey);
if (ws && ws.readyState === WebSocket.OPEN) { // Try WebSocket if available
// Send via WebSocket if connected if (state?.mode === 'websocket' && state.ws?.readyState === WebSocket.OPEN) {
const message = JSON.stringify({ const message = JSON.stringify({
from_party: request.fromParty, from_party: request.fromParty,
to_parties: request.toParties, to_parties: request.toParties,
round_number: request.roundNumber, round_number: request.roundNumber,
payload: request.payload.toString('base64'), payload: request.payload.toString('base64'),
}); });
ws.send(message); state.ws.send(message);
} else { this.logger.debug(`Message sent via WebSocket: session=${request.sessionId}, round=${request.roundNumber}`);
// Fallback to HTTP POST return;
await this.sendMessageViaHttp(request);
} }
// Fallback to HTTP
await this.sendMessageViaHttp(request);
} }
/**
* Send message via HTTP POST
*/
private async sendMessageViaHttp(request: SendMessageRequest): Promise<void> { private async sendMessageViaHttp(request: SendMessageRequest): Promise<void> {
const httpUrl = this.wsUrl.replace('ws://', 'http://').replace('wss://', 'https://'); const url = `${this.httpUrl}/api/v1/messages/route`;
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), this.REQUEST_TIMEOUT_MS);
try { try {
const response = await fetch(`${httpUrl}/sessions/${request.sessionId}/messages`, { const response = await fetch(url, {
method: 'POST', method: 'POST',
headers: { headers: { 'Content-Type': 'application/json' },
'Content-Type': 'application/json',
},
body: JSON.stringify({ body: JSON.stringify({
session_id: request.sessionId,
from_party: request.fromParty, from_party: request.fromParty,
to_parties: request.toParties, to_parties: request.toParties,
round_number: request.roundNumber, round_number: request.roundNumber,
payload: request.payload.toString('base64'), payload: request.payload.toString('base64'),
}), }),
signal: controller.signal,
}); });
if (!response.ok) { if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`); const errorText = await response.text();
throw new Error(`HTTP ${response.status}: ${errorText}`);
} }
this.logger.debug(`Message sent via HTTP: session=${request.sessionId}, round=${request.roundNumber}`);
} catch (err) { } catch (err) {
this.logger.error('Failed to send message via HTTP', err); this.logger.error('Failed to send message via HTTP', err);
throw err; throw err;
} finally {
clearTimeout(timeoutId);
} }
} }
/** /**
* Check if connected to a session * Check connection status
*/ */
isConnected(sessionId: string, partyId: string): boolean { isConnected(sessionId: string, partyId: string): boolean {
const connectionKey = `${sessionId}:${partyId}`; const connectionKey = `${sessionId}:${partyId}`;
const ws = this.connections.get(connectionKey); const state = this.connections.get(connectionKey);
return ws !== undefined && ws.readyState === WebSocket.OPEN; if (!state || state.closed) return false;
if (state.mode === 'websocket') {
return state.ws?.readyState === WebSocket.OPEN;
}
return true; // HTTP polling is always "connected" while active
} }
/** /**
* Disconnect from a session * Get current transport mode
*/
getTransportMode(sessionId: string, partyId: string): TransportMode | null {
const connectionKey = `${sessionId}:${partyId}`;
const state = this.connections.get(connectionKey);
return state?.mode ?? null;
}
/**
* Disconnect
*/ */
disconnect(sessionId: string, partyId: string): void { disconnect(sessionId: string, partyId: string): void {
const connectionKey = `${sessionId}:${partyId}`; const connectionKey = `${sessionId}:${partyId}`;
const ws = this.connections.get(connectionKey); const state = this.connections.get(connectionKey);
if (ws) { if (state) {
ws.close(); state.closed = true;
if (state.ws) {
state.ws.close();
}
this.connections.delete(connectionKey); this.connections.delete(connectionKey);
this.logger.debug(`Disconnected from: ${connectionKey}`); this.logger.debug(`Disconnected: ${connectionKey} (mode: ${state.mode})`);
} }
} }
private sleep(ms: number): Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms));
}
} }