diff --git a/examples/selector/client/main.go b/examples/selector/client/main.go index 3641c6325..30517e480 100644 --- a/examples/selector/client/main.go +++ b/examples/selector/client/main.go @@ -8,6 +8,7 @@ import ( "github.com/go-kratos/kratos/contrib/registry/consul/v2" "github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/v2/middleware/recovery" + "github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/p2c" "github.com/go-kratos/kratos/v2/selector/wrr" @@ -31,7 +32,11 @@ func main() { // 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector // 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter grpc.WithBalancerName(wrr.Name), - grpc.WithSelectFilter(filter.Version("1.0.0")), + grpc.WithNodeFilter( + func(node selector.Node) bool { + return node.Version() == "1.0.0" + }, + ), ) if err != nil { log.Fatal(err) diff --git a/selector/default_selector.go b/selector/default_selector.go index af4455f63..b06601f6f 100644 --- a/selector/default_selector.go +++ b/selector/default_selector.go @@ -2,7 +2,7 @@ package selector import ( "context" - "sync" + "sync/atomic" ) // Default is composite selector. @@ -11,29 +11,40 @@ type Default struct { Balancer Balancer Filters []Filter - lk sync.RWMutex - weightedNodes []Node + nodes atomic.Value } // Select select one node. func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected Node, done DoneFunc, err error) { - d.lk.RLock() - weightedNodes := d.weightedNodes - d.lk.RUnlock() - - for _, f := range d.Filters { - weightedNodes = f(ctx, weightedNodes) + var ( + options SelectOptions + candidates []WeightedNode + ) + nodes, ok := d.nodes.Load().([]WeightedNode) + if !ok { + return nil, nil, ErrNoAvailable } - var options SelectOptions for _, o := range opts { o(&options) } - for _, f := range options.Filters { - weightedNodes = f(ctx, weightedNodes) + if len(d.Filters) > 0 { + newNodes := make([]Node, len(nodes)) + for i, wc := range nodes { + newNodes[i] = wc + } + for _, f := range d.Filters { + newNodes = f(ctx, newNodes) + } + candidates = make([]WeightedNode, len(newNodes)) + for i, n := range newNodes { + candidates[i] = n.(WeightedNode) + } + } else { + candidates = nodes } - candidates := make([]WeightedNode, 0, len(weightedNodes)) - for _, n := range weightedNodes { - candidates = append(candidates, n.(WeightedNode)) + + if len(options.Filters) > 0 { + candidates = d.nodeFilter(options.Filters, candidates) } if len(candidates) == 0 { return nil, nil, ErrNoAvailable @@ -45,16 +56,31 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No return wn.Raw(), done, nil } +func (d *Default) nodeFilter(filters []NodeFilter, nodes []WeightedNode) []WeightedNode { + newNodes := make([]WeightedNode, 0, len(nodes)) + for _, n := range nodes { + var remove bool + for _, f := range filters { + if !f(n) { + remove = true + break + } + } + if !remove { + newNodes = append(newNodes, n) + } + } + return newNodes +} + // Apply update nodes info. func (d *Default) Apply(nodes []Node) { - weightedNodes := make([]Node, 0, len(nodes)) + weightedNodes := make([]WeightedNode, 0, len(nodes)) for _, n := range nodes { weightedNodes = append(weightedNodes, d.NodeBuilder.Build(n)) } - d.lk.Lock() // TODO: Do not delete unchanged nodes - d.weightedNodes = weightedNodes - d.lk.Unlock() + d.nodes.Store(weightedNodes) } // DefaultBuilder is de diff --git a/selector/filter.go b/selector/filter.go new file mode 100644 index 000000000..008be6586 --- /dev/null +++ b/selector/filter.go @@ -0,0 +1,10 @@ +package selector + +import "context" + +// Filter is select filter. +type Filter func(context.Context, []Node) []Node + +// NodeFilter is node filter. +// If it returns false, the node will be removed out from the balancer pick list +type NodeFilter func(node Node) bool diff --git a/selector/filter/version_test.go b/selector/filter/version_test.go index e4d2acbfd..7deed51d3 100644 --- a/selector/filter/version_test.go +++ b/selector/filter/version_test.go @@ -30,7 +30,7 @@ func TestVersion(t *testing.T) { Endpoints: []string{"http://127.0.0.2:9090"}, })) - n := f(context.Background(), nodes) - assert.Equal(t, 1, len(n)) - assert.Equal(t, "127.0.0.2:9090", n[0].Address()) + nodes = f(context.Background(), nodes) + assert.Equal(t, 1, len(nodes)) + assert.Equal(t, "127.0.0.2:9090", nodes[0].Address()) } diff --git a/selector/options.go b/selector/options.go index f4ce0f731..7e1cca4b1 100644 --- a/selector/options.go +++ b/selector/options.go @@ -1,20 +1,15 @@ package selector -import "context" - // SelectOptions is Select Options. type SelectOptions struct { - Filters []Filter + Filters []NodeFilter } // SelectOption is Selector option. type SelectOption func(*SelectOptions) -// Filter is node filter function. -type Filter func(context.Context, []Node) []Node - -// WithFilter with filter options -func WithFilter(fn ...Filter) SelectOption { +// WithNodeFilter with filter options +func WithNodeFilter(fn ...NodeFilter) SelectOption { return func(opts *SelectOptions) { opts.Filters = fn } diff --git a/selector/selector_test.go b/selector/selector_test.go index 3c0fa13e8..f6f3838d3 100644 --- a/selector/selector_test.go +++ b/selector/selector_test.go @@ -3,6 +3,7 @@ package selector import ( "context" "math/rand" + "strconv" "sync/atomic" "testing" "time" @@ -106,7 +107,9 @@ func TestDefault(t *testing.T) { Metadata: map[string]string{"weight": "10"}, })) selector.Apply(nodes) - n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) + n, done, err := selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { + return (node.Version() == "v2.0.0") + })) assert.Nil(t, err) assert.NotNil(t, n) assert.NotNil(t, done) @@ -118,15 +121,74 @@ func TestDefault(t *testing.T) { done(context.Background(), DoneInfo{}) // no v3.0.0 instance - n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) + n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { + return (node.Version() == "v3.0.0") + })) assert.Equal(t, ErrNoAvailable, err) assert.Nil(t, done) assert.Nil(t, n) // apply zero instance selector.Apply([]Node{}) - n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) + n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { + return (node.Version() == "v2.0.0") + })) assert.Equal(t, ErrNoAvailable, err) assert.Nil(t, done) assert.Nil(t, n) + + // apply zero instance + selector.Apply(nil) + n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool { + return (node.Version() == "v2.0.0") + })) + assert.Equal(t, ErrNoAvailable, err) + assert.Nil(t, done) + assert.Nil(t, n) +} + +func TestNodeFilterWithRandom(t *testing.T) { + for i := 0; i < 100; i++ { + testBaseFilter(t, 1000, rand.Intn(1000)) + } + + testBaseFilter(t, 0, rand.Intn(1000)) + testBaseFilter(t, 1, 1000) + testBaseFilter(t, 2, 1000) + testBaseFilter(t, 3, 1000) + testBaseFilter(t, 1, 0) + testBaseFilter(t, 2, 0) + testBaseFilter(t, 3, 0) +} + +func testBaseFilter(t *testing.T, length int, reservedRatio int) { + var raw []WeightedNode + var targets map[string]WeightedNode = make(map[string]WeightedNode) + for i := 0; i < length; i++ { + addr := strconv.FormatInt(int64(i), 10) + raw = append(raw, &mockWeightedNode{Node: NewNode( + addr, + ®istry.ServiceInstance{ + ID: addr, + Name: "helloworld", + Endpoints: []string{addr}, + })}) + if reservedRatio > rand.Intn(length) { + targets[addr] = raw[i] + } + } + + f := func(node Node) bool { + if _, ok := targets[node.Address()]; ok { + return true + } + return false + } + d := Default{} + raw = d.nodeFilter([]NodeFilter{f}, raw) + assert.Equal(t, len(targets), len(raw)) + for _, n := range raw { + _, ok := targets[n.Address()] + assert.True(t, ok) + } } diff --git a/transport/grpc/balancer.go b/transport/grpc/balancer.go index 9d0da64d7..d44c96bf7 100644 --- a/transport/grpc/balancer.go +++ b/transport/grpc/balancer.go @@ -71,14 +71,14 @@ type Picker struct { // Pick pick instances. func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { - var filters []selector.Filter + var filters []selector.NodeFilter if tr, ok := transport.FromClientContext(info.Ctx); ok { if gtr, ok := tr.(*Transport); ok { - filters = gtr.Filters() + filters = gtr.NodeFilters() } } - n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) + n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...)) if err != nil { return gBalancer.PickResult{}, err } diff --git a/transport/grpc/balancer_test.go b/transport/grpc/balancer_test.go index dcce7ffe0..da99aebee 100644 --- a/transport/grpc/balancer_test.go +++ b/transport/grpc/balancer_test.go @@ -3,7 +3,7 @@ package grpc import ( "testing" - "github.com/go-kratos/kratos/v2/selector/filter" + "github.com/go-kratos/kratos/v2/selector" "github.com/stretchr/testify/assert" "google.golang.org/grpc/metadata" ) @@ -24,6 +24,8 @@ func TestBalancerName(t *testing.T) { func TestFilters(t *testing.T) { o := &clientOptions{} - WithSelectFilter(filter.Version("2"))(o) + WithNodeFilter(func(selector.Node) bool { + return true + })(o) assert.Equal(t, 1, len(o.filters)) } diff --git a/transport/grpc/client.go b/transport/grpc/client.go index dd1a614b3..a84a09a78 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -80,8 +80,8 @@ func WithBalancerName(name string) ClientOption { } } -// WithSelectFilter with select filters -func WithSelectFilter(filters ...selector.Filter) ClientOption { +// WithNodeFilter with select filters +func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { return func(o *clientOptions) { o.filters = filters } @@ -97,7 +97,7 @@ type clientOptions struct { ints []grpc.UnaryClientInterceptor grpcOpts []grpc.DialOption balancerName string - filters []selector.Filter + filters []selector.NodeFilter } // Dial returns a GRPC connection. @@ -143,7 +143,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien return grpc.DialContext(ctx, options.endpoint, grpcOpts...) } -func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { +func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = transport.NewClientContext(ctx, &Transport{ endpoint: cc.Target(), diff --git a/transport/grpc/transport.go b/transport/grpc/transport.go index eb8c51272..b7df26311 100644 --- a/transport/grpc/transport.go +++ b/transport/grpc/transport.go @@ -14,7 +14,7 @@ type Transport struct { operation string reqHeader headerCarrier replyHeader headerCarrier - filters []selector.Filter + filters []selector.NodeFilter } // Kind returns the transport kind. @@ -43,7 +43,7 @@ func (tr *Transport) ReplyHeader() transport.Header { } // Filters returns the client select filters. -func (tr *Transport) Filters() []selector.Filter { +func (tr *Transport) NodeFilters() []selector.NodeFilter { return tr.filters }