/*
 * it0/packages/shared/database/src/tenant-context.middleware.ts
 *
 * 136 lines
 * 4.4 KiB
 * TypeScript
 */

import { Injectable, NestMiddleware, Optional } from '@nestjs/common';
import { DataSource } from 'typeorm';
import { TenantContextService, TenantInfo } from '@it0/common';
/**
 * Raw result row of the `public.tenants` lookup issued by the middleware.
 *
 * Property names mirror the selected column names, hence the snake_case.
 *
 * NOTE(review): the loader applies `??` defaults to every quota column, which
 * suggests they may be NULL in the database — confirm, and widen the quota
 * fields to `number | null` if so.
 */
interface TenantRow {
  /** Value of the `id` column; the query filters on it. */
  id: string;

  /** Value of the `name` column; surfaced as the tenant's display name. */
  name: string;

  /** Subscription plan; also selects which default quota set applies. */
  plan: 'free' | 'pro' | 'enterprise';

  /** Per-tenant server quota. */
  max_servers: number;

  /** Per-tenant user quota. */
  max_users: number;

  /** Per-tenant standing-order quota. */
  max_standing_orders: number;

  /** Per-tenant monthly agent-token quota. */
  max_agent_tokens_per_month: number;
}
/** One slot of the middleware's in-memory tenant cache. */
interface CacheEntry {
  /** The resolved tenant context that is replayed on a cache hit. */
  info: TenantInfo;

  /** Epoch-millisecond timestamp (compared against `Date.now()`) after which the entry is stale. */
  expiresAt: number;
}
/**
 * Quota defaults keyed by plan name.
 *
 * Used in two situations: as the complete quota set when a fallback tenant is
 * built, and per field when a tenant row was loaded but a quota column has no
 * value.
 *
 * NOTE(review): `-1` on the enterprise plan presumably means "unlimited" —
 * confirm against the quota-enforcement code.
 */
const DEFAULT_QUOTAS: Record<
  string,
  Pick<TenantInfo, 'maxServers' | 'maxUsers' | 'maxStandingOrders' | 'maxAgentTokensPerMonth'>
> = {
  free: {
    maxServers: 5,
    maxUsers: 3,
    maxStandingOrders: 10,
    maxAgentTokensPerMonth: 100_000,
  },
  pro: {
    maxServers: 50,
    maxUsers: 20,
    maxStandingOrders: 100,
    maxAgentTokensPerMonth: 1_000_000,
  },
  enterprise: {
    maxServers: -1,
    maxUsers: -1,
    maxStandingOrders: -1,
    maxAgentTokensPerMonth: 10_000_000,
  },
};

/** Lifetime of a cached tenant entry: five minutes, expressed in milliseconds. */
const CACHE_TTL_MS = 5 * 60 * 1_000;
/**
 * Nest middleware that resolves the tenant of each request and runs the rest
 * of the pipeline inside that tenant's context via `TenantContextService.run`.
 *
 * Resolution order for the tenant id: the `x-tenant-id` header, then the
 * `tenantId` claim of the bearer token. Requests without either pass through
 * untouched. Tenant details are read from `public.tenants` and cached in
 * memory for `CACHE_TTL_MS`; when no data source is injected or the query
 * fails, a free-plan fallback is used so the request is never blocked.
 *
 * NOTE(review): the header takes precedence over the token's `tenantId` claim
 * and the token is decoded without signature verification (the original
 * author's note says validation is handled by Kong). Confirm the gateway
 * strips or overwrites a client-supplied `x-tenant-id`.
 */
@Injectable()
export class TenantContextMiddleware implements NestMiddleware {
  /**
   * Upper bound on cached tenants. The cache key comes from the request, so
   * without a cap the map could grow without limit.
   */
  private static readonly MAX_CACHE_ENTRIES = 10_000;

  private readonly cache = new Map<string, CacheEntry>();

  constructor(@Optional() private readonly dataSource: DataSource) {}

