// File: iconsulting/packages/services/conversation-service/src/infrastructure/claude/token-usage.service.ts
// (source-viewer metadata: 304 lines, 8.4 KiB, TypeScript)

import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, Between, MoreThanOrEqual } from 'typeorm';
import { TokenUsageORM } from '../database/postgres/entities/token-usage.orm';
/**
 * Per-token prices, in USD, for a single Claude model.
 */
interface PricingTier {
  /** USD per regular (uncached) input token. */
  input: number;
  /** USD per output token. */
  output: number;
  /** USD per input token written to the prompt cache. */
  cacheWrite: number;
  /** USD per input token read from the prompt cache. */
  cacheRead: number;
}

/**
 * Claude API prices for claude-sonnet-4-20250514 (model snapshot dated 2025-05-14):
 * - Input: $3 / MTok
 * - Output: $15 / MTok
 * - Cache write: $3.75 / MTok
 * - Cache read: $0.30 / MTok
 * Values are stored per token (price / 1,000,000).
 * NOTE(review): re-verify against Anthropic's current pricing page when models change.
 */
const SONNET_4_PRICING: PricingTier = {
  input: 3 / 1_000_000,
  output: 15 / 1_000_000,
  cacheWrite: 3.75 / 1_000_000,
  cacheRead: 0.30 / 1_000_000,
};

/**
 * Pricing table keyed by model ID. The `default` entry is the fallback for
 * model IDs not listed here; it currently reuses the Sonnet 4 prices so the
 * two entries cannot drift apart.
 */
const PRICING: Record<string, PricingTier> = {
  'claude-sonnet-4-20250514': SONNET_4_PRICING,
  default: SONNET_4_PRICING,
};
/**
 * Measurements for a single Claude API call, as accepted by
 * `TokenUsageService.recordUsage`.
 */
export interface TokenUsageInput {
  // --- Identity -----------------------------------------------------------
  /** Conversation this call belongs to. */
  conversationId: string;
  /** Optional user ID; recordUsage stores falsy values as null. */
  userId?: string;
  /** Optional message ID; recordUsage stores falsy values as null. */
  messageId?: string;

  // --- Model & token counts -----------------------------------------------
  /** Model ID, used to pick a pricing tier (unlisted IDs use the default tier). */
  model: string;
  /**
   * Input token count.
   * NOTE(review): the cost logic subtracts cacheReadTokens from this value,
   * i.e. it assumes cache reads are included — confirm against the caller.
   */
  inputTokens: number;
  /** Output token count. */
  outputTokens: number;
  /** Tokens written to the prompt cache; treated as 0 when omitted. */
  cacheCreationTokens?: number;
  /** Tokens read from the prompt cache; treated as 0 when omitted. */
  cacheReadTokens?: number;

  // --- Request metadata ---------------------------------------------------
  /** Optional intent label; stored as null when falsy. */
  intentType?: string;
  /** Number of tool calls; stored as 0 when omitted. */
  toolCalls?: number;
  /** Response length; stored as 0 when omitted. */
  responseLength?: number;
  /** Call latency in milliseconds; stored as 0 when omitted. */
  latencyMs?: number;
}
/**
 * Aggregate usage over a set of recorded API calls.
 */
export interface UsageStats {
  /** Number of recorded calls in the set. */
  totalRequests: number;

  // --- Token totals -------------------------------------------------------
  totalInputTokens: number;
  totalOutputTokens: number;
  totalCacheReadTokens: number;
  totalCacheCreationTokens: number;
  /** Sum of each record's stored totalTokens (input + output at record time). */
  totalTokens: number;

  /** Sum of per-call estimated costs, in USD. */
  totalCost: number;

  // --- Per-request averages, rounded to the nearest integer ---------------
  avgInputTokens: number;
  avgOutputTokens: number;
  avgLatencyMs: number;

  /**
   * Cache-read tokens as a percentage of input tokens, rounded to two
   * decimal places; 0 when there are no input tokens.
   */
  cacheHitRate: number;
}
/**
 * Records per-request Claude token usage through a TypeORM repository and
 * derives usage/cost statistics per user, per conversation, globally, per day,
 * and for the top users.
 */
@Injectable()
export class TokenUsageService {
  constructor(
    @InjectRepository(TokenUsageORM)
    private readonly tokenUsageRepository: Repository<TokenUsageORM>,
  ) {}

