refactor: unify selector filter (#2277)

* unify selector

Co-authored-by: caoguoliang01 <caoguoliang01@bilibili.com>
Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
pull/2283/head
longxboy 2 years ago committed by GitHub
parent d11c6892b4
commit 11cd43e3c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 12
      selector/default_selector.go
  2. 4
      selector/filter.go
  3. 2
      selector/filter/version.go
  4. 13
      selector/global.go
  5. 8
      selector/options.go
  6. 12
      selector/p2c/p2c.go
  7. 8
      selector/p2c/p2c_test.go
  8. 12
      selector/random/random.go
  9. 4
      selector/random/random_test.go
  10. 6
      selector/selector.go
  11. 11
      selector/selector_test.go
  12. 12
      selector/wrr/wrr.go
  13. 4
      selector/wrr/wrr_test.go
  14. 68
      transport/grpc/balancer.go
  15. 11
      transport/grpc/balancer_test.go
  16. 33
      transport/grpc/client.go
  17. 8
      transport/grpc/transport.go
  18. 24
      transport/http/client.go
  19. 15
      transport/http/client_test.go

@ -9,7 +9,6 @@ import (
type Default struct {
NodeBuilder WeightedNodeBuilder
Balancer Balancer
Filters []Filter
nodes atomic.Value
}
@ -27,16 +26,13 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
for _, o := range opts {
o(&options)
}
if len(d.Filters) > 0 || len(options.Filters) > 0 {
if len(options.NodeFilters) > 0 {
newNodes := make([]Node, len(nodes))
for i, wc := range nodes {
newNodes[i] = wc
}
for _, f := range d.Filters {
newNodes = f(ctx, newNodes)
}
for _, f := range options.Filters {
newNodes = f(ctx, newNodes)
for _, filter := range options.NodeFilters {
newNodes = filter(ctx, newNodes)
}
candidates = make([]WeightedNode, len(newNodes))
for i, n := range newNodes {
@ -74,7 +70,6 @@ func (d *Default) Apply(nodes []Node) {
type DefaultBuilder struct {
Node WeightedNodeBuilder
Balancer BalancerBuilder
Filters []Filter
}
// Build create builder
@ -82,6 +77,5 @@ func (db *DefaultBuilder) Build() Selector {
return &Default{
NodeBuilder: db.Node,
Balancer: db.Balancer.Build(),
Filters: db.Filters,
}
}

@ -2,5 +2,5 @@ package selector
import "context"
// Filter is select filter.
type Filter func(context.Context, []Node) []Node
// NodeFilter is select filter.
type NodeFilter func(context.Context, []Node) []Node

@ -7,7 +7,7 @@ import (
)
// Version is version filter.
func Version(version string) selector.Filter {
func Version(version string) selector.NodeFilter {
return func(_ context.Context, nodes []selector.Node) []selector.Node {
newNodes := nodes[:0]
for _, n := range nodes {

@ -0,0 +1,13 @@
package selector
var globalSelector Builder
// GlobalSelector returns global selector builder.
func GlobalSelector() Builder {
return globalSelector
}
// SetGlobalSelector set global selector builder.
func SetGlobalSelector(builder Builder) {
globalSelector = builder
}

@ -2,15 +2,15 @@ package selector
// SelectOptions is Select Options.
type SelectOptions struct {
Filters []Filter
NodeFilters []NodeFilter
}
// SelectOption is Selector option.
type SelectOption func(*SelectOptions)
// WithFilter with filter options
func WithFilter(fn ...Filter) SelectOption {
// WithNodeFilter with filter options
func WithNodeFilter(fn ...NodeFilter) SelectOption {
return func(opts *SelectOptions) {
opts.Filters = fn
opts.NodeFilters = fn
}
}

@ -19,20 +19,11 @@ const (
var _ selector.Balancer = &Balancer{}
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option.
type Option func(o *options)
// options is random builder options
type options struct {
filters []selector.Filter
}
type options struct{}
// New creates a p2c selector.
func New(opts ...Option) selector.Selector {
@ -95,7 +86,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option)
}
return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{},
Node: &ewma.Builder{},
}

@ -16,7 +16,7 @@ import (
)
func TestWrr3(t *testing.T) {
p2c := New(WithFilter(filter.Version("v2.0.0")))
p2c := New()
var nodes []selector.Node
for i := 0; i < 3; i++ {
addr := fmt.Sprintf("127.0.0.%d:8080", i)
@ -41,7 +41,7 @@ func TestWrr3(t *testing.T) {
d := time.Duration(rand.Intn(500)) * time.Millisecond
lk.Unlock()
time.Sleep(d)
n, done, err := p2c.Select(context.Background())
n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
@ -92,7 +92,7 @@ func TestEmpty(t *testing.T) {
}
func TestOne(t *testing.T) {
p2c := New(WithFilter(filter.Version("v2.0.0")))
p2c := New()
var nodes []selector.Node
for i := 0; i < 1; i++ {
addr := fmt.Sprintf("127.0.0.%d:8080", i)
@ -106,7 +106,7 @@ func TestOne(t *testing.T) {
}))
}
p2c.Apply(nodes)
n, done, err := p2c.Select(context.Background())
n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}

@ -15,20 +15,11 @@ const (
var _ selector.Balancer = &Balancer{} // Name is balancer name
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option.
type Option func(o *options)
// options is random builder options
type options struct {
filters []selector.Filter
}
type options struct{}
// Balancer is a random balancer.
type Balancer struct{}
@ -56,7 +47,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option)
}
return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{},
Node: &direct.Builder{},
}

@ -10,7 +10,7 @@ import (
)
func TestWrr(t *testing.T) {
random := New(WithFilter(filter.Version("v2.0.0")))
random := New()
var nodes []selector.Node
nodes = append(nodes, selector.NewNode(
"http",
@ -31,7 +31,7 @@ func TestWrr(t *testing.T) {
random.Apply(nodes)
var count1, count2 int
for i := 0; i < 200; i++ {
n, done, err := random.Select(context.Background())
n, done, err := random.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil {
t.Errorf("expect no error, got %v", err)
}

@ -57,7 +57,7 @@ type DoneInfo struct {
// Response Error
Err error
// Response Metadata
ReplyMeta ReplyMeta
ReplyMD ReplyMD
// BytesSent indicates if any bytes have been sent to the server.
BytesSent bool
@ -65,8 +65,8 @@ type DoneInfo struct {
BytesReceived bool
}
// ReplyMeta is Reply Metadata.
type ReplyMeta interface {
// ReplyMD is Reply Metadata.
type ReplyMD interface {
Get(key string) string
}

@ -49,7 +49,7 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode {
return &mockWeightedNode{Node: n}
}
func mockFilter(version string) Filter {
func mockFilter(version string) NodeFilter {
return func(_ context.Context, nodes []Node) []Node {
newNodes := nodes[:0]
for _, n := range nodes {
@ -83,7 +83,6 @@ func (b *mockBalancer) Pick(ctx context.Context, nodes []WeightedNode) (selected
func TestDefault(t *testing.T) {
builder := DefaultBuilder{
Node: &mockWeightedNodeBuilder{},
Filters: []Filter{mockFilter("v2.0.0")},
Balancer: &mockBalancerBuilder{},
}
selector := builder.Build()
@ -109,7 +108,7 @@ func TestDefault(t *testing.T) {
Metadata: map[string]string{"weight": "10"},
}))
selector.Apply(nodes)
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
n, done, err := selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
@ -137,7 +136,7 @@ func TestDefault(t *testing.T) {
done(context.Background(), DoneInfo{})
// no v3.0.0 instance
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0")))
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v3.0.0")))
if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
}
@ -150,7 +149,7 @@ func TestDefault(t *testing.T) {
// apply zero instance
selector.Apply([]Node{})
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
}
@ -163,7 +162,7 @@ func TestDefault(t *testing.T) {
// apply zero instance
selector.Apply(nil)
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
if !errors.Is(ErrNoAvailable, err) {
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
}

@ -15,20 +15,11 @@ const (
var _ selector.Balancer = &Balancer{} // Name is balancer name
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) Option {
return func(o *options) {
o.filters = filters
}
}
// Option is random builder option.
type Option func(o *options)
// options is random builder options
type options struct {
filters []selector.Filter
}
type options struct{}
// Balancer is a random balancer.
type Balancer struct {
@ -77,7 +68,6 @@ func NewBuilder(opts ...Option) selector.Builder {
opt(&option)
}
return &selector.DefaultBuilder{
Filters: option.filters,
Balancer: &Builder{},
Node: &direct.Builder{},
}

@ -11,7 +11,7 @@ import (
)
func TestWrr(t *testing.T) {
wrr := New(WithFilter(filter.Version("v2.0.0")))
wrr := New()
var nodes []selector.Node
nodes = append(nodes, selector.NewNode(
"http",
@ -32,7 +32,7 @@ func TestWrr(t *testing.T) {
wrr.Apply(nodes)
var count1, count2 int
for i := 0; i < 90; i++ {
n, done, err := wrr.Select(context.Background())
n, done, err := wrr.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
if err != nil {
t.Errorf("expect no error, got %v", err)
}

@ -1,59 +1,45 @@
package grpc
import (
"sync"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/selector/random"
"github.com/go-kratos/kratos/v2/selector/wrr"
"github.com/go-kratos/kratos/v2/transport"
gBalancer "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/metadata"
)
var (
_ base.PickerBuilder = &Builder{}
_ gBalancer.Picker = &Picker{}
const (
balancerName = "selector"
)
mu sync.Mutex
var (
_ base.PickerBuilder = &balancerBuilder{}
_ balancer.Picker = &balancerPicker{}
)
func init() {
// inject global grpc balancer
SetGlobalBalancer(random.Name, random.NewBuilder())
SetGlobalBalancer(wrr.Name, wrr.NewBuilder())
SetGlobalBalancer(p2c.Name, p2c.NewBuilder())
}
// SetGlobalBalancer set grpc balancer with scheme.
func SetGlobalBalancer(scheme string, builder selector.Builder) {
mu.Lock()
defer mu.Unlock()
b := base.NewBalancerBuilder(
scheme,
&Builder{builder: builder},
balancerName,
&balancerBuilder{
builder: selector.GlobalSelector(),
},
base.Config{HealthCheck: true},
)
gBalancer.Register(b)
balancer.Register(b)
}
// Builder is grpc balancer builder.
type Builder struct {
type balancerBuilder struct {
builder selector.Builder
}
// Build creates a grpc Picker.
func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
if len(info.ReadySCs) == 0 {
// Block the RPC until a new picker is available via UpdateState().
return base.NewErrPicker(gBalancer.ErrNoSubConnAvailable)
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
nodes := make([]selector.Node, 0)
for conn, info := range info.ReadySCs {
ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance)
@ -62,40 +48,40 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
subConn: conn,
})
}
p := &Picker{
p := &balancerPicker{
selector: b.builder.Build(),
}
p.selector.Apply(nodes)
return p
}
// Picker is a grpc picker.
type Picker struct {
// balancerPicker is a grpc picker.
type balancerPicker struct {
selector selector.Selector
}
// Pick pick instances.
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
var filters []selector.Filter
func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
var filters []selector.NodeFilter
if tr, ok := transport.FromClientContext(info.Ctx); ok {
if gtr, ok := tr.(*Transport); ok {
filters = gtr.SelectFilters()
filters = gtr.NodeFilters()
}
}
n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...))
n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...))
if err != nil {
return gBalancer.PickResult{}, err
return balancer.PickResult{}, err
}
return gBalancer.PickResult{
return balancer.PickResult{
SubConn: n.(*grpcNode).subConn,
Done: func(di gBalancer.DoneInfo) {
Done: func(di balancer.DoneInfo) {
done(info.Ctx, selector.DoneInfo{
Err: di.Err,
BytesSent: di.BytesSent,
BytesReceived: di.BytesReceived,
ReplyMeta: Trailer(di.Trailer),
ReplyMD: Trailer(di.Trailer),
})
},
}, nil
@ -115,5 +101,5 @@ func (t Trailer) Get(k string) string {
type grpcNode struct {
selector.Node
subConn gBalancer.SubConn
subConn balancer.SubConn
}

@ -19,19 +19,10 @@ func TestTrailer(t *testing.T) {
}
}
func TestBalancerName(t *testing.T) {
o := &clientOptions{}
WithBalancerName("p2c")(o)
if !reflect.DeepEqual("p2c", o.balancerName) {
t.Errorf("expect %v, got %v", "p2c", o.balancerName)
}
}
func TestFilters(t *testing.T) {
o := &clientOptions{}
WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
WithNodeFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
return nodes
})(o)
if !reflect.DeepEqual(1, len(o.filters)) {

@ -10,7 +10,7 @@ import (
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/wrr"
"github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery"
@ -23,6 +23,12 @@ import (
grpcmd "google.golang.org/grpc/metadata"
)
func init() {
if selector.GlobalSelector() == nil {
selector.SetGlobalSelector(p2c.NewBuilder())
}
}
// ClientOption is gRPC client option.
type ClientOption func(o *clientOptions)
@ -75,15 +81,8 @@ func WithOptions(opts ...grpc.DialOption) ClientOption {
}
}
// WithBalancerName with balancer name
func WithBalancerName(name string) ClientOption {
return func(o *clientOptions) {
o.balancerName = name
}
}
// WithFilter with select filters
func WithFilter(filters ...selector.Filter) ClientOption {
// WithNodeFilter with select filters
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
return func(o *clientOptions) {
o.filters = filters
}
@ -105,7 +104,7 @@ type clientOptions struct {
ints []grpc.UnaryClientInterceptor
grpcOpts []grpc.DialOption
balancerName string
filters []selector.Filter
filters []selector.NodeFilter
}
// Dial returns a GRPC connection.
@ -121,7 +120,7 @@ func DialInsecure(ctx context.Context, opts ...ClientOption) (*grpc.ClientConn,
func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) {
options := clientOptions{
timeout: 2000 * time.Millisecond,
balancerName: wrr.Name,
balancerName: balancerName,
}
for _, o := range opts {
o(&options)
@ -156,13 +155,13 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
}
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor {
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(),
operation: method,
reqHeader: headerCarrier{},
filters: filters,
endpoint: cc.Target(),
operation: method,
reqHeader: headerCarrier{},
nodeFilters: filters,
})
if timeout > 0 {
var cancel context.CancelFunc

@ -14,7 +14,7 @@ type Transport struct {
operation string
reqHeader headerCarrier
replyHeader headerCarrier
filters []selector.Filter
nodeFilters []selector.NodeFilter
}
// Kind returns the transport kind.
@ -42,9 +42,9 @@ func (tr *Transport) ReplyHeader() transport.Header {
return tr.replyHeader
}
// SelectFilters returns the client select filters.
func (tr *Transport) SelectFilters() []selector.Filter {
return tr.filters
// NodeFilters returns the client select filters.
func (tr *Transport) NodeFilters() []selector.NodeFilter {
return tr.nodeFilters
}
type headerCarrier metadata.MD

@ -16,10 +16,16 @@ import (
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/wrr"
"github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/transport"
)
func init() {
if selector.GlobalSelector() == nil {
selector.SetGlobalSelector(p2c.NewBuilder())
}
}
// DecodeErrorFunc is decode error func.
type DecodeErrorFunc func(ctx context.Context, res *http.Response) error
@ -43,7 +49,7 @@ type clientOptions struct {
decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc
transport http.RoundTripper
selector selector.Selector
nodeFilters []selector.NodeFilter
discovery registry.Discovery
middleware []middleware.Middleware
block bool
@ -112,10 +118,10 @@ func WithDiscovery(d registry.Discovery) ClientOption {
}
}
// WithSelector with client selector.
func WithSelector(selector selector.Selector) ClientOption {
// WithNodeFilter with select filters
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
return func(o *clientOptions) {
o.selector = selector
o.nodeFilters = filters
}
}
@ -140,6 +146,7 @@ type Client struct {
r *resolver
cc *http.Client
insecure bool
selector selector.Selector
}
// NewClient returns an HTTP client.
@ -151,7 +158,6 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
decoder: DefaultResponseDecoder,
errorDecoder: DefaultErrorDecoder,
transport: http.DefaultTransport,
selector: wrr.New(),
}
for _, o := range opts {
o(&options)
@ -166,10 +172,11 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
if err != nil {
return nil, err
}
selector := selector.GlobalSelector().Build()
var r *resolver
if options.discovery != nil {
if target.Scheme == "discovery" {
if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil {
if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil {
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint)
}
} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil {
@ -185,6 +192,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
Timeout: options.timeout,
Transport: options.transport,
},
selector: selector,
}, nil
}
@ -276,7 +284,7 @@ func (client *Client) do(req *http.Request) (*http.Response, error) {
err error
node selector.Node
)
if node, done, err = client.opts.selector.Select(req.Context()); err != nil {
if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
}
if client.insecure {

@ -182,13 +182,18 @@ func TestWithDiscovery(t *testing.T) {
}
}
func TestWithSelector(t *testing.T) {
ov := &selector.Default{}
o := WithSelector(ov)
func TestWithNodeFilter(t *testing.T) {
ov := func(context.Context, []selector.Node) []selector.Node {
return []selector.Node{&selector.DefaultNode{}}
}
o := WithNodeFilter(ov)
co := &clientOptions{}
o(co)
if !reflect.DeepEqual(co.selector, ov) {
t.Errorf("expected selector to be %v, got %v", ov, co.selector)
for _, n := range co.nodeFilters {
ret := n(context.Background(), nil)
if len(ret) != 1 {
t.Errorf("expected node length to be 1, got %v", len(ret))
}
}
}

Loading…
Cancel
Save