package grpc import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/metadata" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/transport" ) const ( balancerName = "selector" ) var ( _ base.PickerBuilder = (*balancerBuilder)(nil) _ balancer.Picker = (*balancerPicker)(nil) ) type balancerBuilder struct { builder selector.Builder } func init() { b := base.NewBalancerBuilder( balancerName, &balancerBuilder{ builder: selector.GlobalSelector(), }, base.Config{HealthCheck: true}, ) balancer.Register(b) } // Build creates a grpc Picker. func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { if len(info.ReadySCs) == 0 { // Block the RPC until a new picker is available via UpdateState(). return base.NewErrPicker(balancer.ErrNoSubConnAvailable) } nodes := make([]selector.Node, 0, len(info.ReadySCs)) for conn, info := range info.ReadySCs { ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance) nodes = append(nodes, &grpcNode{ Node: selector.NewNode("grpc", info.Address.Addr, ins), subConn: conn, }) } p := &balancerPicker{ selector: b.builder.Build(), } p.selector.Apply(nodes) return p } // balancerPicker is a grpc picker. type balancerPicker struct { selector selector.Selector } // Pick pick instances. func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { var filters []selector.NodeFilter if tr, ok := transport.FromClientContext(info.Ctx); ok { if gtr, ok := tr.(*Transport); ok { filters = gtr.NodeFilters() } } n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...)) if err != nil { return balancer.PickResult{}, err } return balancer.PickResult{ SubConn: n.(*grpcNode).subConn, Done: func(di balancer.DoneInfo) { done(info.Ctx, selector.DoneInfo{ Err: di.Err, BytesSent: di.BytesSent, BytesReceived: di.BytesReceived, ReplyMD: Trailer(di.Trailer), }) }, }, nil } // Trailer is a grpc trailer MD. type Trailer metadata.MD // Get get a grpc trailer value. func (t Trailer) Get(k string) string { v := metadata.MD(t).Get(k) if len(v) > 0 { return v[0] } return "" } type grpcNode struct { selector.Node subConn balancer.SubConn }