diff --git a/selector/default_selector.go b/selector/default_selector.go index 7550befdd..0e99249eb 100644 --- a/selector/default_selector.go +++ b/selector/default_selector.go @@ -53,6 +53,10 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No if err != nil { return nil, nil, err } + p, ok := FromPeerContext(ctx) + if ok { + p.Node = wn.Raw() + } return wn.Raw(), done, nil } diff --git a/selector/peer.go b/selector/peer.go new file mode 100644 index 000000000..273ba0649 --- /dev/null +++ b/selector/peer.go @@ -0,0 +1,25 @@ +package selector + +import ( + "context" +) + +type peerKey struct{} + +// Peer contains the information of the peer for an RPC, such as the address +// and authentication information. +type Peer struct { + // node is the peer node. + Node Node +} + +// NewPeerContext creates a new context with peer information attached. +func NewPeerContext(ctx context.Context, p *Peer) context.Context { + return context.WithValue(ctx, peerKey{}, p) +} + +// FromPeerContext returns the peer information in ctx if it exists. +func FromPeerContext(ctx context.Context) (p *Peer, ok bool) { + p, ok = ctx.Value(peerKey{}).(*Peer) + return +} diff --git a/selector/peer_test.go b/selector/peer_test.go new file mode 100644 index 000000000..d0eeeb819 --- /dev/null +++ b/selector/peer_test.go @@ -0,0 +1,24 @@ +package selector + +import ( + "context" + "testing" +) + +func TestPeer(t *testing.T) { + p := Peer{ + Node: mockWeightedNode{}, + } + ctx := NewPeerContext(context.Background(), &p) + p2, ok := FromPeerContext(ctx) + if !ok || p2.Node == nil { + t.Fatalf(" no peer found!") + } +} + +func TestNotPeer(t *testing.T) { + _, ok := FromPeerContext(context.Background()) + if ok { + t.Fatalf("test no peer found peer!") + } +} diff --git a/transport/grpc/client.go b/transport/grpc/client.go index a5174b738..923e52671 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -184,6 +184,8 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f if len(ms) > 0 { h = middleware.Chain(ms...)(h) } + var p selector.Peer + ctx = selector.NewPeerContext(ctx, &p) _, err := h(ctx, req) return err } diff --git a/transport/http/client.go b/transport/http/client.go index 2ae29ffb2..f1c213598 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -247,6 +247,8 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf } return reply, nil } + var p selector.Peer + ctx = selector.NewPeerContext(ctx, &p) if len(client.opts.middleware) > 0 { h = middleware.Chain(client.opts.middleware...)(h) }