// Package upstream manages a pool of gRPC connections to a group of
// downstream nodes. It keeps a short-lived health cache per node, opens a
// per-node circuit breaker after consecutive failures, and retries calls on
// other nodes when an error is considered retryable.
package upstream

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"chatappgateway/internal/config"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

// seedMask keeps the round-robin seed within the non-negative range of a
// 32-bit int so converting the uint64 counter to int can never go negative.
const seedMask = 1<<31 - 1

// Handle is the per-node handle exposed to callers.
type Handle struct {
	Target string
	Conn   *grpc.ClientConn
}

// Pool manages the connections, health state and fault-tolerance policy for
// a group of downstream nodes.
type Pool struct {
	name             string
	timeout          time.Duration // per-attempt deadline applied by Call
	maxAttempts      int           // values below 1 are treated as 1 by Call
	retryBackoff     time.Duration // pause between retry attempts
	failureThreshold int           // consecutive failures that open the circuit
	openTimeout      time.Duration // how long an opened circuit stays open
	healthCacheTTL   time.Duration // how long a recorded health result is "fresh"
	counter          atomic.Uint64 // round-robin seed, bumped by candidateOrder
	endpoints        []*endpoint
}

// endpoint is one downstream node together with its mutable health state.
// All fields below mu are guarded by mu.
type endpoint struct {
	handle Handle

	mu                  sync.Mutex
	consecutiveFailures int
	openUntil           time.Time
	lastCheckedAt       time.Time
	lastHealthy         bool
	lastErr             string
}

// snapshot is an immutable copy of an endpoint's state at one instant.
type snapshot struct {
	handle    Handle
	fresh     bool // the last recorded result is within the cache TTL
	healthy   bool // the last recorded result was a success
	open      bool // the circuit breaker is currently open
	errString string
}

// New builds a connection pool for the targets listed in cfg.
//
// One gRPC client connection is created per target using insecure transport
// credentials. If any dial fails, the connections created so far are closed
// and the dial error is returned wrapped with the pool name and target.
func New(ctx context.Context, name string, cfg config.UpstreamConfig) (*Pool, error) {
	pool := &Pool{
		name:             name,
		timeout:          cfg.Timeout,
		maxAttempts:      cfg.Retry.MaxAttempts,
		retryBackoff:     cfg.Retry.Backoff,
		failureThreshold: cfg.CircuitBreaker.FailureThreshold,
		openTimeout:      cfg.CircuitBreaker.OpenTimeout,
		healthCacheTTL:   cfg.HealthCache.TTL,
		endpoints:        make([]*endpoint, 0, len(cfg.Targets)),
	}

	for _, target := range cfg.Targets {
		conn, err := grpc.DialContext(
			ctx,
			target,
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)
		if err != nil {
			// Best-effort cleanup: the dial error is the one worth reporting,
			// so a secondary close error is deliberately dropped.
			_ = pool.Close()
			return nil, fmt.Errorf("dial %s %s: %w", name, target, err)
		}
		pool.endpoints = append(pool.endpoints, &endpoint{
			handle: Handle{
				Target: target,
				Conn:   conn,
			},
		})
	}

	return pool, nil
}

