feat: add base fitler to improve performace (#1612)

* add node filter

Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
pull/1615/head
longxboy 3 years ago committed by GitHub
parent c3d0bb66bb
commit 988c2312b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      examples/selector/client/main.go
  2. 64
      selector/default_selector.go
  3. 10
      selector/filter.go
  4. 6
      selector/filter/version_test.go
  5. 11
      selector/options.go
  6. 68
      selector/selector_test.go
  7. 6
      transport/grpc/balancer.go
  8. 6
      transport/grpc/balancer_test.go
  9. 8
      transport/grpc/client.go
  10. 4
      transport/grpc/transport.go

@ -8,6 +8,7 @@ import (
"github.com/go-kratos/kratos/contrib/registry/consul/v2" "github.com/go-kratos/kratos/contrib/registry/consul/v2"
"github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/middleware/recovery"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/filter"
"github.com/go-kratos/kratos/v2/selector/p2c" "github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/selector/wrr" "github.com/go-kratos/kratos/v2/selector/wrr"
@ -31,7 +32,11 @@ func main() {
// 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector // 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector
// 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter // 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter
grpc.WithBalancerName(wrr.Name), grpc.WithBalancerName(wrr.Name),
grpc.WithSelectFilter(filter.Version("1.0.0")), grpc.WithNodeFilter(
func(node selector.Node) bool {
return node.Version() == "1.0.0"
},
),
) )
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

@ -2,7 +2,7 @@ package selector
import ( import (
"context" "context"
"sync" "sync/atomic"
) )
// Default is composite selector. // Default is composite selector.
@ -11,29 +11,40 @@ type Default struct {
Balancer Balancer Balancer Balancer
Filters []Filter Filters []Filter
lk sync.RWMutex nodes atomic.Value
weightedNodes []Node
} }
// Select select one node. // Select select one node.
func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected Node, done DoneFunc, err error) { func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected Node, done DoneFunc, err error) {
d.lk.RLock() var (
weightedNodes := d.weightedNodes options SelectOptions
d.lk.RUnlock() candidates []WeightedNode
)
for _, f := range d.Filters { nodes, ok := d.nodes.Load().([]WeightedNode)
weightedNodes = f(ctx, weightedNodes) if !ok {
return nil, nil, ErrNoAvailable
} }
var options SelectOptions
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
for _, f := range options.Filters { if len(d.Filters) > 0 {
weightedNodes = f(ctx, weightedNodes) newNodes := make([]Node, len(nodes))
for i, wc := range nodes {
newNodes[i] = wc
}
for _, f := range d.Filters {
newNodes = f(ctx, newNodes)
}
candidates = make([]WeightedNode, len(newNodes))
for i, n := range newNodes {
candidates[i] = n.(WeightedNode)
}
} else {
candidates = nodes
} }
candidates := make([]WeightedNode, 0, len(weightedNodes))
for _, n := range weightedNodes { if len(options.Filters) > 0 {
candidates = append(candidates, n.(WeightedNode)) candidates = d.nodeFilter(options.Filters, candidates)
} }
if len(candidates) == 0 { if len(candidates) == 0 {
return nil, nil, ErrNoAvailable return nil, nil, ErrNoAvailable
@ -45,16 +56,31 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
return wn.Raw(), done, nil return wn.Raw(), done, nil
} }
func (d *Default) nodeFilter(filters []NodeFilter, nodes []WeightedNode) []WeightedNode {
newNodes := make([]WeightedNode, 0, len(nodes))
for _, n := range nodes {
var remove bool
for _, f := range filters {
if !f(n) {
remove = true
break
}
}
if !remove {
newNodes = append(newNodes, n)
}
}
return newNodes
}
// Apply update nodes info. // Apply update nodes info.
func (d *Default) Apply(nodes []Node) { func (d *Default) Apply(nodes []Node) {
weightedNodes := make([]Node, 0, len(nodes)) weightedNodes := make([]WeightedNode, 0, len(nodes))
for _, n := range nodes { for _, n := range nodes {
weightedNodes = append(weightedNodes, d.NodeBuilder.Build(n)) weightedNodes = append(weightedNodes, d.NodeBuilder.Build(n))
} }
d.lk.Lock()
// TODO: Do not delete unchanged nodes // TODO: Do not delete unchanged nodes
d.weightedNodes = weightedNodes d.nodes.Store(weightedNodes)
d.lk.Unlock()
} }
// DefaultBuilder is de // DefaultBuilder is de

@ -0,0 +1,10 @@
package selector
import "context"
// Filter is select filter.
type Filter func(context.Context, []Node) []Node
// NodeFilter is node filter.
// If it returns false, the node will be removed out from the balancer pick list
type NodeFilter func(node Node) bool

@ -30,7 +30,7 @@ func TestVersion(t *testing.T) {
Endpoints: []string{"http://127.0.0.2:9090"}, Endpoints: []string{"http://127.0.0.2:9090"},
})) }))
n := f(context.Background(), nodes) nodes = f(context.Background(), nodes)
assert.Equal(t, 1, len(n)) assert.Equal(t, 1, len(nodes))
assert.Equal(t, "127.0.0.2:9090", n[0].Address()) assert.Equal(t, "127.0.0.2:9090", nodes[0].Address())
} }