  /**
   * Returns a Date `days` calendar days before now (local-time arithmetic via
   * setDate); used as the lower bound of time-windowed queries.
   */
  private sinceDaysAgo(days: number): Date {
    const since = new Date();
    since.setDate(since.getDate() - days);
    return since;
  }

  /**
   * Estimates the USD cost of one API call.
   *
   * @param model - model ID; IDs missing from PRICING use `PRICING.default`
   * @returns estimated cost in USD; never negative
   */
  private calculateCost(
    model: string,
    inputTokens: number,
    outputTokens: number,
    cacheCreationTokens: number,
    cacheReadTokens: number,
  ): number {
    const pricing = PRICING[model] || PRICING.default;
    // Cache-hit tokens are billed at the cheaper cache-read rate, so they are
    // removed from the regular-input bucket. This assumes `inputTokens`
    // includes them.
    // NOTE(review): Anthropic's Messages API reports `usage.input_tokens`
    // *excluding* cache reads/writes. If callers pass that value through
    // unchanged, this subtraction under-bills input — confirm the caller's
    // convention. The clamp stops the estimate going negative when
    // cacheReadTokens > inputTokens.
    const regularInputTokens = Math.max(0, inputTokens - cacheReadTokens);
    return (
      regularInputTokens * pricing.input +
      outputTokens * pricing.output +
      cacheCreationTokens * pricing.cacheWrite +
      cacheReadTokens * pricing.cacheRead
    );
  }

  /**
   * Persists the token usage of one API call and logs a one-line summary.
   *
   * Optional counters default to 0 and optional IDs/labels to null.
   * totalTokens is stored as inputTokens + outputTokens.
   *
   * @returns the saved entity
   */
  async recordUsage(input: TokenUsageInput): Promise<TokenUsageORM> {
    const cacheCreationTokens = input.cacheCreationTokens || 0;
    const cacheReadTokens = input.cacheReadTokens || 0;
    const totalTokens = input.inputTokens + input.outputTokens;
    const estimatedCost = this.calculateCost(
      input.model,
      input.inputTokens,
      input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
    );

    const entity = this.tokenUsageRepository.create({
      userId: input.userId || null,
      conversationId: input.conversationId,
      messageId: input.messageId || null,
      model: input.model,
      inputTokens: input.inputTokens,
      outputTokens: input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
      totalTokens,
      estimatedCost,
      intentType: input.intentType || null,
      toolCalls: input.toolCalls || 0,
      responseLength: input.responseLength || 0,
      latencyMs: input.latencyMs || 0,
    });
    const saved = await this.tokenUsageRepository.save(entity);

    console.log(
      `[TokenUsage] Recorded: in=${input.inputTokens}, out=${input.outputTokens}, ` +
      `cache_read=${cacheReadTokens}, cost=$${estimatedCost.toFixed(6)}, ` +
      `intent=${input.intentType}, latency=${input.latencyMs}ms`
    );
    return saved;
  }