// Close closes every gRPC connection in the pool. All connections are
// attempted; the first close error encountered (if any) is returned.
func (p *Pool) Close() error {
	var firstErr error
	for _, ep := range p.endpoints {
		if ep.handle.Conn == nil {
			continue
		}
		if err := ep.handle.Conn.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// Ready reports whether at least one node is currently usable.
//
// It answers from the health cache when the cache is conclusive. Otherwise it
// probes candidate nodes in preference order and returns nil as soon as one
// becomes ready. If every probe fails, the per-node errors are joined with
// "; " into a single error.
func (p *Pool) Ready(ctx context.Context) error {
	now := time.Now()
	if cached, err := p.cachedReady(now); cached {
		return err
	}

	ordered := p.candidateOrder(now)
	if len(ordered) == 0 {
		return p.unavailableError(now)
	}

	errs := make([]string, 0, len(ordered))
	for _, ep := range ordered {
		err := p.refreshEndpoint(ctx, ep)
		if err == nil {
			return nil
		}
		errs = append(errs, fmt.Sprintf("%s: %v", ep.handle.Target, err))
	}
	return errors.New(strings.Join(errs, "; "))
}

// Call picks a node, runs invoke against it with a per-attempt timeout, and
// retries on another node when the error is retryable.
//
// At most pool.maxAttempts attempts are made; a configured value below 1 is
// treated as a single attempt so the call is never silently skipped. Between
// attempts Call waits pool.retryBackoff, returning the context error if ctx
// ends during the wait. On failure the zero value of T is returned together
// with the last error from invoke, or an "unavailable" error if no node could
// be selected at all.
func Call[T any](ctx context.Context, pool *Pool, invoke func(context.Context, Handle) (T, error)) (T, error) {
	var zero T
	var lastErr error

	attempts := pool.maxAttempts
	if attempts < 1 {
		attempts = 1
	}

	tried := make(map[string]struct{}, attempts)
	for attempt := 1; attempt <= attempts; attempt++ {
		ep := pool.pickEndpoint(time.Now(), tried)
		if ep == nil {
			if lastErr != nil {
				return zero, lastErr
			}
			return zero, pool.unavailableError(time.Now())
		}
		tried[ep.handle.Target] = struct{}{}

		callCtx, cancel := context.WithTimeout(ctx, pool.timeout)
		result, err := invoke(callCtx, ep.handle)
		cancel()
		if err == nil {
			ep.recordSuccess(time.Now())
			return result, nil
		}
		lastErr = err

		if !shouldRetry(err, ctx) {
			// A non-retryable error means the node answered with a failure
			// but the link itself is reachable, so it must not trip the
			// circuit breaker. The exception is when the caller's own
			// context has ended: that says nothing about the node, so its
			// health state is left untouched.
			if ctx.Err() == nil {
				ep.recordSuccess(time.Now())
			}
			return zero, err
		}

		ep.recordFailure(time.Now(), err, pool.failureThreshold, pool.openTimeout)
		if attempt == attempts {
			return zero, err
		}
		if err := sleepWithContext(ctx, pool.retryBackoff); err != nil {
			return zero, err
		}
	}

	if lastErr != nil {
		return zero, lastErr
	}
	return zero, pool.unavailableError(time.Now())
}

// cachedReady tries to answer Ready from cached health state alone.
//
// The boolean reports whether the cache was conclusive. It is conclusive with
// a nil error when some node is fresh, healthy and not circuit-open, and
// conclusive with a non-nil error when every node is fresh yet none
// qualifies. A pool with no endpoints is never conclusive, so the caller
// falls through to its "unavailable" path instead of getting an empty error.
func (p *Pool) cachedReady(now time.Time) (bool, error) {
	if len(p.endpoints) == 0 {
		return false, nil
	}

	allFresh := true
	var errs []string
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)
		switch {
		case snap.fresh && snap.healthy && !snap.open:
			return true, nil
		case !snap.fresh:
			allFresh = false
		case snap.errString != "":
			errs = append(errs, fmt.Sprintf("%s: %s", snap.handle.Target, snap.errString))
		default:
			errs = append(errs, fmt.Sprintf("%s: not ready", snap.handle.Target))
		}
	}

	if !allFresh {
		return false, nil
	}
	return true, errors.New(strings.Join(errs, "; "))
}

// refreshEndpoint probes one endpoint's connection under a bounded timeout
// and records the outcome in the endpoint's health state.
func (p *Pool) refreshEndpoint(ctx context.Context, ep *endpoint) error {
	checkCtx, cancel := context.WithTimeout(ctx, p.readyCheckTimeout())
	defer cancel()

	if err := connectionReady(checkCtx, ep.handle.Conn); err != nil {
		ep.recordFailure(time.Now(), err, p.failureThreshold, p.openTimeout)
		return err
	}
	ep.recordSuccess(time.Now())
	return nil
}

// readyCheckTimeout returns the readiness-probe timeout: the pool timeout,
// capped at two seconds.
func (p *Pool) readyCheckTimeout() time.Duration {
	const limit = 2 * time.Second
	if p.timeout <= limit {
		return p.timeout
	}
	return limit
}

// pickEndpoint returns the most preferred candidate whose target is not in
// tried. If every candidate has been tried it falls back to the most
// preferred one, and it returns nil when there are no candidates at all.
func (p *Pool) pickEndpoint(now time.Time, tried map[string]struct{}) *endpoint {
	ordered := p.candidateOrder(now)
	if len(ordered) == 0 {
		return nil
	}
	for _, ep := range ordered {
		if _, seen := tried[ep.handle.Target]; !seen {
			return ep
		}
	}
	return ordered[0]
}

// candidateOrder lists the endpoints whose circuit is not open, ordered by
// preference: fresh-and-healthy first, then those with no fresh result, then
// fresh-but-unhealthy. Each group is rotated by a counter that advances on
// every call, so successive calls start from different nodes.
func (p *Pool) candidateOrder(now time.Time) []*endpoint {
	var healthy, unknown, unhealthy []*endpoint
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)
		if snap.open {
			continue
		}
		switch {
		case snap.fresh && snap.healthy:
			healthy = append(healthy, ep)
		case !snap.fresh:
			unknown = append(unknown, ep)
		default:
			unhealthy = append(unhealthy, ep)
		}
	}

	// Mask before converting: a raw uint64 -> int conversion can produce a
	// negative seed (notably where int is 32 bits), which would make the
	// rotation offset negative.
	seed := int((p.counter.Add(1) - 1) & seedMask)

	ordered := make([]*endpoint, 0, len(healthy)+len(unknown)+len(unhealthy))
	ordered = append(ordered, rotate(healthy, seed)...)
	ordered = append(ordered, rotate(unknown, seed)...)
	ordered = append(ordered, rotate(unhealthy, seed)...)
	return ordered
}