@ -1,20 +1,15 @@
package selector package selector
import "context"
// SelectOptions is Select Options. // SelectOptions is Select Options.
type SelectOptions struct { type SelectOptions struct {
Filters []Filter Filters []NodeFilter
} }
// SelectOption is Selector option. // SelectOption is Selector option.
type SelectOption func(*SelectOptions) type SelectOption func(*SelectOptions)
// Filter is node filter function. // WithNodeFilter with filter options
type Filter func(context.Context, []Node) []Node func WithNodeFilter(fn ...NodeFilter) SelectOption {
// WithFilter with filter options
func WithFilter(fn ...Filter) SelectOption {
return func(opts *SelectOptions) { return func(opts *SelectOptions) {
opts.Filters = fn opts.Filters = fn
} }

@ -3,6 +3,7 @@ package selector
import ( import (
"context" "context"
"math/rand" "math/rand"
"strconv"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -106,7 +107,9 @@ func TestDefault(t *testing.T) {
Metadata: map[string]string{"weight": "10"}, Metadata: map[string]string{"weight": "10"},
})) }))
selector.Apply(nodes) selector.Apply(nodes)
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err := selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
return (node.Version() == "v2.0.0")
}))
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, n) assert.NotNil(t, n)
assert.NotNil(t, done) assert.NotNil(t, done)
@ -118,15 +121,74 @@ func TestDefault(t *testing.T) {
done(context.Background(), DoneInfo{}) done(context.Background(), DoneInfo{})
// no v3.0.0 instance // no v3.0.0 instance
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
return (node.Version() == "v3.0.0")
}))
assert.Equal(t, ErrNoAvailable, err) assert.Equal(t, ErrNoAvailable, err)
assert.Nil(t, done) assert.Nil(t, done)
assert.Nil(t, n) assert.Nil(t, n)
// apply zero instance // apply zero instance
selector.Apply([]Node{}) selector.Apply([]Node{})
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
return (node.Version() == "v2.0.0")
}))
assert.Equal(t, ErrNoAvailable, err) assert.Equal(t, ErrNoAvailable, err)
assert.Nil(t, done) assert.Nil(t, done)
assert.Nil(t, n) assert.Nil(t, n)
// apply zero instance
selector.Apply(nil)
n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
return (node.Version() == "v2.0.0")
}))
assert.Equal(t, ErrNoAvailable, err)
assert.Nil(t, done)
assert.Nil(t, n)
}
func TestNodeFilterWithRandom(t *testing.T) {
for i := 0; i < 100; i++ {
testBaseFilter(t, 1000, rand.Intn(1000))
}
testBaseFilter(t, 0, rand.Intn(1000))
testBaseFilter(t, 1, 1000)
testBaseFilter(t, 2, 1000)
testBaseFilter(t, 3, 1000)
testBaseFilter(t, 1, 0)
testBaseFilter(t, 2, 0)
testBaseFilter(t, 3, 0)
}
func testBaseFilter(t *testing.T, length int, reservedRatio int) {
var raw []WeightedNode
var targets map[string]WeightedNode = make(map[string]WeightedNode)
for i := 0; i < length; i++ {
addr := strconv.FormatInt(int64(i), 10)
raw = append(raw, &mockWeightedNode{Node: NewNode(
addr,
&registry.ServiceInstance{
ID: addr,
Name: "helloworld",
Endpoints: []string{addr},
})})
if reservedRatio > rand.Intn(length) {
targets[addr] = raw[i]
}
}
f := func(node Node) bool {
if _, ok := targets[node.Address()]; ok {
return true
}
return false
}
d := Default{}
raw = d.nodeFilter([]NodeFilter{f}, raw)
assert.Equal(t, len(targets), len(raw))
for _, n := range raw {
_, ok := targets[n.Address()]
assert.True(t, ok)
}
} }

