refactor: unify selector filter (#2277)

* unify selector

Co-authored-by: caoguoliang01 <caoguoliang01@bilibili.com>
Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
pull/2283/head
longxboy 2 years ago committed by GitHub
parent d11c6892b4
commit 11cd43e3c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 12
      selector/default_selector.go
  2. 4
      selector/filter.go
  3. 2
      selector/filter/version.go
  4. 13
      selector/global.go
  5. 8
      selector/options.go
  6. 12
      selector/p2c/p2c.go
  7. 8
      selector/p2c/p2c_test.go
  8. 12
      selector/random/random.go
  9. 4
      selector/random/random_test.go
  10. 6
      selector/selector.go
  11. 11
      selector/selector_test.go
  12. 12
      selector/wrr/wrr.go
  13. 4
      selector/wrr/wrr_test.go
  14. 68
      transport/grpc/balancer.go
  15. 11
      transport/grpc/balancer_test.go
  16. 33
      transport/grpc/client.go
  17. 8
      transport/grpc/transport.go
  18. 24
      transport/http/client.go
  19. 15
      transport/http/client_test.go

@ -9,7 +9,6 @@ import (
type Default struct { type Default struct {
NodeBuilder WeightedNodeBuilder NodeBuilder WeightedNodeBuilder
Balancer Balancer Balancer Balancer
Filters []Filter
nodes atomic.Value nodes atomic.Value
} }
@ -27,16 +26,13 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
if len(d.Filters) > 0 || len(options.Filters) > 0 { if len(options.NodeFilters) > 0 {
newNodes := make([]Node, len(nodes)) newNodes := make([]Node, len(nodes))
for i, wc := range nodes { for i, wc := range nodes {
newNodes[i] = wc newNodes[i] = wc
} }
for _, f := range d.Filters { for _, filter := range options.NodeFilters {
newNodes = f(ctx, newNodes) newNodes = filter(ctx, newNodes)
}
for _, f := range options.Filters {
newNodes = f(ctx, newNodes)
} }
candidates = make([]WeightedNode, len(newNodes)) candidates = make([]WeightedNode, len(newNodes))
for i, n := range newNodes { for i, n := range newNodes {
@ -74,7 +70,6 @@ func (d *Default) Apply(nodes []Node) {
type DefaultBuilder struct { type DefaultBuilder struct {
Node WeightedNodeBuilder Node WeightedNodeBuilder
Balancer BalancerBuilder Balancer BalancerBuilder
Filters []Filter
} }
// Build create builder // Build create builder
@ -82,6 +77,5 @@ func (db *DefaultBuilder) Build() Selector {
return &Default{ return &Default{
NodeBuilder: db.Node, NodeBuilder: db.Node,
Balancer: db.Balancer.Build(), Balancer: db.Balancer.Build(),
Filters: db.Filters,
} }
} }

@ -2,5 +2,5 @@ package selector
import "context" import "context"
// Filter is select filter. // NodeFilter is select filter.
type Filter func(context.Context, []Node) []Node type NodeFilter func(context.Context, []Node) []Node

@ -7,7 +7,7 @@ import (
) )
// Version is version filter. // Version is version filter.
func Version(version string) selector.Filter { func Version(version string) selector.NodeFilter {
return func(_ context.Context, nodes []selector.Node) []selector.Node { return func(_ context.Context, nodes []selector.Node) []selector.Node {
newNodes := nodes[:0] newNodes := nodes[:0]
for _, n := range nodes { for _, n := range nodes {

@ -0,0 +1,13 @@
package selector
var globalSelector Builder
// GlobalSelector returns global selector builder.
func GlobalSelector() Builder {
return globalSelector
}
// SetGlobalSelector set global selector builder.
func SetGlobalSelector(builder Builder) {
globalSelector = builder
}

@ -2,15 +2,15 @@ package selector
// SelectOptions is Select Options. // SelectOptions is Select Options.
type SelectOptions struct { type SelectOptions struct {
Filters []Filter NodeFilters []NodeFilter
} }
// SelectOption is Selector option. // SelectOption is Selector option.
type SelectOption func(*SelectOptions) type SelectOption func(*SelectOptions)
// WithFilter with filter options // WithNodeFilter with filter options
func WithFilter(fn ...Filter) SelectOption { func WithNodeFilter(fn ...NodeFilter) SelectOption {
return func(opts *SelectOptions) { return func(opts *SelectOptions) {
opts.Filters = fn opts.NodeFilters = fn
} }
} }

@ -19,20 +19,11 @@ const (
var _ selector.Balancer = &Balancer{} var _ selector.Balancer = &Balancer{}
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option. // Option is random builder option.
type Option func(o *options) type Option func(o *options)
// options is random builder options // options is random builder options
type options struct { type options struct{}
filters []selector.Filter
}
// New creates a p2c selector. // New creates a p2c selector.
func New(opts ...Option) selector.Selector { func New(opts ...Option) selector.Selector {
@ -95,7 +86,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option) opt(&option)
} }
return &selector.DefaultBuilder{ return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{}, Balancer: &Builder{},
Node: &ewma.Builder{}, Node: &ewma.Builder{},
} }

@ -16,7 +16,7 @@ import (
) )
func TestWrr3(t *testing.T) { func TestWrr3(t *testing.T) {
p2c := New(WithFilter(filter.Version("v2.0.0"))) p2c := New()
var nodes []selector.Node var nodes []selector.Node
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
addr := fmt.Sprintf("127.0.0.%d:8080", i) addr := fmt.Sprintf("127.0.0.%d:8080", i)
@ -41,7 +41,7 @@ func TestWrr3(t *testing.T) {
d := time.Duration(rand.Intn(500)) * time.Millisecond d := time.Duration(rand.Intn(500)) * time.Millisecond
lk.Unlock() lk.Unlock()
time.Sleep(d) time.Sleep(d)
n, done, err := p2c.Select(context.Background()) n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil { if err != nil {
t.Errorf("expect %v, got %v", nil, err) t.Errorf("expect %v, got %v", nil, err)
} }
@ -92,7 +92,7 @@ func TestEmpty(t *testing.T) {
} }
func TestOne(t *testing.T) { func TestOne(t *testing.T) {
p2c := New(WithFilter(filter.Version("v2.0.0"))) p2c := New()
var nodes []selector.Node var nodes []selector.Node
for i := 0; i < 1; i++ { for i := 0; i < 1; i++ {
addr := fmt.Sprintf("127.0.0.%d:8080", i) addr := fmt.Sprintf("127.0.0.%d:8080", i)
@ -106,7 +106,7 @@ func TestOne(t *testing.T) {
})) }))
} }
p2c.Apply(nodes) p2c.Apply(nodes)
n, done, err := p2c.Select(context.Background()) n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil { if err != nil {
t.Errorf("expect %v, got %v", nil, err) t.Errorf("expect %v, got %v", nil, err)
} }

@ -15,20 +15,11 @@ const (
var _ selector.Balancer = &Balancer{} // Name is balancer name var _ selector.Balancer = &Balancer{} // Name is balancer name
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option. // Option is random builder option.
type Option func(o *options) type Option func(o *options)
// options is random builder options // options is random builder options
type options struct { type options struct{}
filters []selector.Filter
}
// Balancer is a random balancer. // Balancer is a random balancer.
type Balancer struct{} type Balancer struct{}
@ -56,7 +47,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option) opt(&option)
} }
return &selector.DefaultBuilder{ return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{}, Balancer: &Builder{},
Node: &direct.Builder{}, Node: &direct.Builder{},
} }

@ -10,7 +10,7 @@ import (
) )
func TestWrr(t *testing.T) { func TestWrr(t *testing.T) {
random := New(WithFilter(filter.Version("v2.0.0"))) random := New()
var nodes []selector.Node var nodes []selector.Node
nodes = append(nodes, selector.NewNode( nodes = append(nodes, selector.NewNode(
"http", "http",
@ -31,7 +31,7 @@ func TestWrr(t *testing.T) {
random.Apply(nodes) random.Apply(nodes)
var count1, count2 int var count1, count2 int
for i := 0; i < 200; i++ { for i := 0; i < 200; i++ {
n, done, err := random.Select(context.Background()) n, done, err := random.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil { if err != nil {
t.Errorf("expect no error, got %v", err) t.Errorf("expect no error, got %v", err)
} }

@ -57,7 +57,7 @@ type DoneInfo struct {
// Response Error // Response Error
Err error Err error
// Response Metadata // Response Metadata
ReplyMeta ReplyMeta ReplyMD ReplyMD
// BytesSent indicates if any bytes have been sent to the server. // BytesSent indicates if any bytes have been sent to the server.
BytesSent bool BytesSent bool
@ -65,8 +65,8 @@ type DoneInfo struct {
BytesReceived bool BytesReceived bool
} }
// ReplyMeta is Reply Metadata. // ReplyMD is Reply Metadata.
type ReplyMeta interface { type ReplyMD interface {
Get(key string) string Get(key string) string
} }

@ -49,7 +49,7 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode {
return &mockWeightedNode{Node: n} return &mockWeightedNode{Node: n}
} }
func mockFilter(version string) Filter { func mockFilter(version string) NodeFilter {
return func(_ context.Context, nodes []Node) []Node { return func(_ context.Context, nodes []Node) []Node {
newNodes := nodes[:0] newNodes := nodes[:0]
for _, n := range nodes { for _, n := range nodes {
@ -83,7 +83,6 @@ func (b *mockBalancer) Pick(ctx context.Context, nodes []WeightedNode) (selected
func TestDefault(t *testing.T) { func TestDefault(t *testing.T) {
builder := DefaultBuilder{ builder := DefaultBuilder{
Node: &mockWeightedNodeBuilder{}, Node: &mockWeightedNodeBuilder{},
Filters: []Filter{mockFilter("v2.0.0")},
Balancer: &mockBalancerBuilder{}, Balancer: &mockBalancerBuilder{},
} }
selector := builder.Build() selector := builder.Build()
@ -109,7 +108,7 @@ func TestDefault(t *testing.T) {
Metadata: map[string]string{"weight": "10"}, Metadata: map[string]string{"weight": "10"},
})) }))
selector.Apply(nodes) selector.Apply(nodes)
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err := selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if err != nil { if err != nil {
t.Errorf("expect %v, got %v", nil, err) t.Errorf("expect %v, got %v", nil, err)
} }
@ -137,7 +136,7 @@ func TestDefault(t *testing.T) {
done(context.Background(), DoneInfo{}) done(context.Background(), DoneInfo{})
// no v3.0.0 instance // no v3.0.0 instance
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v3.0.0")))
if !errors.Is(ErrNoAvailable, err) { if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
} }
@ -150,7 +149,7 @@ func TestDefault(t *testing.T) {
// apply zero instance // apply zero instance
selector.Apply([]Node{}) selector.Apply([]Node{})
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if !errors.Is(ErrNoAvailable, err) { if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
} }
@ -163,7 +162,7 @@ func TestDefault(t *testing.T) {
// apply zero instance // apply zero instance
selector.Apply(nil) selector.Apply(nil)
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if !errors.Is(ErrNoAvailable, err) { if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
} }

@ -15,20 +15,11 @@ const (
var _ selector.Balancer = &Balancer{} // Name is balancer name var _ selector.Balancer = &Balancer{} // Name is balancer name
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option. // Option is random builder option.
type Option func(o *options) type Option func(o *options)
// options is random builder options // options is random builder options
type options struct { type options struct{}
filters []selector.Filter
}
// Balancer is a random balancer. // Balancer is a random balancer.
type Balancer struct { type Balancer struct {
@ -77,7 +68,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option) opt(&option)
} }
return &selector.DefaultBuilder{ return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{}, Balancer: &Builder{},
Node: &direct.Builder{}, Node: &direct.Builder{},
} }

@ -11,7 +11,7 @@ import (
) )
func TestWrr(t *testing.T) { func TestWrr(t *testing.T) {
wrr := New(WithFilter(filter.Version("v2.0.0"))) wrr := New()
var nodes []selector.Node var nodes []selector.Node
nodes = append(nodes, selector.NewNode( nodes = append(nodes, selector.NewNode(
"http", "http",
@ -32,7 +32,7 @@ func TestWrr(t *testing.T) {
wrr.Apply(nodes) wrr.Apply(nodes)
var count1, count2 int var count1, count2 int
for i := 0; i < 90; i++ { for i := 0; i < 90; i++ {
n, done, err := wrr.Select(context.Background()) n, done, err := wrr.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil { if err != nil {
t.Errorf("expect no error, got %v", err) t.Errorf("expect no error, got %v", err)
} }

@ -1,59 +1,45 @@
package grpc package grpc
import ( import (
"sync"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/selector/random"
"github.com/go-kratos/kratos/v2/selector/wrr"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
gBalancer "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
var ( const (
_ base.PickerBuilder = &Builder{} balancerName = "selector"
_ gBalancer.Picker = &Picker{} )
mu sync.Mutex var (
_ base.PickerBuilder = &balancerBuilder{}
_ balancer.Picker = &balancerPicker{}
) )
func init() { func init() {
// inject global grpc balancer
SetGlobalBalancer(random.Name, random.NewBuilder())
SetGlobalBalancer(wrr.Name, wrr.NewBuilder())
SetGlobalBalancer(p2c.Name, p2c.NewBuilder())
}
// SetGlobalBalancer set grpc balancer with scheme.
func SetGlobalBalancer(scheme string, builder selector.Builder) {
mu.Lock()
defer mu.Unlock()
b := base.NewBalancerBuilder( b := base.NewBalancerBuilder(
scheme, balancerName,
&Builder{builder: builder}, &balancerBuilder{
builder: selector.GlobalSelector(),
},
base.Config{HealthCheck: true}, base.Config{HealthCheck: true},
) )
gBalancer.Register(b) balancer.Register(b)
} }
// Builder is grpc balancer builder. type balancerBuilder struct {
type Builder struct {
builder selector.Builder builder selector.Builder
} }
// Build creates a grpc Picker. // Build creates a grpc Picker.
func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
if len(info.ReadySCs) == 0 { if len(info.ReadySCs) == 0 {
// Block the RPC until a new picker is available via UpdateState(). // Block the RPC until a new picker is available via UpdateState().
return base.NewErrPicker(gBalancer.ErrNoSubConnAvailable) return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
} }
nodes := make([]selector.Node, 0) nodes := make([]selector.Node, 0)
for conn, info := range info.ReadySCs { for conn, info := range info.ReadySCs {
ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance) ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance)
@ -62,40 +48,40 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
subConn: conn, subConn: conn,
}) })
} }
p := &Picker{ p := &balancerPicker{
selector: b.builder.Build(), selector: b.builder.Build(),
} }
p.selector.Apply(nodes) p.selector.Apply(nodes)
return p return p
} }
// Picker is a grpc picker. // balancerPicker is a grpc picker.
type Picker struct { type balancerPicker struct {
selector selector.Selector selector selector.Selector
} }
// Pick pick instances. // Pick pick instances.
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
var filters []selector.Filter var filters []selector.NodeFilter
if tr, ok := transport.FromClientContext(info.Ctx); ok { if tr, ok := transport.FromClientContext(info.Ctx); ok {
if gtr, ok := tr.(*Transport); ok { if gtr, ok := tr.(*Transport); ok {
filters = gtr.SelectFilters() filters = gtr.NodeFilters()
} }
} }
n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...))
if err != nil { if err != nil {
return gBalancer.PickResult{}, err return balancer.PickResult{}, err
} }
return gBalancer.PickResult{ return balancer.PickResult{
SubConn: n.(*grpcNode).subConn, SubConn: n.(*grpcNode).subConn,
Done: func(di gBalancer.DoneInfo) { Done: func(di balancer.DoneInfo) {
done(info.Ctx, selector.DoneInfo{ done(info.Ctx, selector.DoneInfo{
Err: di.Err, Err: di.Err,
BytesSent: di.BytesSent, BytesSent: di.BytesSent,
BytesReceived: di.BytesReceived, BytesReceived: di.BytesReceived,
ReplyMeta: Trailer(di.Trailer), ReplyMD: Trailer(di.Trailer),
}) })
}, },
}, nil }, nil
@ -115,5 +101,5 @@ func (t Trailer) Get(k string) string {
type grpcNode struct { type grpcNode struct {
selector.Node selector.Node
subConn gBalancer.SubConn subConn balancer.SubConn
} }

@ -19,19 +19,10 @@ func TestTrailer(t *testing.T) {
} }
} }
func TestBalancerName(t *testing.T) {
o := &clientOptions{}
WithBalancerName("p2c")(o)
if !reflect.DeepEqual("p2c", o.balancerName) {
t.Errorf("expect %v, got %v", "p2c", o.balancerName)
}
}
func TestFilters(t *testing.T) { func TestFilters(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { WithNodeFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
return nodes return nodes
})(o) })(o)
if !reflect.DeepEqual(1, len(o.filters)) { if !reflect.DeepEqual(1, len(o.filters)) {

@ -10,7 +10,7 @@ import (
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/wrr" "github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery" "github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery"
@ -23,6 +23,12 @@ import (
grpcmd "google.golang.org/grpc/metadata" grpcmd "google.golang.org/grpc/metadata"
) )
func init() {
if selector.GlobalSelector() == nil {
selector.SetGlobalSelector(p2c.NewBuilder())
}
}
// ClientOption is gRPC client option. // ClientOption is gRPC client option.
type ClientOption func(o *clientOptions) type ClientOption func(o *clientOptions)
@ -75,15 +81,8 @@ func WithOptions(opts ...grpc.DialOption) ClientOption {
} }
} }
// WithBalancerName with balancer name // WithNodeFilter with select filters
func WithBalancerName(name string) ClientOption { func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
return func(o *clientOptions) {
o.balancerName = name
}
}
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) ClientOption {
return func(o *clientOptions) { return func(o *clientOptions) {
o.filters = filters o.filters = filters
} }
@ -105,7 +104,7 @@ type clientOptions struct {
ints []grpc.UnaryClientInterceptor ints []grpc.UnaryClientInterceptor
grpcOpts []grpc.DialOption grpcOpts []grpc.DialOption
balancerName string balancerName string
filters []selector.Filter filters []selector.NodeFilter
} }
// Dial returns a GRPC connection. // Dial returns a GRPC connection.
@ -121,7 +120,7 @@ func DialInsecure(ctx context.Context, opts ...ClientOption) (*grpc.ClientConn,
func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) { func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) {
options := clientOptions{ options := clientOptions{
timeout: 2000 * time.Millisecond, timeout: 2000 * time.Millisecond,
balancerName: wrr.Name, balancerName: balancerName,
} }
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
@ -156,13 +155,13 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
return grpc.DialContext(ctx, options.endpoint, grpcOpts...) return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
} }
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = transport.NewClientContext(ctx, &Transport{ ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(), endpoint: cc.Target(),
operation: method, operation: method,
reqHeader: headerCarrier{}, reqHeader: headerCarrier{},
filters: filters, nodeFilters: filters,
}) })
if timeout > 0 { if timeout > 0 {
var cancel context.CancelFunc var cancel context.CancelFunc

@ -14,7 +14,7 @@ type Transport struct {
operation string operation string
reqHeader headerCarrier reqHeader headerCarrier
replyHeader headerCarrier replyHeader headerCarrier
filters []selector.Filter nodeFilters []selector.NodeFilter
} }
// Kind returns the transport kind. // Kind returns the transport kind.
@ -42,9 +42,9 @@ func (tr *Transport) ReplyHeader() transport.Header {
return tr.replyHeader return tr.replyHeader
} }
// SelectFilters returns the client select filters. // NodeFilters returns the client select filters.
func (tr *Transport) SelectFilters() []selector.Filter { func (tr *Transport) NodeFilters() []selector.NodeFilter {
return tr.filters return tr.nodeFilters
} }
type headerCarrier metadata.MD type headerCarrier metadata.MD

@ -16,10 +16,16 @@ import (
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/wrr" "github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
) )
func init() {
if selector.GlobalSelector() == nil {
selector.SetGlobalSelector(p2c.NewBuilder())
}
}
// DecodeErrorFunc is decode error func. // DecodeErrorFunc is decode error func.
type DecodeErrorFunc func(ctx context.Context, res *http.Response) error type DecodeErrorFunc func(ctx context.Context, res *http.Response) error
@ -43,7 +49,7 @@ type clientOptions struct {
decoder DecodeResponseFunc decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc errorDecoder DecodeErrorFunc
transport http.RoundTripper transport http.RoundTripper
selector selector.Selector nodeFilters []selector.NodeFilter
discovery registry.Discovery discovery registry.Discovery
middleware []middleware.Middleware middleware []middleware.Middleware
block bool block bool
@ -112,10 +118,10 @@ func WithDiscovery(d registry.Discovery) ClientOption {
} }
} }
// WithSelector with client selector. // WithNodeFilter with select filters
func WithSelector(selector selector.Selector) ClientOption { func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
return func(o *clientOptions) { return func(o *clientOptions) {
o.selector = selector o.nodeFilters = filters
} }
} }
@ -140,6 +146,7 @@ type Client struct {
r *resolver r *resolver
cc *http.Client cc *http.Client
insecure bool insecure bool
selector selector.Selector
} }
// NewClient returns an HTTP client. // NewClient returns an HTTP client.
@ -151,7 +158,6 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
decoder: DefaultResponseDecoder, decoder: DefaultResponseDecoder,
errorDecoder: DefaultErrorDecoder, errorDecoder: DefaultErrorDecoder,
transport: http.DefaultTransport, transport: http.DefaultTransport,
selector: wrr.New(),
} }
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
@ -166,10 +172,11 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
selector := selector.GlobalSelector().Build()
var r *resolver var r *resolver
if options.discovery != nil { if options.discovery != nil {
if target.Scheme == "discovery" { if target.Scheme == "discovery" {
if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil { if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil {
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint)
} }
} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil { } else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil {
@ -185,6 +192,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
Timeout: options.timeout, Timeout: options.timeout,
Transport: options.transport, Transport: options.transport,
}, },
selector: selector,
}, nil }, nil
} }
@ -276,7 +284,7 @@ func (client *Client) do(req *http.Request) (*http.Response, error) {
err error err error
node selector.Node node selector.Node
) )
if node, done, err = client.opts.selector.Select(req.Context()); err != nil { if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
} }
if client.insecure { if client.insecure {

@ -182,13 +182,18 @@ func TestWithDiscovery(t *testing.T) {
} }
} }
func TestWithSelector(t *testing.T) { func TestWithNodeFilter(t *testing.T) {
ov := &selector.Default{} ov := func(context.Context, []selector.Node) []selector.Node {
o := WithSelector(ov) return []selector.Node{&selector.DefaultNode{}}
}
o := WithNodeFilter(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
if !reflect.DeepEqual(co.selector, ov) { for _, n := range co.nodeFilters {
t.Errorf("expected selector to be %v, got %v", ov, co.selector) ret := n(context.Background(), nil)
if len(ret) != 1 {
t.Errorf("expected node length to be 1, got %v", len(ret))
}
} }
} }

Loading…
Cancel
Save