// rotate returns a new slice holding items rotated left by seed modulo
// len(items). The input slice is not modified and the result does not share
// its backing array. It returns nil for an empty input.
func rotate(items []*endpoint, seed int) []*endpoint {
	n := len(items)
	if n == 0 {
		return nil
	}
	offset := seed % n
	if offset < 0 {
		// Go's % keeps the sign of the dividend; normalise into [0, n).
		offset += n
	}
	out := make([]*endpoint, 0, n)
	out = append(out, items[offset:]...)
	return append(out, items[:offset]...)
}

// unavailableError builds an error describing why no upstream can be used,
// listing each endpoint's circuit state or last recorded error.
func (p *Pool) unavailableError(now time.Time) error {
	if len(p.endpoints) == 0 {
		return fmt.Errorf("%s has no available upstreams: no targets configured", p.name)
	}

	parts := make([]string, 0, len(p.endpoints))
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)
		switch {
		case snap.open:
			parts = append(parts, fmt.Sprintf("%s: circuit open", snap.handle.Target))
		case snap.errString != "":
			parts = append(parts, fmt.Sprintf("%s: %s", snap.handle.Target, snap.errString))
		default:
			parts = append(parts, fmt.Sprintf("%s: unavailable", snap.handle.Target))
		}
	}
	return fmt.Errorf("%s has no available upstreams: %s", p.name, strings.Join(parts, "; "))
}

// snapshot returns a consistent copy of the endpoint's state as of now.
// A result is fresh when one has been recorded and it is no older than ttl;
// the circuit is open while now is before openUntil.
func (e *endpoint) snapshot(now time.Time, ttl time.Duration) snapshot {
	e.mu.Lock()
	defer e.mu.Unlock()

	fresh := !e.lastCheckedAt.IsZero() && now.Sub(e.lastCheckedAt) <= ttl
	open := !e.openUntil.IsZero() && now.Before(e.openUntil)
	return snapshot{
		handle:    e.handle,
		fresh:     fresh,
		healthy:   e.lastHealthy,
		open:      open,
		errString: e.lastErr,
	}
}

// recordSuccess marks the endpoint healthy as of now, clears the failure
// streak and last error, and closes the circuit.
func (e *endpoint) recordSuccess(now time.Time) {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.consecutiveFailures = 0
	e.openUntil = time.Time{}
	e.lastCheckedAt = now
	e.lastHealthy = true
	e.lastErr = ""
}

// recordFailure marks the endpoint unhealthy as of now and extends the
// failure streak. Once the streak reaches failureThreshold the circuit is
// opened until now+openTimeout; a threshold of zero or less therefore opens
// it on the first failure.
func (e *endpoint) recordFailure(now time.Time, err error, failureThreshold int, openTimeout time.Duration) {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.consecutiveFailures++
	e.lastCheckedAt = now
	e.lastHealthy = false
	e.lastErr = err.Error()
	if e.consecutiveFailures >= failureThreshold {
		e.openUntil = now.Add(openTimeout)
	}
}

// connectionReady asks conn to connect and waits until it reports Ready.
// It returns an error if conn is nil, if the connection is shut down, or if
// ctx ends before the connection becomes ready.
func connectionReady(ctx context.Context, conn *grpc.ClientConn) error {
	if conn == nil {
		return errors.New("connection is nil")
	}

	conn.Connect()
	for {
		state := conn.GetState()
		switch state {
		case connectivity.Ready:
			return nil
		case connectivity.Shutdown:
			return errors.New("connection shutdown")
		}
		if !conn.WaitForStateChange(ctx, state) {
			if err := ctx.Err(); err != nil {
				return err
			}
			return errors.New("state did not change")
		}
	}
}

// shouldRetry reports whether err justifies retrying on another node.
//
// It never retries once the parent context has ended. Otherwise it retries on
// a context deadline error (the per-attempt timeout), on errors that carry no
// gRPC status, and on the gRPC codes DeadlineExceeded, Unavailable,
// ResourceExhausted, Aborted and Internal.
func shouldRetry(err error, parent context.Context) bool {
	if err == nil || parent.Err() != nil {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}

	st, ok := status.FromError(err)
	if !ok {
		return true
	}
	switch st.Code() {
	case codes.DeadlineExceeded,
		codes.Unavailable,
		codes.ResourceExhausted,
		codes.Aborted,
		codes.Internal:
		return true
	default:
		return false
	}
}

// sleepWithContext waits for d, returning early with the context error if
// ctx ends first. A non-positive d returns immediately with nil.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}

	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// Compile-time check that *Pool satisfies io.Closer.
var _ io.Closer = (*Pool)(nil)