fix(ws): set tenant context for WebSocket message handling

WebSocket gateway was missing AsyncLocalStorage tenant context setup,
causing 'Tenant context not set' error on every message. Now extracts
tenantId from handshake and wraps handleMessage in tenantContext.runAsync().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-06 08:49:56 -08:00
parent 4ac1fc4f88
commit d1907636fe
1 changed files with 120 additions and 86 deletions

View File

@ -1,3 +1,4 @@
import { Inject } from '@nestjs/common';
import { import {
WebSocketGateway, WebSocketGateway,
WebSocketServer, WebSocketServer,
@ -8,6 +9,12 @@ import {
MessageBody, MessageBody,
} from '@nestjs/websockets'; } from '@nestjs/websockets';
import { Server, Socket } from 'socket.io'; import { Server, Socket } from 'socket.io';
import {
TenantContextService,
TENANT_FINDER,
DEFAULT_TENANT_ID,
} from '@iconsulting/shared';
import type { ITenantFinder } from '@iconsulting/shared';
import { ConversationService, FileAttachment } from '../../application/services/conversation.service'; import { ConversationService, FileAttachment } from '../../application/services/conversation.service';
interface SendMessagePayload { interface SendMessagePayload {
@ -16,6 +23,11 @@ interface SendMessagePayload {
attachments?: FileAttachment[]; attachments?: FileAttachment[];
} }
interface ConnectionInfo {
userId: string;
tenantId: string;
}
@WebSocketGateway({ @WebSocketGateway({
cors: { cors: {
origin: process.env.CORS_ORIGINS?.split(',') || ['http://localhost:5173'], origin: process.env.CORS_ORIGINS?.split(',') || ['http://localhost:5173'],
@ -30,10 +42,14 @@ export class ConversationGateway
@WebSocketServer() @WebSocketServer()
server: Server; server: Server;
// Map socket ID to user ID // Map socket ID to connection info (userId + tenantId)
private connections = new Map<string, string>(); private connections = new Map<string, ConnectionInfo>();
constructor(private conversationService: ConversationService) {} constructor(
private conversationService: ConversationService,
private readonly tenantContext: TenantContextService,
@Inject(TENANT_FINDER) private readonly tenantFinder: ITenantFinder,
) {}
async handleConnection(client: Socket) { async handleConnection(client: Socket) {
// Extract user ID from query or headers // Extract user ID from query or headers
@ -48,16 +64,22 @@ export class ConversationGateway
return; return;
} }
this.connections.set(client.id, userId); // Extract tenant ID from query or headers, fall back to default
console.log(`Client ${client.id} connected as user ${userId}`); const tenantId =
(client.handshake.query.tenantId as string) ||
(client.handshake.headers['x-tenant-id'] as string) ||
DEFAULT_TENANT_ID;
this.connections.set(client.id, { userId, tenantId });
console.log(`Client ${client.id} connected as user ${userId} (tenant: ${tenantId})`);
client.emit('connected', { userId, socketId: client.id }); client.emit('connected', { userId, socketId: client.id });
} }
handleDisconnect(client: Socket) { handleDisconnect(client: Socket) {
const userId = this.connections.get(client.id); const conn = this.connections.get(client.id);
this.connections.delete(client.id); this.connections.delete(client.id);
console.log(`Client ${client.id} (user ${userId}) disconnected`); console.log(`Client ${client.id} (user ${conn?.userId}) disconnected`);
} }
@SubscribeMessage('message') @SubscribeMessage('message')
@ -65,9 +87,9 @@ export class ConversationGateway
@ConnectedSocket() client: Socket, @ConnectedSocket() client: Socket,
@MessageBody() payload: SendMessagePayload, @MessageBody() payload: SendMessagePayload,
) { ) {
const userId = this.connections.get(client.id); const conn = this.connections.get(client.id);
if (!userId) { if (!conn) {
client.emit('error', { message: 'Not authenticated' }); client.emit('error', { message: 'Not authenticated' });
return; return;
} }
@ -79,6 +101,17 @@ export class ConversationGateway
return; return;
} }
// Resolve tenant context for this connection
const tenant = await this.tenantFinder.findById(conn.tenantId);
if (!tenant) {
client.emit('error', { message: 'Invalid tenant', conversationId });
return;
}
// Run the message handler inside tenant context (AsyncLocalStorage)
await this.tenantContext.runAsync(tenant, async () => {
this.tenantContext.setCurrentUserId(conn.userId);
try { try {
// Generate unique message ID for this response // Generate unique message ID for this response
const messageId = `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`; const messageId = `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
@ -91,7 +124,7 @@ export class ConversationGateway
// Stream the response // Stream the response
for await (const chunk of this.conversationService.sendMessage({ for await (const chunk of this.conversationService.sendMessage({
conversationId, conversationId,
userId, userId: conn.userId,
content, content,
attachments, attachments,
})) { })) {
@ -157,6 +190,7 @@ export class ConversationGateway
conversationId, conversationId,
}); });
} }
}, conn.userId);
} }
@SubscribeMessage('typing_start') @SubscribeMessage('typing_start')