package upstream
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"chatappgateway/internal/config"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/connectivity"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// Handle is the per-node handle exposed to callers for a single downstream node.
type Handle struct {
	// Target is the dial address of the downstream node.
	Target string
	// Conn is the shared gRPC client connection to Target.
	Conn *grpc.ClientConn
}
|
|
|
|
// Pool manages connections, health state, and fault-tolerance policy
// (retry, circuit breaking, cached health checks) for a set of
// downstream nodes.
type Pool struct {
	name             string        // logical upstream name, used in error messages
	timeout          time.Duration // per-attempt timeout applied in Call
	maxAttempts      int           // maximum node-level attempts per Call
	retryBackoff     time.Duration // sleep between retry attempts
	failureThreshold int           // consecutive failures before the circuit opens
	openTimeout      time.Duration // how long an opened circuit stays open
	healthCacheTTL   time.Duration // freshness window for cached health results
	counter          atomic.Uint64 // monotonically increasing seed for round-robin rotation
	endpoints        []*endpoint   // one entry per configured target
}
|
|
|
|
// endpoint pairs a Handle with its mutable health and circuit-breaker state.
type endpoint struct {
	handle Handle

	mu                  sync.Mutex // guards all fields below
	consecutiveFailures int        // failures recorded since the last success
	openUntil           time.Time  // circuit is open until this instant (zero = closed)
	lastCheckedAt       time.Time  // when health was last recorded (zero = never checked)
	lastHealthy         bool       // result of the most recent observation
	lastErr             string     // message of the most recent failure, "" after a success
}
|
|
|
|
// snapshot is an immutable copy of an endpoint's state at a point in time.
type snapshot struct {
	handle    Handle
	fresh     bool   // last check is within the health-cache TTL
	healthy   bool   // last observation was a success
	open      bool   // circuit breaker currently open
	errString string // last failure message, "" if none
}
|
|
|
|
// New 为一组下游节点建立连接池。
|
|
func New(ctx context.Context, name string, cfg config.UpstreamConfig) (*Pool, error) {
|
|
pool := &Pool{
|
|
name: name,
|
|
timeout: cfg.Timeout,
|
|
maxAttempts: cfg.Retry.MaxAttempts,
|
|
retryBackoff: cfg.Retry.Backoff,
|
|
failureThreshold: cfg.CircuitBreaker.FailureThreshold,
|
|
openTimeout: cfg.CircuitBreaker.OpenTimeout,
|
|
healthCacheTTL: cfg.HealthCache.TTL,
|
|
endpoints: make([]*endpoint, 0, len(cfg.Targets)),
|
|
}
|
|
|
|
for _, target := range cfg.Targets {
|
|
conn, err := grpc.DialContext(
|
|
ctx,
|
|
target,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
_ = pool.Close()
|
|
return nil, fmt.Errorf("dial %s %s: %w", name, target, err)
|
|
}
|
|
pool.endpoints = append(pool.endpoints, &endpoint{
|
|
handle: Handle{
|
|
Target: target,
|
|
Conn: conn,
|
|
},
|
|
})
|
|
}
|
|
|
|
return pool, nil
|
|
}
|
|
|
|
// Close 关闭全部 gRPC 连接。
|
|
func (p *Pool) Close() error {
|
|
var firstErr error
|
|
for _, endpoint := range p.endpoints {
|
|
if endpoint.handle.Conn == nil {
|
|
continue
|
|
}
|
|
if err := endpoint.handle.Conn.Close(); err != nil && firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
}
|
|
return firstErr
|
|
}
|
|
|
|
// Ready 使用缓存和连接状态判断当前是否至少有一个健康节点可用。
|
|
func (p *Pool) Ready(ctx context.Context) error {
|
|
now := time.Now()
|
|
if cached, err := p.cachedReady(now); cached {
|
|
return err
|
|
}
|
|
|
|
ordered := p.candidateOrder(now)
|
|
if len(ordered) == 0 {
|
|
return p.unavailableError(now)
|
|
}
|
|
|
|
var errs []string
|
|
for _, endpoint := range ordered {
|
|
if err := p.refreshEndpoint(ctx, endpoint); err == nil {
|
|
return nil
|
|
} else {
|
|
errs = append(errs, fmt.Sprintf("%s: %v", endpoint.handle.Target, err))
|
|
}
|
|
}
|
|
|
|
return errors.New(strings.Join(errs, "; "))
|
|
}
|
|
|
|
// Call picks a healthy node and executes invoke against it, performing
// node-level retries (with backoff) on retryable errors, up to
// pool.maxAttempts attempts in total.
func Call[T any](ctx context.Context, pool *Pool, invoke func(context.Context, Handle) (T, error)) (T, error) {
	var zero T
	var lastErr error
	// Targets already attempted during this call; pickEndpoint prefers
	// endpoints not in this set.
	tried := make(map[string]struct{})

	for attempt := 1; attempt <= pool.maxAttempts; attempt++ {
		endpoint := pool.pickEndpoint(time.Now(), tried)
		if endpoint == nil {
			// No candidate at all (e.g. every circuit is open).
			if lastErr != nil {
				return zero, lastErr
			}
			return zero, pool.unavailableError(time.Now())
		}

		tried[endpoint.handle.Target] = struct{}{}

		// Per-attempt timeout; the parent ctx still bounds the overall call.
		callCtx, cancel := context.WithTimeout(ctx, pool.timeout)
		result, err := invoke(callCtx, endpoint.handle)
		cancel()

		if err == nil {
			endpoint.recordSuccess(time.Now())
			return result, nil
		}

		lastErr = err
		if shouldRetry(err, ctx) {
			// Transient failure: count it toward the circuit breaker and,
			// if attempts remain, back off before trying the next node.
			endpoint.recordFailure(time.Now(), err, pool.failureThreshold, pool.openTimeout)
			if attempt < pool.maxAttempts {
				if err := sleepWithContext(ctx, pool.retryBackoff); err != nil {
					return zero, err
				}
				continue
			}
			return zero, err
		}

		// A non-retryable error means the node returned a failure but the
		// link is reachable, so it must not trip the circuit breaker.
		endpoint.recordSuccess(time.Now())
		return zero, err
	}

	if lastErr != nil {
		return zero, lastErr
	}
	return zero, pool.unavailableError(time.Now())
}
|
|
|
|
func (p *Pool) cachedReady(now time.Time) (bool, error) {
|
|
allFresh := true
|
|
var errs []string
|
|
|
|
for _, endpoint := range p.endpoints {
|
|
snap := endpoint.snapshot(now, p.healthCacheTTL)
|
|
if snap.fresh && snap.healthy && !snap.open {
|
|
return true, nil
|
|
}
|
|
if !snap.fresh {
|
|
allFresh = false
|
|
continue
|
|
}
|
|
if snap.errString != "" {
|
|
errs = append(errs, fmt.Sprintf("%s: %s", snap.handle.Target, snap.errString))
|
|
} else {
|
|
errs = append(errs, fmt.Sprintf("%s: not ready", snap.handle.Target))
|
|
}
|
|
}
|
|
|
|
if allFresh {
|
|
return true, errors.New(strings.Join(errs, "; "))
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (p *Pool) refreshEndpoint(ctx context.Context, endpoint *endpoint) error {
|
|
checkCtx, cancel := context.WithTimeout(ctx, p.readyCheckTimeout())
|
|
defer cancel()
|
|
|
|
if err := connectionReady(checkCtx, endpoint.handle.Conn); err != nil {
|
|
endpoint.recordFailure(time.Now(), err, p.failureThreshold, p.openTimeout)
|
|
return err
|
|
}
|
|
|
|
endpoint.recordSuccess(time.Now())
|
|
return nil
|
|
}
|
|
|
|
func (p *Pool) readyCheckTimeout() time.Duration {
|
|
if p.timeout <= 2*time.Second {
|
|
return p.timeout
|
|
}
|
|
return 2 * time.Second
|
|
}
|
|
|
|
func (p *Pool) pickEndpoint(now time.Time, tried map[string]struct{}) *endpoint {
|
|
ordered := p.candidateOrder(now)
|
|
for _, endpoint := range ordered {
|
|
if _, ok := tried[endpoint.handle.Target]; !ok {
|
|
return endpoint
|
|
}
|
|
}
|
|
if len(ordered) == 0 {
|
|
return nil
|
|
}
|
|
return ordered[0]
|
|
}
|
|
|
|
func (p *Pool) candidateOrder(now time.Time) []*endpoint {
|
|
var healthy []*endpoint
|
|
var unknown []*endpoint
|
|
var unhealthy []*endpoint
|
|
|
|
for _, endpoint := range p.endpoints {
|
|
snap := endpoint.snapshot(now, p.healthCacheTTL)
|
|
if snap.open {
|
|
continue
|
|
}
|
|
switch {
|
|
case snap.fresh && snap.healthy:
|
|
healthy = append(healthy, endpoint)
|
|
case !snap.fresh:
|
|
unknown = append(unknown, endpoint)
|
|
default:
|
|
unhealthy = append(unhealthy, endpoint)
|
|
}
|
|
}
|
|
|
|
seed := int(p.counter.Add(1) - 1)
|
|
ordered := make([]*endpoint, 0, len(healthy)+len(unknown)+len(unhealthy))
|
|
ordered = append(ordered, rotate(healthy, seed)...)
|
|
ordered = append(ordered, rotate(unknown, seed)...)
|
|
ordered = append(ordered, rotate(unhealthy, seed)...)
|
|
return ordered
|
|
}
|
|
|
|
func rotate(items []*endpoint, seed int) []*endpoint {
|
|
if len(items) == 0 {
|
|
return nil
|
|
}
|
|
offset := seed % len(items)
|
|
return append(items[offset:], items[:offset]...)
|
|
}
|
|
|
|
func (p *Pool) unavailableError(now time.Time) error {
|
|
var parts []string
|
|
for _, endpoint := range p.endpoints {
|
|
snap := endpoint.snapshot(now, p.healthCacheTTL)
|
|
if snap.open {
|
|
parts = append(parts, fmt.Sprintf("%s: circuit open", snap.handle.Target))
|
|
continue
|
|
}
|
|
if snap.errString != "" {
|
|
parts = append(parts, fmt.Sprintf("%s: %s", snap.handle.Target, snap.errString))
|
|
continue
|
|
}
|
|
parts = append(parts, fmt.Sprintf("%s: unavailable", snap.handle.Target))
|
|
}
|
|
return fmt.Errorf("%s has no available upstreams: %s", p.name, strings.Join(parts, "; "))
|
|
}
|
|
|
|
func (e *endpoint) snapshot(now time.Time, ttl time.Duration) snapshot {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
fresh := !e.lastCheckedAt.IsZero() && now.Sub(e.lastCheckedAt) <= ttl
|
|
open := !e.openUntil.IsZero() && now.Before(e.openUntil)
|
|
return snapshot{
|
|
handle: e.handle,
|
|
fresh: fresh,
|
|
healthy: e.lastHealthy,
|
|
open: open,
|
|
errString: e.lastErr,
|
|
}
|
|
}
|
|
|
|
func (e *endpoint) recordSuccess(now time.Time) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
e.consecutiveFailures = 0
|
|
e.openUntil = time.Time{}
|
|
e.lastCheckedAt = now
|
|
e.lastHealthy = true
|
|
e.lastErr = ""
|
|
}
|
|
|
|
func (e *endpoint) recordFailure(now time.Time, err error, failureThreshold int, openTimeout time.Duration) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
e.consecutiveFailures++
|
|
e.lastCheckedAt = now
|
|
e.lastHealthy = false
|
|
e.lastErr = err.Error()
|
|
if e.consecutiveFailures >= failureThreshold {
|
|
e.openUntil = now.Add(openTimeout)
|
|
}
|
|
}
|
|
|
|
func connectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
|
conn.Connect()
|
|
|
|
for {
|
|
state := conn.GetState()
|
|
switch state {
|
|
case connectivity.Ready:
|
|
return nil
|
|
case connectivity.Shutdown:
|
|
return fmt.Errorf("connection shutdown")
|
|
}
|
|
|
|
if !conn.WaitForStateChange(ctx, state) {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("state did not change")
|
|
}
|
|
}
|
|
}
|
|
|
|
func shouldRetry(err error, parent context.Context) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
if parent.Err() != nil {
|
|
return false
|
|
}
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return true
|
|
}
|
|
|
|
st, ok := status.FromError(err)
|
|
if !ok {
|
|
return true
|
|
}
|
|
|
|
switch st.Code() {
|
|
case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted, codes.Aborted, codes.Internal:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
|
if d <= 0 {
|
|
return nil
|
|
}
|
|
timer := time.NewTimer(d)
|
|
defer timer.Stop()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-timer.C:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var _ io.Closer = (*Pool)(nil)
|