fix(ws): set tenant context for WebSocket message handling

WebSocket gateway was missing AsyncLocalStorage tenant context setup,
causing 'Tenant context not set' error on every message. Now extracts
tenantId from handshake and wraps handleMessage in tenantContext.runAsync().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-06 08:49:56 -08:00
parent 4ac1fc4f88
commit d1907636fe
1 changed files with 120 additions and 86 deletions

View File

@ -1,3 +1,4 @@
import { Inject } from '@nestjs/common';
import {
WebSocketGateway,
WebSocketServer,
@ -8,6 +9,12 @@ import {
MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import {
TenantContextService,
TENANT_FINDER,
DEFAULT_TENANT_ID,
} from '@iconsulting/shared';
import type { ITenantFinder } from '@iconsulting/shared';
import { ConversationService, FileAttachment } from '../../application/services/conversation.service';
interface SendMessagePayload {
@ -16,6 +23,11 @@ interface SendMessagePayload {
attachments?: FileAttachment[];
}
interface ConnectionInfo {
userId: string;
tenantId: string;
}
@WebSocketGateway({
cors: {
origin: process.env.CORS_ORIGINS?.split(',') || ['http://localhost:5173'],
@ -30,10 +42,14 @@ export class ConversationGateway
@WebSocketServer()
server: Server;
// Map socket ID to user ID
private connections = new Map<string, string>();
// Map socket ID to connection info (userId + tenantId)
private connections = new Map<string, ConnectionInfo>();
constructor(private conversationService: ConversationService) {}
constructor(
private conversationService: ConversationService,
private readonly tenantContext: TenantContextService,
@Inject(TENANT_FINDER) private readonly tenantFinder: ITenantFinder,
) {}
async handleConnection(client: Socket) {
// Extract user ID from query or headers
@ -48,16 +64,22 @@ export class ConversationGateway
return;
}
this.connections.set(client.id, userId);
console.log(`Client ${client.id} connected as user ${userId}`);
// Extract tenant ID from query or headers, fall back to default
const tenantId =
(client.handshake.query.tenantId as string) ||
(client.handshake.headers['x-tenant-id'] as string) ||
DEFAULT_TENANT_ID;
this.connections.set(client.id, { userId, tenantId });
console.log(`Client ${client.id} connected as user ${userId} (tenant: ${tenantId})`);
client.emit('connected', { userId, socketId: client.id });
}
handleDisconnect(client: Socket) {
const userId = this.connections.get(client.id);
const conn = this.connections.get(client.id);
this.connections.delete(client.id);
console.log(`Client ${client.id} (user ${userId}) disconnected`);
console.log(`Client ${client.id} (user ${conn?.userId}) disconnected`);
}
@SubscribeMessage('message')
@ -65,9 +87,9 @@ export class ConversationGateway
@ConnectedSocket() client: Socket,
@MessageBody() payload: SendMessagePayload,
) {
const userId = this.connections.get(client.id);
const conn = this.connections.get(client.id);
if (!userId) {
if (!conn) {
client.emit('error', { message: 'Not authenticated' });
return;
}
@ -79,84 +101,96 @@ export class ConversationGateway
return;
}
try {
// Generate unique message ID for this response
const messageId = `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
// Emit stream start
client.emit('stream_start', { messageId, conversationId });
let chunkIndex = 0;
// Stream the response
for await (const chunk of this.conversationService.sendMessage({
conversationId,
userId,
content,
attachments,
})) {
if (chunk.type === 'text' && chunk.content) {
client.emit('stream_chunk', {
messageId,
conversationId,
content: chunk.content,
index: chunkIndex++,
});
} else if (chunk.type === 'tool_use') {
client.emit('tool_call', {
messageId,
conversationId,
tool: chunk.toolName,
input: chunk.toolInput,
});
} else if (chunk.type === 'tool_result') {
client.emit('tool_result', {
messageId,
conversationId,
tool: chunk.toolName,
result: chunk.toolResult,
});
} else if (chunk.type === 'agent_start') {
client.emit('agent_start', {
messageId,
conversationId,
agentType: chunk.agentType,
agentName: chunk.agentName,
description: chunk.description,
});
} else if (chunk.type === 'agent_complete') {
client.emit('agent_complete', {
messageId,
conversationId,
agentType: chunk.agentType,
agentName: chunk.agentName,
durationMs: chunk.durationMs,
success: chunk.success,
});
} else if (chunk.type === 'coordinator_thinking') {
client.emit('coordinator_thinking', {
messageId,
conversationId,
phase: chunk.phase,
message: chunk.message,
});
} else if (chunk.type === 'end') {
client.emit('stream_end', {
messageId,
conversationId,
isComplete: true,
inputTokens: chunk.inputTokens,
outputTokens: chunk.outputTokens,
});
}
}
} catch (error) {
console.error('Error processing message:', error);
client.emit('error', {
message: error instanceof Error ? error.message : 'Failed to process message',
conversationId,
});
// Resolve tenant context for this connection
const tenant = await this.tenantFinder.findById(conn.tenantId);
if (!tenant) {
client.emit('error', { message: 'Invalid tenant', conversationId });
return;
}
// Run the message handler inside tenant context (AsyncLocalStorage)
await this.tenantContext.runAsync(tenant, async () => {
this.tenantContext.setCurrentUserId(conn.userId);
try {
// Generate unique message ID for this response
const messageId = `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
// Emit stream start
client.emit('stream_start', { messageId, conversationId });
let chunkIndex = 0;
// Stream the response
for await (const chunk of this.conversationService.sendMessage({
conversationId,
userId: conn.userId,
content,
attachments,
})) {
if (chunk.type === 'text' && chunk.content) {
client.emit('stream_chunk', {
messageId,
conversationId,
content: chunk.content,
index: chunkIndex++,
});
} else if (chunk.type === 'tool_use') {
client.emit('tool_call', {
messageId,
conversationId,
tool: chunk.toolName,
input: chunk.toolInput,
});
} else if (chunk.type === 'tool_result') {
client.emit('tool_result', {
messageId,
conversationId,
tool: chunk.toolName,
result: chunk.toolResult,
});
} else if (chunk.type === 'agent_start') {
client.emit('agent_start', {
messageId,
conversationId,
agentType: chunk.agentType,
agentName: chunk.agentName,
description: chunk.description,
});
} else if (chunk.type === 'agent_complete') {
client.emit('agent_complete', {
messageId,
conversationId,
agentType: chunk.agentType,
agentName: chunk.agentName,
durationMs: chunk.durationMs,
success: chunk.success,
});
} else if (chunk.type === 'coordinator_thinking') {
client.emit('coordinator_thinking', {
messageId,
conversationId,
phase: chunk.phase,
message: chunk.message,
});
} else if (chunk.type === 'end') {
client.emit('stream_end', {
messageId,
conversationId,
isComplete: true,
inputTokens: chunk.inputTokens,
outputTokens: chunk.outputTokens,
});
}
}
} catch (error) {
console.error('Error processing message:', error);
client.emit('error', {
message: error instanceof Error ? error.message : 'Failed to process message',
conversationId,
});
}
}, conn.userId);
}
@SubscribeMessage('typing_start')