diff --git a/selector/balancer.go b/selector/balancer.go index 9e5accbdc..dbd64bbbd 100644 --- a/selector/balancer.go +++ b/selector/balancer.go @@ -19,6 +19,9 @@ type BalancerBuilder interface { type WeightedNode interface { Node + // Raw returns the original node + Raw() Node + // Weight is the runtime calculated weight Weight() float64 diff --git a/selector/default_selector.go b/selector/default_selector.go index 0e08d615e..af4455f63 100644 --- a/selector/default_selector.go +++ b/selector/default_selector.go @@ -38,7 +38,11 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No if len(candidates) == 0 { return nil, nil, ErrNoAvailable } - return d.Balancer.Pick(ctx, candidates) + wn, done, err := d.Balancer.Pick(ctx, candidates) + if err != nil { + return nil, nil, err + } + return wn.Raw(), done, nil } // Apply update nodes info. diff --git a/selector/node/direct/direct.go b/selector/node/direct/direct.go index e3bf139da..62b30ed7f 100644 --- a/selector/node/direct/direct.go +++ b/selector/node/direct/direct.go @@ -50,3 +50,7 @@ func (n *Node) Weight() float64 { func (n *Node) PickElapsed() time.Duration { return time.Duration(time.Now().UnixNano() - atomic.LoadInt64(&n.lastPick)) } + +func (n *Node) Raw() selector.Node { + return n.Node +} diff --git a/selector/node/direct/direct_test.go b/selector/node/direct/direct_test.go index cca733683..51689f528 100644 --- a/selector/node/direct/direct_test.go +++ b/selector/node/direct/direct_test.go @@ -48,6 +48,6 @@ func TestDirectDefaultWeight(t *testing.T) { time.Sleep(time.Millisecond * 10) done(context.Background(), selector.DoneInfo{}) assert.Equal(t, float64(100), wn.Weight()) - assert.Greater(t, time.Millisecond*15, wn.PickElapsed()) + assert.Greater(t, time.Millisecond*20, wn.PickElapsed()) assert.Less(t, time.Millisecond*5, wn.PickElapsed()) } diff --git a/selector/node/ewma/node.go b/selector/node/ewma/node.go index 444d8018a..2f2ff7e68 100644 --- a/selector/node/ewma/node.go +++ b/selector/node/ewma/node.go @@ -178,3 +178,7 @@ func (n *Node) Weight() (weight float64) { func (n *Node) PickElapsed() time.Duration { return time.Duration(time.Now().UnixNano() - atomic.LoadInt64(&n.lastPick)) } + +func (n *Node) Raw() selector.Node { + return n.Node +} diff --git a/selector/selector_test.go b/selector/selector_test.go index 8168c9bb0..3c0fa13e8 100644 --- a/selector/selector_test.go +++ b/selector/selector_test.go @@ -17,6 +17,11 @@ type mockWeightedNode struct { lastPick int64 } +// Raw returns the original node +func (n *mockWeightedNode) Raw() Node { + return n.Node +} + // Weight is the runtime calculated weight func (n *mockWeightedNode) Weight() float64 { if n.InitialWeight() != nil { diff --git a/transport/grpc/balancer.go b/transport/grpc/balancer.go index 94b3d29bb..9d0da64d7 100644 --- a/transport/grpc/balancer.go +++ b/transport/grpc/balancer.go @@ -50,19 +50,15 @@ type Builder struct { // Build creates a grpc Picker. func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { nodes := make([]selector.Node, 0) - subConns := make(map[string]gBalancer.SubConn) for conn, info := range info.ReadySCs { - if _, ok := subConns[info.Address.Addr]; ok { - continue - } - subConns[info.Address.Addr] = conn - ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance) - nodes = append(nodes, selector.NewNode(info.Address.Addr, ins)) + nodes = append(nodes, &grpcNode{ + Node: selector.NewNode(info.Address.Addr, ins), + subConn: conn, + }) } p := &Picker{ selector: b.builder.Build(), - subConns: subConns, } p.selector.Apply(nodes) return p @@ -70,7 +66,6 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { // Picker is a grpc picker. type Picker struct { - subConns map[string]gBalancer.SubConn selector selector.Selector } @@ -87,10 +82,9 @@ func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { if err != nil { return gBalancer.PickResult{}, err } - sub := p.subConns[n.Address()] return gBalancer.PickResult{ - SubConn: sub, + SubConn: n.(*grpcNode).subConn, Done: func(di gBalancer.DoneInfo) { done(info.Ctx, selector.DoneInfo{ Err: di.Err, @@ -113,3 +107,8 @@ func (t Trailer) Get(k string) string { } return "" } + +type grpcNode struct { + selector.Node + subConn gBalancer.SubConn +}