import { Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, Between, MoreThanOrEqual } from 'typeorm';
import { TokenUsageORM } from '../database/postgres/entities/token-usage.orm';

/**
 * Per-token prices (USD) used to estimate the cost of a Claude API call.
 *
 * claude-sonnet-4-20250514 list prices:
 * - Input:       $3.00 / MTok
 * - Output:      $15.00 / MTok
 * - Cache write: $3.75 / MTok (presumably the 5-minute cache-TTL rate — confirm)
 * - Cache read:  $0.30 / MTok
 *
 * NOTE(review): prices change over time; re-check against Anthropic's published
 * pricing when adding models or updating these numbers.
 */
interface PricingTier {
  /** USD per regular (uncached) input token. */
  input: number;
  /** USD per output token. */
  output: number;
  /** USD per token written to the prompt cache. */
  cacheWrite: number;
  /** USD per token read from the prompt cache. */
  cacheRead: number;
}

/** Fallback rates for models without an explicit entry in `PRICING` (Sonnet 4 rates). */
const DEFAULT_PRICING: PricingTier = {
  input: 3 / 1_000_000,
  output: 15 / 1_000_000,
  cacheWrite: 3.75 / 1_000_000,
  cacheRead: 0.3 / 1_000_000,
};

/** Model id -> per-token pricing. Unknown models fall back to `DEFAULT_PRICING`. */
const PRICING: Record<string, PricingTier> = {
  'claude-sonnet-4-20250514': {
    input: 3 / 1_000_000,
    output: 15 / 1_000_000,
    cacheWrite: 3.75 / 1_000_000,
    cacheRead: 0.3 / 1_000_000,
  },
  default: DEFAULT_PRICING,
};

/**
 * Coerce a numeric column or SQL aggregate value to a JS number.
 *
 * Postgres drivers such as `pg` return NUMERIC/DECIMAL and BIGINT values
 * (including SUM/COUNT results) as strings; adding those with `+` would
 * concatenate instead of sum. null/undefined/non-numeric values become 0.
 */
function toNumber(value: unknown): number {
  const n = Number(value);
  return Number.isFinite(n) ? n : 0;
}

/** Data describing a single model API call, as supplied to `recordUsage`. */
export interface TokenUsageInput {
  userId?: string;
  conversationId: string;
  messageId?: string;
  model: string;
  inputTokens: number;
  outputTokens: number;
  cacheCreationTokens?: number;
  cacheReadTokens?: number;
  intentType?: string;
  toolCalls?: number;
  responseLength?: number;
  latencyMs?: number;
}

/** Aggregated usage over a set of recorded calls. */
export interface UsageStats {
  totalRequests: number;
  totalInputTokens: number;
  totalOutputTokens: number;
  totalCacheReadTokens: number;
  totalCacheCreationTokens: number;
  totalTokens: number;
  totalCost: number;
  /** Rounded to the nearest integer. */
  avgInputTokens: number;
  /** Rounded to the nearest integer. */
  avgOutputTokens: number;
  /** Rounded to the nearest integer. */
  avgLatencyMs: number;
  /** Cache-read tokens as a percentage of input tokens, rounded to 2 decimals. */
  cacheHitRate: number;
}

/** `UsageStats` plus the number of distinct non-null user ids in the window. */
export interface GlobalUsageStats extends UsageStats {
  uniqueUsers: number;
}

/** Per-day totals returned by `getDailyStats`. */
export interface DailyUsageStats {
  /** UTC calendar date, `YYYY-MM-DD`. */
  date: string;
  requests: number;
  totalTokens: number;
  totalCost: number;
}

/** One row of the token-consumption leaderboard returned by `getTopUsers`. */
export interface TopUserUsage {
  userId: string;
  totalTokens: number;
  totalCost: number;
  requestCount: number;
}

/** Raw row shape produced by the `getTopUsers` query (aggregates may arrive as strings). */
interface RawTopUserRow {
  userId: string;
  totalTokens: string | number | null;
  totalCost: string | number | null;
  requestCount: string | number | null;
}

@Injectable()
export class TokenUsageService {
  private readonly logger = new Logger('TokenUsage');

  constructor(
    @InjectRepository(TokenUsageORM)
    private readonly tokenUsageRepository: Repository<TokenUsageORM>,
  ) {}

