Feat: add load balancer (#1437)

* add balancer
* add p2c balancer
* add http client selector filter

Co-authored-by: yuemoxi <99347745@qq.com>
Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
pull/1466/head
longxboy 3 years ago committed by GitHub
parent 0184d217cf
commit 20f0a07d36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 47
      examples/registry/consul/client/main.go
  2. 9
      examples/registry/consul/server/main.go
  3. 30
      selector/balancer.go
  4. 49
      selector/default.go
  5. 20
      selector/filter/version.go
  6. 60
      selector/node/default.go
  7. 52
      selector/node/direct/direct.go
  8. 180
      selector/node/ewma/node.go
  9. 21
      selector/options.go
  10. 74
      selector/p2c/p2c.go
  11. 38
      selector/random/random.go
  12. 66
      selector/selector.go
  13. 106
      transport/grpc/balancer/balancer.go
  14. 6
      transport/grpc/client.go
  15. 7
      transport/grpc/resolver/discovery/resolver.go
  16. 21
      transport/http/balancer/balancer.go
  17. 43
      transport/http/balancer/random/random.go
  18. 23
      transport/http/calloption.go
  19. 35
      transport/http/client.go
  20. 49
      transport/http/resolver.go

@ -14,19 +14,13 @@ import (
)
func main() {
client, err := api.NewClient(api.DefaultConfig())
consulCli, err := api.NewClient(api.DefaultConfig())
if err != nil {
panic(err)
}
for {
callHTTP(client)
callGRPC(client)
time.Sleep(time.Second)
}
}
r := consul.New(consulCli)
func callGRPC(cli *api.Client) {
r := consul.New(cli)
// new grpc client
conn, err := grpc.DialInsecure(
context.Background(),
grpc.WithEndpoint("discovery:///helloworld"),
@ -36,17 +30,10 @@ func callGRPC(cli *api.Client) {
log.Fatal(err)
}
defer conn.Close()
client := helloworld.NewGreeterClient(conn)
reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"})
if err != nil {
log.Fatal(err)
}
log.Printf("[grpc] SayHello %+v\n", reply)
}
gClient := helloworld.NewGreeterClient(conn)
func callHTTP(cli *api.Client) {
r := consul.New(cli)
conn, err := http.NewClient(
// new http client
hConn, err := http.NewClient(
context.Background(),
http.WithMiddleware(
recovery.Recovery(),
@ -57,9 +44,25 @@ func callHTTP(cli *api.Client) {
if err != nil {
log.Fatal(err)
}
defer conn.Close()
time.Sleep(time.Millisecond * 250)
client := helloworld.NewGreeterHTTPClient(conn)
defer hConn.Close()
hClient := helloworld.NewGreeterHTTPClient(hConn)
for {
time.Sleep(time.Second)
callGRPC(gClient)
callHTTP(hClient)
}
}
func callGRPC(client helloworld.GreeterClient) {
reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"})
if err != nil {
log.Fatal(err)
}
log.Printf("[grpc] SayHello %+v\n", reply)
}
func callHTTP(client helloworld.GreeterHTTPClient) {
reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"})
if err != nil {
log.Fatal(err)

@ -3,11 +3,13 @@ package main
import (
"context"
"fmt"
"log"
"os"
"github.com/go-kratos/kratos/contrib/registry/consul/v2"
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware/logging"
"github.com/go-kratos/kratos/v2/middleware/recovery"
"github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http"
@ -25,6 +27,9 @@ func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*he
}
func main() {
logger := log.NewStdLogger(os.Stdout)
log := log.NewHelper(logger)
consulClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
log.Fatal(err)
@ -34,12 +39,14 @@ func main() {
http.Address(":8000"),
http.Middleware(
recovery.Recovery(),
logging.Server(logger),
),
)
grpcSrv := grpc.NewServer(
grpc.Address(":9000"),
grpc.Middleware(
recovery.Recovery(),
logging.Server(logger),
),
)

@ -0,0 +1,30 @@
package selector
import (
"context"
"time"
)
// Balancer is balancer interface
type Balancer interface {
Pick(ctx context.Context, nodes []WeightedNode) (selected WeightedNode, done DoneFunc, err error)
}
// WeightedNode calculates scheduling weight in real time
type WeightedNode interface {
Node
// Weight is the runtime calculated weight
Weight() float64
// Pick the node
Pick() DoneFunc
// PickElapsed is time elapsed since the latest pick
PickElapsed() time.Duration
}
// WeightedNodeBuilder is WeightedNode Builder
type WeightedNodeBuilder interface {
Build(Node) WeightedNode
}

@ -0,0 +1,49 @@
package selector
import (
"context"
"sync"
)
// Default is composite selector.
type Default struct {
NodeBuilder WeightedNodeBuilder
Balancer Balancer
lk sync.RWMutex
weightedNodes []Node
}
// Select select one node.
func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected Node, done DoneFunc, err error) {
d.lk.RLock()
weightedNodes := d.weightedNodes
d.lk.RUnlock()
var options SelectOptions
for _, o := range opts {
o(&options)
}
for _, f := range options.Filters {
weightedNodes = f(ctx, weightedNodes)
}
candidates := make([]WeightedNode, 0, len(weightedNodes))
for _, n := range weightedNodes {
candidates = append(candidates, n.(WeightedNode))
}
if len(candidates) == 0 {
return nil, nil, ErrNoAvailable
}
return d.Balancer.Pick(ctx, candidates)
}
// Apply update nodes info.
func (d *Default) Apply(nodes []Node) {
weightedNodes := make([]Node, 0, len(nodes))
for _, n := range nodes {
weightedNodes = append(weightedNodes, d.NodeBuilder.Build(n))
}
d.lk.Lock()
// TODO: Do not delete unchanged nodes
d.weightedNodes = weightedNodes
d.lk.Unlock()
}

@ -0,0 +1,20 @@
package filter
import (
"context"
"github.com/go-kratos/kratos/v2/selector"
)
// Version is verion filter.
func Version(version string) selector.Filter {
return func(_ context.Context, nodes []selector.Node) []selector.Node {
filters := make([]selector.Node, 0, len(nodes))
for _, n := range nodes {
if n.Version() == version {
filters = append(filters, n)
}
}
return filters
}
}

@ -0,0 +1,60 @@
package node
import (
"strconv"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
)
// Node is slector node
type Node struct {
addr string
weight *int64
version string
name string
metadata map[string]string
}
// Address is node address
func (n *Node) Address() string {
return n.addr
}
// ServiceName is node serviceName
func (n *Node) ServiceName() string {
return n.name
}
// InitialWeight is node initialWeight
func (n *Node) InitialWeight() *int64 {
return n.weight
}
// Version is node version
func (n *Node) Version() string {
return n.version
}
// Metadata is node metadata
func (n *Node) Metadata() map[string]string {
return n.metadata
}
// New node
func New(addr string, ins *registry.ServiceInstance) selector.Node {
n := &Node{
addr: addr,
}
if ins != nil {
n.name = ins.Name
n.version = ins.Version
n.metadata = ins.Metadata
if str, ok := ins.Metadata["weight"]; ok {
if weight, err := strconv.ParseInt(str, 10, 64); err == nil {
n.weight = &weight
}
}
}
return n
}

@ -0,0 +1,52 @@
package direct
import (
"context"
"sync/atomic"
"time"
"github.com/go-kratos/kratos/v2/selector"
)
const (
defaultWeight = 100
)
var (
_ selector.WeightedNode = &node{}
_ selector.WeightedNodeBuilder = &Builder{}
)
// node is endpoint instance
type node struct {
selector.Node
// last lastPick timestamp
lastPick int64
}
// Builder is direct node builder
type Builder struct{}
// Build create node
func (*Builder) Build(n selector.Node) selector.WeightedNode {
return &node{Node: n, lastPick: 0}
}
func (n *node) Pick() selector.DoneFunc {
now := time.Now().UnixNano()
atomic.StoreInt64(&n.lastPick, now)
return func(ctx context.Context, di selector.DoneInfo) {}
}
// Weight is node effective weight
func (n *node) Weight() float64 {
if n.InitialWeight() != nil {
return float64(*n.InitialWeight())
}
return defaultWeight
}
func (n *node) PickElapsed() time.Duration {
return time.Duration(time.Now().UnixNano() - atomic.LoadInt64(&n.lastPick))
}

@ -0,0 +1,180 @@
package ewma
import (
"container/list"
"context"
"math"
"sync"
"sync/atomic"
"time"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/selector"
)
const (
// The mean lifetime of `cost`, it reaches its half-life after Tau*ln(2).
tau = int64(time.Millisecond * 600)
// if statistic not collected,we add a big lag penalty to endpoint
penalty = uint64(time.Second * 10)
)
var (
_ selector.WeightedNode = &node{}
_ selector.WeightedNodeBuilder = &Builder{}
)
// node is endpoint instance
type node struct {
selector.Node
// client statistic data
lag int64
success uint64
inflight int64
inflights *list.List
// last collected timestamp
stamp int64
predictTs int64
predict int64
// request number in a period time
reqs int64
// last lastPick timestamp
lastPick int64
errHandler func(err error) (isErr bool)
lk sync.RWMutex
}
// Builder is ewma node builder.
type Builder struct {
ErrHandler func(err error) (isErr bool)
}
// Build create a weighted node.
func (b *Builder) Build(n selector.Node) selector.WeightedNode {
s := &node{
Node: n,
lag: 0,
success: 1000,
inflight: 1,
inflights: list.New(),
errHandler: b.ErrHandler,
}
return s
}
func (n *node) health() uint64 {
return atomic.LoadUint64(&n.success)
}
func (n *node) load() (load uint64) {
now := time.Now().UnixNano()
avgLag := atomic.LoadInt64(&n.lag)
lastPredictTs := atomic.LoadInt64(&n.predictTs)
predicInterval := avgLag / 5
if predicInterval < int64(time.Millisecond*5) {
predicInterval = int64(time.Millisecond * 5)
} else if predicInterval > int64(time.Millisecond*200) {
predicInterval = int64(time.Millisecond * 200)
}
if now-lastPredictTs > predicInterval {
if atomic.CompareAndSwapInt64(&n.predictTs, lastPredictTs, now) {
var (
total int64
count int
predict int64
)
n.lk.RLock()
first := n.inflights.Front()
for first != nil {
lag := now - first.Value.(int64)
if lag > avgLag {
count++
total += lag
}
first = first.Next()
}
if count > (n.inflights.Len()/2 + 1) {
predict = total / int64(count)
}
n.lk.RUnlock()
atomic.StoreInt64(&n.predict, predict)
}
}
if avgLag == 0 {
// penalty是node刚启动时没有数据时的惩罚值,默认为1e9 * 10
load = penalty * uint64(atomic.LoadInt64(&n.inflight))
} else {
predict := atomic.LoadInt64(&n.predict)
if predict > avgLag {
avgLag = predict
}
load = uint64(avgLag) * uint64(atomic.LoadInt64(&n.inflight))
}
return
}
// Pick pick a node.
func (n *node) Pick() selector.DoneFunc {
now := time.Now().UnixNano()
atomic.StoreInt64(&n.lastPick, now)
atomic.AddInt64(&n.inflight, 1)
atomic.AddInt64(&n.reqs, 1)
n.lk.Lock()
e := n.inflights.PushBack(now)
n.lk.Unlock()
return func(ctx context.Context, di selector.DoneInfo) {
n.lk.Lock()
n.inflights.Remove(e)
n.lk.Unlock()
atomic.AddInt64(&n.inflight, -1)
now := time.Now().UnixNano()
// get moving average ratio w
stamp := atomic.SwapInt64(&n.stamp, now)
td := now - stamp
if td < 0 {
td = 0
}
w := math.Exp(float64(-td) / float64(tau))
start := e.Value.(int64)
lag := now - start
if lag < 0 {
lag = 0
}
oldLag := atomic.LoadInt64(&n.lag)
if oldLag == 0 {
w = 0.0
}
lag = int64(float64(oldLag)*w + float64(lag)*(1.0-w))
atomic.StoreInt64(&n.lag, lag)
success := uint64(1000) // error value ,if error set 1
if di.Err != nil {
if n.errHandler != nil {
if n.errHandler(di.Err) {
success = 0
}
} else if errors.Is(context.DeadlineExceeded, di.Err) || errors.Is(context.Canceled, di.Err) ||
errors.IsServiceUnavailable(di.Err) || errors.IsGatewayTimeout(di.Err) {
success = 0
}
}
oldSuc := atomic.LoadUint64(&n.success)
success = uint64(float64(oldSuc)*w + float64(success)*(1.0-w))
atomic.StoreUint64(&n.success, success)
}
}
// Weight is node effective weight.
func (n *node) Weight() (weight float64) {
weight = float64(n.health()*uint64(time.Second)) / float64(n.load())
return
}
func (n *node) PickElapsed() time.Duration {
return time.Duration(time.Now().UnixNano() - atomic.LoadInt64(&n.lastPick))
}

@ -0,0 +1,21 @@
package selector
import "context"
// SelectOptions is Select Options.
type SelectOptions struct {
Filters []Filter
}
// SelectOption is Selector option.
type SelectOption func(*SelectOptions)
// Filter is node filter function.
type Filter func(context.Context, []Node) []Node
// WithFilter with filter options
func WithFilter(fn ...Filter) SelectOption {
return func(opts *SelectOptions) {
opts.Filters = fn
}
}

@ -0,0 +1,74 @@
package p2c
import (
"context"
"math/rand"
"sync/atomic"
"time"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/node/ewma"
)
const (
forcePick = time.Second * 3
// Name is balancer name
Name = "p2c"
)
var _ selector.Balancer = &Balancer{}
// New creates a p2c selector.
func New() selector.Selector {
return &selector.Default{
NodeBuilder: &ewma.Builder{},
Balancer: &Balancer{
r: rand.New(rand.NewSource(time.Now().UnixNano())),
},
}
}
// Balancer is p2c selector.
type Balancer struct {
r *rand.Rand
lk int64
}
// choose two distinct nodes.
func (s *Balancer) prePick(nodes []selector.WeightedNode) (nodeA selector.WeightedNode, nodeB selector.WeightedNode) {
a := s.r.Intn(len(nodes))
b := s.r.Intn(len(nodes) - 1)
if b >= a {
b = b + 1
}
nodeA, nodeB = nodes[a], nodes[b]
return
}
// Pick pick a node.
func (s *Balancer) Pick(ctx context.Context, nodes []selector.WeightedNode) (selector.WeightedNode, selector.DoneFunc, error) {
if len(nodes) == 0 {
return nil, nil, selector.ErrNoAvailable
} else if len(nodes) == 1 {
done := nodes[0].Pick()
return nodes[0], done, nil
}
var pc, upc selector.WeightedNode
nodeA, nodeB := s.prePick(nodes)
// meta.Weight为服务发布者在discovery中设置的权重
if nodeB.Weight() > nodeA.Weight() {
pc, upc = nodeB, nodeA
} else {
pc, upc = nodeA, nodeB
}
// 如果落选节点在forceGap期间内从来没有被选中一次,则强制选一次
// 利用强制的机会,来触发成功率、延迟的更新
if upc.PickElapsed() > forcePick && atomic.CompareAndSwapInt64(&s.lk, 0, 1) {
pc = upc
atomic.StoreInt64(&s.lk, 0)
}
done := pc.Pick()
return pc, done, nil
}

@ -0,0 +1,38 @@
package random
import (
"context"
"math/rand"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/node/direct"
)
var (
_ selector.Balancer = &Balancer{}
// Name is balancer name
Name = "random"
)
// Balancer is a random balancer.
type Balancer struct{}
// New random a selector.
func New() selector.Selector {
return &selector.Default{
Balancer: &Balancer{},
NodeBuilder: &direct.Builder{},
}
}
// Pick pick a weighted node.
func (p *Balancer) Pick(_ context.Context, nodes []selector.WeightedNode) (selector.WeightedNode, selector.DoneFunc, error) {
if len(nodes) == 0 {
return nil, nil, selector.ErrNoAvailable
}
cur := rand.Intn(len(nodes))
selected := nodes[cur]
d := selected.Pick()
return selected, d, nil
}

@ -0,0 +1,66 @@
package selector
import (
"context"
"github.com/go-kratos/kratos/v2/errors"
)
// ErrNoAvailable is no available node.
var ErrNoAvailable = errors.ServiceUnavailable("no_available_node", "")
// Selector is node pick balancer.
type Selector interface {
Rebalancer
// Select nodes
// if err == nil, selected and done must not be empty.
Select(ctx context.Context, opts ...SelectOption) (selected Node, done DoneFunc, err error)
}
// Rebalancer is nodes rebalancer.
type Rebalancer interface {
// apply all nodes when any changes happen
Apply(nodes []Node)
}
// Node is node interface.
type Node interface {
// Address is the unique address under the same service
Address() string
// ServiceName is service name
ServiceName() string
// InitialWeight is the initial value of scheduling weight
// if not set return nil
InitialWeight() *int64
// Version is service node version
Version() string
// Metadata is the kv pair metadata associated with the service instance.
// version,namespace,region,protocol etc..
Metadata() map[string]string
}
// DoneInfo is callback info when RPC invoke done.
type DoneInfo struct {
// Response Error
Err error
// Response Metadata
ReplyMeta ReplyMeta
// BytesSent indicates if any bytes have been sent to the server.
BytesSent bool
// BytesReceived indicates if any byte has been received from the server.
BytesReceived bool
}
// ReplyMeta is Reply Metadata.
type ReplyMeta interface {
Get(key string) string
}
// DoneFunc is callback function when RPC invoke done.
type DoneFunc func(ctx context.Context, di DoneInfo)

@ -0,0 +1,106 @@
package balancer
import (
"sync"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/node"
"github.com/go-kratos/kratos/v2/selector/p2c"
"github.com/go-kratos/kratos/v2/selector/random"
gBalancer "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/metadata"
)
var (
_ base.PickerBuilder = &Builder{}
_ gBalancer.Picker = &Picker{}
mu sync.Mutex
)
func init() {
// inject global grpc balancer
SetGlobalBalancer(random.Name, random.New())
SetGlobalBalancer(p2c.Name, p2c.New())
}
// SetGlobalBalancer set grpc balancer with scheme.
func SetGlobalBalancer(scheme string, selector selector.Selector) {
mu.Lock()
defer mu.Unlock()
b := base.NewBalancerBuilder(
scheme,
&Builder{selector},
base.Config{HealthCheck: true},
)
gBalancer.Register(b)
}
// Builder is grpc balancer builder.
type Builder struct {
selector selector.Selector
}
// Build creates a grpc Picker.
func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
nodes := make([]selector.Node, 0)
subConns := make(map[string]gBalancer.SubConn)
for conn, info := range info.ReadySCs {
if _, ok := subConns[info.Address.Addr]; ok {
continue
}
subConns[info.Address.Addr] = conn
ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance)
nodes = append(nodes, node.New(info.Address.Addr, ins))
}
p := &Picker{
selector: b.selector,
subConns: subConns,
}
p.selector.Apply(nodes)
return p
}
// Picker is a grpc picker.
type Picker struct {
subConns map[string]gBalancer.SubConn
selector selector.Selector
}
// Pick pick instances.
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
n, done, err := p.selector.Select(info.Ctx)
if err != nil {
return gBalancer.PickResult{}, err
}
sub := p.subConns[n.Address()]
return gBalancer.PickResult{
SubConn: sub,
Done: func(di gBalancer.DoneInfo) {
done(info.Ctx, selector.DoneInfo{
Err: di.Err,
BytesSent: di.BytesSent,
BytesReceived: di.BytesReceived,
ReplyMeta: Trailer(di.Trailer),
})
},
}, nil
}
// Trailer is a grpc trailder MD.
type Trailer metadata.MD
// Get get a grpc trailer value.
func (t Trailer) Get(k string) string {
v := metadata.MD(t).Get(k)
if len(v) > 0 {
return v[0]
}
return ""
}

@ -8,14 +8,16 @@ import (
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector/random"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery"
// init resolver
_ "github.com/go-kratos/kratos/v2/transport/grpc/resolver/direct"
// init balancer
_ "github.com/go-kratos/kratos/v2/transport/grpc/balancer"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/credentials"
grpcmd "google.golang.org/grpc/metadata"
)
@ -107,7 +109,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
ints = append(ints, options.ints...)
}
grpcOpts := []grpc.DialOption{
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, roundrobin.Name)),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, random.Name)),
grpc.WithChainUnaryInterceptor(ints...),
}
if options.discovery != nil {

@ -46,6 +46,7 @@ func (r *discoveryResolver) watch() {
func (r *discoveryResolver) update(ins []*registry.ServiceInstance) {
addrs := make([]resolver.Address, 0)
endpoints := make(map[string]struct{})
for _, in := range ins {
endpoint, err := endpoint.ParseEndpoint(in.Endpoints, "grpc", !r.insecure)
if err != nil {
@ -55,11 +56,17 @@ func (r *discoveryResolver) update(ins []*registry.ServiceInstance) {
if endpoint == "" {
continue
}
// filter redundant endpoints
if _, ok := endpoints[endpoint]; ok {
continue
}
endpoints[endpoint] = struct{}{}
addr := resolver.Address{
ServerName: in.Name,
Attributes: parseAttributes(in.Metadata),
Addr: endpoint,
}
addr.Attributes = addr.Attributes.WithValues("rawServiceInstance", in)
addrs = append(addrs, addr)
}
if len(addrs) == 0 {

@ -1,21 +0,0 @@
package balancer
import (
"context"
"github.com/go-kratos/kratos/v2/registry"
)
// DoneInfo is callback when rpc done
type DoneInfo struct {
Err error
Trailer map[string]string
}
// Balancer is node pick balancer
type Balancer interface {
// Pick one node
Pick(ctx context.Context) (node *registry.ServiceInstance, done func(context.Context, DoneInfo), err error)
// Update nodes when nodes removed or added
Update(nodes []*registry.ServiceInstance)
}

@ -1,43 +0,0 @@
package random
import (
"context"
"fmt"
"math/rand"
"sync"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/transport/http/balancer"
)
var _ balancer.Balancer = &Balancer{}
type Balancer struct {
lock sync.RWMutex
nodes []*registry.ServiceInstance
}
func New() *Balancer {
return &Balancer{}
}
func (b *Balancer) Pick(ctx context.Context) (node *registry.ServiceInstance, done func(context.Context, balancer.DoneInfo), err error) {
b.lock.RLock()
nodes := b.nodes
b.lock.RUnlock()
if len(nodes) == 0 {
return nil, nil, fmt.Errorf("no instances available")
}
if len(nodes) == 1 {
return nodes[0], func(context.Context, balancer.DoneInfo) {}, nil
}
idx := rand.Intn(len(nodes))
return nodes[idx], func(context.Context, balancer.DoneInfo) {}, nil
}
func (b *Balancer) Update(nodes []*registry.ServiceInstance) {
b.lock.Lock()
defer b.lock.Unlock()
b.nodes = nodes
}

@ -1,6 +1,10 @@
package http
import "net/http"
import (
"net/http"
"github.com/go-kratos/kratos/v2/selector"
)
// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
@ -18,6 +22,7 @@ type callInfo struct {
contentType string
operation string
pathTemplate string
filters []selector.Filter
}
// EmptyCallOption does not alter the Call configuration.
@ -88,6 +93,22 @@ func (o PathTemplateCallOption) before(c *callInfo) error {
return nil
}
// SelectFilter is http select filter
func SelectFilter(filters ...selector.Filter) CallOption {
return SelectFilterCallOption{filters: filters}
}
// SelectFilterCallOption is set call select filters
type SelectFilterCallOption struct {
EmptyCallOption
filters []selector.Filter
}
func (o SelectFilterCallOption) before(c *callInfo) error {
c.filters = o.filters
return nil
}
// Header returns a CallOptions that retrieves the http response header
// from server reply.
func Header(header *http.Header) CallOption {

@ -12,14 +12,13 @@ import (
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/endpoint"
"github.com/go-kratos/kratos/v2/internal/host"
"github.com/go-kratos/kratos/v2/internal/httputil"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/random"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/http/balancer"
"github.com/go-kratos/kratos/v2/transport/http/balancer/random"
)
// DecodeErrorFunc is decode error func.
@ -45,7 +44,7 @@ type clientOptions struct {
decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc
transport http.RoundTripper
balancer balancer.Balancer
selector selector.Selector
discovery registry.Discovery
middleware []middleware.Middleware
block bool
@ -114,12 +113,10 @@ func WithDiscovery(d registry.Discovery) ClientOption {
}
}
// WithBalancer with client balancer.
// Experimental
// Notice: This type is EXPERIMENTAL and may be changed or removed in a later release.
func WithBalancer(b balancer.Balancer) ClientOption {
// WithSelector with client selector.
func WithSelector(selector selector.Selector) ClientOption {
return func(o *clientOptions) {
o.balancer = b
o.selector = selector
}
}
@ -155,7 +152,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
decoder: DefaultResponseDecoder,
errorDecoder: DefaultErrorDecoder,
transport: http.DefaultTransport,
balancer: random.New(),
selector: random.New(),
}
for _, o := range opts {
o(&options)
@ -173,7 +170,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
var r *resolver
if options.discovery != nil {
if target.Scheme == "discovery" {
if r, err = newResolver(ctx, options.discovery, target, options.balancer, options.block, insecure); err != nil {
if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil {
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint)
}
} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil {
@ -235,17 +232,13 @@ func (client *Client) Invoke(ctx context.Context, method, path string, args inte
func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo, opts ...CallOption) error {
h := func(ctx context.Context, in interface{}) (interface{}, error) {
var done func(context.Context, balancer.DoneInfo)
var done func(context.Context, selector.DoneInfo)
if client.r != nil {
var (
err error
node *registry.ServiceInstance
node selector.Node
)
if node, done, err = client.opts.balancer.Pick(ctx); err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
}
endpoint, err := endpoint.ParseEndpoint(node.Endpoints, "http", !client.insecure)
if err != nil {
if node, done, err = client.opts.selector.Select(ctx, selector.WithFilter(c.filters...)); err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
}
if client.insecure {
@ -253,12 +246,12 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf
} else {
req.URL.Scheme = "https"
}
req.URL.Host = endpoint
req.Host = endpoint
req.URL.Host = node.Address()
req.Host = node.Address()
}
res, err := client.do(ctx, req, c)
if done != nil {
done(ctx, balancer.DoneInfo{Err: err})
done(ctx, selector.DoneInfo{Err: err})
}
if res != nil {
cs := csAttempt{res: res}

@ -5,19 +5,15 @@ import (
"errors"
"net/url"
"strings"
"sync"
"time"
"github.com/go-kratos/kratos/v2/internal/endpoint"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/node"
)
// Updater is resolver nodes updater
type Updater interface {
Update(nodes []*registry.ServiceInstance)
}
// Target is resolver target
type Target struct {
Scheme string
@ -45,9 +41,7 @@ func parseTarget(endpoint string, insecure bool) (*Target, error) {
}
type resolver struct {
lock sync.RWMutex
nodes []*registry.ServiceInstance
updater Updater
rebalancer selector.Rebalancer
target *Target
watcher registry.Watcher
@ -56,17 +50,17 @@ type resolver struct {
insecure bool
}
func newResolver(ctx context.Context, discovery registry.Discovery, target *Target, updater Updater, block, insecure bool) (*resolver, error) {
func newResolver(ctx context.Context, discovery registry.Discovery, target *Target, rebalancer selector.Rebalancer, block, insecure bool) (*resolver, error) {
watcher, err := discovery.Watch(ctx, target.Endpoint)
if err != nil {
return nil, err
}
r := &resolver{
target: target,
watcher: watcher,
logger: log.NewHelper(log.DefaultLogger),
updater: updater,
insecure: insecure,
target: target,
watcher: watcher,
logger: log.NewHelper(log.DefaultLogger),
rebalancer: rebalancer,
insecure: insecure,
}
if block {
done := make(chan error, 1)
@ -77,8 +71,7 @@ func newResolver(ctx context.Context, discovery registry.Discovery, target *Targ
done <- err
return
}
r.update(services)
if len(r.nodes) > 0 {
if r.update(services) {
done <- nil
return
}
@ -119,27 +112,25 @@ func newResolver(ctx context.Context, discovery registry.Discovery, target *Targ
return r, nil
}
func (r *resolver) update(services []*registry.ServiceInstance) {
nodes := make([]*registry.ServiceInstance, 0)
for _, in := range services {
ept, err := endpoint.ParseEndpoint(in.Endpoints, "http", !r.insecure)
func (r *resolver) update(services []*registry.ServiceInstance) bool {
nodes := make([]selector.Node, 0)
for _, ins := range services {
ept, err := endpoint.ParseEndpoint(ins.Endpoints, "http", !r.insecure)
if err != nil {
r.logger.Errorf("Failed to parse (%v) discovery endpoint: %v error %v", r.target, in.Endpoints, err)
r.logger.Errorf("Failed to parse (%v) discovery endpoint: %v error %v", r.target, ins.Endpoints, err)
continue
}
if ept == "" {
continue
}
nodes = append(nodes, in)
nodes = append(nodes, node.New(ept, ins))
}
if len(nodes) != 0 {
r.updater.Update(nodes)
r.lock.Lock()
r.nodes = nodes
r.lock.Unlock()
} else {
if len(nodes) == 0 {
r.logger.Warnf("[http resovler]Zero endpoint found,refused to write,ser: %s ins: %v", r.target.Endpoint, nodes)
return false
}
r.rebalancer.Apply(nodes)
return true
}
func (r *resolver) Close() error {

Loading…
Cancel
Save