From 11cd43e3c302fb0583b6fd01a4afedffd0e93665 Mon Sep 17 00:00:00 2001 From: longxboy Date: Tue, 16 Aug 2022 21:21:58 +0800 Subject: [PATCH] refactor: unify selector filter (#2277) * unify selector Co-authored-by: caoguoliang01 Co-authored-by: chenzhihui --- selector/default_selector.go | 12 ++---- selector/filter.go | 4 +- selector/filter/version.go | 2 +- selector/global.go | 13 +++++++ selector/options.go | 8 ++-- selector/p2c/p2c.go | 12 +----- selector/p2c/p2c_test.go | 8 ++-- selector/random/random.go | 12 +----- selector/random/random_test.go | 4 +- selector/selector.go | 6 +-- selector/selector_test.go | 11 +++--- selector/wrr/wrr.go | 12 +----- selector/wrr/wrr_test.go | 4 +- transport/grpc/balancer.go | 68 +++++++++++++-------------------- transport/grpc/balancer_test.go | 11 +----- transport/grpc/client.go | 33 ++++++++-------- transport/grpc/transport.go | 8 ++-- transport/http/client.go | 24 ++++++++---- transport/http/client_test.go | 15 +++++--- 19 files changed, 116 insertions(+), 151 deletions(-) create mode 100644 selector/global.go diff --git a/selector/default_selector.go b/selector/default_selector.go index 0e99249eb..247af2e92 100644 --- a/selector/default_selector.go +++ b/selector/default_selector.go @@ -9,7 +9,6 @@ import ( type Default struct { NodeBuilder WeightedNodeBuilder Balancer Balancer - Filters []Filter nodes atomic.Value } @@ -27,16 +26,13 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No for _, o := range opts { o(&options) } - if len(d.Filters) > 0 || len(options.Filters) > 0 { + if len(options.NodeFilters) > 0 { newNodes := make([]Node, len(nodes)) for i, wc := range nodes { newNodes[i] = wc } - for _, f := range d.Filters { - newNodes = f(ctx, newNodes) - } - for _, f := range options.Filters { - newNodes = f(ctx, newNodes) + for _, filter := range options.NodeFilters { + newNodes = filter(ctx, newNodes) } candidates = make([]WeightedNode, len(newNodes)) for i, n := range newNodes { @@ -74,7 +70,6 @@ func (d *Default) Apply(nodes []Node) { type DefaultBuilder struct { Node WeightedNodeBuilder Balancer BalancerBuilder - Filters []Filter } // Build create builder @@ -82,6 +77,5 @@ func (db *DefaultBuilder) Build() Selector { return &Default{ NodeBuilder: db.Node, Balancer: db.Balancer.Build(), - Filters: db.Filters, } } diff --git a/selector/filter.go b/selector/filter.go index e88a52daf..c8b630d2b 100644 --- a/selector/filter.go +++ b/selector/filter.go @@ -2,5 +2,5 @@ package selector import "context" -// Filter is select filter. -type Filter func(context.Context, []Node) []Node +// NodeFilter is select filter. +type NodeFilter func(context.Context, []Node) []Node diff --git a/selector/filter/version.go b/selector/filter/version.go index 815d3fb70..41d56fbd3 100644 --- a/selector/filter/version.go +++ b/selector/filter/version.go @@ -7,7 +7,7 @@ import ( ) // Version is version filter. -func Version(version string) selector.Filter { +func Version(version string) selector.NodeFilter { return func(_ context.Context, nodes []selector.Node) []selector.Node { newNodes := nodes[:0] for _, n := range nodes { diff --git a/selector/global.go b/selector/global.go new file mode 100644 index 000000000..2f951212b --- /dev/null +++ b/selector/global.go @@ -0,0 +1,13 @@ +package selector + +var globalSelector Builder + +// GlobalSelector returns global selector builder. +func GlobalSelector() Builder { + return globalSelector +} + +// SetGlobalSelector set global selector builder. +func SetGlobalSelector(builder Builder) { + globalSelector = builder +} diff --git a/selector/options.go b/selector/options.go index c9129fb32..14b441d7b 100644 --- a/selector/options.go +++ b/selector/options.go @@ -2,15 +2,15 @@ package selector // SelectOptions is Select Options. type SelectOptions struct { - Filters []Filter + NodeFilters []NodeFilter } // SelectOption is Selector option. type SelectOption func(*SelectOptions) -// WithFilter with filter options -func WithFilter(fn ...Filter) SelectOption { +// WithNodeFilter with filter options +func WithNodeFilter(fn ...NodeFilter) SelectOption { return func(opts *SelectOptions) { - opts.Filters = fn + opts.NodeFilters = fn } } diff --git a/selector/p2c/p2c.go b/selector/p2c/p2c.go index 21a2199b8..6f8d5a3fa 100644 --- a/selector/p2c/p2c.go +++ b/selector/p2c/p2c.go @@ -19,20 +19,11 @@ const ( var _ selector.Balancer = &Balancer{} -// WithFilter with select filters -func WithFilter(filters ...selector.Filter) Option { - return func(o *options) { - o.filters = filters - } -} - // Option is random builder option. type Option func(o *options) // options is random builder options -type options struct { - filters []selector.Filter -} +type options struct{} // New creates a p2c selector. func New(opts ...Option) selector.Selector { @@ -95,7 +86,6 @@ func NewBuilder(opts ...Option) selector.Builder { opt(&option) } return &selector.DefaultBuilder{ - Filters: option.filters, Balancer: &Builder{}, Node: &ewma.Builder{}, } diff --git a/selector/p2c/p2c_test.go b/selector/p2c/p2c_test.go index 03d549e2c..f23412b71 100644 --- a/selector/p2c/p2c_test.go +++ b/selector/p2c/p2c_test.go @@ -16,7 +16,7 @@ import ( ) func TestWrr3(t *testing.T) { - p2c := New(WithFilter(filter.Version("v2.0.0"))) + p2c := New() var nodes []selector.Node for i := 0; i < 3; i++ { addr := fmt.Sprintf("127.0.0.%d:8080", i) @@ -41,7 +41,7 @@ func TestWrr3(t *testing.T) { d := time.Duration(rand.Intn(500)) * time.Millisecond lk.Unlock() time.Sleep(d) - n, done, err := p2c.Select(context.Background()) + n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -92,7 +92,7 @@ func TestEmpty(t *testing.T) { } func TestOne(t *testing.T) { - p2c := New(WithFilter(filter.Version("v2.0.0"))) + p2c := New() var nodes []selector.Node for i := 0; i < 1; i++ { addr := fmt.Sprintf("127.0.0.%d:8080", i) @@ -106,7 +106,7 @@ func TestOne(t *testing.T) { })) } p2c.Apply(nodes) - n, done, err := p2c.Select(context.Background()) + n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) if err != nil { t.Errorf("expect %v, got %v", nil, err) } diff --git a/selector/random/random.go b/selector/random/random.go index 824fb8834..cf59534d0 100644 --- a/selector/random/random.go +++ b/selector/random/random.go @@ -15,20 +15,11 @@ const ( var _ selector.Balancer = &Balancer{} // Name is balancer name -// WithFilter with select filters -func WithFilter(filters ...selector.Filter) Option { - return func(o *options) { - o.filters = filters - } -} - // Option is random builder option. type Option func(o *options) // options is random builder options -type options struct { - filters []selector.Filter -} +type options struct{} // Balancer is a random balancer. type Balancer struct{} @@ -56,7 +47,6 @@ func NewBuilder(opts ...Option) selector.Builder { opt(&option) } return &selector.DefaultBuilder{ - Filters: option.filters, Balancer: &Builder{}, Node: &direct.Builder{}, } diff --git a/selector/random/random_test.go b/selector/random/random_test.go index 7e69f3ee6..33e689941 100644 --- a/selector/random/random_test.go +++ b/selector/random/random_test.go @@ -10,7 +10,7 @@ import ( ) func TestWrr(t *testing.T) { - random := New(WithFilter(filter.Version("v2.0.0"))) + random := New() var nodes []selector.Node nodes = append(nodes, selector.NewNode( "http", @@ -31,7 +31,7 @@ func TestWrr(t *testing.T) { random.Apply(nodes) var count1, count2 int for i := 0; i < 200; i++ { - n, done, err := random.Select(context.Background()) + n, done, err := random.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) if err != nil { t.Errorf("expect no error, got %v", err) } diff --git a/selector/selector.go b/selector/selector.go index 79ecdf786..e346f6c35 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -57,7 +57,7 @@ type DoneInfo struct { // Response Error Err error // Response Metadata - ReplyMeta ReplyMeta + ReplyMD ReplyMD // BytesSent indicates if any bytes have been sent to the server. BytesSent bool @@ -65,8 +65,8 @@ type DoneInfo struct { BytesReceived bool } -// ReplyMeta is Reply Metadata. -type ReplyMeta interface { +// ReplyMD is Reply Metadata. +type ReplyMD interface { Get(key string) string } diff --git a/selector/selector_test.go b/selector/selector_test.go index 25e70acde..f7c1f1bcf 100644 --- a/selector/selector_test.go +++ b/selector/selector_test.go @@ -49,7 +49,7 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode { return &mockWeightedNode{Node: n} } -func mockFilter(version string) Filter { +func mockFilter(version string) NodeFilter { return func(_ context.Context, nodes []Node) []Node { newNodes := nodes[:0] for _, n := range nodes { @@ -83,7 +83,6 @@ func (b *mockBalancer) Pick(ctx context.Context, nodes []WeightedNode) (selected func TestDefault(t *testing.T) { builder := DefaultBuilder{ Node: &mockWeightedNodeBuilder{}, - Filters: []Filter{mockFilter("v2.0.0")}, Balancer: &mockBalancerBuilder{}, } selector := builder.Build() @@ -109,7 +108,7 @@ func TestDefault(t *testing.T) { Metadata: map[string]string{"weight": "10"}, })) selector.Apply(nodes) - n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) + n, done, err := selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -137,7 +136,7 @@ func TestDefault(t *testing.T) { done(context.Background(), DoneInfo{}) // no v3.0.0 instance - n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) + n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v3.0.0"))) if !errors.Is(ErrNoAvailable, err) { t.Errorf("expect %v, got %v", ErrNoAvailable, err) } @@ -150,7 +149,7 @@ func TestDefault(t *testing.T) { // apply zero instance selector.Apply([]Node{}) - n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) + n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) if !errors.Is(ErrNoAvailable, err) { t.Errorf("expect %v, got %v", ErrNoAvailable, err) } @@ -163,7 +162,7 @@ func TestDefault(t *testing.T) { // apply zero instance selector.Apply(nil) - n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) + n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) if !errors.Is(ErrNoAvailable, err) { t.Errorf("expect %v, got %v", ErrNoAvailable, err) } diff --git a/selector/wrr/wrr.go b/selector/wrr/wrr.go index c3405915b..0b95e04fb 100644 --- a/selector/wrr/wrr.go +++ b/selector/wrr/wrr.go @@ -15,20 +15,11 @@ const ( var _ selector.Balancer = &Balancer{} // Name is balancer name -// WithFilter with select filters -func WithFilter(filters ...selector.Filter) Option { - return func(o *options) { - o.filters = filters - } -} - // Option is random builder option. type Option func(o *options) // options is random builder options -type options struct { - filters []selector.Filter -} +type options struct{} // Balancer is a random balancer. type Balancer struct { @@ -77,7 +68,6 @@ func NewBuilder(opts ...Option) selector.Builder { opt(&option) } return &selector.DefaultBuilder{ - Filters: option.filters, Balancer: &Builder{}, Node: &direct.Builder{}, } diff --git a/selector/wrr/wrr_test.go b/selector/wrr/wrr_test.go index e3e96ed45..f1d4b45f5 100644 --- a/selector/wrr/wrr_test.go +++ b/selector/wrr/wrr_test.go @@ -11,7 +11,7 @@ import ( ) func TestWrr(t *testing.T) { - wrr := New(WithFilter(filter.Version("v2.0.0"))) + wrr := New() var nodes []selector.Node nodes = append(nodes, selector.NewNode( "http", @@ -32,7 +32,7 @@ func TestWrr(t *testing.T) { wrr.Apply(nodes) var count1, count2 int for i := 0; i < 90; i++ { - n, done, err := wrr.Select(context.Background()) + n, done, err := wrr.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) if err != nil { t.Errorf("expect no error, got %v", err) } diff --git a/transport/grpc/balancer.go b/transport/grpc/balancer.go index 2a27c05ad..bd50f0983 100644 --- a/transport/grpc/balancer.go +++ b/transport/grpc/balancer.go @@ -1,59 +1,45 @@ package grpc import ( - "sync" - "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" - "github.com/go-kratos/kratos/v2/selector/p2c" - "github.com/go-kratos/kratos/v2/selector/random" - "github.com/go-kratos/kratos/v2/selector/wrr" "github.com/go-kratos/kratos/v2/transport" - gBalancer "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/metadata" ) -var ( - _ base.PickerBuilder = &Builder{} - _ gBalancer.Picker = &Picker{} +const ( + balancerName = "selector" +) - mu sync.Mutex +var ( + _ base.PickerBuilder = &balancerBuilder{} + _ balancer.Picker = &balancerPicker{} ) func init() { - // inject global grpc balancer - SetGlobalBalancer(random.Name, random.NewBuilder()) - SetGlobalBalancer(wrr.Name, wrr.NewBuilder()) - SetGlobalBalancer(p2c.Name, p2c.NewBuilder()) -} - -// SetGlobalBalancer set grpc balancer with scheme. -func SetGlobalBalancer(scheme string, builder selector.Builder) { - mu.Lock() - defer mu.Unlock() - b := base.NewBalancerBuilder( - scheme, - &Builder{builder: builder}, + balancerName, + &balancerBuilder{ + builder: selector.GlobalSelector(), + }, base.Config{HealthCheck: true}, ) - gBalancer.Register(b) + balancer.Register(b) } -// Builder is grpc balancer builder. -type Builder struct { +type balancerBuilder struct { builder selector.Builder } // Build creates a grpc Picker. -func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { +func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { if len(info.ReadySCs) == 0 { // Block the RPC until a new picker is available via UpdateState(). - return base.NewErrPicker(gBalancer.ErrNoSubConnAvailable) + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) } - nodes := make([]selector.Node, 0) for conn, info := range info.ReadySCs { ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance) @@ -62,40 +48,40 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { subConn: conn, }) } - p := &Picker{ + p := &balancerPicker{ selector: b.builder.Build(), } p.selector.Apply(nodes) return p } -// Picker is a grpc picker. -type Picker struct { +// balancerPicker is a grpc picker. +type balancerPicker struct { selector selector.Selector } // Pick pick instances. -func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { - var filters []selector.Filter +func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + var filters []selector.NodeFilter if tr, ok := transport.FromClientContext(info.Ctx); ok { if gtr, ok := tr.(*Transport); ok { - filters = gtr.SelectFilters() + filters = gtr.NodeFilters() } } - n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) + n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...)) if err != nil { - return gBalancer.PickResult{}, err + return balancer.PickResult{}, err } - return gBalancer.PickResult{ + return balancer.PickResult{ SubConn: n.(*grpcNode).subConn, - Done: func(di gBalancer.DoneInfo) { + Done: func(di balancer.DoneInfo) { done(info.Ctx, selector.DoneInfo{ Err: di.Err, BytesSent: di.BytesSent, BytesReceived: di.BytesReceived, - ReplyMeta: Trailer(di.Trailer), + ReplyMD: Trailer(di.Trailer), }) }, }, nil @@ -115,5 +101,5 @@ func (t Trailer) Get(k string) string { type grpcNode struct { selector.Node - subConn gBalancer.SubConn + subConn balancer.SubConn } diff --git a/transport/grpc/balancer_test.go b/transport/grpc/balancer_test.go index bff519bd3..5b7002643 100644 --- a/transport/grpc/balancer_test.go +++ b/transport/grpc/balancer_test.go @@ -19,19 +19,10 @@ func TestTrailer(t *testing.T) { } } -func TestBalancerName(t *testing.T) { - o := &clientOptions{} - - WithBalancerName("p2c")(o) - if !reflect.DeepEqual("p2c", o.balancerName) { - t.Errorf("expect %v, got %v", "p2c", o.balancerName) - } -} - func TestFilters(t *testing.T) { o := &clientOptions{} - WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { + WithNodeFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { return nodes })(o) if !reflect.DeepEqual(1, len(o.filters)) { diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 923e52671..a1424d209 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -10,7 +10,7 @@ import ( "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" - "github.com/go-kratos/kratos/v2/selector/wrr" + "github.com/go-kratos/kratos/v2/selector/p2c" "github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery" @@ -23,6 +23,12 @@ import ( grpcmd "google.golang.org/grpc/metadata" ) +func init() { + if selector.GlobalSelector() == nil { + selector.SetGlobalSelector(p2c.NewBuilder()) + } +} + // ClientOption is gRPC client option. type ClientOption func(o *clientOptions) @@ -75,15 +81,8 @@ func WithOptions(opts ...grpc.DialOption) ClientOption { } } -// WithBalancerName with balancer name -func WithBalancerName(name string) ClientOption { - return func(o *clientOptions) { - o.balancerName = name - } -} - -// WithFilter with select filters -func WithFilter(filters ...selector.Filter) ClientOption { +// WithNodeFilter with select filters +func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { return func(o *clientOptions) { o.filters = filters } @@ -105,7 +104,7 @@ type clientOptions struct { ints []grpc.UnaryClientInterceptor grpcOpts []grpc.DialOption balancerName string - filters []selector.Filter + filters []selector.NodeFilter } // Dial returns a GRPC connection. @@ -121,7 +120,7 @@ func DialInsecure(ctx context.Context, opts ...ClientOption) (*grpc.ClientConn, func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) { options := clientOptions{ timeout: 2000 * time.Millisecond, - balancerName: wrr.Name, + balancerName: balancerName, } for _, o := range opts { o(&options) @@ -156,13 +155,13 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien return grpc.DialContext(ctx, options.endpoint, grpcOpts...) } -func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { +func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = transport.NewClientContext(ctx, &Transport{ - endpoint: cc.Target(), - operation: method, - reqHeader: headerCarrier{}, - filters: filters, + endpoint: cc.Target(), + operation: method, + reqHeader: headerCarrier{}, + nodeFilters: filters, }) if timeout > 0 { var cancel context.CancelFunc diff --git a/transport/grpc/transport.go b/transport/grpc/transport.go index efa70e74c..a11a271ae 100644 --- a/transport/grpc/transport.go +++ b/transport/grpc/transport.go @@ -14,7 +14,7 @@ type Transport struct { operation string reqHeader headerCarrier replyHeader headerCarrier - filters []selector.Filter + nodeFilters []selector.NodeFilter } // Kind returns the transport kind. @@ -42,9 +42,9 @@ func (tr *Transport) ReplyHeader() transport.Header { return tr.replyHeader } -// SelectFilters returns the client select filters. -func (tr *Transport) SelectFilters() []selector.Filter { - return tr.filters +// NodeFilters returns the client select filters. +func (tr *Transport) NodeFilters() []selector.NodeFilter { + return tr.nodeFilters } type headerCarrier metadata.MD diff --git a/transport/http/client.go b/transport/http/client.go index f1c213598..79ac4f911 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -16,10 +16,16 @@ import ( "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" - "github.com/go-kratos/kratos/v2/selector/wrr" + "github.com/go-kratos/kratos/v2/selector/p2c" "github.com/go-kratos/kratos/v2/transport" ) +func init() { + if selector.GlobalSelector() == nil { + selector.SetGlobalSelector(p2c.NewBuilder()) + } +} + // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error @@ -43,7 +49,7 @@ type clientOptions struct { decoder DecodeResponseFunc errorDecoder DecodeErrorFunc transport http.RoundTripper - selector selector.Selector + nodeFilters []selector.NodeFilter discovery registry.Discovery middleware []middleware.Middleware block bool @@ -112,10 +118,10 @@ func WithDiscovery(d registry.Discovery) ClientOption { } } -// WithSelector with client selector. -func WithSelector(selector selector.Selector) ClientOption { +// WithNodeFilter with select filters +func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { return func(o *clientOptions) { - o.selector = selector + o.nodeFilters = filters } } @@ -140,6 +146,7 @@ type Client struct { r *resolver cc *http.Client insecure bool + selector selector.Selector } // NewClient returns an HTTP client. @@ -151,7 +158,6 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { decoder: DefaultResponseDecoder, errorDecoder: DefaultErrorDecoder, transport: http.DefaultTransport, - selector: wrr.New(), } for _, o := range opts { o(&options) @@ -166,10 +172,11 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } + selector := selector.GlobalSelector().Build() var r *resolver if options.discovery != nil { if target.Scheme == "discovery" { - if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil { + if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil { return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) } } else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil { @@ -185,6 +192,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { Timeout: options.timeout, Transport: options.transport, }, + selector: selector, }, nil } @@ -276,7 +284,7 @@ func (client *Client) do(req *http.Request) (*http.Response, error) { err error node selector.Node ) - if node, done, err = client.opts.selector.Select(req.Context()); err != nil { + if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil { return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } if client.insecure { diff --git a/transport/http/client_test.go b/transport/http/client_test.go index c768b364b..4ee00456f 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -182,13 +182,18 @@ func TestWithDiscovery(t *testing.T) { } } -func TestWithSelector(t *testing.T) { - ov := &selector.Default{} - o := WithSelector(ov) +func TestWithNodeFilter(t *testing.T) { + ov := func(context.Context, []selector.Node) []selector.Node { + return []selector.Node{&selector.DefaultNode{}} + } + o := WithNodeFilter(ov) co := &clientOptions{} o(co) - if !reflect.DeepEqual(co.selector, ov) { - t.Errorf("expected selector to be %v, got %v", ov, co.selector) + for _, n := range co.nodeFilters { + ret := n(context.Background(), nil) + if len(ret) != 1 { + t.Errorf("expected node length to be 1, got %v", len(ret)) + } } }