Version: 0.9.80.dev.260506
后端: 1. LLM 独立服务与统一计费出口落地:新增 `cmd/llm`、`client/llm` 与 `services/llm/rpc`,补齐 BillingContext、CreditBalanceGuard、价格规则解析、stream usage 归集与 `credit.charge.requested` outbox 发布,active-scheduler / agent / course / memory / gateway fallback 全部改走 llm zrpc,不再各自本地初始化模型。 2. TokenStore 收口为 Credit 权威账本:新增 credit account / ledger / product / order / price-rule / reward-rule 能力与 Redis 快照缓存,扩展 tokenstore rpc/client 支撑余额快照、消耗看板、商品、订单、流水、价格规则和奖励规则,并接入 LLM charge 事件消费完成 Credit 扣费落账。 3. 计费旧链路下线与网关切口切换:`/token-store` 语义整体切到 `/credit-store`,agent chat 移除旧 TokenQuotaGuard,userauth 的 CheckTokenQuota / AdjustTokenUsage 改为废弃,聊天历史落库不再同步旧 token 额度账本,course 图片解析请求补 user_id 进入新计费口径。 前端: 4. 计划广场从 mock 数据切到真实接口:新增 forum api/types,首页支持真实列表、标签、搜索、防抖、点赞、导入和发布计划,详情页补齐帖子详情、评论树、回复和删除评论链路,同时补上“至少一个标签”的前后端约束与默认标签兜底。 5. 商店页切到 Credit 体系并重做展示:顶部改为余额 + Credit/Token 消耗看板,支持 24h/7d/30d/all 周期切换;套餐区展示原价与当前价;历史区改为当前用户 Credit 流水并支持查看更多,整体视觉和交互同步收口。 仓库: 6. 配置与本地启动体系补齐 llm / outbox 编排:`config.example.yaml` 增加 llm rpc 和统一 outbox service 配置,`dev-common.ps1` 把 llm 纳入多服务依赖并自动建 Kafka topic,`docker-compose.yml` 同步初始化 agent/task/memory/active-scheduler/notification/taskclass-forum/llm/token-store 全量 outbox topic。
This commit is contained in:
@@ -46,8 +46,9 @@ type ArkResponsesResult struct {
|
||||
|
||||
// ArkResponsesClient 是 Ark SDK Responses 的统一模型出口。
|
||||
type ArkResponsesClient struct {
|
||||
model string
|
||||
client *arkruntime.Client
|
||||
model string
|
||||
client *arkruntime.Client
|
||||
generateText func(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error)
|
||||
}
|
||||
|
||||
// NewArkResponsesClient 创建 Ark SDK Responses 客户端。
|
||||
@@ -71,8 +72,28 @@ func NewArkResponsesClient(apiKey string, baseURL string, model string) *ArkResp
|
||||
}
|
||||
}
|
||||
|
||||
// NewArkResponsesClientWithFunc 使用外部注入的 GenerateText 能力构造客户端。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 供 llm zrpc remote client 和测试替身复用;
|
||||
// 2. 这里只负责挂接统一函数签名,不负责远端连接初始化;
|
||||
// 3. model 仅作为兼容字段保留,真正调用行为以 generateText 为准。
|
||||
func NewArkResponsesClientWithFunc(model string, generateText func(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error)) *ArkResponsesClient {
|
||||
if generateText == nil {
|
||||
return nil
|
||||
}
|
||||
return &ArkResponsesClient{
|
||||
model: strings.TrimSpace(model),
|
||||
generateText: generateText,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateText 执行一次非流式 Responses 调用并提取文本。
|
||||
func (c *ArkResponsesClient) GenerateText(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error) {
|
||||
if c != nil && c.generateText != nil {
|
||||
return c.generateText(ctx, messages, options)
|
||||
}
|
||||
|
||||
req, err := c.buildRequest(messages, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
76
backend/services/llm/billing.go
Normal file
76
backend/services/llm/billing.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type billingContextKey struct{}
|
||||
|
||||
// BillingContext 描述一次 LLM 调用必需的计费上下文。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只承载计费、审计、幂等所需的调用上下文;
|
||||
// 2. 不承载 Temperature / MaxTokens 这类模型行为参数;
|
||||
// 3. 不混入 prompt 文本,避免把业务输入复制成第二份协议。
|
||||
type BillingContext struct {
|
||||
UserID uint64 `json:"user_id"`
|
||||
EventID string `json:"event_id"`
|
||||
Scene string `json:"scene"`
|
||||
RequestID string `json:"request_id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
ModelAlias string `json:"model_alias"`
|
||||
SkipCharge bool `json:"skip_charge"`
|
||||
}
|
||||
|
||||
// Normalize 返回去空格后的 BillingContext 副本。
|
||||
func (c BillingContext) Normalize() BillingContext {
|
||||
c.EventID = strings.TrimSpace(c.EventID)
|
||||
c.Scene = strings.TrimSpace(c.Scene)
|
||||
c.RequestID = strings.TrimSpace(c.RequestID)
|
||||
c.ConversationID = strings.TrimSpace(c.ConversationID)
|
||||
c.ModelAlias = strings.TrimSpace(c.ModelAlias)
|
||||
return c
|
||||
}
|
||||
|
||||
// IsZero 判断是否完全没有注入计费上下文。
|
||||
func (c BillingContext) IsZero() bool {
|
||||
return c.UserID == 0 &&
|
||||
strings.TrimSpace(c.EventID) == "" &&
|
||||
strings.TrimSpace(c.Scene) == "" &&
|
||||
strings.TrimSpace(c.RequestID) == "" &&
|
||||
strings.TrimSpace(c.ConversationID) == "" &&
|
||||
strings.TrimSpace(c.ModelAlias) == "" &&
|
||||
!c.SkipCharge
|
||||
}
|
||||
|
||||
// WithBillingContext 把计费上下文挂进调用 ctx。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 这次优先保持 GenerateText / GenerateJSON / Stream 原有签名基本不变;
|
||||
// 2. 计费必填信息不再塞进 GenerateOptions.Metadata,而是走强语义 ctx;
|
||||
// 3. 后续若统一切为显式 request struct,可继续复用本结构体,不改业务语义。
|
||||
func WithBillingContext(ctx context.Context, billing BillingContext) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
billing = billing.Normalize()
|
||||
return context.WithValue(ctx, billingContextKey{}, billing)
|
||||
}
|
||||
|
||||
// BillingContextFromContext 读取调用上下文中的计费信息。
|
||||
func BillingContextFromContext(ctx context.Context) (BillingContext, bool) {
|
||||
if ctx == nil {
|
||||
return BillingContext{}, false
|
||||
}
|
||||
value := ctx.Value(billingContextKey{})
|
||||
billing, ok := value.(BillingContext)
|
||||
if !ok {
|
||||
return BillingContext{}, false
|
||||
}
|
||||
billing = billing.Normalize()
|
||||
if billing.IsZero() {
|
||||
return BillingContext{}, false
|
||||
}
|
||||
return billing, true
|
||||
}
|
||||
68
backend/services/llm/billing_compat.go
Normal file
68
backend/services/llm/billing_compat.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
llmcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/llm"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// EnsureTextBillingIdentity 负责在老调用点未显式提供 event_id 时兜底补稳定事件号。
|
||||
//
|
||||
// 兼容策略:
|
||||
// 1. 只在 user_id、request_id 已具备且 event_id 为空时触发,避免覆盖显式幂等键;
|
||||
// 2. stage 优先从 GenerateOptions.Metadata["stage"] 读取,兼容 agent 现有大量调用点;
|
||||
// 3. 输入摘要使用 messages 的稳定哈希,确保同一 request_id 下不同阶段/不同输入不会串账。
|
||||
func EnsureTextBillingIdentity(billing BillingContext, options llmcontracts.GenerateOptions, messages []*schema.Message) BillingContext {
|
||||
return ensureBillingIdentity(billing, readStageFromMetadata(options.Metadata), messages)
|
||||
}
|
||||
|
||||
// EnsureResponsesBillingIdentity 负责给 Responses 调用补稳定事件号。
|
||||
func EnsureResponsesBillingIdentity(billing BillingContext, messages []llmcontracts.ResponsesMessage) BillingContext {
|
||||
return ensureBillingIdentity(billing, "", messages)
|
||||
}
|
||||
|
||||
func ensureBillingIdentity(billing BillingContext, stage string, payload any) BillingContext {
|
||||
billing = billing.Normalize()
|
||||
if billing.UserID == 0 || strings.TrimSpace(billing.RequestID) == "" || strings.TrimSpace(billing.EventID) != "" {
|
||||
return billing
|
||||
}
|
||||
|
||||
stage = strings.TrimSpace(stage)
|
||||
billing.EventID = buildStableBillingEventID(billing.RequestID, stage, hashPayload(payload))
|
||||
return billing
|
||||
}
|
||||
|
||||
func readStageFromMetadata(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := metadata["stage"]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(raw))
|
||||
}
|
||||
|
||||
func hashPayload(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil || len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha1.Sum(raw)
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func buildStableBillingEventID(requestID, stage, payloadDigest string) string {
|
||||
requestID = strings.TrimSpace(requestID)
|
||||
stage = strings.TrimSpace(stage)
|
||||
payloadDigest = strings.TrimSpace(payloadDigest)
|
||||
|
||||
base := requestID + "|" + stage + "|" + payloadDigest
|
||||
sum := sha1.Sum([]byte(base))
|
||||
return requestID + ":" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
107
backend/services/llm/dao/cache.go
Normal file
107
backend/services/llm/dao/cache.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
const defaultCreditSnapshotTTL = 10 * time.Minute
|
||||
|
||||
// CreditBalanceSnapshot 是 LLM 准入守卫读取的余额快照。
|
||||
type CreditBalanceSnapshot struct {
|
||||
AvailableCredit int64 `json:"balance"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// CacheDAO 只承载 LLM 服务私有的 Redis Key 读写。
|
||||
type CacheDAO struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewCacheDAO(client *redis.Client) *CacheDAO {
|
||||
return &CacheDAO{client: client}
|
||||
}
|
||||
|
||||
func userCreditBalanceSnapshotKey(userID uint64) string {
|
||||
return fmt.Sprintf("smartflow:credit_balance_snapshot:%d", userID)
|
||||
}
|
||||
|
||||
func userCreditBlockedKey(userID uint64) string {
|
||||
return fmt.Sprintf("smartflow:credit_blocked:%d", userID)
|
||||
}
|
||||
|
||||
func (d *CacheDAO) GetUserCreditBalanceSnapshot(ctx context.Context, userID uint64) (*CreditBalanceSnapshot, bool, error) {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
value, err := d.client.Get(ctx, userCreditBalanceSnapshotKey(userID)).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
var snapshot CreditBalanceSnapshot
|
||||
if err = json.Unmarshal([]byte(value), &snapshot); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return &snapshot, true, nil
|
||||
}
|
||||
|
||||
func (d *CacheDAO) SetUserCreditBalanceSnapshot(ctx context.Context, userID uint64, snapshot CreditBalanceSnapshot, ttl time.Duration) error {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return nil
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = defaultCreditSnapshotTTL
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(snapshot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.client.Set(ctx, userCreditBalanceSnapshotKey(userID), raw, ttl).Err()
|
||||
}
|
||||
|
||||
func (d *CacheDAO) DeleteUserCreditBalanceSnapshot(ctx context.Context, userID uint64) error {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return nil
|
||||
}
|
||||
return d.client.Del(ctx, userCreditBalanceSnapshotKey(userID)).Err()
|
||||
}
|
||||
|
||||
func (d *CacheDAO) IsUserCreditBlocked(ctx context.Context, userID uint64) (bool, error) {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
value, err := d.client.Get(ctx, userCreditBlockedKey(userID)).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return value == "1", nil
|
||||
}
|
||||
|
||||
func (d *CacheDAO) SetUserCreditBlocked(ctx context.Context, userID uint64, ttl time.Duration) error {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return nil
|
||||
}
|
||||
return d.client.Set(ctx, userCreditBlockedKey(userID), "1", ttl).Err()
|
||||
}
|
||||
|
||||
func (d *CacheDAO) DeleteUserCreditBlocked(ctx context.Context, userID uint64) error {
|
||||
if d == nil || d.client == nil || userID == 0 {
|
||||
return nil
|
||||
}
|
||||
return d.client.Del(ctx, userCreditBlockedKey(userID)).Err()
|
||||
}
|
||||
42
backend/services/llm/dao/connect.go
Normal file
42
backend/services/llm/dao/connect.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
|
||||
mysqlinfra "github.com/LoveLosita/smartflow/backend/shared/infra/mysql"
|
||||
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OpenDBFromConfig 负责打开 LLM 独立服务需要的数据库连接。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只初始化通用 MySQL 连接并补齐 LLM 自己的 outbox 表;
|
||||
// 2. 不负责启动 Kafka relay,也不负责装配 Redis/模型客户端;
|
||||
// 3. 当前阶段不额外声明业务私表,避免和主代理后续 Credit 表迁移交叉。
|
||||
func OpenDBFromConfig() (*gorm.DB, error) {
|
||||
db, err := mysqlinfra.OpenDBFromConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = autoMigrateLLMOutboxTable(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func autoMigrateLLMOutboxTable(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("llm database is not initialized")
|
||||
}
|
||||
|
||||
cfg, ok := outboxinfra.ResolveServiceConfig(outboxinfra.ServiceLLM)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolve llm outbox config failed")
|
||||
}
|
||||
if err := db.Table(cfg.TableName).AutoMigrate(&model.AgentOutboxMessage{}); err != nil {
|
||||
return fmt.Errorf("auto migrate llm outbox table failed for %s (%s): %w", cfg.Name, cfg.TableName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
53
backend/services/llm/dao/pricing.go
Normal file
53
backend/services/llm/dao/pricing.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const creditPriceRuleStatusActive = "active"
|
||||
|
||||
type CreditPriceRule struct {
|
||||
ID uint64 `gorm:"column:id"`
|
||||
Scene string `gorm:"column:scene"`
|
||||
ProviderName string `gorm:"column:provider_name"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
InputPriceMicros int64 `gorm:"column:input_price_micros"`
|
||||
OutputPriceMicros int64 `gorm:"column:output_price_micros"`
|
||||
CachedPriceMicros int64 `gorm:"column:cached_price_micros"`
|
||||
ReasoningPriceMicros int64 `gorm:"column:reasoning_price_micros"`
|
||||
CreditPerYuan int64 `gorm:"column:credit_per_yuan"`
|
||||
Status string `gorm:"column:status"`
|
||||
Priority int `gorm:"column:priority"`
|
||||
Description string `gorm:"column:description"`
|
||||
}
|
||||
|
||||
func (CreditPriceRule) TableName() string {
|
||||
return "credit_price_rules"
|
||||
}
|
||||
|
||||
type PriceRuleDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewPriceRuleDAO(db *gorm.DB) *PriceRuleDAO {
|
||||
return &PriceRuleDAO{db: db}
|
||||
}
|
||||
|
||||
func (d *PriceRuleDAO) ListActiveRules(ctx context.Context) ([]CreditPriceRule, error) {
|
||||
if d == nil || d.db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rules []CreditPriceRule
|
||||
err := d.db.WithContext(ctx).
|
||||
Model(&CreditPriceRule{}).
|
||||
Where("status = ?", creditPriceRuleStatusActive).
|
||||
Order("priority DESC, id ASC").
|
||||
Find(&rules).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
152
backend/services/llm/guard.go
Normal file
152
backend/services/llm/guard.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
llmdao "github.com/LoveLosita/smartflow/backend/services/llm/dao"
|
||||
creditcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/creditstore"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRuntimeServiceNotReady = errors.New("llm runtime service dependency not initialized")
|
||||
ErrUnsupportedModelAlias = errors.New("llm model alias is unsupported")
|
||||
ErrCreditBalanceBlocked = errors.New("credit balance is insufficient")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCreditBlockedTTL = 5 * time.Minute
|
||||
defaultCreditSnapshotTimeout = time.Second
|
||||
)
|
||||
|
||||
type CreditBalanceSnapshotProvider interface {
|
||||
GetCreditBalanceSnapshot(ctx context.Context, userID uint64) (*creditcontracts.CreditBalanceSnapshot, error)
|
||||
}
|
||||
|
||||
// CreditBalanceGuard 负责在真正发起 LLM 调用前做一次轻量余额准入。
|
||||
type CreditBalanceGuard struct {
|
||||
cacheDAO *llmdao.CacheDAO
|
||||
snapshotProvider CreditBalanceSnapshotProvider
|
||||
blockTTL time.Duration
|
||||
snapshotTimeout time.Duration
|
||||
}
|
||||
|
||||
type CreditBalanceGuardOptions struct {
|
||||
CacheDAO *llmdao.CacheDAO
|
||||
SnapshotProvider CreditBalanceSnapshotProvider
|
||||
BlockTTL time.Duration
|
||||
SnapshotTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewCreditBalanceGuard(opts CreditBalanceGuardOptions) *CreditBalanceGuard {
|
||||
blockTTL := opts.BlockTTL
|
||||
if blockTTL <= 0 {
|
||||
blockTTL = defaultCreditBlockedTTL
|
||||
}
|
||||
snapshotTimeout := opts.SnapshotTimeout
|
||||
if snapshotTimeout <= 0 {
|
||||
snapshotTimeout = defaultCreditSnapshotTimeout
|
||||
}
|
||||
return &CreditBalanceGuard{
|
||||
cacheDAO: opts.CacheDAO,
|
||||
snapshotProvider: opts.SnapshotProvider,
|
||||
blockTTL: blockTTL,
|
||||
snapshotTimeout: snapshotTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Guard 只做 Redis 快照级别的 fail-open 准入检查。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 先查 blocked key,命中则直接拒绝,避免每次都回源余额快照;
|
||||
// 2. 再查余额快照;若快照明确余额 <= 0,则写 blocked key 并拒绝;
|
||||
// 3. Redis 读失败或快照缺失时保持放行,避免基础设施抖动直接阻断全部 LLM 调用。
|
||||
func (g *CreditBalanceGuard) Guard(ctx context.Context, billing BillingContext) error {
|
||||
if g == nil || g.cacheDAO == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
billing = billing.Normalize()
|
||||
if billing.UserID == 0 || billing.SkipCharge {
|
||||
return nil
|
||||
}
|
||||
|
||||
blocked, err := g.cacheDAO.IsUserCreditBlocked(ctx, billing.UserID)
|
||||
if err != nil {
|
||||
log.Printf("llm credit guard read blocked key failed: user_id=%d err=%v", billing.UserID, err)
|
||||
return nil
|
||||
}
|
||||
if blocked {
|
||||
return ErrCreditBalanceBlocked
|
||||
}
|
||||
|
||||
snapshot, found, err := g.cacheDAO.GetUserCreditBalanceSnapshot(ctx, billing.UserID)
|
||||
if err != nil {
|
||||
log.Printf("llm credit guard read balance snapshot failed: user_id=%d err=%v", billing.UserID, err)
|
||||
return nil
|
||||
}
|
||||
if !found || snapshot == nil {
|
||||
snapshot, err = g.fetchSnapshot(ctx, billing.UserID)
|
||||
if err != nil {
|
||||
log.Printf("llm credit guard fetch balance snapshot failed: user_id=%d err=%v", billing.UserID, err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
if snapshot.AvailableCredit > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = g.cacheDAO.SetUserCreditBlocked(ctx, billing.UserID, g.blockTTL); err != nil {
|
||||
log.Printf("llm credit guard set blocked key failed: user_id=%d err=%v", billing.UserID, err)
|
||||
}
|
||||
return ErrCreditBalanceBlocked
|
||||
}
|
||||
|
||||
func (g *CreditBalanceGuard) fetchSnapshot(ctx context.Context, userID uint64) (*llmdao.CreditBalanceSnapshot, error) {
|
||||
if g == nil || g.snapshotProvider == nil || userID == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fetchCtx := ctx
|
||||
if fetchCtx == nil {
|
||||
fetchCtx = context.Background()
|
||||
}
|
||||
if g.snapshotTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
fetchCtx, cancel = context.WithTimeout(context.WithoutCancel(fetchCtx), g.snapshotTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
snapshotView, err := g.snapshotProvider.GetCreditBalanceSnapshot(fetchCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if snapshotView == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
snapshot := &llmdao.CreditBalanceSnapshot{
|
||||
AvailableCredit: snapshotView.Balance,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err = g.cacheDAO.SetUserCreditBalanceSnapshot(fetchCtx, userID, *snapshot, 0); err != nil {
|
||||
log.Printf("llm credit guard backfill balance snapshot failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
|
||||
if snapshotView.IsBlocked || snapshotView.Balance <= 0 {
|
||||
if err = g.cacheDAO.SetUserCreditBlocked(fetchCtx, userID, g.blockTTL); err != nil {
|
||||
log.Printf("llm credit guard backfill blocked key failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
if err = g.cacheDAO.DeleteUserCreditBlocked(fetchCtx, userID); err != nil {
|
||||
log.Printf("llm credit guard clear blocked key failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
return snapshot, nil
|
||||
}
|
||||
211
backend/services/llm/outbox.go
Normal file
211
backend/services/llm/outbox.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
llmcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/llm"
|
||||
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
|
||||
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultOutboxMaxRetry = 20
|
||||
defaultBillingPersistWindow = 2 * time.Second
|
||||
)
|
||||
|
||||
// ChargeRecorder 负责把一次已完成的 LLM usage 写入 LLM 自己的 outbox。
|
||||
type ChargeRecorder struct {
|
||||
publisher *outboxinfra.RepositoryPublisher
|
||||
providerName string
|
||||
pricing UsagePricingResolver
|
||||
}
|
||||
|
||||
type ChargeRecorderOptions struct {
|
||||
Repo *outboxinfra.Repository
|
||||
MaxRetry int
|
||||
ProviderName string
|
||||
Pricing UsagePricingResolver
|
||||
}
|
||||
|
||||
func NewChargeRecorder(opts ChargeRecorderOptions) (*ChargeRecorder, error) {
|
||||
if err := RegisterCreditChargeRoute(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
providerName := strings.TrimSpace(opts.ProviderName)
|
||||
if providerName == "" {
|
||||
providerName = llmcontracts.ProviderNameArk
|
||||
}
|
||||
|
||||
if opts.Repo == nil {
|
||||
return &ChargeRecorder{providerName: providerName}, nil
|
||||
}
|
||||
|
||||
maxRetry := opts.MaxRetry
|
||||
if maxRetry <= 0 {
|
||||
maxRetry = defaultOutboxMaxRetry
|
||||
}
|
||||
return &ChargeRecorder{
|
||||
// 1. 当前 outbox infra 仍是“由归属服务自己 dispatch + consume 自己的 outbox”模型。
|
||||
// 2. 因此这里必须让 Repository 按事件归属把 credit 事件写进 token-store 的 outbox,
|
||||
// 不能再强绑到 llm 自己的 route,否则消息只会停在 published 而无人消费。
|
||||
publisher: outboxinfra.NewRepositoryPublisher(opts.Repo, maxRetry),
|
||||
providerName: providerName,
|
||||
pricing: opts.Pricing,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RegisterCreditChargeRoute() error {
|
||||
return outboxinfra.RegisterEventService(sharedevents.CreditChargeRequestedEventType, outboxinfra.ServiceTokenStore)
|
||||
}
|
||||
|
||||
func (r *ChargeRecorder) RecordTextUsage(ctx context.Context, billing BillingContext, modelAlias, modelName, defaultScene string, usage *schema.TokenUsage) error {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
return r.publish(ctx, billing, publishUsageInput{
|
||||
ModelAlias: modelAlias,
|
||||
ModelName: modelName,
|
||||
DefaultScene: defaultScene,
|
||||
InputTokens: int64(usage.PromptTokens),
|
||||
OutputTokens: int64(usage.CompletionTokens),
|
||||
CachedTokens: int64(usage.PromptTokenDetails.CachedTokens),
|
||||
ReasoningTokens: int64(usage.CompletionTokensDetails.ReasoningTokens),
|
||||
TotalTokens: int64(usage.TotalTokens),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *ChargeRecorder) RecordResponsesUsage(ctx context.Context, billing BillingContext, modelAlias, modelName, defaultScene string, usage *ArkResponsesUsage) error {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
return r.publish(ctx, billing, publishUsageInput{
|
||||
ModelAlias: modelAlias,
|
||||
ModelName: modelName,
|
||||
DefaultScene: defaultScene,
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
})
|
||||
}
|
||||
|
||||
type publishUsageInput struct {
|
||||
ModelAlias string
|
||||
ModelName string
|
||||
DefaultScene string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CachedTokens int64
|
||||
ReasoningTokens int64
|
||||
TotalTokens int64
|
||||
}
|
||||
|
||||
func (r *ChargeRecorder) publish(ctx context.Context, billing BillingContext, input publishUsageInput) error {
|
||||
if r == nil || r.publisher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
billing = billing.Normalize()
|
||||
if billing.UserID == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
eventID := firstNonEmptyString(strings.TrimSpace(billing.EventID), uuid.NewString())
|
||||
requestID := firstNonEmptyString(strings.TrimSpace(billing.RequestID), eventID)
|
||||
scene := firstNonEmptyString(strings.TrimSpace(billing.Scene), strings.TrimSpace(input.DefaultScene))
|
||||
modelAlias := firstNonEmptyString(strings.TrimSpace(billing.ModelAlias), strings.TrimSpace(input.ModelAlias))
|
||||
modelName := firstNonEmptyString(strings.TrimSpace(input.ModelName), modelAlias)
|
||||
totalTokens := input.TotalTokens
|
||||
if totalTokens <= 0 {
|
||||
totalTokens = input.InputTokens + input.OutputTokens
|
||||
}
|
||||
|
||||
payload := sharedevents.CreditChargeRequestedPayload{
|
||||
EventID: eventID,
|
||||
UserID: billing.UserID,
|
||||
Scene: scene,
|
||||
RequestID: requestID,
|
||||
ConversationID: strings.TrimSpace(billing.ConversationID),
|
||||
ModelAlias: modelAlias,
|
||||
ProviderName: r.providerName,
|
||||
ModelName: modelName,
|
||||
InputTokens: input.InputTokens,
|
||||
OutputTokens: input.OutputTokens,
|
||||
CachedTokens: input.CachedTokens,
|
||||
ReasoningTokens: input.ReasoningTokens,
|
||||
TotalTokens: totalTokens,
|
||||
RMBCostMicros: 0,
|
||||
CreditCost: 0,
|
||||
TriggeredAt: time.Now(),
|
||||
SkipCharge: billing.SkipCharge,
|
||||
}
|
||||
if !billing.SkipCharge {
|
||||
quote, err := r.resolvePriceQuote(ctx, payload)
|
||||
if err != nil {
|
||||
log.Printf("llm price quote resolve failed: event_id=%s user_id=%d err=%v", payload.EventID, payload.UserID, err)
|
||||
} else {
|
||||
payload.RMBCostMicros = quote.RMBCostMicros
|
||||
payload.CreditCost = quote.CreditCost
|
||||
}
|
||||
}
|
||||
if err := payload.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
return r.publisher.Publish(recordCtx, outboxinfra.PublishRequest{
|
||||
EventID: payload.EventID,
|
||||
EventType: sharedevents.CreditChargeRequestedEventType,
|
||||
EventVersion: sharedevents.CreditChargeEventVersion,
|
||||
MessageKey: payload.MessageKey(),
|
||||
AggregateID: payload.AggregateID(),
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *ChargeRecorder) resolvePriceQuote(ctx context.Context, payload sharedevents.CreditChargeRequestedPayload) (UsagePriceQuote, error) {
|
||||
if r == nil || r.pricing == nil {
|
||||
return UsagePriceQuote{}, nil
|
||||
}
|
||||
|
||||
return r.pricing.Resolve(ctx, UsagePricingInput{
|
||||
Scene: payload.Scene,
|
||||
ProviderName: payload.ProviderName,
|
||||
ModelName: payload.ModelName,
|
||||
InputTokens: payload.InputTokens,
|
||||
OutputTokens: payload.OutputTokens,
|
||||
CachedTokens: payload.CachedTokens,
|
||||
ReasoningTokens: payload.ReasoningTokens,
|
||||
})
|
||||
}
|
||||
|
||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, defaultBillingPersistWindow)
|
||||
}
|
||||
|
||||
func logChargeRecordError(scene string, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Printf("llm charge record failed: scene=%s err=%v", strings.TrimSpace(scene), err)
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
207
backend/services/llm/pricing.go
Normal file
207
backend/services/llm/pricing.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
llmdao "github.com/LoveLosita/smartflow/backend/services/llm/dao"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPriceRuleCacheTTL = time.Minute
|
||||
tokenPriceScalePer1K = int64(1000)
|
||||
rmbMicrosPerYuan = int64(1_000_000)
|
||||
)
|
||||
|
||||
type UsagePricingInput struct {
|
||||
Scene string
|
||||
ProviderName string
|
||||
ModelName string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CachedTokens int64
|
||||
ReasoningTokens int64
|
||||
}
|
||||
|
||||
type UsagePriceQuote struct {
|
||||
RuleID uint64
|
||||
RMBCostMicros int64
|
||||
CreditCost int64
|
||||
MatchedScene string
|
||||
MatchedProvider string
|
||||
MatchedModel string
|
||||
}
|
||||
|
||||
type UsagePricingResolver interface {
|
||||
Resolve(ctx context.Context, input UsagePricingInput) (UsagePriceQuote, error)
|
||||
}
|
||||
|
||||
type CreditPriceResolverOptions struct {
|
||||
DAO *llmdao.PriceRuleDAO
|
||||
CacheTTL time.Duration
|
||||
}
|
||||
|
||||
type CreditPriceResolver struct {
|
||||
dao *llmdao.PriceRuleDAO
|
||||
cacheTTL time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
cachedAt time.Time
|
||||
cachedSet []llmdao.CreditPriceRule
|
||||
}
|
||||
|
||||
func NewCreditPriceResolver(opts CreditPriceResolverOptions) *CreditPriceResolver {
|
||||
cacheTTL := opts.CacheTTL
|
||||
if cacheTTL <= 0 {
|
||||
cacheTTL = defaultPriceRuleCacheTTL
|
||||
}
|
||||
return &CreditPriceResolver{
|
||||
dao: opts.DAO,
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *CreditPriceResolver) Resolve(ctx context.Context, input UsagePricingInput) (UsagePriceQuote, error) {
|
||||
if r == nil || r.dao == nil {
|
||||
return UsagePriceQuote{}, nil
|
||||
}
|
||||
|
||||
rules, err := r.loadRules(ctx)
|
||||
if err != nil {
|
||||
return UsagePriceQuote{}, err
|
||||
}
|
||||
if len(rules) == 0 {
|
||||
return UsagePriceQuote{}, nil
|
||||
}
|
||||
|
||||
scene := strings.TrimSpace(input.Scene)
|
||||
providerName := strings.TrimSpace(input.ProviderName)
|
||||
modelName := strings.TrimSpace(input.ModelName)
|
||||
|
||||
for _, rule := range rules {
|
||||
if !matchesPriceRuleField(rule.Scene, scene) {
|
||||
continue
|
||||
}
|
||||
if !matchesPriceRuleField(rule.ProviderName, providerName) {
|
||||
continue
|
||||
}
|
||||
if !matchesPriceRuleField(rule.ModelName, modelName) {
|
||||
continue
|
||||
}
|
||||
return quoteUsagePrice(rule, input), nil
|
||||
}
|
||||
|
||||
return UsagePriceQuote{}, nil
|
||||
}
|
||||
|
||||
func (r *CreditPriceResolver) loadRules(ctx context.Context) ([]llmdao.CreditPriceRule, error) {
|
||||
now := time.Now()
|
||||
|
||||
r.mu.RLock()
|
||||
if len(r.cachedSet) > 0 && now.Sub(r.cachedAt) < r.cacheTTL {
|
||||
rules := clonePriceRules(r.cachedSet)
|
||||
r.mu.RUnlock()
|
||||
return rules, nil
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(r.cachedSet) > 0 && now.Sub(r.cachedAt) < r.cacheTTL {
|
||||
return clonePriceRules(r.cachedSet), nil
|
||||
}
|
||||
|
||||
rules, err := r.dao.ListActiveRules(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.cachedSet = clonePriceRules(rules)
|
||||
r.cachedAt = now
|
||||
return clonePriceRules(r.cachedSet), nil
|
||||
}
|
||||
|
||||
func clonePriceRules(input []llmdao.CreditPriceRule) []llmdao.CreditPriceRule {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
output := make([]llmdao.CreditPriceRule, len(input))
|
||||
copy(output, input)
|
||||
return output
|
||||
}
|
||||
|
||||
func matchesPriceRuleField(ruleValue string, actual string) bool {
|
||||
ruleValue = strings.TrimSpace(ruleValue)
|
||||
actual = strings.TrimSpace(actual)
|
||||
|
||||
if ruleValue == "" || ruleValue == "*" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(ruleValue, actual)
|
||||
}
|
||||
|
||||
func quoteUsagePrice(rule llmdao.CreditPriceRule, input UsagePricingInput) UsagePriceQuote {
|
||||
inputTokens := maxInt64(input.InputTokens, 0)
|
||||
outputTokens := maxInt64(input.OutputTokens, 0)
|
||||
cachedTokens := clampInt64(input.CachedTokens, 0, inputTokens)
|
||||
reasoningTokens := clampInt64(input.ReasoningTokens, 0, outputTokens)
|
||||
|
||||
nonCachedInputTokens := inputTokens - cachedTokens
|
||||
nonReasoningOutputTokens := outputTokens - reasoningTokens
|
||||
|
||||
cachedPriceMicros := rule.CachedPriceMicros
|
||||
if cachedPriceMicros <= 0 {
|
||||
cachedPriceMicros = rule.InputPriceMicros
|
||||
}
|
||||
reasoningPriceMicros := rule.ReasoningPriceMicros
|
||||
if reasoningPriceMicros <= 0 {
|
||||
reasoningPriceMicros = rule.OutputPriceMicros
|
||||
}
|
||||
|
||||
totalMicrosScaled := nonCachedInputTokens*maxInt64(rule.InputPriceMicros, 0) +
|
||||
cachedTokens*maxInt64(cachedPriceMicros, 0) +
|
||||
nonReasoningOutputTokens*maxInt64(rule.OutputPriceMicros, 0) +
|
||||
reasoningTokens*maxInt64(reasoningPriceMicros, 0)
|
||||
|
||||
rmbCostMicros := ceilDivInt64(totalMicrosScaled, tokenPriceScalePer1K)
|
||||
creditCost := int64(0)
|
||||
if rmbCostMicros > 0 && rule.CreditPerYuan > 0 {
|
||||
creditCost = ceilDivInt64(rmbCostMicros*rule.CreditPerYuan, rmbMicrosPerYuan)
|
||||
}
|
||||
|
||||
return UsagePriceQuote{
|
||||
RuleID: rule.ID,
|
||||
RMBCostMicros: rmbCostMicros,
|
||||
CreditCost: creditCost,
|
||||
MatchedScene: strings.TrimSpace(rule.Scene),
|
||||
MatchedProvider: strings.TrimSpace(rule.ProviderName),
|
||||
MatchedModel: strings.TrimSpace(rule.ModelName),
|
||||
}
|
||||
}
|
||||
|
||||
func ceilDivInt64(numerator int64, denominator int64) int64 {
|
||||
if numerator <= 0 || denominator <= 0 {
|
||||
return 0
|
||||
}
|
||||
return (numerator + denominator - 1) / denominator
|
||||
}
|
||||
|
||||
func clampInt64(value int64, minValue int64, maxValue int64) int64 {
|
||||
if value < minValue {
|
||||
return minValue
|
||||
}
|
||||
if value > maxValue {
|
||||
return maxValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func maxInt64(value int64, minValue int64) int64 {
|
||||
if value < minValue {
|
||||
return minValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
71
backend/services/llm/rpc/errors.go
Normal file
71
backend/services/llm/rpc/errors.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
|
||||
"github.com/LoveLosita/smartflow/backend/shared/respond"
|
||||
"google.golang.org/genproto/googleapis/rpc/errdetails"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const llmErrorDomain = "smartflow.llm"
|
||||
|
||||
func grpcErrorFromServiceError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var resp respond.Response
|
||||
if errors.As(err, &resp) {
|
||||
return grpcErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, llmservice.ErrUnsupportedModelAlias):
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
case errors.Is(err, llmservice.ErrCreditBalanceBlocked):
|
||||
return status.Error(codes.ResourceExhausted, err.Error())
|
||||
case errors.Is(err, llmservice.ErrRuntimeServiceNotReady):
|
||||
return status.Error(codes.FailedPrecondition, err.Error())
|
||||
}
|
||||
|
||||
log.Printf("llm rpc internal error: %v", err)
|
||||
return status.Error(codes.Internal, "llm service internal error")
|
||||
}
|
||||
|
||||
func grpcErrorFromResponse(resp respond.Response) error {
|
||||
code := grpcCodeFromRespondStatus(resp.Status)
|
||||
message := strings.TrimSpace(resp.Info)
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(resp.Status)
|
||||
}
|
||||
|
||||
st := status.New(code, message)
|
||||
detail := &errdetails.ErrorInfo{
|
||||
Domain: llmErrorDomain,
|
||||
Reason: resp.Status,
|
||||
Metadata: map[string]string{
|
||||
"info": resp.Info,
|
||||
},
|
||||
}
|
||||
withDetails, err := st.WithDetails(detail)
|
||||
if err != nil {
|
||||
return st.Err()
|
||||
}
|
||||
return withDetails.Err()
|
||||
}
|
||||
|
||||
func grpcCodeFromRespondStatus(statusValue string) codes.Code {
|
||||
switch strings.TrimSpace(statusValue) {
|
||||
case respond.MissingParam.Status, respond.WrongParamType.Status:
|
||||
return codes.InvalidArgument
|
||||
}
|
||||
if strings.HasPrefix(strings.TrimSpace(statusValue), "5") {
|
||||
return codes.Internal
|
||||
}
|
||||
return codes.InvalidArgument
|
||||
}
|
||||
122
backend/services/llm/rpc/handler.go
Normal file
122
backend/services/llm/rpc/handler.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
|
||||
llmcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/llm"
|
||||
"github.com/LoveLosita/smartflow/backend/shared/respond"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
UnimplementedLLMServer
|
||||
svc *llmservice.RuntimeService
|
||||
}
|
||||
|
||||
func NewHandler(svc *llmservice.RuntimeService) *Handler {
|
||||
return &Handler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *Handler) Ping(ctx context.Context, req *llmcontracts.PingRequest) (*llmcontracts.PingResponse, error) {
|
||||
if err := h.ensureReady(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &llmcontracts.PingResponse{}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) GenerateText(ctx context.Context, req *llmcontracts.TextRequest) (*llmcontracts.TextResponse, error) {
|
||||
if err := h.ensureReady(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := h.svc.GenerateText(ctx, *req)
|
||||
if err != nil {
|
||||
return nil, grpcErrorFromServiceError(err)
|
||||
}
|
||||
return &llmcontracts.TextResponse{Result: llmserviceToContractTextResult(result)}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) StreamText(req *llmcontracts.StreamTextRequest, stream LLM_StreamTextServer) error {
|
||||
if err := h.ensureReady(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reader, err := h.svc.StreamText(stream.Context(), *req)
|
||||
if err != nil {
|
||||
return grpcErrorFromServiceError(err)
|
||||
}
|
||||
if reader == nil {
|
||||
return grpcErrorFromServiceError(llmservice.ErrRuntimeServiceNotReady)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
for {
|
||||
message, recvErr := reader.Recv()
|
||||
if recvErr != nil {
|
||||
if errors.Is(recvErr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return grpcErrorFromServiceError(recvErr)
|
||||
}
|
||||
if message == nil {
|
||||
continue
|
||||
}
|
||||
if err = stream.Send(&llmcontracts.StreamChunk{Message: message}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) GenerateResponsesText(ctx context.Context, req *llmcontracts.ResponsesRequest) (*llmcontracts.ResponsesResponse, error) {
|
||||
if err := h.ensureReady(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := h.svc.GenerateResponsesText(ctx, *req)
|
||||
if err != nil {
|
||||
return nil, grpcErrorFromServiceError(err)
|
||||
}
|
||||
return &llmcontracts.ResponsesResponse{Result: llmserviceToContractResponsesResult(result)}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) ensureReady(req any) error {
|
||||
if h == nil || h.svc == nil {
|
||||
return grpcErrorFromServiceError(llmservice.ErrRuntimeServiceNotReady)
|
||||
}
|
||||
if req == nil {
|
||||
return grpcErrorFromServiceError(respond.MissingParam)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func llmserviceToContractTextResult(result *llmservice.TextResult) *llmcontracts.TextResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return &llmcontracts.TextResult{
|
||||
Text: result.Text,
|
||||
Usage: llmservice.CloneUsage(result.Usage),
|
||||
FinishReason: result.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
func llmserviceToContractResponsesResult(result *llmservice.ArkResponsesResult) *llmcontracts.ResponsesResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
output := &llmcontracts.ResponsesResult{
|
||||
Text: result.Text,
|
||||
Status: result.Status,
|
||||
IncompleteReason: result.IncompleteReason,
|
||||
ErrorCode: result.ErrorCode,
|
||||
ErrorMessage: result.ErrorMessage,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
output.Usage = &llmcontracts.ResponsesUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
TotalTokens: result.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
38
backend/services/llm/rpc/json_codec.go
Normal file
38
backend/services/llm/rpc/json_codec.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/encoding"
|
||||
)
|
||||
|
||||
const jsonCodecName = "smartflow-json"
|
||||
|
||||
type jsonCodec struct{}
|
||||
|
||||
func init() {
|
||||
encoding.RegisterCodec(jsonCodec{})
|
||||
}
|
||||
|
||||
func (jsonCodec) Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func (jsonCodec) Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (jsonCodec) Name() string {
|
||||
return jsonCodecName
|
||||
}
|
||||
|
||||
// JSONCodecDialOption 负责让 zrpc client 按 JSON 编解码本服务请求体。
|
||||
func JSONCodecDialOption() grpc.DialOption {
|
||||
return grpc.WithDefaultCallOptions(grpc.ForceCodec(jsonCodec{}))
|
||||
}
|
||||
|
||||
// JSONCodecServerOption 负责让 zrpc server 按 JSON 编解码本服务请求体。
|
||||
func JSONCodecServerOption() grpc.ServerOption {
|
||||
return grpc.ForceServerCodec(jsonCodec{})
|
||||
}
|
||||
19
backend/services/llm/rpc/llm.proto
Normal file
19
backend/services/llm/rpc/llm.proto
Normal file
@@ -0,0 +1,19 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package smartflow.llm;
|
||||
|
||||
service LLM {
|
||||
rpc Ping (PingRequest) returns (PingResponse);
|
||||
rpc GenerateText (TextRequest) returns (TextResponse);
|
||||
rpc StreamText (StreamTextRequest) returns (stream StreamChunk);
|
||||
rpc GenerateResponsesText (ResponsesRequest) returns (ResponsesResponse);
|
||||
}
|
||||
|
||||
message PingRequest {}
|
||||
message PingResponse {}
|
||||
message TextRequest {}
|
||||
message TextResponse {}
|
||||
message StreamTextRequest {}
|
||||
message StreamChunk {}
|
||||
message ResponsesRequest {}
|
||||
message ResponsesResponse {}
|
||||
55
backend/services/llm/rpc/server.go
Normal file
55
backend/services/llm/rpc/server.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
|
||||
"github.com/zeromicro/go-zero/core/service"
|
||||
"github.com/zeromicro/go-zero/zrpc"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultListenOn = "0.0.0.0:9096"
|
||||
defaultTimeout = 0
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
ListenOn string
|
||||
Timeout time.Duration
|
||||
Service *llmservice.RuntimeService
|
||||
}
|
||||
|
||||
// NewServer 负责创建 LLM 独立进程的最小 zrpc server。
|
||||
func NewServer(opts ServerOptions) (*zrpc.RpcServer, string, error) {
|
||||
if opts.Service == nil {
|
||||
return nil, "", errors.New("llm runtime service dependency not initialized")
|
||||
}
|
||||
|
||||
listenOn := strings.TrimSpace(opts.ListenOn)
|
||||
if listenOn == "" {
|
||||
listenOn = defaultListenOn
|
||||
}
|
||||
timeout := opts.Timeout
|
||||
if timeout < 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
|
||||
server, err := zrpc.NewServer(zrpc.RpcServerConf{
|
||||
ServiceConf: service.ServiceConf{
|
||||
Name: "llm.rpc",
|
||||
Mode: service.DevMode,
|
||||
},
|
||||
ListenOn: listenOn,
|
||||
Timeout: int64(timeout / time.Millisecond),
|
||||
}, func(grpcServer *grpc.Server) {
|
||||
RegisterLLMServer(grpcServer, NewHandler(opts.Service))
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
server.AddOptions(JSONCodecServerOption())
|
||||
return server, listenOn, nil
|
||||
}
|
||||
195
backend/services/llm/rpc/transport.go
Normal file
195
backend/services/llm/rpc/transport.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
llmcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/llm"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
LLM_Ping_FullMethodName = "/smartflow.llm.LLM/Ping"
|
||||
LLM_GenerateText_FullMethodName = "/smartflow.llm.LLM/GenerateText"
|
||||
LLM_StreamText_FullMethodName = "/smartflow.llm.LLM/StreamText"
|
||||
LLM_GenerateResponsesText_FullMethodName = "/smartflow.llm.LLM/GenerateResponsesText"
|
||||
)
|
||||
|
||||
type LLMClient interface {
|
||||
Ping(ctx context.Context, in *llmcontracts.PingRequest, opts ...grpc.CallOption) (*llmcontracts.PingResponse, error)
|
||||
GenerateText(ctx context.Context, in *llmcontracts.TextRequest, opts ...grpc.CallOption) (*llmcontracts.TextResponse, error)
|
||||
StreamText(ctx context.Context, in *llmcontracts.StreamTextRequest, opts ...grpc.CallOption) (LLM_StreamTextClient, error)
|
||||
GenerateResponsesText(ctx context.Context, in *llmcontracts.ResponsesRequest, opts ...grpc.CallOption) (*llmcontracts.ResponsesResponse, error)
|
||||
}
|
||||
|
||||
type llmClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewLLMClient(cc grpc.ClientConnInterface) LLMClient {
|
||||
return &llmClient{cc: cc}
|
||||
}
|
||||
|
||||
func (c *llmClient) Ping(ctx context.Context, in *llmcontracts.PingRequest, opts ...grpc.CallOption) (*llmcontracts.PingResponse, error) {
|
||||
out := new(llmcontracts.PingResponse)
|
||||
err := c.cc.Invoke(ctx, LLM_Ping_FullMethodName, in, out, opts...)
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (c *llmClient) GenerateText(ctx context.Context, in *llmcontracts.TextRequest, opts ...grpc.CallOption) (*llmcontracts.TextResponse, error) {
|
||||
out := new(llmcontracts.TextResponse)
|
||||
err := c.cc.Invoke(ctx, LLM_GenerateText_FullMethodName, in, out, opts...)
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (c *llmClient) StreamText(ctx context.Context, in *llmcontracts.StreamTextRequest, opts ...grpc.CallOption) (LLM_StreamTextClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &LLM_ServiceDesc.Streams[0], LLM_StreamText_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &llmStreamTextClient{ClientStream: stream}
|
||||
if err = client.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = client.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *llmClient) GenerateResponsesText(ctx context.Context, in *llmcontracts.ResponsesRequest, opts ...grpc.CallOption) (*llmcontracts.ResponsesResponse, error) {
|
||||
out := new(llmcontracts.ResponsesResponse)
|
||||
err := c.cc.Invoke(ctx, LLM_GenerateResponsesText_FullMethodName, in, out, opts...)
|
||||
return out, err
|
||||
}
|
||||
|
||||
type LLM_StreamTextClient interface {
|
||||
Recv() (*llmcontracts.StreamChunk, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type llmStreamTextClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *llmStreamTextClient) Recv() (*llmcontracts.StreamChunk, error) {
|
||||
m := new(llmcontracts.StreamChunk)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
if err == io.EOF {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type LLMServer interface {
|
||||
Ping(context.Context, *llmcontracts.PingRequest) (*llmcontracts.PingResponse, error)
|
||||
GenerateText(context.Context, *llmcontracts.TextRequest) (*llmcontracts.TextResponse, error)
|
||||
StreamText(*llmcontracts.StreamTextRequest, LLM_StreamTextServer) error
|
||||
GenerateResponsesText(context.Context, *llmcontracts.ResponsesRequest) (*llmcontracts.ResponsesResponse, error)
|
||||
}
|
||||
|
||||
type UnimplementedLLMServer struct{}
|
||||
|
||||
func (UnimplementedLLMServer) Ping(context.Context, *llmcontracts.PingRequest) (*llmcontracts.PingResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Ping not implemented")
|
||||
}
|
||||
|
||||
func (UnimplementedLLMServer) GenerateText(context.Context, *llmcontracts.TextRequest) (*llmcontracts.TextResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GenerateText not implemented")
|
||||
}
|
||||
|
||||
func (UnimplementedLLMServer) StreamText(*llmcontracts.StreamTextRequest, LLM_StreamTextServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method StreamText not implemented")
|
||||
}
|
||||
|
||||
func (UnimplementedLLMServer) GenerateResponsesText(context.Context, *llmcontracts.ResponsesRequest) (*llmcontracts.ResponsesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GenerateResponsesText not implemented")
|
||||
}
|
||||
|
||||
func RegisterLLMServer(s grpc.ServiceRegistrar, srv LLMServer) {
|
||||
s.RegisterService(&LLM_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
type LLM_StreamTextServer interface {
|
||||
Send(*llmcontracts.StreamChunk) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type llmStreamTextServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *llmStreamTextServer) Send(m *llmcontracts.StreamChunk) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _LLM_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(llmcontracts.PingRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LLMServer).Ping(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{Server: srv, FullMethod: LLM_Ping_FullMethodName}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LLMServer).Ping(ctx, req.(*llmcontracts.PingRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _LLM_GenerateText_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(llmcontracts.TextRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LLMServer).GenerateText(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{Server: srv, FullMethod: LLM_GenerateText_FullMethodName}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LLMServer).GenerateText(ctx, req.(*llmcontracts.TextRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _LLM_GenerateResponsesText_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(llmcontracts.ResponsesRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LLMServer).GenerateResponsesText(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{Server: srv, FullMethod: LLM_GenerateResponsesText_FullMethodName}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LLMServer).GenerateResponsesText(ctx, req.(*llmcontracts.ResponsesRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _LLM_StreamText_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(llmcontracts.StreamTextRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(LLMServer).StreamText(m, &llmStreamTextServer{ServerStream: stream})
|
||||
}
|
||||
|
||||
var LLM_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "smartflow.llm.LLM",
|
||||
HandlerType: (*LLMServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{MethodName: "Ping", Handler: _LLM_Ping_Handler},
|
||||
{MethodName: "GenerateText", Handler: _LLM_GenerateText_Handler},
|
||||
{MethodName: "GenerateResponsesText", Handler: _LLM_GenerateResponsesText_Handler},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{StreamName: "StreamText", Handler: _LLM_StreamText_Handler, ServerStreams: true},
|
||||
},
|
||||
Metadata: "services/llm/rpc/llm.proto",
|
||||
}
|
||||
315
backend/services/llm/runtime_service.go
Normal file
315
backend/services/llm/runtime_service.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
llmdao "github.com/LoveLosita/smartflow/backend/services/llm/dao"
|
||||
llmcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/llm"
|
||||
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// RuntimeService 是独立 LLM 进程对外暴露的业务门面。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责模型别名选择、BillingContext 注入、准入守卫与 outbox 写入;
|
||||
// 2. 不负责 prompt 编排,调用方仍然直接传入 messages;
|
||||
// 3. 不负责价格换算细则,本轮先把 usage 事件稳定写入 outbox,价格字段留给后续主代理接线。
|
||||
type RuntimeService struct {
|
||||
legacy *Service
|
||||
textClients map[string]*Client
|
||||
textModelNames map[string]string
|
||||
responsesClient *ArkResponsesClient
|
||||
responsesModel string
|
||||
balanceGuard *CreditBalanceGuard
|
||||
chargeRecorder *ChargeRecorder
|
||||
defaultProvider string
|
||||
}
|
||||
|
||||
type RuntimeServiceOptions struct {
|
||||
LegacyService *Service
|
||||
CacheDAO *llmdao.CacheDAO
|
||||
PriceRuleDAO *llmdao.PriceRuleDAO
|
||||
SnapshotProvider CreditBalanceSnapshotProvider
|
||||
OutboxRepo *outboxinfra.Repository
|
||||
OutboxMaxRetry int
|
||||
ProviderName string
|
||||
LiteModelName string
|
||||
ProModelName string
|
||||
MaxModelName string
|
||||
CourseVisionModel string
|
||||
}
|
||||
|
||||
func NewRuntimeService(opts RuntimeServiceOptions) (*RuntimeService, error) {
|
||||
if opts.LegacyService == nil {
|
||||
return nil, ErrRuntimeServiceNotReady
|
||||
}
|
||||
|
||||
chargeRecorder, err := NewChargeRecorder(ChargeRecorderOptions{
|
||||
Repo: opts.OutboxRepo,
|
||||
MaxRetry: opts.OutboxMaxRetry,
|
||||
ProviderName: opts.ProviderName,
|
||||
Pricing: NewCreditPriceResolver(CreditPriceResolverOptions{DAO: opts.PriceRuleDAO}),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RuntimeService{
|
||||
legacy: opts.LegacyService,
|
||||
textClients: map[string]*Client{
|
||||
llmcontracts.ModelAliasLite: opts.LegacyService.LiteClient(),
|
||||
llmcontracts.ModelAliasPro: opts.LegacyService.ProClient(),
|
||||
llmcontracts.ModelAliasMax: opts.LegacyService.MaxClient(),
|
||||
},
|
||||
textModelNames: map[string]string{
|
||||
llmcontracts.ModelAliasLite: strings.TrimSpace(opts.LiteModelName),
|
||||
llmcontracts.ModelAliasPro: strings.TrimSpace(opts.ProModelName),
|
||||
llmcontracts.ModelAliasMax: strings.TrimSpace(opts.MaxModelName),
|
||||
},
|
||||
responsesClient: opts.LegacyService.CourseImageResponsesClient(),
|
||||
responsesModel: strings.TrimSpace(opts.CourseVisionModel),
|
||||
balanceGuard: NewCreditBalanceGuard(CreditBalanceGuardOptions{
|
||||
CacheDAO: opts.CacheDAO,
|
||||
SnapshotProvider: opts.SnapshotProvider,
|
||||
}),
|
||||
chargeRecorder: chargeRecorder,
|
||||
defaultProvider: firstNonEmptyString(strings.TrimSpace(opts.ProviderName), llmcontracts.ProviderNameArk),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RuntimeService) LegacyService() *Service {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.legacy
|
||||
}
|
||||
|
||||
// GenerateText 负责处理一次非流式文本调用。
|
||||
func (s *RuntimeService) GenerateText(ctx context.Context, req llmcontracts.TextRequest) (*TextResult, error) {
|
||||
client, alias, modelName, err := s.resolveTextClient(req.ModelAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1. 先把跨进程 billing 副本还原回 ctx,保持业务侧调用面不改签名。
|
||||
// 2. 再做一次 Redis 快照级准入守卫;守卫失败直接短路,不继续发起模型调用。
|
||||
// 3. 模型成功后同步写 LLM outbox;写失败只打日志,避免因为记账侧抖动反向打挂主链路。
|
||||
ctx, billing := applyRequestBillingContext(ctx, req.Billing, alias)
|
||||
billing = EnsureTextBillingIdentity(billing, req.Options, req.Messages)
|
||||
if !billing.IsZero() {
|
||||
ctx = WithBillingContext(ctx, billing)
|
||||
}
|
||||
if err = s.balanceGuard.Guard(ctx, billing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := client.GenerateText(ctx, req.Messages, toServiceGenerateOptions(req.Options))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logChargeRecordError("llm.text.generate", s.chargeRecorder.RecordTextUsage(ctx, billing, alias, modelName, "llm.text.generate", result.Usage))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// StreamText 负责处理一次流式文本调用。
|
||||
func (s *RuntimeService) StreamText(ctx context.Context, req llmcontracts.StreamTextRequest) (StreamReader, error) {
|
||||
client, alias, modelName, err := s.resolveTextClient(req.ModelAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, billing := applyRequestBillingContext(ctx, req.Billing, alias)
|
||||
billing = EnsureTextBillingIdentity(billing, req.Options, req.Messages)
|
||||
if !billing.IsZero() {
|
||||
ctx = WithBillingContext(ctx, billing)
|
||||
}
|
||||
if err = s.balanceGuard.Guard(ctx, billing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader, err := client.Stream(ctx, req.Messages, toServiceGenerateOptions(req.Options))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewUsageAccountingStreamReader(reader, func(usage *schema.TokenUsage) {
|
||||
logChargeRecordError("llm.text.stream", s.chargeRecorder.RecordTextUsage(ctx, billing, alias, modelName, "llm.text.stream", usage))
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GenerateResponsesText 负责处理课程图片解析使用的 Responses 文本调用。
|
||||
func (s *RuntimeService) GenerateResponsesText(ctx context.Context, req llmcontracts.ResponsesRequest) (*ArkResponsesResult, error) {
|
||||
client, alias, modelName, err := s.resolveResponsesClient(req.ModelAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, billing := applyRequestBillingContext(ctx, req.Billing, alias)
|
||||
billing = EnsureResponsesBillingIdentity(billing, req.Messages)
|
||||
if !billing.IsZero() {
|
||||
ctx = WithBillingContext(ctx, billing)
|
||||
}
|
||||
if err = s.balanceGuard.Guard(ctx, billing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := client.GenerateText(ctx, toServiceResponsesMessages(req.Messages), toServiceResponsesOptions(req.Options))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logChargeRecordError("llm.responses.generate", s.chargeRecorder.RecordResponsesUsage(ctx, billing, alias, modelName, "llm.responses.generate", result.Usage))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *RuntimeService) resolveTextClient(modelAlias string) (*Client, string, string, error) {
|
||||
if s == nil {
|
||||
return nil, "", "", ErrRuntimeServiceNotReady
|
||||
}
|
||||
|
||||
alias := llmcontracts.NormalizeModelAlias(modelAlias)
|
||||
client, ok := s.textClients[alias]
|
||||
if !ok {
|
||||
return nil, alias, "", ErrUnsupportedModelAlias
|
||||
}
|
||||
if client == nil {
|
||||
return nil, alias, "", ErrRuntimeServiceNotReady
|
||||
}
|
||||
return client, alias, firstNonEmptyString(s.textModelNames[alias], alias), nil
|
||||
}
|
||||
|
||||
func (s *RuntimeService) resolveResponsesClient(modelAlias string) (*ArkResponsesClient, string, string, error) {
|
||||
if s == nil || s.responsesClient == nil {
|
||||
return nil, "", "", ErrRuntimeServiceNotReady
|
||||
}
|
||||
|
||||
alias := strings.TrimSpace(modelAlias)
|
||||
if alias == "" {
|
||||
alias = llmcontracts.ModelAliasCourseImageResponses
|
||||
}
|
||||
if alias != llmcontracts.ModelAliasCourseImageResponses {
|
||||
return nil, alias, "", ErrUnsupportedModelAlias
|
||||
}
|
||||
return s.responsesClient, alias, firstNonEmptyString(s.responsesModel, alias), nil
|
||||
}
|
||||
|
||||
func applyRequestBillingContext(ctx context.Context, input *llmcontracts.BillingContext, modelAlias string) (context.Context, BillingContext) {
|
||||
billing := BillingContext{}
|
||||
if input != nil {
|
||||
billing = BillingContext{
|
||||
UserID: input.UserID,
|
||||
EventID: input.EventID,
|
||||
Scene: input.Scene,
|
||||
RequestID: input.RequestID,
|
||||
ConversationID: input.ConversationID,
|
||||
ModelAlias: input.ModelAlias,
|
||||
SkipCharge: input.SkipCharge,
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(billing.ModelAlias) == "" {
|
||||
billing.ModelAlias = strings.TrimSpace(modelAlias)
|
||||
}
|
||||
if billing.IsZero() {
|
||||
return ctx, billing
|
||||
}
|
||||
return WithBillingContext(ctx, billing), billing
|
||||
}
|
||||
|
||||
func toServiceGenerateOptions(input llmcontracts.GenerateOptions) GenerateOptions {
|
||||
return GenerateOptions{
|
||||
Temperature: input.Temperature,
|
||||
MaxTokens: input.MaxTokens,
|
||||
Thinking: ThinkingMode(strings.TrimSpace(input.Thinking)),
|
||||
Metadata: input.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
func toServiceResponsesMessages(input []llmcontracts.ResponsesMessage) []ArkResponsesMessage {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
output := make([]ArkResponsesMessage, 0, len(input))
|
||||
for _, item := range input {
|
||||
output = append(output, ArkResponsesMessage{
|
||||
Role: item.Role,
|
||||
Text: item.Text,
|
||||
ImageURL: item.ImageURL,
|
||||
ImageDetail: item.ImageDetail,
|
||||
})
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func toServiceResponsesOptions(input llmcontracts.ResponsesOptions) ArkResponsesOptions {
|
||||
return ArkResponsesOptions{
|
||||
Model: input.Model,
|
||||
Temperature: input.Temperature,
|
||||
MaxOutputTokens: input.MaxOutputTokens,
|
||||
Thinking: ThinkingMode(strings.TrimSpace(input.Thinking)),
|
||||
TextFormat: input.TextFormat,
|
||||
}
|
||||
}
|
||||
|
||||
func toContractTextResult(result *TextResult) *llmcontracts.TextResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return &llmcontracts.TextResult{
|
||||
Text: result.Text,
|
||||
Usage: CloneUsage(result.Usage),
|
||||
FinishReason: result.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
func toContractResponsesResult(result *ArkResponsesResult) *llmcontracts.ResponsesResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
output := &llmcontracts.ResponsesResult{
|
||||
Text: result.Text,
|
||||
Status: result.Status,
|
||||
IncompleteReason: result.IncompleteReason,
|
||||
ErrorCode: result.ErrorCode,
|
||||
ErrorMessage: result.ErrorMessage,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
output.Usage = &llmcontracts.ResponsesUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
TotalTokens: result.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func toServiceTextResult(result *llmcontracts.TextResult) *TextResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return &TextResult{
|
||||
Text: result.Text,
|
||||
Usage: CloneUsage(result.Usage),
|
||||
FinishReason: result.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
func toServiceResponsesResult(result *llmcontracts.ResponsesResult) *ArkResponsesResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
output := &ArkResponsesResult{
|
||||
Text: result.Text,
|
||||
Status: result.Status,
|
||||
IncompleteReason: result.IncompleteReason,
|
||||
ErrorCode: result.ErrorCode,
|
||||
ErrorMessage: result.ErrorMessage,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
output.Usage = &ArkResponsesUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
TotalTokens: result.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
@@ -35,6 +35,19 @@ type AgentModelClients struct {
|
||||
Summary *Client
|
||||
}
|
||||
|
||||
// StaticClients 用于在不依赖 AIHub 的情况下直接注入已构造好的客户端。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责把已经准备好的 client 聚合成 Service;
|
||||
// 2. 不负责选择 provider,也不负责初始化远端 RPC 连接;
|
||||
// 3. 供独立 llm zrpc client、测试替身和迁移期桥接入口复用。
|
||||
type StaticClients struct {
|
||||
Lite *Client
|
||||
Pro *Client
|
||||
Max *Client
|
||||
CourseImageResponses *ArkResponsesClient
|
||||
}
|
||||
|
||||
// New 构造 llm-service。
|
||||
// 1. 不返回 error,是为了让上层继续按 nil 客户端做逐步降级。
|
||||
// 2. 只要 AIHub 已初始化,就把其中的 ChatModel 收敛成统一 Client。
|
||||
@@ -62,6 +75,16 @@ func New(opts Options) *Service {
|
||||
return svc
|
||||
}
|
||||
|
||||
// NewWithClients 使用外部注入的现成客户端构造 Service。
|
||||
func NewWithClients(clients StaticClients) *Service {
|
||||
return &Service{
|
||||
liteClient: clients.Lite,
|
||||
proClient: clients.Pro,
|
||||
maxClient: clients.Max,
|
||||
courseImageResponsesClient: clients.CourseImageResponses,
|
||||
}
|
||||
}
|
||||
|
||||
// LiteClient 返回低成本短输出模型客户端。
|
||||
func (s *Service) LiteClient() *Client {
|
||||
if s == nil {
|
||||
|
||||
61
backend/services/llm/stream_accounting.go
Normal file
61
backend/services/llm/stream_accounting.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// usageAccountingStreamReader 负责在流式读取结束时统一回收 usage。
|
||||
type usageAccountingStreamReader struct {
|
||||
source StreamReader
|
||||
onDone func(usage *schema.TokenUsage)
|
||||
|
||||
once sync.Once
|
||||
usage *schema.TokenUsage
|
||||
}
|
||||
|
||||
func NewUsageAccountingStreamReader(source StreamReader, onDone func(usage *schema.TokenUsage)) StreamReader {
|
||||
if source == nil {
|
||||
return nil
|
||||
}
|
||||
return &usageAccountingStreamReader{
|
||||
source: source,
|
||||
onDone: onDone,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageAccountingStreamReader) Recv() (*schema.Message, error) {
|
||||
if r == nil || r.source == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
msg, err := r.source.Recv()
|
||||
if msg != nil && msg.ResponseMeta != nil {
|
||||
r.usage = MergeUsage(r.usage, msg.ResponseMeta.Usage)
|
||||
}
|
||||
if err != nil {
|
||||
r.finish()
|
||||
}
|
||||
return msg, err
|
||||
}
|
||||
|
||||
func (r *usageAccountingStreamReader) Close() error {
|
||||
if r == nil || r.source == nil {
|
||||
return nil
|
||||
}
|
||||
err := r.source.Close()
|
||||
r.finish()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageAccountingStreamReader) finish() {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
if r.onDone != nil {
|
||||
r.onDone(CloneUsage(r.usage))
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user