diff --git a/examples/selector/client/main.go b/examples/selector/client/main.go index 30517e480..f490646a7 100644 --- a/examples/selector/client/main.go +++ b/examples/selector/client/main.go @@ -8,7 +8,6 @@ import ( "github.com/go-kratos/kratos/contrib/registry/consul/v2" "github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/v2/middleware/recovery" - "github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/p2c" "github.com/go-kratos/kratos/v2/selector/wrr" @@ -32,10 +31,8 @@ func main() { // 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector // 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter grpc.WithBalancerName(wrr.Name), - grpc.WithNodeFilter( - func(node selector.Node) bool { - return node.Version() == "1.0.0" - }, + grpc.WithFilter( + filter.Version("1.0.0"), ), ) if err != nil { diff --git a/selector/default_selector.go b/selector/default_selector.go index ae5e4596a..7550befdd 100644 --- a/selector/default_selector.go +++ b/selector/default_selector.go @@ -27,7 +27,7 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No for _, o := range opts { o(&options) } - if len(d.Filters) > 0 { + if len(d.Filters) > 0 || len(options.Filters) > 0 { newNodes := make([]Node, len(nodes)) for i, wc := range nodes { newNodes[i] = wc @@ -35,6 +35,9 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No for _, f := range d.Filters { newNodes = f(ctx, newNodes) } + for _, f := range options.Filters { + newNodes = f(ctx, newNodes) + } candidates = make([]WeightedNode, len(newNodes)) for i, n := range newNodes { candidates[i] = n.(WeightedNode) @@ -43,9 +46,6 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No candidates = nodes } - if len(options.Filters) > 0 { - candidates = d.nodeFilter(options.Filters, candidates) - } if len(candidates) == 0 { return nil, nil, ErrNoAvailable } @@ -56,23 +56,6 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No return wn.Raw(), done, nil } -func (d *Default) nodeFilter(filters []NodeFilter, nodes []WeightedNode) []WeightedNode { - newNodes := make([]WeightedNode, 0, len(nodes)) - for _, n := range nodes { - var remove bool - for _, f := range filters { - if !f(n) { - remove = true - break - } - } - if !remove { - newNodes = append(newNodes, n) - } - } - return newNodes -} - // Apply update nodes info. func (d *Default) Apply(nodes []Node) { weightedNodes := make([]WeightedNode, 0, len(nodes)) diff --git a/selector/filter.go b/selector/filter.go index 008be6586..e88a52daf 100644 --- a/selector/filter.go +++ b/selector/filter.go @@ -4,7 +4,3 @@ import "context" // Filter is select filter. type Filter func(context.Context, []Node) []Node - -// NodeFilter is node filter. -// If it returns false, the node will be removed out from the balancer pick list -type NodeFilter func(node Node) bool diff --git a/selector/filter/version.go b/selector/filter/version.go index fdcfcc7e3..815d3fb70 100644 --- a/selector/filter/version.go +++ b/selector/filter/version.go @@ -9,12 +9,12 @@ import ( // Version is version filter. func Version(version string) selector.Filter { return func(_ context.Context, nodes []selector.Node) []selector.Node { - filters := make([]selector.Node, 0, len(nodes)) + newNodes := nodes[:0] for _, n := range nodes { if n.Version() == version { - filters = append(filters, n) + newNodes = append(newNodes, n) } } - return filters + return newNodes } } diff --git a/selector/options.go b/selector/options.go index 7e1cca4b1..c9129fb32 100644 --- a/selector/options.go +++ b/selector/options.go @@ -2,14 +2,14 @@ package selector // SelectOptions is Select Options. type SelectOptions struct { - Filters []NodeFilter + Filters []Filter } // SelectOption is Selector option. type SelectOption func(*SelectOptions) -// WithNodeFilter with filter options -func WithNodeFilter(fn ...NodeFilter) SelectOption { +// WithFilter with filter options +func WithFilter(fn ...Filter) SelectOption { return func(opts *SelectOptions) { opts.Filters = fn } diff --git a/selector/selector_test.go b/selector/selector_test.go index f6f3838d3..7c735a334 100644 --- a/selector/selector_test.go +++ b/selector/selector_test.go @@ -3,7 +3,6 @@ package selector import ( "context" "math/rand" - "strconv" "sync/atomic" "testing" "time" @@ -51,13 +50,13 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode { func mockFilter(version string) Filter { return func(_ context.Context, nodes []Node) []Node { - filters := make([]Node, 0, len(nodes)) + newNodes := nodes[:0] for _, n := range nodes { if n.Version() == version { - filters = append(filters, n) + newNodes = append(newNodes, n) } } - return filters + return newNodes } } @@ -107,9 +106,7 @@ func TestDefault(t *testing.T) { Metadata: map[string]string{"weight": "10"}, })) selector.Apply(nodes) - n, done, err := selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { - return (node.Version() == "v2.0.0") - })) + n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) assert.Nil(t, err) assert.NotNil(t, n) assert.NotNil(t, done) @@ -121,74 +118,22 @@ func TestDefault(t *testing.T) { done(context.Background(), DoneInfo{}) // no v3.0.0 instance - n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { - return (node.Version() == "v3.0.0") - })) + n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) assert.Equal(t, ErrNoAvailable, err) assert.Nil(t, done) assert.Nil(t, n) // apply zero instance selector.Apply([]Node{}) - n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { - return (node.Version() == "v2.0.0") - })) + n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) assert.Equal(t, ErrNoAvailable, err) assert.Nil(t, done) assert.Nil(t, n) // apply zero instance selector.Apply(nil) - n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { - return (node.Version() == "v2.0.0") - })) + n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) assert.Equal(t, ErrNoAvailable, err) assert.Nil(t, done) assert.Nil(t, n) } - -func TestNodeFilterWithRandom(t *testing.T) { - for i := 0; i < 100; i++ { - testBaseFilter(t, 1000, rand.Intn(1000)) - } - - testBaseFilter(t, 0, rand.Intn(1000)) - testBaseFilter(t, 1, 1000) - testBaseFilter(t, 2, 1000) - testBaseFilter(t, 3, 1000) - testBaseFilter(t, 1, 0) - testBaseFilter(t, 2, 0) - testBaseFilter(t, 3, 0) -} - -func testBaseFilter(t *testing.T, length int, reservedRatio int) { - var raw []WeightedNode - var targets map[string]WeightedNode = make(map[string]WeightedNode) - for i := 0; i < length; i++ { - addr := strconv.FormatInt(int64(i), 10) - raw = append(raw, &mockWeightedNode{Node: NewNode( - addr, - ®istry.ServiceInstance{ - ID: addr, - Name: "helloworld", - Endpoints: []string{addr}, - })}) - if reservedRatio > rand.Intn(length) { - targets[addr] = raw[i] - } - } - - f := func(node Node) bool { - if _, ok := targets[node.Address()]; ok { - return true - } - return false - } - d := Default{} - raw = d.nodeFilter([]NodeFilter{f}, raw) - assert.Equal(t, len(targets), len(raw)) - for _, n := range raw { - _, ok := targets[n.Address()] - assert.True(t, ok) - } -} diff --git a/transport/grpc/balancer.go b/transport/grpc/balancer.go index d44c96bf7..5731bb9ea 100644 --- a/transport/grpc/balancer.go +++ b/transport/grpc/balancer.go @@ -71,14 +71,14 @@ type Picker struct { // Pick pick instances. func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { - var filters []selector.NodeFilter + var filters []selector.Filter if tr, ok := transport.FromClientContext(info.Ctx); ok { if gtr, ok := tr.(*Transport); ok { - filters = gtr.NodeFilters() + filters = gtr.SelectFilters() } } - n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...)) + n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) if err != nil { return gBalancer.PickResult{}, err } diff --git a/transport/grpc/balancer_test.go b/transport/grpc/balancer_test.go index da99aebee..c0ea6aa94 100644 --- a/transport/grpc/balancer_test.go +++ b/transport/grpc/balancer_test.go @@ -1,6 +1,7 @@ package grpc import ( + "context" "testing" "github.com/go-kratos/kratos/v2/selector" @@ -24,8 +25,8 @@ func TestBalancerName(t *testing.T) { func TestFilters(t *testing.T) { o := &clientOptions{} - WithNodeFilter(func(selector.Node) bool { - return true + WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { + return nodes })(o) assert.Equal(t, 1, len(o.filters)) } diff --git a/transport/grpc/client.go b/transport/grpc/client.go index a84a09a78..9d7ede93f 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -80,8 +80,8 @@ func WithBalancerName(name string) ClientOption { } } -// WithNodeFilter with select filters -func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { +// WithFilter with select filters +func WithFilter(filters ...selector.Filter) ClientOption { return func(o *clientOptions) { o.filters = filters } @@ -97,7 +97,7 @@ type clientOptions struct { ints []grpc.UnaryClientInterceptor grpcOpts []grpc.DialOption balancerName string - filters []selector.NodeFilter + filters []selector.Filter } // Dial returns a GRPC connection. @@ -143,7 +143,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien return grpc.DialContext(ctx, options.endpoint, grpcOpts...) } -func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor { +func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = transport.NewClientContext(ctx, &Transport{ endpoint: cc.Target(), diff --git a/transport/grpc/transport.go b/transport/grpc/transport.go index b7df26311..efa70e74c 100644 --- a/transport/grpc/transport.go +++ b/transport/grpc/transport.go @@ -14,7 +14,7 @@ type Transport struct { operation string reqHeader headerCarrier replyHeader headerCarrier - filters []selector.NodeFilter + filters []selector.Filter } // Kind returns the transport kind. @@ -42,8 +42,8 @@ func (tr *Transport) ReplyHeader() transport.Header { return tr.replyHeader } -// Filters returns the client select filters. -func (tr *Transport) NodeFilters() []selector.NodeFilter { +// SelectFilters returns the client select filters. +func (tr *Transport) SelectFilters() []selector.Filter { return tr.filters }