  /**
   * Aggregated usage for one user over the last `days` days (default 30).
   */
  async getUserStats(userId: string, days: number = 30): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: {
        userId,
        createdAt: MoreThanOrEqual(this.sinceDaysAgo(days)),
      },
    });
    return this.calculateStats(records);
  }

  /**
   * Aggregated usage for one conversation, across all time.
   */
  async getConversationStats(conversationId: string): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: { conversationId },
    });
    return this.calculateStats(records);
  }

  /**
   * Aggregated usage across all users over the last `days` days (default 30),
   * plus the number of distinct non-null user IDs (admin use).
   */
  async getGlobalStats(days: number = 30): Promise<UsageStats & { uniqueUsers: number }> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.sinceDaysAgo(days)),
      },
    });
    const stats = this.calculateStats(records);
    const uniqueUsers = new Set(records.filter((r) => r.userId).map((r) => r.userId)).size;
    return { ...stats, uniqueUsers };
  }

  /**
   * Per-day request count, token total and cost over the last `days` days
   * (default 7). Days are UTC calendar dates (YYYY-MM-DD). Only days that have
   * at least one record appear, in ascending date order.
   */
  async getDailyStats(days: number = 7): Promise<Array<{
    date: string;
    requests: number;
    totalTokens: number;
    totalCost: number;
  }>> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.sinceDaysAgo(days)),
      },
      order: { createdAt: 'ASC' },
    });

    // Map preserves insertion order; because records are sorted by createdAt,
    // the buckets come out in ascending date order.
    const byDate = new Map<string, {
      date: string;
      requests: number;
      totalTokens: number;
      totalCost: number;
    }>();
    for (const record of records) {
      const date = record.createdAt.toISOString().slice(0, 10);
      let day = byDate.get(date);
      if (!day) {
        day = { date, requests: 0, totalTokens: 0, totalCost: 0 };
        byDate.set(date, day);
      }
      day.requests += 1;
      // Number() guards against numeric columns the driver may return as strings.
      day.totalTokens += Number(record.totalTokens);
      day.totalCost += Number(record.estimatedCost);
    }
    return Array.from(byDate.values());
  }

  /**
   * Top `limit` users (default 10) by total tokens over the last `days` days
   * (default 30). Rows with a null user_id are excluded. The aggregation runs
   * in SQL, so the raw SUM/COUNT values are parsed from their string form.
   */
  async getTopUsers(days: number = 30, limit: number = 10): Promise<Array<{
    userId: string;
    totalTokens: number;
    totalCost: number;
    requestCount: number;
  }>> {
    const result = await this.tokenUsageRepository
      .createQueryBuilder('usage')
      .select('usage.user_id', 'userId')
      .addSelect('SUM(usage.total_tokens)', 'totalTokens')
      .addSelect('SUM(usage.estimated_cost)', 'totalCost')
      .addSelect('COUNT(*)', 'requestCount')
      .where('usage.created_at >= :since', { since: this.sinceDaysAgo(days) })
      .andWhere('usage.user_id IS NOT NULL')
      .groupBy('usage.user_id')
      .orderBy('SUM(usage.total_tokens)', 'DESC')
      .limit(limit)
      .getRawMany();

    return result.map((r) => ({
      userId: r.userId,
      totalTokens: parseInt(r.totalTokens, 10) || 0,
      totalCost: parseFloat(r.totalCost) || 0,
      requestCount: parseInt(r.requestCount, 10) || 0,
    }));
  }

  /**
   * Folds a set of usage records into totals, per-request averages (rounded
   * to integers) and a cache hit rate (percentage, two decimals).
   * An empty set yields all zeros.
   */
  private calculateStats(records: TokenUsageORM[]): UsageStats {
    const totalRequests = records.length;
    if (totalRequests === 0) {
      return {
        totalRequests: 0,
        totalInputTokens: 0,
        totalOutputTokens: 0,
        totalCacheReadTokens: 0,
        totalCacheCreationTokens: 0,
        totalTokens: 0,
        totalCost: 0,
        avgInputTokens: 0,
        avgOutputTokens: 0,
        avgLatencyMs: 0,
        cacheHitRate: 0,
      };
    }

    let totalInputTokens = 0;
    let totalOutputTokens = 0;
    let totalCacheReadTokens = 0;
    let totalCacheCreationTokens = 0;
    let totalTokens = 0;
    let totalCost = 0;
    let totalLatency = 0;
    // Single pass over the records. Number() guards against numeric columns
    // the driver may return as strings (e.g. decimal/bigint), which `+` would
    // otherwise concatenate.
    for (const r of records) {
      totalInputTokens += Number(r.inputTokens);
      totalOutputTokens += Number(r.outputTokens);
      totalCacheReadTokens += Number(r.cacheReadTokens);
      totalCacheCreationTokens += Number(r.cacheCreationTokens);
      totalTokens += Number(r.totalTokens);
      totalCost += Number(r.estimatedCost);
      totalLatency += Number(r.latencyMs);
    }

    // Cache hit rate (%) = cache-read tokens / input tokens.
    // NOTE(review): if inputTokens excludes cache reads (Anthropic's API
    // convention), this ratio can exceed 100% — confirm the caller's convention.
    const cacheHitRate = totalInputTokens > 0
      ? (totalCacheReadTokens / totalInputTokens) * 100
      : 0;

    return {
      totalRequests,
      totalInputTokens,
      totalOutputTokens,
      totalCacheReadTokens,
      totalCacheCreationTokens,
      totalTokens,
      totalCost,
      avgInputTokens: Math.round(totalInputTokens / totalRequests),
      avgOutputTokens: Math.round(totalOutputTokens / totalRequests),
      avgLatencyMs: Math.round(totalLatency / totalRequests),
      cacheHitRate: Math.round(cacheHitRate * 100) / 100,
    };
  }
}