  /** Returns a Date `days` days before now (calendar-day arithmetic in local time). */
  private daysAgo(days: number): Date {
    const since = new Date();
    since.setDate(since.getDate() - days);
    return since;
  }

  /**
   * Estimate the USD cost of one call from its token counts.
   *
   * @param model Model id; unknown ids are priced with `DEFAULT_PRICING`.
   * @returns Estimated cost in USD.
   */
  private calculateCost(
    model: string,
    inputTokens: number,
    outputTokens: number,
    cacheCreationTokens: number,
    cacheReadTokens: number,
  ): number {
    const pricing = PRICING[model] ?? DEFAULT_PRICING;

    // Cache-read tokens are billed at the cache-read rate, so they are removed
    // from the regular-input count. Clamped at 0 so inconsistent inputs can
    // never produce a negative input charge.
    // NOTE(review): the Anthropic Messages API reports `usage.input_tokens`
    // EXCLUDING `cache_read_input_tokens` and `cache_creation_input_tokens`.
    // If callers pass the raw `input_tokens`, this subtraction under-bills; if
    // they pass a total that includes cache tokens, cache-creation tokens
    // should be subtracted as well. Confirm against the caller.
    const regularInputTokens = Math.max(0, inputTokens - cacheReadTokens);

    return (
      regularInputTokens * pricing.input +
      outputTokens * pricing.output +
      cacheCreationTokens * pricing.cacheWrite +
      cacheReadTokens * pricing.cacheRead
    );
  }

  /**
   * Persist the token usage of one API call and log a one-line summary.
   *
   * Optional fields are normalized: missing ids/intent become `null`, missing
   * counters become 0. `totalTokens` is stored as input + output tokens (cache
   * token counts are stored separately and not added in).
   *
   * @returns The saved entity.
   */
  async recordUsage(input: TokenUsageInput): Promise<TokenUsageORM> {
    const cacheCreationTokens = input.cacheCreationTokens || 0;
    const cacheReadTokens = input.cacheReadTokens || 0;
    const intentType = input.intentType || null;
    const latencyMs = input.latencyMs || 0;
    const totalTokens = input.inputTokens + input.outputTokens;
    const estimatedCost = this.calculateCost(
      input.model,
      input.inputTokens,
      input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
    );

    const entity = this.tokenUsageRepository.create({
      userId: input.userId || null,
      conversationId: input.conversationId,
      messageId: input.messageId || null,
      model: input.model,
      inputTokens: input.inputTokens,
      outputTokens: input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
      totalTokens,
      estimatedCost,
      intentType,
      toolCalls: input.toolCalls || 0,
      responseLength: input.responseLength || 0,
      latencyMs,
    });

    const saved = await this.tokenUsageRepository.save(entity);

    // Log the normalized values that were actually stored.
    this.logger.log(
      `Recorded: in=${input.inputTokens}, out=${input.outputTokens}, ` +
        `cache_read=${cacheReadTokens}, cost=$${estimatedCost.toFixed(6)}, ` +
        `intent=${intentType}, latency=${latencyMs}ms`,
    );

    return saved;
  }