  /**
   * Populates `req.user` from the bearer token (when decodable), resolves the
   * tenant, and invokes `next` exactly once inside the tenant context.
   *
   * @param req  Incoming request; `headers` is read and `user` may be assigned.
   * @param res  Response object (unused).
   * @param next Continuation for the remaining middleware/handlers.
   */
  use(req: any, res: any, next: () => void) {
    let tenantId = req.headers?.['x-tenant-id'] as string;

    // Decode JWT to populate req.user for RolesGuard
    const authHeader = req.headers?.['authorization'] as string;
    if (authHeader?.startsWith('Bearer ')) {
      try {
        const token = authHeader.slice(7);
        const payload = JSON.parse(
          Buffer.from(token.split('.')[1], 'base64').toString(),
        );
        req.user = {
          id: payload.sub,
          email: payload.email,
          tenantId: payload.tenantId,
          roles: payload.roles || [],
        };
        // Fall back to JWT tenantId if header is missing
        if (!tenantId && payload.tenantId) {
          tenantId = payload.tenantId;
        }
      } catch {
        // Ignore decode errors - JWT validation is handled by Kong
      }
    }

    if (!tenantId) {
      return next();
    }

    // Serve from cache when available (synchronous fast path)
    const cached = this.cache.get(tenantId);
    if (cached && cached.expiresAt > Date.now()) {
      TenantContextService.run(cached.info, () => next());
      return;
    }

    const resolvedTenantId = tenantId;
    const proceed = (info: TenantInfo): void => {
      TenantContextService.run(info, () => next());
    };

    // Two-argument `then` on purpose: the rejection handler only covers a
    // failed load. With `.then(...).catch(...)` an exception thrown by the
    // continuation itself would land in the catch and dispatch the request a
    // second time under free-plan defaults.
    void this.loadTenantInfo(resolvedTenantId).then(proceed, () =>
      proceed(this.buildFallback(resolvedTenantId, 'free')),
    );
  }

  /**
   * Loads tenant details from `public.tenants` and caches the result.
   *
   * A missing row yields free-plan values (and is cached like any other
   * result). A missing data source or a query error yields an uncached
   * free-plan fallback. This method does not reject.
   *
   * @param tenantId Tenant identifier used as query parameter and cache key.
   * @returns The resolved tenant context.
   */
  private async loadTenantInfo(tenantId: string): Promise<TenantInfo> {
    if (!this.dataSource) {
      return this.buildFallback(tenantId, 'free');
    }

    try {
      const rows: TenantRow[] = await this.dataSource.query(
        `SELECT id, name, plan, max_servers, max_users, max_standing_orders, max_agent_tokens_per_month
         FROM public.tenants WHERE id = $1 LIMIT 1`,
        [tenantId],
      );
      const row = rows[0];
      const plan = (row?.plan ?? 'free') as 'free' | 'pro' | 'enterprise';
      const defaults = DEFAULT_QUOTAS[plan] ?? DEFAULT_QUOTAS['free'];

      const tenantInfo: TenantInfo = {
        tenantId,
        tenantName: row?.name ?? tenantId,
        plan,
        schemaName: `it0_t_${tenantId}`,
        maxServers: row?.max_servers ?? defaults.maxServers,
        maxUsers: row?.max_users ?? defaults.maxUsers,
        maxStandingOrders: row?.max_standing_orders ?? defaults.maxStandingOrders,
        maxAgentTokensPerMonth: row?.max_agent_tokens_per_month ?? defaults.maxAgentTokensPerMonth,
      };
      this.storeInCache(tenantId, tenantInfo);
      return tenantInfo;
    } catch (err: unknown) {
      // Best-effort: keep serving the request, but make the downgrade visible.
      // The id is request-supplied, so it is JSON-quoted to keep control
      // characters out of the log line.
      const reason = err instanceof Error ? err.message : String(err);
      console.warn(
        `TenantContextMiddleware: tenant lookup failed for ${JSON.stringify(tenantId)}; ` +
          `falling back to free-plan defaults: ${reason}`,
      );
      return this.buildFallback(tenantId, 'free');
    }
  }

  /**
   * Inserts or refreshes a cache entry, keeping the map within
   * `MAX_CACHE_ENTRIES`. When full, expired entries are purged first; if that
   * frees nothing, the oldest-inserted entry is evicted.
   */
  private storeInCache(tenantId: string, info: TenantInfo): void {
    const limit = TenantContextMiddleware.MAX_CACHE_ENTRIES;
    const now = Date.now();

    if (!this.cache.has(tenantId) && this.cache.size >= limit) {
      for (const [key, entry] of this.cache) {
        if (entry.expiresAt <= now) {
          this.cache.delete(key);
        }
      }
      if (this.cache.size >= limit) {
        // Map iterates in insertion order, so the first key is the oldest.
        const oldest = this.cache.keys().next().value;
        if (oldest !== undefined) {
          this.cache.delete(oldest);
        }
      }
    }

    this.cache.set(tenantId, { info, expiresAt: now + CACHE_TTL_MS });
  }

  /**
   * Builds a tenant context from plan defaults only, using the tenant id as
   * the display name.
   *
   * @param tenantId Tenant identifier.
   * @param plan     Plan whose default quotas are applied.
   */
  private buildFallback(tenantId: string, plan: 'free' | 'pro' | 'enterprise'): TenantInfo {
    const defaults = DEFAULT_QUOTAS[plan] ?? DEFAULT_QUOTAS['free'];
    return {
      tenantId,
      tenantName: tenantId,
      plan,
      schemaName: `it0_t_${tenantId}`,
      ...defaults,
    };
  }

  /** Invalidate cached tenant info — call after plan upgrade/downgrade. */
  invalidate(tenantId: string): void {
    this.cache.delete(tenantId);
  }
}