2026-04-04 01:28:57 +08:00

117 lines
2.8 KiB
Go

package config
import (
"bytes"
"fmt"
"net"
"os"
"time"
"gopkg.in/yaml.v3"
)
const DefaultPath = "config/local.yaml"
// Config 汇总整个网关服务的运行参数。
type Config struct {
App AppConfig `yaml:"app"`
GRPC GRPCConfig `yaml:"grpc"`
}
// AppConfig 描述应用自身参数。
type AppConfig struct {
Name string `yaml:"name"`
Env string `yaml:"env"`
HTTPAddr string `yaml:"http_addr"`
ShutdownTimeout time.Duration `yaml:"shutdown_timeout"`
}
// GRPCConfig 聚合所有下游 gRPC 服务配置。
type GRPCConfig struct {
User UpstreamConfig `yaml:"user"`
Pay UpstreamConfig `yaml:"pay"`
}
// UpstreamConfig 描述单个下游服务的地址和超时。
type UpstreamConfig struct {
Target string `yaml:"target"`
Timeout time.Duration `yaml:"timeout"`
}
// Load 从 YAML 文件加载配置,并补齐默认值和校验。
func Load(path string) (Config, error) {
// 先读取配置文件内容。
data, err := os.ReadFile(path)
if err != nil {
return Config{}, fmt.Errorf("read config file %s: %w", path, err)
}
// 使用 KnownFields 防止配置拼写错误悄悄溜过。
cfg := defaultConfig()
decoder := yaml.NewDecoder(bytes.NewReader(data))
decoder.KnownFields(true)
if err := decoder.Decode(&cfg); err != nil {
return Config{}, fmt.Errorf("decode config file %s: %w", path, err)
}
if err := validate(cfg); err != nil {
return Config{}, err
}
return cfg, nil
}
func defaultConfig() Config {
return Config{
App: AppConfig{
Name: "chatappgateway",
Env: "local",
HTTPAddr: ":8080",
ShutdownTimeout: 10 * time.Second,
},
GRPC: GRPCConfig{
User: UpstreamConfig{
Target: "127.0.0.1:9001",
Timeout: 3 * time.Second,
},
Pay: UpstreamConfig{
Target: "127.0.0.1:9002",
Timeout: 3 * time.Second,
},
},
}
}
func validate(cfg Config) error {
// 应用名用于日志标签,不能为空。
if cfg.App.Name == "" {
return fmt.Errorf("app.name is required")
}
// 监听地址必须能被 TCP 地址解析。
if _, err := net.ResolveTCPAddr("tcp", cfg.App.HTTPAddr); err != nil {
return fmt.Errorf("app.http_addr is invalid: %w", err)
}
if cfg.App.ShutdownTimeout <= 0 {
return fmt.Errorf("app.shutdown_timeout must be greater than 0")
}
if err := validateUpstream("grpc.user", cfg.GRPC.User); err != nil {
return err
}
if err := validateUpstream("grpc.pay", cfg.GRPC.Pay); err != nil {
return err
}
return nil
}
func validateUpstream(name string, cfg UpstreamConfig) error {
if cfg.Target == "" {
return fmt.Errorf("%s.target is required", name)
}
if _, err := net.ResolveTCPAddr("tcp", cfg.Target); err != nil {
return fmt.Errorf("%s.target is invalid: %w", name, err)
}
if cfg.Timeout <= 0 {
return fmt.Errorf("%s.timeout must be greater than 0", name)
}
return nil
}