  /** Usage statistics for one user over the last `days` days (default 30). */
  async getUserStats(userId: string, days: number = 30): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: {
        userId,
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
    });
    return this.calculateStats(records);
  }

  /** Usage statistics for one conversation over its entire history. */
  async getConversationStats(conversationId: string): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: { conversationId },
    });
    return this.calculateStats(records);
  }

  /**
   * Usage statistics across all users over the last `days` days (admin use).
   *
   * NOTE(review): loads every record in the window into memory; consider SQL
   * aggregation if the table grows large.
   */
  async getGlobalStats(days: number = 30): Promise<GlobalUsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
    });

    const stats = this.calculateStats(records);
    const uniqueUsers = new Set(records.filter((r) => r.userId).map((r) => r.userId)).size;
    return { ...stats, uniqueUsers };
  }

  /**
   * Per-day request count, token total and cost over the last `days` days
   * (default 7).
   *
   * Days are UTC calendar dates. Days with no records are omitted, and the
   * first day may be partial because the window starts at "now minus `days`",
   * not at midnight.
   */
  async getDailyStats(days: number = 7): Promise<DailyUsageStats[]> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
      order: { createdAt: 'ASC' },
    });

    // Records arrive sorted ascending and Map preserves insertion order, so
    // the resulting array is in ascending date order.
    const byDate = new Map<string, TokenUsageORM[]>();
    for (const record of records) {
      const date = record.createdAt.toISOString().slice(0, 10);
      const bucket = byDate.get(date);
      if (bucket) {
        bucket.push(record);
      } else {
        byDate.set(date, [record]);
      }
    }

    return Array.from(byDate, ([date, dayRecords]) => ({
      date,
      requests: dayRecords.length,
      totalTokens: dayRecords.reduce((sum, r) => sum + toNumber(r.totalTokens), 0),
      totalCost: dayRecords.reduce((sum, r) => sum + toNumber(r.estimatedCost), 0),
    }));
  }

  /**
   * Leaderboard of users by total tokens consumed in the last `days` days.
   * Anonymous usage (null user_id) is excluded.
   *
   * @param limit Maximum number of users returned (default 10).
   */
  async getTopUsers(days: number = 30, limit: number = 10): Promise<TopUserUsage[]> {
    // NOTE(review): assumes snake_case column names (user_id, total_tokens,
    // estimated_cost, created_at) — confirm against TokenUsageORM's mapping.
    const rows: RawTopUserRow[] = await this.tokenUsageRepository
      .createQueryBuilder('usage')
      .select('usage.user_id', 'userId')
      .addSelect('SUM(usage.total_tokens)', 'totalTokens')
      .addSelect('SUM(usage.estimated_cost)', 'totalCost')
      .addSelect('COUNT(*)', 'requestCount')
      .where('usage.created_at >= :since', { since: this.daysAgo(days) })
      .andWhere('usage.user_id IS NOT NULL')
      .groupBy('usage.user_id')
      .orderBy('SUM(usage.total_tokens)', 'DESC')
      .limit(limit)
      .getRawMany();

    return rows.map((row) => ({
      userId: row.userId,
      totalTokens: toNumber(row.totalTokens),
      totalCost: toNumber(row.totalCost),
      requestCount: toNumber(row.requestCount),
    }));
  }

  /** Aggregate a list of usage records into `UsageStats` (all zeros when empty). */
  private calculateStats(records: readonly TokenUsageORM[]): UsageStats {
    const count = records.length;
    if (count === 0) {
      return {
        totalRequests: 0,
        totalInputTokens: 0,
        totalOutputTokens: 0,
        totalCacheReadTokens: 0,
        totalCacheCreationTokens: 0,
        totalTokens: 0,
        totalCost: 0,
        avgInputTokens: 0,
        avgOutputTokens: 0,
        avgLatencyMs: 0,
        cacheHitRate: 0,
      };
    }

    let totalInputTokens = 0;
    let totalOutputTokens = 0;
    let totalCacheReadTokens = 0;
    let totalCacheCreationTokens = 0;
    let totalTokens = 0;
    let totalCost = 0;
    let totalLatency = 0;
    for (const r of records) {
      totalInputTokens += toNumber(r.inputTokens);
      totalOutputTokens += toNumber(r.outputTokens);
      totalCacheReadTokens += toNumber(r.cacheReadTokens);
      totalCacheCreationTokens += toNumber(r.cacheCreationTokens);
      totalTokens += toNumber(r.totalTokens);
      totalCost += toNumber(r.estimatedCost);
      totalLatency += toNumber(r.latencyMs);
    }

    // Cache hit rate = cache-read tokens / total input tokens, as a percentage.
    // NOTE(review): if stored inputTokens exclude cache reads (raw Anthropic
    // `usage.input_tokens`), this ratio can exceed 100%.
    const cacheHitRate =
      totalInputTokens > 0 ? (totalCacheReadTokens / totalInputTokens) * 100 : 0;

    return {
      totalRequests: count,
      totalInputTokens,
      totalOutputTokens,
      totalCacheReadTokens,
      totalCacheCreationTokens,
      totalTokens,
      totalCost,
      avgInputTokens: Math.round(totalInputTokens / count),
      avgOutputTokens: Math.round(totalOutputTokens / count),
      avgLatencyMs: Math.round(totalLatency / count),
      cacheHitRate: Math.round(cacheHitRate * 100) / 100,
    };
  }
}