diff --git a/backend/services/mpc-service/src/infrastructure/external/mpc-system/message-router-client.ts b/backend/services/mpc-service/src/infrastructure/external/mpc-system/message-router-client.ts index 28630a6b..74a336df 100644 --- a/backend/services/mpc-service/src/infrastructure/external/mpc-system/message-router-client.ts +++ b/backend/services/mpc-service/src/infrastructure/external/mpc-system/message-router-client.ts @@ -1,13 +1,13 @@ /** * MPC Message Router Client * - * WebSocket client for real-time message exchange between MPC parties. + * Hybrid client for message exchange between MPC parties. + * Strategy: Try WebSocket first, fallback to HTTP polling if WebSocket fails. */ import { Injectable, Logger, OnModuleInit, OnModuleDestroy } from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; -import * as WebSocket from 'ws'; -import { EventEmitter } from 'events'; +import WebSocket from 'ws'; export interface MPCMessage { fromParty: string; @@ -29,197 +29,421 @@ export interface MessageStream { close(): void; } +type TransportMode = 'websocket' | 'http'; + +interface ConnectionState { + sessionId: string; + partyId: string; + mode: TransportMode; + closed: boolean; + ws?: WebSocket; + lastPollTime?: number; +} + @Injectable() export class MPCMessageRouterClient implements OnModuleInit, OnModuleDestroy { private readonly logger = new Logger(MPCMessageRouterClient.name); private wsUrl: string; - private connections: Map = new Map(); + private httpUrl: string; + private connections: Map = new Map(); + + // Configuration + private readonly WS_CONNECT_TIMEOUT_MS = 5000; // 5 seconds to establish WebSocket + private readonly POLL_INTERVAL_MS = 100; // Poll every 100ms (HTTP fallback) + private readonly POLL_TIMEOUT_MS = 300000; // 5 minute total timeout + private readonly REQUEST_TIMEOUT_MS = 10000; // 10 second per-request timeout constructor(private readonly configService: ConfigService) {} onModuleInit() { this.wsUrl = this.configService.get('MPC_MESSAGE_ROUTER_WS_URL') || ''; + this.httpUrl = this.wsUrl.replace('ws://', 'http://').replace('wss://', 'https://'); + if (!this.wsUrl) { this.logger.warn('MPC_MESSAGE_ROUTER_WS_URL not configured'); + } else { + this.logger.log(`Message router configured - WS: ${this.wsUrl}, HTTP: ${this.httpUrl}`); } } onModuleDestroy() { - // Close all WebSocket connections - for (const [key, ws] of this.connections) { - this.logger.debug(`Closing WebSocket connection: ${key}`); - ws.close(); + for (const [key, state] of this.connections) { + this.logger.debug(`Closing connection: ${key}`); + state.closed = true; + if (state.ws) { + state.ws.close(); + } } this.connections.clear(); } /** - * Subscribe to messages for a session/party + * Subscribe to messages - tries WebSocket first, falls back to HTTP polling */ async subscribeMessages(sessionId: string, partyId: string): Promise { const connectionKey = `${sessionId}:${partyId}`; this.logger.log(`Subscribing to messages: ${connectionKey}`); + // Try WebSocket first + try { + const stream = await this.tryWebSocketSubscribe(sessionId, partyId); + this.logger.log(`WebSocket connection established for ${connectionKey}`); + return stream; + } catch (wsError) { + this.logger.warn(`WebSocket failed for ${connectionKey}, falling back to HTTP polling: ${wsError}`); + } + + // Fallback to HTTP polling + return this.httpPollingSubscribe(sessionId, partyId); + } + + /** + * Try to establish WebSocket connection with timeout + */ + private async tryWebSocketSubscribe(sessionId: string, partyId: string): Promise { + const connectionKey = `${sessionId}:${partyId}`; const url = `${this.wsUrl}/sessions/${sessionId}/messages?party_id=${partyId}`; - const ws = new WebSocket(url); - this.connections.set(connectionKey, ws); + + return new Promise((resolve, reject) => { + const ws = new WebSocket(url); + const connectTimeout = setTimeout(() => { + ws.close(); + reject(new Error('WebSocket connection timeout')); + }, this.WS_CONNECT_TIMEOUT_MS); + + const state: ConnectionState = { + sessionId, + partyId, + mode: 'websocket', + closed: false, + ws, + }; + + const messageQueue: MPCMessage[] = []; + const waiters: Array<{ + resolve: (value: { value: MPCMessage; done: false } | { done: true; value: undefined }) => void; + reject: (error: Error) => void; + }> = []; + let error: Error | null = null; + + ws.on('open', () => { + clearTimeout(connectTimeout); + this.connections.set(connectionKey, state); + this.logger.debug(`WebSocket connected: ${connectionKey}`); + + resolve({ + next: () => { + return new Promise((res, rej) => { + if (error) { + rej(error); + return; + } + if (messageQueue.length > 0) { + res({ value: messageQueue.shift()!, done: false }); + return; + } + if (state.closed) { + res({ done: true, value: undefined }); + return; + } + waiters.push({ resolve: res, reject: rej }); + }); + }, + close: () => { + state.closed = true; + ws.close(); + this.connections.delete(connectionKey); + this.logger.debug(`WebSocket closed: ${connectionKey}`); + }, + }); + }); + + ws.on('message', (data: Buffer) => { + try { + const parsed = JSON.parse(data.toString()); + const message: MPCMessage = { + fromParty: parsed.from_party, + toParties: parsed.to_parties, + roundNumber: parsed.round_number, + payload: Buffer.from(parsed.payload, 'base64'), + }; + + if (waiters.length > 0) { + const waiter = waiters.shift()!; + waiter.resolve({ value: message, done: false }); + } else { + messageQueue.push(message); + } + } catch (err) { + this.logger.error('Failed to parse WebSocket message', err); + } + }); + + ws.on('error', (err) => { + clearTimeout(connectTimeout); + this.logger.error(`WebSocket error: ${connectionKey}`, err); + error = err instanceof Error ? err : new Error(String(err)); + + // If not yet connected, reject the promise + if (!this.connections.has(connectionKey)) { + reject(error); + return; + } + + // Reject all waiting consumers + while (waiters.length > 0) { + const waiter = waiters.shift()!; + waiter.reject(error); + } + }); + + ws.on('close', () => { + clearTimeout(connectTimeout); + this.logger.debug(`WebSocket closed: ${connectionKey}`); + state.closed = true; + this.connections.delete(connectionKey); + + // If not yet connected, reject + if (!this.connections.has(connectionKey) && !error) { + reject(new Error('WebSocket closed before connection established')); + return; + } + + // Resolve all waiting consumers with done + while (waiters.length > 0) { + const waiter = waiters.shift()!; + waiter.resolve({ done: true, value: undefined }); + } + }); + }); + } + + /** + * HTTP polling fallback + */ + private async httpPollingSubscribe(sessionId: string, partyId: string): Promise { + const connectionKey = `${sessionId}:${partyId}`; + this.logger.log(`Starting HTTP polling for ${connectionKey}`); + + const state: ConnectionState = { + sessionId, + partyId, + mode: 'http', + closed: false, + lastPollTime: 0, + }; + this.connections.set(connectionKey, state); const messageQueue: MPCMessage[] = []; - const waiters: Array<{ - resolve: (value: { value: MPCMessage; done: false } | { done: true; value: undefined }) => void; - reject: (error: Error) => void; - }> = []; - let closed = false; let error: Error | null = null; + const startTime = Date.now(); - ws.on('open', () => { - this.logger.debug(`WebSocket connected: ${connectionKey}`); - }); - - ws.on('message', (data: Buffer) => { - try { - const parsed = JSON.parse(data.toString()); - const message: MPCMessage = { - fromParty: parsed.from_party, - toParties: parsed.to_parties, - roundNumber: parsed.round_number, - payload: Buffer.from(parsed.payload, 'base64'), - }; - - // If there's a waiting consumer, deliver immediately - if (waiters.length > 0) { - const waiter = waiters.shift()!; - waiter.resolve({ value: message, done: false }); - } else { - // Otherwise queue the message - messageQueue.push(message); + // Background polling + const pollLoop = async () => { + while (!state.closed) { + if (Date.now() - startTime > this.POLL_TIMEOUT_MS) { + this.logger.warn(`HTTP polling timeout for ${connectionKey}`); + state.closed = true; + break; } - } catch (err) { - this.logger.error('Failed to parse message', err); - } - }); - ws.on('error', (err) => { - this.logger.error(`WebSocket error: ${connectionKey}`, err); + try { + const messages = await this.fetchPendingMessages(sessionId, partyId); + for (const msg of messages) { + messageQueue.push(msg); + } + } catch (err) { + this.logger.debug(`HTTP poll error for ${connectionKey}: ${err}`); + } + + if (!state.closed) { + await this.sleep(this.POLL_INTERVAL_MS); + } + } + }; + + pollLoop().catch((err) => { + this.logger.error(`HTTP polling failed for ${connectionKey}`, err); error = err instanceof Error ? err : new Error(String(err)); - - // Reject all waiting consumers - while (waiters.length > 0) { - const waiter = waiters.shift()!; - waiter.reject(error); - } - }); - - ws.on('close', () => { - this.logger.debug(`WebSocket closed: ${connectionKey}`); - closed = true; - this.connections.delete(connectionKey); - - // Resolve all waiting consumers with done - while (waiters.length > 0) { - const waiter = waiters.shift()!; - waiter.resolve({ done: true, value: undefined }); - } }); return { - next: () => { - return new Promise((resolve, reject) => { - if (error) { - reject(error); - return; - } + next: async () => { + const waitStart = Date.now(); + const maxWait = 30000; + + while (Date.now() - waitStart < maxWait) { + if (error) throw error; if (messageQueue.length > 0) { - resolve({ value: messageQueue.shift()!, done: false }); - return; + return { value: messageQueue.shift()!, done: false as const }; } - if (closed) { - resolve({ done: true, value: undefined }); - return; + if (state.closed && messageQueue.length === 0) { + return { done: true as const, value: undefined }; } - // Wait for next message - waiters.push({ resolve, reject }); - }); + await this.sleep(50); + } + + if (state.closed) { + return { done: true as const, value: undefined }; + } + + // No messages received in time, but not closed - keep waiting + return { done: true as const, value: undefined }; }, close: () => { - if (!closed) { - ws.close(); - } + state.closed = true; + this.connections.delete(connectionKey); + this.logger.debug(`HTTP polling stopped: ${connectionKey}`); }, }; } /** - * Send a message to other parties + * Fetch pending messages via HTTP + */ + private async fetchPendingMessages(sessionId: string, partyId: string): Promise { + const url = `${this.httpUrl}/api/v1/messages/pending?session_id=${sessionId}&party_id=${partyId}`; + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), this.REQUEST_TIMEOUT_MS); + + try { + const response = await fetch(url, { + method: 'GET', + headers: { 'Accept': 'application/json' }, + signal: controller.signal, + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const data = await response.json(); + const messages: MPCMessage[] = []; + + if (data.messages && Array.isArray(data.messages)) { + for (const msg of data.messages) { + messages.push({ + fromParty: msg.from_party, + toParties: msg.to_parties, + roundNumber: msg.round_number, + payload: Buffer.from(msg.payload, 'base64'), + }); + } + } + + return messages; + } finally { + clearTimeout(timeoutId); + } + } + + /** + * Send message - uses WebSocket if connected, otherwise HTTP */ async sendMessage(request: SendMessageRequest): Promise { const connectionKey = `${request.sessionId}:${request.fromParty}`; - const ws = this.connections.get(connectionKey); + const state = this.connections.get(connectionKey); - if (ws && ws.readyState === WebSocket.OPEN) { - // Send via WebSocket if connected + // Try WebSocket if available + if (state?.mode === 'websocket' && state.ws?.readyState === WebSocket.OPEN) { const message = JSON.stringify({ from_party: request.fromParty, to_parties: request.toParties, round_number: request.roundNumber, payload: request.payload.toString('base64'), }); - ws.send(message); - } else { - // Fallback to HTTP POST - await this.sendMessageViaHttp(request); + state.ws.send(message); + this.logger.debug(`Message sent via WebSocket: session=${request.sessionId}, round=${request.roundNumber}`); + return; } + + // Fallback to HTTP + await this.sendMessageViaHttp(request); } + /** + * Send message via HTTP POST + */ private async sendMessageViaHttp(request: SendMessageRequest): Promise { - const httpUrl = this.wsUrl.replace('ws://', 'http://').replace('wss://', 'https://'); + const url = `${this.httpUrl}/api/v1/messages/route`; + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), this.REQUEST_TIMEOUT_MS); try { - const response = await fetch(`${httpUrl}/sessions/${request.sessionId}/messages`, { + const response = await fetch(url, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, + headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ + session_id: request.sessionId, from_party: request.fromParty, to_parties: request.toParties, round_number: request.roundNumber, payload: request.payload.toString('base64'), }), + signal: controller.signal, }); if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`); + const errorText = await response.text(); + throw new Error(`HTTP ${response.status}: ${errorText}`); } + + this.logger.debug(`Message sent via HTTP: session=${request.sessionId}, round=${request.roundNumber}`); } catch (err) { this.logger.error('Failed to send message via HTTP', err); throw err; + } finally { + clearTimeout(timeoutId); } } /** - * Check if connected to a session + * Check connection status */ isConnected(sessionId: string, partyId: string): boolean { const connectionKey = `${sessionId}:${partyId}`; - const ws = this.connections.get(connectionKey); - return ws !== undefined && ws.readyState === WebSocket.OPEN; + const state = this.connections.get(connectionKey); + if (!state || state.closed) return false; + + if (state.mode === 'websocket') { + return state.ws?.readyState === WebSocket.OPEN; + } + + return true; // HTTP polling is always "connected" while active } /** - * Disconnect from a session + * Get current transport mode + */ + getTransportMode(sessionId: string, partyId: string): TransportMode | null { + const connectionKey = `${sessionId}:${partyId}`; + const state = this.connections.get(connectionKey); + return state?.mode ?? null; + } + + /** + * Disconnect */ disconnect(sessionId: string, partyId: string): void { const connectionKey = `${sessionId}:${partyId}`; - const ws = this.connections.get(connectionKey); + const state = this.connections.get(connectionKey); - if (ws) { - ws.close(); + if (state) { + state.closed = true; + if (state.ws) { + state.ws.close(); + } this.connections.delete(connectionKey); - this.logger.debug(`Disconnected from: ${connectionKey}`); + this.logger.debug(`Disconnected: ${connectionKey} (mode: ${state.mode})`); } } + + private sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); + } }