@ -71,14 +71,14 @@ type Picker struct {
// Pick pick instances. // Pick pick instances.
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
var filters []selector.Filter var filters []selector.NodeFilter
if tr, ok := transport.FromClientContext(info.Ctx); ok { if tr, ok := transport.FromClientContext(info.Ctx); ok {
if gtr, ok := tr.(*Transport); ok { if gtr, ok := tr.(*Transport); ok {
filters = gtr.Filters() filters = gtr.NodeFilters()
} }
} }
n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...))
if err != nil { if err != nil {
return gBalancer.PickResult{}, err return gBalancer.PickResult{}, err
} }

@ -3,7 +3,7 @@ package grpc
import ( import (
"testing" "testing"
"github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
@ -24,6 +24,8 @@ func TestBalancerName(t *testing.T) {
func TestFilters(t *testing.T) { func TestFilters(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
WithSelectFilter(filter.Version("2"))(o) WithNodeFilter(func(selector.Node) bool {
return true
})(o)
assert.Equal(t, 1, len(o.filters)) assert.Equal(t, 1, len(o.filters))
} }

@ -80,8 +80,8 @@ func WithBalancerName(name string) ClientOption {
} }
} }
// WithSelectFilter with select filters // WithNodeFilter with select filters
func WithSelectFilter(filters ...selector.Filter) ClientOption { func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
return func(o *clientOptions) { return func(o *clientOptions) {
o.filters = filters o.filters = filters
} }
@ -97,7 +97,7 @@ type clientOptions struct {
ints []grpc.UnaryClientInterceptor ints []grpc.UnaryClientInterceptor
grpcOpts []grpc.DialOption grpcOpts []grpc.DialOption
balancerName string balancerName string
filters []selector.Filter filters []selector.NodeFilter
} }
// Dial returns a GRPC connection. // Dial returns a GRPC connection.
@ -143,7 +143,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
return grpc.DialContext(ctx, options.endpoint, grpcOpts...) return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
} }
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = transport.NewClientContext(ctx, &Transport{ ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(), endpoint: cc.Target(),

@ -14,7 +14,7 @@ type Transport struct {
operation string operation string
reqHeader headerCarrier reqHeader headerCarrier
replyHeader headerCarrier replyHeader headerCarrier
filters []selector.Filter filters []selector.NodeFilter
} }
// Kind returns the transport kind. // Kind returns the transport kind.
@ -43,7 +43,7 @@ func (tr *Transport) ReplyHeader() transport.Header {
} }
// Filters returns the client select filters. // Filters returns the client select filters.
func (tr *Transport) Filters() []selector.Filter { func (tr *Transport) NodeFilters() []selector.NodeFilter {
return tr.filters return tr.filters
} }

Loading…
Cancel
Save