2026-04-06 17:08:51 +08:00

393 lines
9.0 KiB
Go

package upstream
import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"time"
"chatappgateway/internal/config"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// Handle is the per-node handle exposed to callers: one downstream target
// together with the gRPC client connection that was dialed for it.
type Handle struct {
// Target is the target string that was passed to the dial call.
Target string
// Conn is the gRPC client connection dialed for Target.
Conn *grpc.ClientConn
}
// Pool manages the connections, health state, and fault-tolerance policy
// (retry and circuit breaking) for a group of downstream nodes.
type Pool struct {
// name identifies the pool in the "no available upstreams" error message.
name string
// timeout bounds each invocation made through Call; the readiness probe
// uses it as well, but never waits longer than two seconds.
timeout time.Duration
// maxAttempts is the upper bound on invocation attempts per Call.
maxAttempts int
// retryBackoff is the pause between attempts after a retryable failure.
retryBackoff time.Duration
// failureThreshold is the number of consecutive failures at which an
// endpoint's circuit is opened.
failureThreshold int
// openTimeout is how long a circuit stays open once it has been opened.
openTimeout time.Duration
// healthCacheTTL is how long a recorded health result counts as fresh.
healthCacheTTL time.Duration
// counter is incremented on every candidate ordering and seeds the rotation
// of the candidate list.
counter atomic.Uint64
// endpoints holds one entry per dialed target.
endpoints []*endpoint
}
// endpoint is the pool's internal record for one downstream node: its handle
// plus the mutable health and circuit-breaker state.
type endpoint struct {
// handle is the target/connection pair handed to callers.
handle Handle
// mu is held by snapshot, recordSuccess and recordFailure while they read or
// write the fields below.
mu sync.Mutex
// consecutiveFailures counts failures recorded since the last success.
consecutiveFailures int
// openUntil is the instant until which the circuit is open; the zero value
// means the circuit is closed.
openUntil time.Time
// lastCheckedAt is when the latest success or failure was recorded; the zero
// value means nothing has been recorded yet.
lastCheckedAt time.Time
// lastHealthy is true when the latest recorded outcome was a success.
lastHealthy bool
// lastErr is the message of the latest recorded failure; it is cleared on
// success.
lastErr string
}
// snapshot is a point-in-time copy of an endpoint's state, taken under the
// endpoint's mutex so it can be inspected without holding the lock.
type snapshot struct {
// handle is the endpoint's target/connection pair.
handle Handle
// fresh is true when a health result exists and is no older than the TTL.
fresh bool
// healthy is the latest recorded health outcome (meaningful only if fresh).
healthy bool
// open is true when the circuit breaker was open at snapshot time.
open bool
// errString is the message of the latest recorded failure, if any.
errString string
}
// New builds a connection pool for a group of downstream nodes.
//
// The retry, circuit-breaker, timeout and health-cache settings are copied
// from cfg, and one gRPC client connection is dialed per entry of cfg.Targets
// using insecure (non-TLS) transport credentials. If any dial fails, the
// connections created so far are closed and the dial error is returned wrapped
// with the pool name and target.
func New(ctx context.Context, name string, cfg config.UpstreamConfig) (*Pool, error) {
	p := new(Pool)
	p.name = name
	p.timeout = cfg.Timeout
	p.maxAttempts = cfg.Retry.MaxAttempts
	p.retryBackoff = cfg.Retry.Backoff
	p.failureThreshold = cfg.CircuitBreaker.FailureThreshold
	p.openTimeout = cfg.CircuitBreaker.OpenTimeout
	p.healthCacheTTL = cfg.HealthCache.TTL
	p.endpoints = make([]*endpoint, 0, len(cfg.Targets))

	for _, addr := range cfg.Targets {
		creds := grpc.WithTransportCredentials(insecure.NewCredentials())
		conn, dialErr := grpc.DialContext(ctx, addr, creds)
		if dialErr != nil {
			// Best-effort cleanup: the dial error is what the caller needs,
			// so a secondary close error is deliberately dropped.
			_ = p.Close()
			return nil, fmt.Errorf("dial %s %s: %w", name, addr, dialErr)
		}

		ep := &endpoint{handle: Handle{Target: addr, Conn: conn}}
		p.endpoints = append(p.endpoints, ep)
	}

	return p, nil
}
// Close closes the gRPC connection of every endpoint in the pool.
//
// Endpoints without a connection are skipped. Closing continues past
// failures; the first close error encountered (if any) is returned.
func (p *Pool) Close() error {
	var closeErr error
	for _, ep := range p.endpoints {
		conn := ep.handle.Conn
		if conn == nil {
			continue
		}
		err := conn.Close()
		if closeErr == nil && err != nil {
			closeErr = err
		}
	}
	return closeErr
}
// Ready reports whether at least one healthy node is currently usable,
// consulting cached health state first and probing connections otherwise.
//
// When the cache can answer, its verdict is returned as-is. Otherwise the
// non-open endpoints are probed in candidate order and nil is returned as soon
// as one of them is ready. If there is no candidate at all the pool's
// "no available upstreams" error is returned; if every probe fails, the error
// lists each target with its failure, separated by "; ".
func (p *Pool) Ready(ctx context.Context) error {
	now := time.Now()

	cached, cachedErr := p.cachedReady(now)
	if cached {
		return cachedErr
	}

	candidates := p.candidateOrder(now)
	if len(candidates) == 0 {
		return p.unavailableError(now)
	}

	failures := make([]string, 0, len(candidates))
	for _, ep := range candidates {
		err := p.refreshEndpoint(ctx, ep)
		if err == nil {
			return nil
		}
		failures = append(failures, fmt.Sprintf("%s: %v", ep.handle.Target, err))
	}
	return errors.New(strings.Join(failures, "; "))
}
// Call selects a healthy node, runs invoke against it, and retries on the
// node level when the failure is retryable.
//
// Each attempt picks an endpoint that has not been tried yet when one exists,
// and runs invoke with a context derived from ctx and limited to the pool's
// per-call timeout. Outcomes are handled as follows:
//
//   - success: the endpoint is recorded healthy and the result is returned;
//   - ctx itself is done: the error is returned without touching the
//     endpoint's health state, because a cancelled or expired caller says
//     nothing about the node;
//   - retryable error (see shouldRetry): a failure is recorded, which may open
//     the endpoint's circuit, and after the retry backoff another attempt is
//     made unless the attempt budget is spent;
//   - any other error: the endpoint is recorded healthy and the error is
//     returned.
//
// A configured attempt budget below one is treated as one so that a
// misconfigured pool still tries an upstream. When no endpoint can be picked,
// the last invocation error is returned if there is one, otherwise the pool's
// "no available upstreams" error. The returned T is the zero value whenever
// the error is non-nil.
func Call[T any](ctx context.Context, pool *Pool, invoke func(context.Context, Handle) (T, error)) (T, error) {
	var zero T

	// A non-positive budget would skip the loop entirely and report the pool
	// as unavailable without ever contacting a node.
	attempts := pool.maxAttempts
	if attempts < 1 {
		attempts = 1
	}

	var lastErr error
	tried := make(map[string]struct{}, attempts)
	for attempt := 1; attempt <= attempts; attempt++ {
		ep := pool.pickEndpoint(time.Now(), tried)
		if ep == nil {
			if lastErr != nil {
				return zero, lastErr
			}
			return zero, pool.unavailableError(time.Now())
		}
		tried[ep.handle.Target] = struct{}{}

		callCtx, cancel := context.WithTimeout(ctx, pool.timeout)
		result, err := invoke(callCtx, ep.handle)
		cancel()
		if err == nil {
			ep.recordSuccess(time.Now())
			return result, nil
		}
		lastErr = err

		// The caller gave up (cancelled or its own deadline passed). The
		// failure carries no information about the node, so neither mark it
		// healthy nor count it against the circuit breaker.
		if ctx.Err() != nil {
			return zero, err
		}

		if !shouldRetry(err, ctx) {
			// A non-retryable error means the node answered (the link is
			// reachable) even though the call failed, so it must not push the
			// circuit breaker toward open.
			ep.recordSuccess(time.Now())
			return zero, err
		}

		ep.recordFailure(time.Now(), err, pool.failureThreshold, pool.openTimeout)
		if attempt == attempts {
			return zero, err
		}
		if sleepErr := sleepWithContext(ctx, pool.retryBackoff); sleepErr != nil {
			return zero, sleepErr
		}
	}

	// Not reachable with a budget of at least one attempt; kept so the
	// function stays correct if the loop is ever changed.
	if lastErr != nil {
		return zero, lastErr
	}
	return zero, pool.unavailableError(time.Now())
}
// cachedReady tries to answer the readiness question from cached health state
// alone, without probing any connection.
//
// The first result reports whether the cache could answer:
//
//   - (true, nil): some endpoint has a fresh, healthy result and a closed
//     circuit;
//   - (true, err): every endpoint has a fresh result and none is usable; err
//     lists each target with its last error (or "not ready"), joined by "; ";
//   - (false, nil): at least one endpoint has no fresh result, or the pool has
//     no endpoints at all, so the caller has to decide by other means.
func (p *Pool) cachedReady(now time.Time) (bool, error) {
	// An empty pool has nothing cached. Reporting "all fresh" here would hand
	// the caller an error with an empty message.
	if len(p.endpoints) == 0 {
		return false, nil
	}

	allFresh := true
	var errs []string
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)
		if !snap.fresh {
			allFresh = false
			continue
		}
		if snap.healthy && !snap.open {
			return true, nil
		}
		reason := snap.errString
		if reason == "" {
			reason = "not ready"
		}
		errs = append(errs, fmt.Sprintf("%s: %s", snap.handle.Target, reason))
	}
	if !allFresh {
		return false, nil
	}
	return true, errors.New(strings.Join(errs, "; "))
}
// refreshEndpoint probes one endpoint's connection and records the outcome.
//
// The probe runs under a context derived from ctx and limited by
// readyCheckTimeout. A ready connection is recorded as a success and nil is
// returned; otherwise the failure is recorded against the endpoint (which may
// open its circuit) and the probe error is returned.
func (p *Pool) refreshEndpoint(ctx context.Context, ep *endpoint) error {
	probeCtx, cancel := context.WithTimeout(ctx, p.readyCheckTimeout())
	defer cancel()

	err := connectionReady(probeCtx, ep.handle.Conn)
	if err == nil {
		ep.recordSuccess(time.Now())
		return nil
	}

	ep.recordFailure(time.Now(), err, p.failureThreshold, p.openTimeout)
	return err
}
// readyCheckTimeout returns the time budget for a readiness probe: the pool's
// call timeout, capped at two seconds.
func (p *Pool) readyCheckTimeout() time.Duration {
	const ceiling = 2 * time.Second
	if p.timeout > ceiling {
		return ceiling
	}
	return p.timeout
}
// pickEndpoint chooses the endpoint for the next attempt.
//
// Candidates come from candidateOrder. The first one whose target is not a key
// of tried wins; when every candidate has already been tried, the first
// candidate is reused. It returns nil when there are no candidates.
func (p *Pool) pickEndpoint(now time.Time, tried map[string]struct{}) *endpoint {
	candidates := p.candidateOrder(now)
	if len(candidates) == 0 {
		return nil
	}
	for i := range candidates {
		if _, seen := tried[candidates[i].handle.Target]; seen {
			continue
		}
		return candidates[i]
	}
	return candidates[0]
}
// candidateOrder returns the endpoints whose circuit is not open, ordered by
// preference: fresh-and-healthy first, then those without a fresh health
// result, then fresh-but-unhealthy.
//
// Every call increments the pool counter; the previous counter value is used
// to rotate each of the three groups so that successive calls start at
// different endpoints.
func (p *Pool) candidateOrder(now time.Time) []*endpoint {
	const (
		rankHealthy = iota
		rankUnknown
		rankUnhealthy
		rankCount
	)

	var buckets [rankCount][]*endpoint
	total := 0
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)
		if snap.open {
			continue
		}

		rank := rankUnhealthy
		if !snap.fresh {
			rank = rankUnknown
		} else if snap.healthy {
			rank = rankHealthy
		}
		buckets[rank] = append(buckets[rank], ep)
		total++
	}

	seed := int(p.counter.Add(1) - 1)
	ordered := make([]*endpoint, 0, total)
	for _, bucket := range buckets {
		ordered = append(ordered, rotate(bucket, seed)...)
	}
	return ordered
}
// rotate returns a new slice holding items rotated left by seed positions
// (taken modulo len(items)), or nil when items is empty.
//
// A negative seed is normalised into range instead of producing a negative
// slice index; the seed is derived from a wrapping counter and can turn
// negative once converted to int. The result never shares its backing array
// with items, so appending to or editing it cannot disturb the caller's slice.
func rotate(items []*endpoint, seed int) []*endpoint {
	n := len(items)
	if n == 0 {
		return nil
	}

	offset := seed % n
	if offset < 0 {
		// Go's % keeps the sign of the dividend; shift into [0, n).
		offset += n
	}

	rotated := make([]*endpoint, 0, n)
	rotated = append(rotated, items[offset:]...)
	rotated = append(rotated, items[:offset]...)
	return rotated
}
// unavailableError builds the error returned when no upstream can be used.
//
// The message names the pool and lists every endpoint with a reason:
// "circuit open" when its breaker is open, otherwise its last recorded error,
// otherwise "unavailable". Entries are separated by "; ".
func (p *Pool) unavailableError(now time.Time) error {
	details := make([]string, 0, len(p.endpoints))
	for _, ep := range p.endpoints {
		snap := ep.snapshot(now, p.healthCacheTTL)

		reason := "unavailable"
		switch {
		case snap.open:
			reason = "circuit open"
		case snap.errString != "":
			reason = snap.errString
		}
		details = append(details, snap.handle.Target+": "+reason)
	}
	return fmt.Errorf("%s has no available upstreams: %s", p.name, strings.Join(details, "; "))
}
// snapshot returns a copy of the endpoint's state as seen at now, taken while
// holding the endpoint's mutex.
//
// fresh is set when a health result has been recorded and its age does not
// exceed ttl; open is set when an open-until instant has been recorded and now
// is still before it.
func (e *endpoint) snapshot(now time.Time, ttl time.Duration) snapshot {
	e.mu.Lock()
	defer e.mu.Unlock()

	var snap snapshot
	snap.handle = e.handle
	snap.healthy = e.lastHealthy
	snap.errString = e.lastErr
	if !e.lastCheckedAt.IsZero() {
		snap.fresh = now.Sub(e.lastCheckedAt) <= ttl
	}
	if !e.openUntil.IsZero() {
		snap.open = now.Before(e.openUntil)
	}
	return snap
}
// recordSuccess marks the endpoint healthy as of now: it stores the check
// time, clears the last error, resets the consecutive-failure count and closes
// the circuit. The endpoint's mutex is held for the update.
func (e *endpoint) recordSuccess(now time.Time) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Cached health result.
	e.lastCheckedAt, e.lastHealthy, e.lastErr = now, true, ""
	// Circuit-breaker state.
	e.consecutiveFailures, e.openUntil = 0, time.Time{}
}
// recordFailure marks the endpoint unhealthy as of now and stores err's
// message as the last error. It increments the consecutive-failure count and,
// once that count reaches failureThreshold, opens the circuit until
// now+openTimeout. The endpoint's mutex is held for the update.
func (e *endpoint) recordFailure(now time.Time, err error, failureThreshold int, openTimeout time.Duration) {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.consecutiveFailures++
	e.lastCheckedAt, e.lastHealthy = now, false
	e.lastErr = err.Error()

	if e.consecutiveFailures < failureThreshold {
		return
	}
	e.openUntil = now.Add(openTimeout)
}
// connectionReady waits until conn reports the Ready state.
//
// It calls Connect on the connection and then follows its state changes. It
// returns nil once the state is Ready and an error when the state is Shutdown.
// When waiting for a state change reports no change, the context error is
// returned if there is one, otherwise a "state did not change" error.
func connectionReady(ctx context.Context, conn *grpc.ClientConn) error {
	conn.Connect()

	for state := conn.GetState(); ; state = conn.GetState() {
		if state == connectivity.Ready {
			return nil
		}
		if state == connectivity.Shutdown {
			return fmt.Errorf("connection shutdown")
		}

		if conn.WaitForStateChange(ctx, state) {
			// The state moved on; re-read it and evaluate again.
			continue
		}
		if err := ctx.Err(); err != nil {
			return err
		}
		return fmt.Errorf("state did not change")
	}
}
// shouldRetry decides whether a failed invocation may be retried on another
// node.
//
// It returns false for a nil error and whenever the parent context is already
// done. A context deadline error is retryable, as is any error that does not
// carry a gRPC status. For gRPC status errors only DeadlineExceeded,
// Unavailable, ResourceExhausted, Aborted and Internal are retryable.
func shouldRetry(err error, parent context.Context) bool {
	if err == nil || parent.Err() != nil {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}

	st, isStatus := status.FromError(err)
	if !isStatus {
		return true
	}

	switch st.Code() {
	case codes.DeadlineExceeded,
		codes.Unavailable,
		codes.ResourceExhausted,
		codes.Aborted,
		codes.Internal:
		return true
	}
	return false
}
// sleepWithContext pauses for d, or until ctx is done, whichever comes first.
//
// It returns nil immediately for a non-positive d and nil after a full pause;
// when ctx ends first it returns ctx.Err().
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d > 0 {
		t := time.NewTimer(d)
		defer t.Stop()

		select {
		case <-t.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// Compile-time assertion that *Pool implements io.Closer.
var _ io.Closer = (*Pool)(nil)