add warden rpc (#9)

* add warden rpc
pull/13/head
longxboy 6 years ago committed by Felix Hao
parent 50d0129461
commit 637a6a3628
  1. 2
      .gitignore
  2. 18
      pkg/conf/flagvar/flagvar.go
  3. 1
      pkg/ecode/common_ecode.go
  4. 48
      pkg/ecode/pb/ecode.go
  5. 96
      pkg/ecode/pb/ecode.pb.go
  6. 13
      pkg/ecode/pb/ecode.proto
  7. 103
      pkg/ecode/status.go
  8. 66
      pkg/ecode/status_test.go
  9. 102
      pkg/ecode/types/status.pb.go
  10. 23
      pkg/ecode/types/status.proto
  11. 27
      pkg/net/metadata/key.go
  12. 62
      pkg/net/rpc/warden/CHANGELOG.md
  13. 10
      pkg/net/rpc/warden/OWNERS
  14. 13
      pkg/net/rpc/warden/README.md
  15. 20
      pkg/net/rpc/warden/balancer/p2c/CHANGELOG.md
  16. 9
      pkg/net/rpc/warden/balancer/p2c/OWNERS
  17. 13
      pkg/net/rpc/warden/balancer/p2c/README.md
  18. 269
      pkg/net/rpc/warden/balancer/p2c/p2c.go
  19. 347
      pkg/net/rpc/warden/balancer/p2c/p2c_test.go
  20. 17
      pkg/net/rpc/warden/balancer/wrr/CHANGELOG.md
  21. 9
      pkg/net/rpc/warden/balancer/wrr/OWNERS
  22. 13
      pkg/net/rpc/warden/balancer/wrr/README.md
  23. 302
      pkg/net/rpc/warden/balancer/wrr/wrr.go
  24. 189
      pkg/net/rpc/warden/balancer/wrr/wrr_test.go
  25. 334
      pkg/net/rpc/warden/client.go
  26. 91
      pkg/net/rpc/warden/exapmle_test.go
  27. 189
      pkg/net/rpc/warden/internal/benchmark/bench/client/client.go
  28. 1686
      pkg/net/rpc/warden/internal/benchmark/bench/proto/hello.pb.go
  29. 60
      pkg/net/rpc/warden/internal/benchmark/bench/proto/hello.proto
  30. 103
      pkg/net/rpc/warden/internal/benchmark/bench/server/server.go
  31. 15
      pkg/net/rpc/warden/internal/benchmark/helloworld/client.sh
  32. 85
      pkg/net/rpc/warden/internal/benchmark/helloworld/client/greeter_client.go
  33. 50
      pkg/net/rpc/warden/internal/benchmark/helloworld/server/greeter_server.go
  34. 53
      pkg/net/rpc/warden/internal/encoding/json/json.go
  35. 31
      pkg/net/rpc/warden/internal/examples/client/client.go
  36. 191
      pkg/net/rpc/warden/internal/examples/grpcDebug/client.go
  37. 1
      pkg/net/rpc/warden/internal/examples/grpcDebug/data.json
  38. 108
      pkg/net/rpc/warden/internal/examples/server/main.go
  39. 11
      pkg/net/rpc/warden/internal/metadata/metadata.go
  40. 642
      pkg/net/rpc/warden/internal/proto/testproto/hello.pb.go
  41. 33
      pkg/net/rpc/warden/internal/proto/testproto/hello.proto
  42. 151
      pkg/net/rpc/warden/internal/status/status.go
  43. 164
      pkg/net/rpc/warden/internal/status/status_test.go
  44. 118
      pkg/net/rpc/warden/logging.go
  45. 55
      pkg/net/rpc/warden/logging_test.go
  46. 61
      pkg/net/rpc/warden/recovery.go
  47. 17
      pkg/net/rpc/warden/resolver/CHANGELOG.md
  48. 9
      pkg/net/rpc/warden/resolver/OWNERS
  49. 13
      pkg/net/rpc/warden/resolver/README.md
  50. 6
      pkg/net/rpc/warden/resolver/direct/CHANGELOG.md
  51. 14
      pkg/net/rpc/warden/resolver/direct/README.md
  52. 77
      pkg/net/rpc/warden/resolver/direct/direct.go
  53. 85
      pkg/net/rpc/warden/resolver/direct/direct_test.go
  54. 204
      pkg/net/rpc/warden/resolver/resolver.go
  55. 87
      pkg/net/rpc/warden/resolver/test/mockdiscovery.go
  56. 312
      pkg/net/rpc/warden/resolver/test/resovler_test.go
  57. 16
      pkg/net/rpc/warden/resolver/util.go
  58. 332
      pkg/net/rpc/warden/server.go
  59. 570
      pkg/net/rpc/warden/server_test.go
  60. 25
      pkg/net/rpc/warden/stats.go
  61. 31
      pkg/net/rpc/warden/validate.go

2
.gitignore vendored

@ -1,2 +1,4 @@
go.sum
BUILD
.DS_Store
tool/kratos/kratos

@ -0,0 +1,18 @@
package flagvar
import (
"strings"
)
// StringVars is a repeatable string flag: it implements flag.Value so
// the same flag may be passed several times, accumulating every value.
type StringVars []string

// String implements flag.Value by joining all values with commas.
func (s StringVars) String() string {
	var b strings.Builder
	for i, v := range s {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(v)
	}
	return b.String()
}

// Set implements flag.Value by appending val to the list; it never
// fails.
func (s *StringVars) Set(val string) error {
	*s = append(*s, val)
	return nil
}

@ -12,6 +12,7 @@ var (
NothingFound = add(-404) // 啥都木有
MethodNotAllowed = add(-405) // 不支持该方法
Conflict = add(-409) // 冲突
Canceled = add(-498) // 客户端取消请求
ServerErr = add(-500) // 服务器错误
ServiceUnavailable = add(-503) // 过载保护,服务暂不可用
Deadline = add(-504) // 服务调用超时

@ -0,0 +1,48 @@
package pb
import (
"strconv"
"github.com/bilibili/kratos/pkg/ecode"
any "github.com/golang/protobuf/ptypes/any"
)
// Error implements the error interface; the code is rendered as a
// base-10 string.
func (e *Error) Error() string {
	return strconv.Itoa(int(e.GetErrCode()))
}

// Code is the code of error.
func (e *Error) Code() int {
	return int(e.GetErrCode())
}

// Message is error message.
func (e *Error) Message() string {
	return e.GetErrMessage()
}

// Equal compare whether two errors are equal.
func (e *Error) Equal(ec error) bool {
	other := ecode.Cause(ec)
	return other.Code() == e.Code()
}

// Details return error details.
func (e *Error) Details() []interface{} {
	return []interface{}{e.GetErrDetail()}
}

// From will convert ecode.Codes to pb.Error.
//
// Deprecated: please use ecode.Error
func From(ec ecode.Codes) *Error {
	pe := &Error{
		ErrCode:    int32(ec.Code()),
		ErrMessage: ec.Message(),
	}
	if details := ec.Details(); len(details) > 0 {
		// only the first detail is carried; a non-Any detail is dropped
		if d, ok := details[0].(*any.Any); ok {
			pe.ErrDetail = d
		}
	}
	return pe
}

@ -0,0 +1,96 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: error.proto
package pb
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import any "github.com/golang/protobuf/ptypes/any"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Error is the wire form of an ecode error: a numeric code, a
// developer message, and an optional typed detail payload.
//
// Deprecated: please use ecode.Error
type Error struct {
	ErrCode              int32    `protobuf:"varint,1,opt,name=err_code,json=errCode,proto3" json:"err_code,omitempty"`
	ErrMessage           string   `protobuf:"bytes,2,opt,name=err_message,json=errMessage,proto3" json:"err_message,omitempty"`
	ErrDetail            *any.Any `protobuf:"bytes,3,opt,name=err_detail,json=errDetail,proto3" json:"err_detail,omitempty"`
	XXX_NoUnkeyedLiteral struct{} `json:"-"`
	XXX_unrecognized     []byte   `json:"-"`
	XXX_sizecache        int32    `json:"-"`
}

// NOTE(review): everything below is protoc-gen-go boilerplate; do not
// edit by hand — regenerate from error.proto instead.
func (m *Error) Reset()         { *m = Error{} }
func (m *Error) String() string { return proto.CompactTextString(m) }
func (*Error) ProtoMessage()    {}
func (*Error) Descriptor() ([]byte, []int) {
	return fileDescriptor_error_28aad86a4e53115b, []int{0}
}
func (m *Error) XXX_Unmarshal(b []byte) error {
	return xxx_messageInfo_Error.Unmarshal(m, b)
}
func (m *Error) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
	return xxx_messageInfo_Error.Marshal(b, m, deterministic)
}
func (dst *Error) XXX_Merge(src proto.Message) {
	xxx_messageInfo_Error.Merge(dst, src)
}
func (m *Error) XXX_Size() int {
	return xxx_messageInfo_Error.Size(m)
}
func (m *Error) XXX_DiscardUnknown() {
	xxx_messageInfo_Error.DiscardUnknown(m)
}

var xxx_messageInfo_Error proto.InternalMessageInfo

// GetErrCode returns the numeric error code (zero when m is nil).
func (m *Error) GetErrCode() int32 {
	if m != nil {
		return m.ErrCode
	}
	return 0
}

// GetErrMessage returns the developer message (empty when m is nil).
func (m *Error) GetErrMessage() string {
	if m != nil {
		return m.ErrMessage
	}
	return ""
}

// GetErrDetail returns the optional detail payload (nil when m is nil).
func (m *Error) GetErrDetail() *any.Any {
	if m != nil {
		return m.ErrDetail
	}
	return nil
}

func init() {
	proto.RegisterType((*Error)(nil), "err.Error")
}

func init() { proto.RegisterFile("error.proto", fileDescriptor_error_28aad86a4e53115b) }
var fileDescriptor_error_28aad86a4e53115b = []byte{
// 164 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x8d, 0xc1, 0xca, 0x82, 0x40,
0x14, 0x85, 0x99, 0x5f, 0xfc, 0xcb, 0x71, 0x37, 0xb4, 0xd0, 0x36, 0x49, 0x2b, 0x57, 0x23, 0xe4,
0x13, 0x44, 0xb5, 0x6c, 0xe3, 0x0b, 0x88, 0xe6, 0x49, 0x02, 0xf3, 0xc6, 0xd1, 0x20, 0xdf, 0x3e,
0x1c, 0x69, 0x79, 0xcf, 0xf7, 0x71, 0x3f, 0x1d, 0x82, 0x14, 0xda, 0x17, 0x65, 0x14, 0xe3, 0x81,
0xdc, 0xc6, 0xad, 0x48, 0xdb, 0x21, 0x73, 0x53, 0xfd, 0xbe, 0x67, 0x55, 0x3f, 0x2d, 0x7c, 0xff,
0xd1, 0xfe, 0x65, 0xd6, 0x4d, 0xac, 0xd7, 0x20, 0xcb, 0x9b, 0x34, 0x88, 0x54, 0xa2, 0x52, 0xbf,
0x58, 0x81, 0x3c, 0x49, 0x03, 0xb3, 0x73, 0x2f, 0xcb, 0x27, 0x86, 0xa1, 0x6a, 0x11, 0xfd, 0x25,
0x2a, 0x0d, 0x0a, 0x0d, 0xf2, 0xba, 0x2c, 0x26, 0xd7, 0xf3, 0x55, 0x36, 0x18, 0xab, 0x47, 0x17,
0x79, 0x89, 0x4a, 0xc3, 0xc3, 0xc6, 0x2e, 0x51, 0xfb, 0x8b, 0xda, 0x63, 0x3f, 0x15, 0x01, 0xc8,
0xb3, 0xd3, 0xea, 0x7f, 0x07, 0xf2, 0x6f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xf7, 0x41, 0x22, 0xfd,
0xaf, 0x00, 0x00, 0x00,
}

@ -0,0 +1,13 @@
syntax = "proto3";
package pb;
import "google/protobuf/any.proto";
option go_package = "go-common/library/ecode/pb";
// Error carries an application error across RPC boundaries.
message Error {
  // numeric business error code (see ecode.Code)
  int32 err_code = 1;
  // developer-facing error message
  string err_message = 2;
  // optional typed detail payload
  google.protobuf.Any err_detail = 3;
}

@ -0,0 +1,103 @@
package ecode
import (
"fmt"
"strconv"
"github.com/bilibili/kratos/pkg/ecode/types"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
)
// Error creates a *Status carrying the given code and message.
func Error(code Code, message string) *Status {
	return &Status{s: &types.Status{Code: int32(code.Code()), Message: message}}
}

// Errorf creates a *Status carrying the given code and a
// Sprintf-formatted message.
func Errorf(code Code, format string, args ...interface{}) *Status {
	return Error(code, fmt.Sprintf(format, args...))
}
var _ Codes = &Status{}

// Status is an ecode.Codes implementation backed by a types.Status
// protobuf message.
type Status struct {
	s *types.Status
}

// Error implements the error interface; it reports the same text as
// Message.
func (s *Status) Error() string {
	return s.Message()
}

// Code returns the numeric error code.
func (s *Status) Code() int {
	return int(s.s.Code)
}

// Message returns the developer-facing message; when none is set, the
// code itself is rendered in decimal.
func (s *Status) Message() string {
	if msg := s.s.Message; msg != "" {
		return msg
	}
	return strconv.Itoa(int(s.s.Code))
}

// Details returns the decoded detail messages attached to the status;
// a detail that fails to unmarshal is reported as the error itself.
// It is safe on a nil receiver.
func (s *Status) Details() []interface{} {
	if s == nil || s.s == nil {
		return nil
	}
	out := make([]interface{}, 0, len(s.s.Details))
	for _, anyMsg := range s.s.Details {
		dyn := new(ptypes.DynamicAny)
		if err := ptypes.UnmarshalAny(anyMsg, dyn); err != nil {
			out = append(out, err)
			continue
		}
		out = append(out, dyn.Message)
	}
	return out
}

// WithDetails attaches the given protobuf messages as details and
// returns the receiver; it stops at the first marshal failure.
func (s *Status) WithDetails(pbs ...proto.Message) (*Status, error) {
	for _, msg := range pbs {
		packed, err := ptypes.MarshalAny(msg)
		if err != nil {
			return s, err
		}
		s.s.Details = append(s.s.Details, packed)
	}
	return s, nil
}

// Equal for compatible.
// Deprecated: please use ecode.EqualError.
func (s *Status) Equal(err error) bool {
	return EqualError(s, err)
}

// Proto return origin protobuf message
func (s *Status) Proto() *types.Status {
	return s.s
}

// FromCode create status from ecode
func FromCode(code Code) *Status {
	return &Status{s: &types.Status{Code: int32(code)}}
}

// FromProto new status from grpc detail
func FromProto(pbMsg proto.Message) Codes {
	msg, ok := pbMsg.(*types.Status)
	if !ok {
		return Errorf(ServerErr, "invalid proto message get %v", pbMsg)
	}
	if msg.Message == "" {
		// NOTE: an empty message degrades to a bare Code so the text can
		// be resolved from the config center.
		return Code(msg.Code)
	}
	return &Status{s: msg}
}

@ -0,0 +1,66 @@
package ecode
import (
"testing"
"time"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
"github.com/bilibili/kratos/pkg/ecode/types"
)
// TestEqual checks that Error and Errorf build identical statuses for
// the same code and message.
func TestEqual(t *testing.T) {
	convey.Convey("Equal", t, func(ctx convey.C) {
		ctx.Convey("When err1=Error(RequestErr, 'test') and err2=Errorf(RequestErr, 'test')", func(ctx convey.C) {
			err1 := Error(RequestErr, "test")
			err2 := Errorf(RequestErr, "test")
			ctx.Convey("Then err1=err2, err1 != nil", func(ctx convey.C) {
				ctx.So(err1, convey.ShouldResemble, err2)
				ctx.So(err1, convey.ShouldNotBeNil)
			})
		})
	})
	// assert.True(t, OK.Equal(nil))
	// assert.True(t, err1.Equal(err2))
	// assert.False(t, err1.Equal(nil))
	// assert.True(t, Equal(nil, nil))
}

// TestDetail checks that WithDetails round-trips an attached proto
// message through Details.
func TestDetail(t *testing.T) {
	m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
	st, _ := Error(RequestErr, "RequestErr").WithDetails(m)
	assert.Equal(t, "RequestErr", st.Message())
	assert.Equal(t, int(RequestErr), st.Code())
	assert.IsType(t, m, st.Details()[0])
}

// TestFromCode checks that a message-less status renders its code as
// the message text.
func TestFromCode(t *testing.T) {
	err := FromCode(RequestErr)
	assert.Equal(t, int(RequestErr), err.Code())
	assert.Equal(t, "-400", err.Message())
}

// TestFromProto checks both branches of FromProto: a *types.Status is
// adopted as-is, while any other proto message yields a ServerErr.
func TestFromProto(t *testing.T) {
	msg := &types.Status{Code: 2233, Message: "error"}
	err := FromProto(msg)
	assert.Equal(t, 2233, err.Code())
	assert.Equal(t, "error", err.Message())
	m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
	err = FromProto(m)
	assert.Equal(t, -500, err.Code())
	assert.Contains(t, err.Message(), "invalid proto message get")
}

// TestEmpty checks that Details is nil-safe on both an empty and a nil
// *Status receiver.
func TestEmpty(t *testing.T) {
	st := &Status{}
	assert.Len(t, st.Details(), 0)
	st = nil
	assert.Len(t, st.Details(), 0)
}

@ -0,0 +1,102 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: internal/types/status.proto
package types // import "github.com/bilibili/kratos/pkg/ecode/types"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import any "github.com/golang/protobuf/ptypes/any"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Status is the transport-level representation of an ecode error:
// code, message, and a list of typed detail messages.
type Status struct {
	// The error code see ecode.Code
	Code int32 `protobuf:"varint,1,opt,name=code" json:"code,omitempty"`
	// A developer-facing error message, which should be in English.
	Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
	// A list of messages that carry the error details. There is a common set of
	// message types for APIs to use.
	Details              []*any.Any `protobuf:"bytes,3,rep,name=details" json:"details,omitempty"`
	XXX_NoUnkeyedLiteral struct{}   `json:"-"`
	XXX_unrecognized     []byte     `json:"-"`
	XXX_sizecache        int32      `json:"-"`
}

// NOTE(review): protoc-gen-go boilerplate below; regenerate from
// internal/types/status.proto rather than editing by hand.
func (m *Status) Reset()         { *m = Status{} }
func (m *Status) String() string { return proto.CompactTextString(m) }
func (*Status) ProtoMessage()    {}
func (*Status) Descriptor() ([]byte, []int) {
	return fileDescriptor_status_88668d6b2bf80f08, []int{0}
}
func (m *Status) XXX_Unmarshal(b []byte) error {
	return xxx_messageInfo_Status.Unmarshal(m, b)
}
func (m *Status) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
	return xxx_messageInfo_Status.Marshal(b, m, deterministic)
}
func (dst *Status) XXX_Merge(src proto.Message) {
	xxx_messageInfo_Status.Merge(dst, src)
}
func (m *Status) XXX_Size() int {
	return xxx_messageInfo_Status.Size(m)
}
func (m *Status) XXX_DiscardUnknown() {
	xxx_messageInfo_Status.DiscardUnknown(m)
}

var xxx_messageInfo_Status proto.InternalMessageInfo

// GetCode returns the error code (zero when m is nil).
func (m *Status) GetCode() int32 {
	if m != nil {
		return m.Code
	}
	return 0
}

// GetMessage returns the developer-facing message (empty when m is nil).
func (m *Status) GetMessage() string {
	if m != nil {
		return m.Message
	}
	return ""
}

// GetDetails returns the attached detail messages (nil when m is nil).
func (m *Status) GetDetails() []*any.Any {
	if m != nil {
		return m.Details
	}
	return nil
}

func init() {
	proto.RegisterType((*Status)(nil), "bilibili.rpc.Status")
}

func init() { proto.RegisterFile("internal/types/status.proto", fileDescriptor_status_88668d6b2bf80f08) }
var fileDescriptor_status_88668d6b2bf80f08 = []byte{
// 220 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x8f, 0xb1, 0x4a, 0x04, 0x31,
0x10, 0x86, 0xd9, 0x5b, 0xbd, 0xc3, 0x9c, 0x85, 0x04, 0x8b, 0x55, 0x9b, 0xc5, 0x6a, 0x0b, 0x4d,
0x40, 0x4b, 0x2b, 0xcf, 0x17, 0x58, 0x22, 0x36, 0x76, 0x49, 0x6e, 0x2e, 0x04, 0x92, 0xcc, 0x92,
0xe4, 0x8a, 0xbc, 0x8e, 0x4f, 0x2a, 0x9b, 0x65, 0x41, 0x8b, 0x19, 0x66, 0x98, 0xff, 0xe7, 0xfb,
0x87, 0x3c, 0xd8, 0x90, 0x21, 0x06, 0xe9, 0x78, 0x2e, 0x13, 0x24, 0x9e, 0xb2, 0xcc, 0xe7, 0xc4,
0xa6, 0x88, 0x19, 0xe9, 0xb5, 0xb2, 0xce, 0xce, 0xc5, 0xe2, 0xa4, 0xef, 0xef, 0x0c, 0xa2, 0x71,
0xc0, 0xeb, 0x4d, 0x9d, 0x4f, 0x5c, 0x86, 0xb2, 0x08, 0x1f, 0x4f, 0x64, 0xfb, 0x59, 0x8d, 0x94,
0x92, 0x0b, 0x8d, 0x47, 0xe8, 0x9a, 0xbe, 0x19, 0x2e, 0x45, 0x9d, 0x69, 0x47, 0x76, 0x1e, 0x52,
0x92, 0x06, 0xba, 0x4d, 0xdf, 0x0c, 0x57, 0x62, 0x5d, 0x29, 0x23, 0xbb, 0x23, 0x64, 0x69, 0x5d,
0xea, 0xda, 0xbe, 0x1d, 0xf6, 0x2f, 0xb7, 0x6c, 0x81, 0xb0, 0x15, 0xc2, 0xde, 0x43, 0x11, 0xab,
0xe8, 0xf0, 0x45, 0x6e, 0x34, 0x7a, 0xf6, 0x37, 0xd6, 0x61, 0xbf, 0x90, 0xc7, 0xd9, 0x30, 0x36,
0xdf, 0x4f, 0x06, 0x9f, 0x35, 0x7a, 0x8f, 0x81, 0x3b, 0xab, 0xa2, 0x8c, 0x85, 0xc3, 0x9c, 0x82,
0xff, 0x7f, 0xf4, 0xad, 0xf6, 0x9f, 0x4d, 0x2b, 0xc6, 0x0f, 0xb5, 0xad, 0xb4, 0xd7, 0xdf, 0x00,
0x00, 0x00, 0xff, 0xff, 0x80, 0xa3, 0xc1, 0x82, 0x0d, 0x01, 0x00, 0x00,
}

@ -0,0 +1,23 @@
syntax = "proto3";
package bilibili.rpc;
import "google/protobuf/any.proto";
option go_package = "github.com/bilibili/Kratos/pkg/ecode/types;types";
option java_multiple_files = true;
option java_outer_classname = "StatusProto";
option java_package = "com.bilibili.rpc";
option objc_class_prefix = "RPC";
// Status is the transport representation of an ecode error.
message Status {
  // The error code see ecode.Code
  int32 code = 1;
  // A developer-facing error message, which should be in English.
  string message = 2;
  // A list of messages that carry the error details. There is a common set of
  // message types for APIs to use.
  repeated google.protobuf.Any details = 3;
}

@ -37,3 +37,30 @@ const (
// Device 客户端信息
Device = "device"
)
// outgoingKey is the set of metadata keys propagated to downstream
// services on outgoing RPCs.
var outgoingKey = map[string]struct{}{
	Color:      {},
	RemoteIP:   {},
	RemotePort: {},
	Mirror:     {},
}

// incomingKey is the set of keys accepted only from upstream callers.
var incomingKey = map[string]struct{}{
	Caller: {},
}

// IsOutgoingKey represent this key should propagate by rpc.
func IsOutgoingKey(key string) bool {
	_, ok := outgoingKey[key]
	return ok
}

// IsIncomingKey represent this key should extract from rpc metadata.
// Every outgoing key is also accepted on the incoming side.
func IsIncomingKey(key string) bool {
	if _, ok := outgoingKey[key]; ok {
		return true
	}
	_, ok := incomingKey[key]
	return ok
}

@ -0,0 +1,62 @@
### net/rpc/warden
##### Version 1.1.13
1. 设置 caller 为 no_user 如果 user 不存在
##### Version 1.1.12
1. warden支持mirror传递
##### Version 1.1.11
1. Validate RequestErr支持详细报错信息
##### Version 1.1.10
1. 默认读取环境中的color
##### Version 1.1.9
1. 增加NonBlock模式
##### Version 1.1.8
1. 新增appid mock
##### Version 1.1.7
1. 兼容cpu为0和wrr dt为0的情况
##### Version 1.1.6
1. 修改caller传递和获取方式
2. 添加error detail example
##### Version 1.1.5
1. 增加server端json格式支持
##### Version 1.1.4
1. 判断resolver.builder为nil之后再注册
##### Version 1.1.3
1. 支持zone和clusters
##### Version 1.1.2
1. 业务错误日志记为 WARN
##### Version 1.1.1
1. server实现了返回cpu信息
##### Version 1.1.0
1. 增加ErrorDetail
2. 修复日志打印error信息丢失问题
##### Version 1.0.3
1. 给server增加keepalive参数
##### Version 1.0.2
1. 替代默认的timeout,使用duration.Shrink()来传递context
2. 修复peer.Addr为nil时会panic的问题
##### Version 1.0.1
1. 去除timeout的手动传递,改为使用grpc默认自带的grpc-timeout
2. 获取server address改为使用call option的方式,去除对balancer的依赖
##### Version 1.0.0
1. 使用NewClient来新建一个RPC客户端,并默认集成trace、log、recovery、monitor拦截器
2. 使用NewServer来新建一个RPC服务端,并默认集成trace、log、recovery、monitor拦截器

@ -0,0 +1,10 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
- maojian
labels:
- library
reviewers:
- caoguoliang
- maojian

@ -0,0 +1,13 @@
#### net/rpc/warden
##### 项目简介
来自 bilibili 主站技术部的 RPC 框架,融合主站技术部的核心科技,带来如飞一般的体验。
##### 编译环境
- **请只用 Golang v1.9.x 以上版本编译执行**
##### 依赖包
- [grpc](https://google.golang.org/grpc)

@ -0,0 +1,20 @@
### business/warden/balancer/p2c
##### Version 1.3.1
1. add more test
##### Version 1.3.0
1. P2C替换smooth weighted round-robin
##### Version 1.2.1
1. 删除了netflix ribbon的权重算法,改成了平方根算法
##### Version 1.2.0
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据)
##### Version 1.1.0
1. 实现了动态计算的调度轮询算法
##### Version 1.0.0
1. 实现了带权重可以识别Color的轮询算法

@ -0,0 +1,9 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
labels:
- library
reviewers:
- caoguoliang
- maojian

@ -0,0 +1,13 @@
#### business/warden/balancer/wrr
##### 项目简介
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用
##### 编译环境
- **请只用 Golang v1.9.x 以上版本编译执行**
##### 依赖包
- [grpc](https://google.golang.org/grpc)

@ -0,0 +1,269 @@
package p2c
import (
"context"
"math"
"math/rand"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/log"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
wmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
const (
	// tau is the mean lifetime of the moving-average `cost`; it reaches
	// its half-life after tau*ln(2).
	tau = int64(time.Millisecond * 600)
	// penalty is the cost charged to an endpoint before any statistics
	// have been collected, so untested nodes are not blindly favored.
	penalty = uint64(1000 * time.Millisecond * 250)
	// forceGap is the longest a node may go unpicked before a pick is
	// forced, giving its decayed statistics a chance to refresh.
	forceGap = int64(time.Second * 3)
)

var _ base.PickerBuilder = &p2cPickerBuilder{}
var _ balancer.Picker = &p2cPicker{}

// Name is the name of pick of two random choices balancer.
const Name = "p2c"

// newBuilder creates a new p2c (pick of two choices) balancer builder.
func newBuilder() balancer.Builder {
	return base.NewBalancerBuilder(Name, &p2cPickerBuilder{})
}

// init registers the p2c balancer with grpc at package load time.
func init() {
	balancer.Register(newBuilder())
}
// subConn wraps a ready SubConn together with the client- and
// server-side load statistics the p2c picker scores it by.
type subConn struct {
	// metadata
	conn balancer.SubConn
	addr resolver.Address
	meta wmd.MD

	// client statistic data
	lag      uint64 // moving average of observed latency (ns)
	success  uint64 // moving average of success rate, scaled by 1000
	inflight int64  // requests currently outstanding on this conn

	// server statistic data
	svrCPU uint64 // last CPU usage reported via the response trailer

	// stamp is the timestamp of the last statistics update (ns).
	stamp int64
	// pick is the timestamp of the last pick (ns).
	pick int64
	// reqs counts requests in the current reporting period.
	reqs int64
}

// health returns the moving-average success rate.
func (sc *subConn) health() uint64 {
	return atomic.LoadUint64(&sc.success)
}

// cost scores the conn as server CPU * latency * in-flight requests; a
// zero product (no statistics yet) is charged the fixed penalty.
func (sc *subConn) cost() uint64 {
	load := atomic.LoadUint64(&sc.svrCPU) * atomic.LoadUint64(&sc.lag) * uint64(atomic.LoadInt64(&sc.inflight))
	if load == 0 {
		// penalty is the initial score used while no data is available
		load = penalty
	}
	return load
}
// statistic is a per-conn snapshot used only for periodic logging.
type statistic struct {
	addr     string  // endpoint address
	score    float64 // success*weight over load (0 when load is 0)
	cs       uint64  // moving-average success rate
	lantency uint64  // moving-average latency (ns); NOTE(review): typo for "latency", also used by printStats
	cpu      uint64  // last reported server CPU usage
	inflight int64   // outstanding requests at snapshot time
	reqs     int64   // requests since the previous snapshot
}
type p2cPickerBuilder struct{}

// Build snapshots the ready SubConns into a p2cPicker. Conns whose
// metadata carries a color are segregated into per-color sub-pickers;
// conns without balancer metadata default to weight 10.
func (*p2cPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
	p := &p2cPicker{
		colors: make(map[string]*p2cPicker),
		r:      rand.New(rand.NewSource(time.Now().UnixNano())),
	}
	for addr, sc := range readySCs {
		meta, ok := addr.Metadata.(wmd.MD)
		if !ok {
			meta = wmd.MD{
				Weight: 10,
			}
		}
		subc := &subConn{
			conn: sc,
			addr: addr,
			meta: meta,
			// initial statistics for a fresh conn; lag==0 means cost()
			// falls back to the penalty until the first result arrives
			svrCPU:   500,
			lag:      0,
			success:  1000,
			inflight: 1,
		}
		if meta.Color == "" {
			p.subConns = append(p.subConns, subc)
			continue
		}
		// if color not empty, use color picker
		cp, ok := p.colors[meta.Color]
		if !ok {
			cp = &p2cPicker{r: rand.New(rand.NewSource(time.Now().UnixNano()))}
			p.colors[meta.Color] = cp
		}
		cp.subConns = append(cp.subConns, subc)
	}
	return p
}
// p2cPicker implements the pick-of-two-random-choices strategy over a
// fixed snapshot of SubConns.
type p2cPicker struct {
	// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
	// created. The slice is immutable. Each Get() will do a round robin
	// selection from it and return the selected SubConn.
	subConns []*subConn
	colors   map[string]*p2cPicker // per-color sub-pickers keyed by color
	logTs    int64                 // timestamp of the last stats log (ns)
	r        *rand.Rand            // random source; guarded by lk
	lk       sync.Mutex
}

// Pick routes the request to the sub-picker matching the request color
// (falling back to the env color), or to the default set when no
// matching color bucket exists.
func (p *p2cPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	// FIXME refactor to unify the color logic
	color := nmd.String(ctx, nmd.Color)
	if color == "" && env.Color != "" {
		color = env.Color
	}
	if color != "" {
		if cp, ok := p.colors[color]; ok {
			return cp.pick(ctx, opts)
		}
	}
	return p.pick(ctx, opts)
}
// pick chooses between two randomly selected distinct nodes, preferring
// the one with the better weighted cost/health score, and returns a
// done callback that folds the observed result back into the node's
// moving-average statistics.
func (p *p2cPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	var pc, upc *subConn
	start := time.Now().UnixNano()
	if len(p.subConns) <= 0 {
		return nil, nil, balancer.ErrNoSubConnAvailable
	} else if len(p.subConns) == 1 {
		pc = p.subConns[0]
	} else {
		// choose two distinct nodes
		p.lk.Lock()
		a := p.r.Intn(len(p.subConns))
		b := p.r.Intn(len(p.subConns) - 1)
		p.lk.Unlock()
		if b >= a {
			b = b + 1
		}
		nodeA, nodeB := p.subConns[a], p.subConns[b]
		// meta.Weight is the weight the publisher registered in discovery
		if nodeA.cost()*nodeB.health()*nodeB.meta.Weight > nodeB.cost()*nodeA.health()*nodeA.meta.Weight {
			pc, upc = nodeB, nodeA
		} else {
			pc, upc = nodeA, nodeB
		}
		// If the losing node has not been picked once within forceGap,
		// force-pick it so its decayed success/latency stats refresh.
		// The CAS on conn.pick keeps this concurrency-safe and lets only
		// one goroutine through.
		pick := atomic.LoadInt64(&upc.pick)
		if start-pick > forceGap && atomic.CompareAndSwapInt64(&upc.pick, pick, start) {
			pc = upc
		}
	}
	// update the pick time only when no forced switch happened (the
	// forced branch already stored it via the CAS above)
	if pc != upc {
		atomic.StoreInt64(&pc.pick, start)
	}
	atomic.AddInt64(&pc.inflight, 1)
	atomic.AddInt64(&pc.reqs, 1)
	return pc.conn, func(di balancer.DoneInfo) {
		atomic.AddInt64(&pc.inflight, -1)
		now := time.Now().UnixNano()
		// get moving average ratio w
		stamp := atomic.SwapInt64(&pc.stamp, now)
		td := now - stamp
		if td < 0 {
			td = 0
		}
		w := math.Exp(float64(-td) / float64(tau))
		lag := now - start
		if lag < 0 {
			lag = 0
		}
		oldLag := atomic.LoadUint64(&pc.lag)
		if oldLag == 0 {
			// no previous sample: take the new latency as-is
			w = 0.0
		}
		lag = int64(float64(oldLag)*w + float64(lag)*(1.0-w))
		atomic.StoreUint64(&pc.lag, uint64(lag))
		success := uint64(1000) // full-success sample; degraded to 0 below on transport errors
		if di.Err != nil {
			if st, ok := status.FromError(di.Err); ok {
				// only counter the local grpc error, ignore any business error
				if st.Code() != codes.Unknown && st.Code() != codes.OK {
					success = 0
				}
			}
		}
		oldSuc := atomic.LoadUint64(&pc.success)
		success = uint64(float64(oldSuc)*w + float64(success)*(1.0-w))
		atomic.StoreUint64(&pc.success, success)
		trailer := di.Trailer
		if strs, ok := trailer[wmd.CPUUsage]; ok {
			if cpu, err2 := strconv.ParseUint(strs[0], 10, 64); err2 == nil && cpu > 0 {
				atomic.StoreUint64(&pc.svrCPU, cpu)
			}
		}
		// rate-limit stats logging to once per 3s across goroutines
		logTs := atomic.LoadInt64(&p.logTs)
		if now-logTs > int64(time.Second*3) {
			if atomic.CompareAndSwapInt64(&p.logTs, logTs, now) {
				p.printStats()
			}
		}
	}, nil
}
// printStats logs one statistics snapshot per subConn, resetting each
// conn's per-period request counter as it reads it.
func (p *p2cPicker) printStats() {
	if len(p.subConns) <= 0 {
		return
	}
	stats := make([]statistic, 0, len(p.subConns))
	for _, conn := range p.subConns {
		var stat statistic
		stat.addr = conn.addr.Addr
		stat.cpu = atomic.LoadUint64(&conn.svrCPU)
		stat.cs = atomic.LoadUint64(&conn.success)
		stat.inflight = atomic.LoadInt64(&conn.inflight)
		stat.lantency = atomic.LoadUint64(&conn.lag)
		stat.reqs = atomic.SwapInt64(&conn.reqs, 0)
		// score mirrors pick()'s comparison: success*weight over load
		load := stat.cpu * uint64(stat.inflight) * stat.lantency
		if load != 0 {
			stat.score = float64(stat.cs*conn.meta.Weight*1e8) / float64(load)
		}
		stats = append(stats, stat)
	}
	log.Info("p2c %s : %+v", p.subConns[0].addr.ServerName, stats)
	//fmt.Printf("%+v\n", stats)
}

@ -0,0 +1,347 @@
package p2c
import (
"context"
"flag"
"fmt"
"math/rand"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
var serverNum int
var cliNum int
var concurrency int
var extraLoad int64
var extraDelay int64
var extraWeight uint64
func init() {
flag.IntVar(&serverNum, "snum", 5, "-snum 6")
flag.IntVar(&cliNum, "cnum", 5, "-cnum 12")
flag.IntVar(&concurrency, "concurrency", 5, "-cc 10")
flag.Int64Var(&extraLoad, "exload", 3, "-exload 3")
flag.Int64Var(&extraDelay, "exdelay", 0, "-exdelay 250")
flag.Uint64Var(&extraWeight, "extraWeight", 0, "-exdelay 50")
}
// testSubConn is a fake balancer.SubConn simulating a server with a
// bounded service rate plus configurable extra load and delay.
type testSubConn struct {
	addr resolver.Address
	wait chan struct{} // pending work units, drained at a fixed rate

	// statistics
	reqs      int64 // total requests received
	usage     int64 // total work units consumed
	cpu       int64 // simulated CPU usage, recomputed periodically
	prevReq   int64 // reqs at the previous sampling point
	prevUsage int64 // usage at the previous sampling point

	// control params
	loadJitter  int64 // max extra work units randomly added per request
	delayJitter int64 // max extra latency (ms) randomly added per request
}

// newTestSubConn builds a fake conn whose background goroutine drains
// 210 work units then sleeps 20ms, capping simulated throughput.
// NOTE(review): the drain goroutine never exits — acceptable for a
// test, but it leaks if conns are created in a loop.
func newTestSubConn(addr string, weight uint64, color string) (sc *testSubConn) {
	sc = &testSubConn{
		addr: resolver.Address{
			Addr: addr,
			Metadata: wmeta.MD{
				Weight: weight,
				Color:  color,
			},
		},
		wait: make(chan struct{}, 1000),
	}
	go func() {
		for {
			for i := 0; i < 210; i++ {
				<-sc.wait
			}
			time.Sleep(time.Millisecond * 20)
		}
	}()
	return
}
// connect simulates serving one request: a fixed 15ms base latency, one
// work unit, optional random extra load and delay, abandoning the work
// when ctx is canceled.
func (s *testSubConn) connect(ctx context.Context) {
	time.Sleep(time.Millisecond * 15)
	// add qps counter when request come in
	atomic.AddInt64(&s.reqs, 1)
	select {
	case <-ctx.Done():
		return
	case s.wait <- struct{}{}:
		atomic.AddInt64(&s.usage, 1)
	}
	load := atomic.LoadInt64(&s.loadJitter)
	if load > 0 {
		// consume a random number of extra work units to simulate load
		for i := 0; i <= rand.Intn(int(load)); i++ {
			select {
			case <-ctx.Done():
				return
			case s.wait <- struct{}{}:
				atomic.AddInt64(&s.usage, 1)
			}
		}
	}
	delay := atomic.LoadInt64(&s.delayJitter)
	if delay > 0 {
		delay = rand.Int63n(delay)
		time.Sleep(time.Millisecond * time.Duration(delay))
	}
}

// UpdateAddresses implements balancer.SubConn; it is a no-op here.
func (s *testSubConn) UpdateAddresses([]resolver.Address) {
}

// Connect starts the connecting for this SubConn.
func (s *testSubConn) Connect() {
}
// TestBalancerPick verifies color routing: uncolored requests pick only
// the uncolored conn (test1), an unknown color ("black") falls back to
// the default set, and env.Color routes to the matching color bucket.
func TestBalancerPick(t *testing.T) {
	scs := map[resolver.Address]balancer.SubConn{}
	sc1 := &testSubConn{
		addr: resolver.Address{
			Addr: "test1",
			Metadata: wmeta.MD{
				Weight: 8,
			},
		},
	}
	sc2 := &testSubConn{
		addr: resolver.Address{
			Addr: "test2",
			Metadata: wmeta.MD{
				Weight: 4,
				Color:  "red",
			},
		},
	}
	sc3 := &testSubConn{
		addr: resolver.Address{
			Addr: "test3",
			Metadata: wmeta.MD{
				Weight: 2,
				Color:  "red",
			},
		},
	}
	sc4 := &testSubConn{
		addr: resolver.Address{
			Addr: "test4",
			Metadata: wmeta.MD{
				Weight: 2,
				Color:  "purple",
			},
		},
	}
	scs[sc1.addr] = sc1
	scs[sc2.addr] = sc2
	scs[sc3.addr] = sc3
	scs[sc4.addr] = sc4
	b := &p2cPickerBuilder{}
	picker := b.Build(scs)
	// test1 is the only uncolored conn, so every default pick returns it
	res := []string{"test1", "test1", "test1", "test1"}
	for i := 0; i < 3; i++ {
		conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{})
		if err != nil {
			t.Fatalf("picker.Pick failed!idx:=%d", i)
		}
		sc := conn.(*testSubConn)
		if sc.addr.Addr != res[i] {
			t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i])
		}
	}
	// an unknown request color falls back to the default (uncolored) set
	ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"}))
	for i := 0; i < 4; i++ {
		conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
		if err != nil {
			t.Fatalf("picker.Pick failed!idx:=%d", i)
		}
		sc := conn.(*testSubConn)
		if sc.addr.Addr != res[i] {
			t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
		}
	}
	// the env color applies when the request carries no color of its own
	env.Color = "purple"
	ctx2 := context.Background()
	for i := 0; i < 4; i++ {
		conn, _, err := picker.Pick(ctx2, balancer.PickOptions{})
		if err != nil {
			t.Fatalf("picker.Pick failed!idx:=%d", i)
		}
		sc := conn.(*testSubConn)
		if sc.addr.Addr != "test4" {
			t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
		}
	}
}
// Benchmark_Wrr measures Pick over 50 equal-weight subConns.
//
// Bug fixed: the original called done only when Pick FAILED — but on a
// Pick error the returned done callback is nil (calling it would
// panic), and on success done was never invoked, so every iteration
// leaked an inflight count and skewed the picker's statistics. The
// check is now inverted: done is called exactly on success.
func Benchmark_Wrr(b *testing.B) {
	scs := map[resolver.Address]balancer.SubConn{}
	for i := 0; i < 50; i++ {
		addr := resolver.Address{
			Addr:     fmt.Sprintf("addr_%d", i),
			Metadata: wmeta.MD{Weight: 10},
		}
		scs[addr] = &testSubConn{addr: addr}
	}
	wpb := &p2cPickerBuilder{}
	picker := wpb.Build(scs)
	opt := balancer.PickOptions{}
	ctx := context.Background()
	b.ResetTimer() // exclude the picker construction above
	for idx := 0; idx < b.N; idx++ {
		_, done, err := picker.Pick(ctx, opt)
		if err == nil {
			// release the inflight slot so statistics stay balanced
			done(balancer.DoneInfo{})
		}
	}
}
// TestChaosPick runs the full load-balancing simulation: it launches
// concurrent clients against simulated servers, periodically injects
// extra load/delay, and prints per-node statistics for inspection.
// It asserts nothing and previously always slept ~50s; it now honors
// `go test -short` so routine runs are not blocked by it.
func TestChaosPick(t *testing.T) {
	if testing.Short() {
		t.Skip("chaos simulation sleeps ~50s; run without -short to observe it")
	}
	flag.Parse()
	fmt.Printf("start chaos test!svrNum:%d cliNum:%d concurrency:%d exLoad:%d exDelay:%d\n", serverNum, cliNum, concurrency, extraLoad, extraDelay)
	c := newController(serverNum, cliNum)
	c.launch(concurrency)
	go c.updateStatics()
	go c.control(extraLoad, extraDelay)
	time.Sleep(time.Second * 50)
}
// newController builds svrNum simulated servers and cliNum independent
// p2c pickers that all share the same server set.
// NOTE(review): when extraWeight > 0 the starting weight becomes
// extraWeight (not 10) and every server then gets a further
// +extraWeight — confirm this ramp is intentional rather than a
// uniform override.
func newController(svrNum int, cliNum int) *controller {
	// new servers
	servers := []*testSubConn{}
	var weight uint64 = 10
	if extraWeight > 0 {
		weight = extraWeight
	}
	for i := 0; i < svrNum; i++ {
		weight += extraWeight
		sc := newTestSubConn(fmt.Sprintf("addr_%d", i), weight, "")
		servers = append(servers, sc)
	}
	// new clients
	var clients []balancer.Picker
	scs := map[resolver.Address]balancer.SubConn{}
	for _, v := range servers {
		scs[v.addr] = v
	}
	for i := 0; i < cliNum; i++ {
		wpb := &p2cPickerBuilder{}
		picker := wpb.Build(scs)
		clients = append(clients, picker)
	}
	c := &controller{
		servers: servers,
		clients: clients,
	}
	return c
}

// controller owns the simulated servers and the client pickers that
// drive them during the chaos test.
type controller struct {
	servers []*testSubConn
	clients []balancer.Picker
}
// launch starts `concurrency` goroutines per client picker; each loops
// forever picking a server, simulating the request under a 250ms
// budget, and reporting the outcome (with the server CPU in the
// trailer) to the balancer's done callback.
// NOTE(review): the Pick error is discarded — if Pick ever failed, the
// nil SubConn would panic on the type assertion below.
func (c *controller) launch(concurrency int) {
	opt := balancer.PickOptions{}
	bkg := context.Background()
	for i := range c.clients {
		for j := 0; j < concurrency; j++ {
			picker := c.clients[i]
			go func() {
				for {
					ctx, cancel := context.WithTimeout(bkg, time.Millisecond*250)
					sc, done, _ := picker.Pick(ctx, opt)
					server := sc.(*testSubConn)
					server.connect(ctx)
					var err error
					if ctx.Err() != nil {
						// a blown deadline is reported as a transport error
						err = status.Errorf(codes.DeadlineExceeded, "dead")
					}
					cancel()
					cpu := atomic.LoadInt64(&server.cpu)
					md := make(map[string]string)
					md[wmeta.CPUUsage] = strconv.FormatInt(cpu, 10)
					done(balancer.DoneInfo{Trailer: metadata.New(md), Err: err})
					time.Sleep(time.Millisecond * 10)
				}
			}()
		}
	}
}
// updateStatics recomputes each server's simulated CPU reading every
// 500ms from the usage delta since the previous tick (×2 to scale the
// half-second window to per-second). Runs forever.
// NOTE(review): prevUsage is read/written without atomics — safe only
// because this is the sole goroutine touching it.
func (c *controller) updateStatics() {
	for {
		time.Sleep(time.Millisecond * 500)
		for _, sc := range c.servers {
			usage := atomic.LoadInt64(&sc.usage)
			avgCpu := (usage - sc.prevUsage) * 2
			atomic.StoreInt64(&sc.cpu, avgCpu)
			sc.prevUsage = usage
		}
	}
}
// control injects chaos in 5-second rounds: it picks 1-3 servers, gives
// each a random extra load and/or delay jitter, lets traffic run, prints
// per-server QPS/queue stats and each picker's internal stats, then
// resets the jitter for the next round. Runs forever.
func (c *controller) control(extraLoad, extraDelay int64) {
	var chaos int
	for {
		fmt.Printf("\n")
		//make some chaos
		n := rand.Intn(3)
		chaos = n + 1
		// Always the first `chaos` servers get the jitter.
		for i := 0; i < chaos; i++ {
			if extraLoad > 0 {
				degree := rand.Int63n(extraLoad)
				degree++
				atomic.StoreInt64(&c.servers[i].loadJitter, degree)
				fmt.Printf("set addr_%d load:%d ", i, degree)
			}
			if extraDelay > 0 {
				degree := rand.Int63n(extraDelay)
				atomic.StoreInt64(&c.servers[i].delayJitter, degree)
				fmt.Printf("set addr_%d delay:%dms ", i, degree)
			}
		}
		fmt.Printf("\n")
		sleep := int64(5)
		time.Sleep(time.Second * time.Duration(sleep))
		// Report per-server throughput over the window just elapsed.
		for _, sc := range c.servers {
			req := atomic.LoadInt64(&sc.reqs)
			qps := (req - sc.prevReq) / sleep
			wait := len(sc.wait)
			sc.prevReq = req
			fmt.Printf("%s qps:%d waits:%d\n", sc.addr.Addr, qps, wait)
		}
		for _, picker := range c.clients {
			p := picker.(*p2cPicker)
			p.printStats()
		}
		fmt.Printf("\n")
		//reset chaos
		for i := 0; i < chaos; i++ {
			atomic.StoreInt64(&c.servers[i].loadJitter, 0)
			atomic.StoreInt64(&c.servers[i].delayJitter, 0)
		}
		chaos = 0
	}
}

@ -0,0 +1,17 @@
### business/warden/balancer/wrr
##### Version 1.3.0
1. 迁移 stat.Summary 到 metric.RollingCounter,metric.RollingGauge
##### Version 1.2.1
1. 删除了netflix ribbon的权重算法,改成了平方根算法
##### Version 1.2.0
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据)
##### Version 1.1.0
1. 实现了动态计算的调度轮询算法
##### Version 1.0.0
1. 实现了带权重可以识别Color的轮询算法

@ -0,0 +1,9 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
labels:
- library
reviewers:
- caoguoliang
- maojian

@ -0,0 +1,13 @@
#### business/warden/balancer/wrr
##### 项目简介
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用
##### 编译环境
- **请只用 Golang v1.9.x 以上版本编译执行**
##### 依赖包
- [grpc](https://google.golang.org/grpc)

@ -0,0 +1,302 @@
package wrr
import (
"context"
"math"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/log"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"github.com/bilibili/kratos/pkg/stat/metric"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
// Compile-time checks that the builder and picker satisfy the gRPC
// balancer interfaces.
var _ base.PickerBuilder = &wrrPickerBuilder{}
var _ balancer.Picker = &wrrPicker{}

// var dwrrFeature feature.Feature = "dwrr"

// Name is the name of round_robin balancer.
const Name = "wrr"
// newBuilder creates a new weighted-roundrobin balancer builder,
// registered under Name ("wrr").
func newBuilder() balancer.Builder {
	return base.NewBalancerBuilder(Name, &wrrPickerBuilder{})
}
// init registers the wrr balancer with gRPC at package load time.
func init() {
	//feature.DefaultGate.Add(map[feature.Feature]feature.Spec{
	//	dwrrFeature: {Default: false},
	//})
	balancer.Register(newBuilder())
}
// serverInfo holds server-reported health metrics, updated atomically
// from response trailer metadata by the Stats interceptor.
type serverInfo struct {
	cpu     int64  // server CPU usage as parsed from the trailer
	success uint64 // float64 bits of the server-side success ratio (via math.Float64bits)
}
// subConn wraps a gRPC SubConn with the rolling client-side metrics and
// the dynamic weights used by the smooth weighted round-robin picker.
type subConn struct {
	conn balancer.SubConn
	addr resolver.Address
	meta wmeta.MD

	err     metric.RollingCounter // rolling error count (1 per failed request)
	latency metric.RollingGauge   // rolling per-request latency samples
	si      serverInfo            // server-reported cpu / success ratio
	// effective weight
	ewt int64
	// current weight
	cwt int64
	// last score
	score float64
}
// errSummary folds the rolling error counter into the total number of
// recorded requests and how many of them were errors.
func (c *subConn) errSummary() (err int64, req int64) {
	c.err.Reduce(func(it metric.Iterator) float64 {
		for it.Next() {
			b := it.Bucket()
			req += b.Count
			for _, point := range b.Points {
				err += int64(point)
			}
		}
		return 0
	})
	return err, req
}
// latencySummary reduces the rolling latency gauge to the mean recorded
// latency and the number of samples. Returns (0, 0) when no samples
// exist — the original divided by a zero count and returned NaN, which
// would poison any arithmetic done with the result.
func (c *subConn) latencySummary() (latency float64, count int64) {
	c.latency.Reduce(func(iterator metric.Iterator) float64 {
		for iterator.Next() {
			bucket := iterator.Bucket()
			count += bucket.Count
			for _, p := range bucket.Points {
				latency += p
			}
		}
		return 0
	})
	if count == 0 {
		return 0, 0
	}
	return latency / float64(count), count
}
// statistics is info for log: a per-subConn snapshot printed when the
// effective weights are recomputed.
type statistics struct {
	addr     string  // sub-connection address
	ewt      int64   // newly computed effective weight
	cs       float64 // client-side success ratio
	ss       float64 // server-side success ratio
	lantency float64 // mean latency (field name typo kept: referenced elsewhere in this file)
	cpu      float64 // server-reported cpu usage
	req      int64   // requests in the rolling window
}
// Stats is grpc Interceptor for client to collect server stats.
// It reads the subConn that pick() stashed under the "conn" metadata key
// and updates that connection's serverInfo.cpu from the response trailer.
func Stats() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
		var (
			trailer metadata.MD
			md      nmd.MD
			ok      bool
		)
		// Ensure a mutable request-scoped metadata map exists so pick()
		// can hand the chosen subConn back to this interceptor.
		if md, ok = nmd.FromContext(ctx); !ok {
			md = nmd.MD{}
		} else {
			md = md.Copy()
		}
		ctx = nmd.NewContext(ctx, md)
		opts = append(opts, grpc.Trailer(&trailer))
		err = invoker(ctx, method, req, reply, cc, opts...)
		conn, ok := md["conn"].(*subConn)
		if !ok {
			return
		}
		// Parse the server's CPU usage from the trailer; ignore missing
		// or non-positive values.
		if strs, ok := trailer[wmeta.CPUUsage]; ok {
			if cpu, err2 := strconv.ParseInt(strs[0], 10, 64); err2 == nil && cpu > 0 {
				atomic.StoreInt64(&conn.si.cpu, cpu)
			}
		}
		return
	}
}
type wrrPickerBuilder struct{}
// Build creates a wrrPicker over the ready sub-connections. Connections
// without wrr metadata get a default weight of 10; connections carrying
// a non-empty Color are grouped into per-color child pickers instead of
// the default pool.
func (*wrrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
	p := &wrrPicker{
		colors: make(map[string]*wrrPicker),
	}
	for addr, sc := range readySCs {
		meta, ok := addr.Metadata.(wmeta.MD)
		if !ok {
			meta = wmeta.MD{
				Weight: 10,
			}
		}
		subc := &subConn{
			conn: sc,
			addr: addr,
			meta: meta,
			// Effective weight starts at the static metadata weight and
			// is later adjusted dynamically; score -1 marks "no data yet".
			ewt:   int64(meta.Weight),
			score: -1,
			err: metric.NewRollingCounter(metric.RollingCounterOpts{
				Size:           10,
				BucketDuration: time.Millisecond * 100,
			}),
			latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
				Size:           10,
				BucketDuration: time.Millisecond * 100,
			}),
			// Optimistic initial server stats: moderate cpu, 100% success.
			si: serverInfo{cpu: 500, success: math.Float64bits(1)},
		}
		if meta.Color == "" {
			p.subConns = append(p.subConns, subc)
			continue
		}
		// if color not empty, use color picker
		cp, ok := p.colors[meta.Color]
		if !ok {
			cp = &wrrPicker{}
			p.colors[meta.Color] = cp
		}
		cp.subConns = append(cp.subConns, subc)
	}
	return p
}
type wrrPicker struct {
	// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
	// created. The slice is immutable. Each Get() will do a round robin
	// selection from it and return the selected SubConn.
	subConns []*subConn
	// colors maps a color label to the child picker holding that color's
	// sub-connections.
	colors   map[string]*wrrPicker
	// updateAt is the UnixNano timestamp of the last weight recompute,
	// accessed atomically to rate-limit updates to once per second.
	updateAt int64

	// mu guards cwt/ewt mutation during selection and weight updates.
	mu sync.Mutex
}
// Pick routes the request to the color-specific child picker when the
// request (or the process environment) carries a color with a matching
// pool, and to the default pool otherwise.
func (p *wrrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	// FIXME refactor to unify the color logic
	color := nmd.String(ctx, nmd.Color)
	if color == "" {
		// Fall back to the process-wide color (may itself be empty).
		color = env.Color
	}
	if cp, ok := p.colors[color]; ok {
		return cp.pick(ctx, opts)
	}
	// No color, or no pool registered for it: use the default pool.
	return p.pick(ctx, opts)
}
// pick selects a sub-connection with the nginx smooth weighted
// round-robin algorithm and returns a done callback that feeds observed
// latency and errors back into a once-per-second effective-weight
// recomputation.
func (p *wrrPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	var (
		conn        *subConn
		totalWeight int64
	)
	if len(p.subConns) <= 0 {
		return nil, nil, balancer.ErrNoSubConnAvailable
	}
	p.mu.Lock()
	// nginx wrr load balancing algorithm: http://blog.csdn.net/zhangskd/article/details/50194069
	for _, sc := range p.subConns {
		totalWeight += sc.ewt
		sc.cwt += sc.ewt
		if conn == nil || conn.cwt < sc.cwt {
			conn = sc
		}
	}
	conn.cwt -= totalWeight
	p.mu.Unlock()
	start := time.Now()
	// Expose the chosen subConn to the Stats interceptor via the
	// request-scoped metadata map.
	if cmd, ok := nmd.FromContext(ctx); ok {
		cmd["conn"] = conn
	}
	//if !feature.DefaultGate.Enabled(dwrrFeature) {
	//	return conn.conn, nil, nil
	//}
	return conn.conn, func(di balancer.DoneInfo) {
		ev := int64(0) // error value ,if error set 1
		if di.Err != nil {
			if st, ok := status.FromError(di.Err); ok {
				// only counter the local grpc error, ignore any business error
				if st.Code() != codes.Unknown && st.Code() != codes.OK {
					ev = 1
				}
			}
		}
		conn.err.Add(ev)
		now := time.Now()
		// Latency is recorded in units of 0.1ms (ns / 1e5).
		conn.latency.Add(now.Sub(start).Nanoseconds() / 1e5)
		// Rate-limit weight recomputation to once per second; the CAS
		// ensures only one concurrent done callback performs it.
		u := atomic.LoadInt64(&p.updateAt)
		if now.UnixNano()-u < int64(time.Second) {
			return
		}
		if !atomic.CompareAndSwapInt64(&p.updateAt, u, now.UnixNano()) {
			return
		}
		var (
			stats = make([]statistics, len(p.subConns))
			count int
			total float64
		)
		for i, conn := range p.subConns {
			cpu := float64(atomic.LoadInt64(&conn.si.cpu))
			ss := math.Float64frombits(atomic.LoadUint64(&conn.si.success))
			errc, req := conn.errSummary()
			lagv, lagc := conn.latencySummary()
			if req > 0 && lagc > 0 && lagv > 0 {
				// client-side success ratio
				cs := 1 - (float64(errc) / float64(req))
				if cs <= 0 {
					cs = 0.1
				} else if cs <= 0.2 && req <= 5 {
					cs = 0.2
				}
				// Score rises with success ratios and falls with latency
				// and cpu load (square-root damped).
				conn.score = math.Sqrt((cs * ss * ss * 1e9) / (lagv * cpu))
				stats[i] = statistics{cs: cs, ss: ss, lantency: lagv, cpu: cpu, req: req}
			}
			stats[i].addr = conn.addr.Addr
			if conn.score > 0 {
				total += conn.score
				count++
			}
		}
		// count must be greater than 1,otherwise will lead ewt to 0
		if count < 2 {
			return
		}
		avgscore := total / float64(count)
		p.mu.Lock()
		for i, conn := range p.subConns {
			// Connections with no data yet inherit the average score.
			if conn.score <= 0 {
				conn.score = avgscore
			}
			conn.ewt = int64(conn.score * float64(conn.meta.Weight))
			stats[i].ewt = conn.ewt
		}
		p.mu.Unlock()
		log.Info("warden wrr(%s): %+v", conn.addr.ServerName, stats)
	}, nil
}

@ -0,0 +1,189 @@
package wrr
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"testing"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"github.com/bilibili/kratos/pkg/stat/metric"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
// testSubConn is a minimal balancer.SubConn stub carrying only the
// resolver address it was built from.
type testSubConn struct {
	addr resolver.Address
}
// UpdateAddresses is a no-op; required to satisfy balancer.SubConn.
func (s *testSubConn) UpdateAddresses([]resolver.Address) {
}
// Connect starts the connecting for this SubConn.
// Stub implementation: just prints the address.
func (s *testSubConn) Connect() {
	fmt.Println(s.addr.Addr)
}
func TestBalancerPick(t *testing.T) {
scs := map[resolver.Address]balancer.SubConn{}
sc1 := &testSubConn{
addr: resolver.Address{
Addr: "test1",
Metadata: wmeta.MD{
Weight: 8,
},
},
}
sc2 := &testSubConn{
addr: resolver.Address{
Addr: "test2",
Metadata: wmeta.MD{
Weight: 4,
Color: "red",
},
},
}
sc3 := &testSubConn{
addr: resolver.Address{
Addr: "test3",
Metadata: wmeta.MD{
Weight: 2,
Color: "red",
},
},
}
scs[sc1.addr] = sc1
scs[sc2.addr] = sc2
scs[sc3.addr] = sc3
b := &wrrPickerBuilder{}
picker := b.Build(scs)
res := []string{"test1", "test1", "test1", "test1"}
for i := 0; i < 3; i++ {
conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{})
if err != nil {
t.Fatalf("picker.Pick failed!idx:=%d", i)
}
sc := conn.(*testSubConn)
if sc.addr.Addr != res[i] {
t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i])
}
}
res2 := []string{"test2", "test3", "test2", "test2", "test3", "test2"}
ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "red"}))
for i := 0; i < 6; i++ {
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
if err != nil {
t.Fatalf("picker.Pick failed!idx:=%d", i)
}
sc := conn.(*testSubConn)
if sc.addr.Addr != res2[i] {
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i])
}
}
ctx = nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"}))
for i := 0; i < 4; i++ {
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
if err != nil {
t.Fatalf("picker.Pick failed!idx:=%d", i)
}
sc := conn.(*testSubConn)
if sc.addr.Addr != res[i] {
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
}
}
// test for env color
ctx = context.Background()
env.Color = "red"
for i := 0; i < 6; i++ {
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
if err != nil {
t.Fatalf("picker.Pick failed!idx:=%d", i)
}
sc := conn.(*testSubConn)
if sc.addr.Addr != res2[i] {
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i])
}
}
}
// TestBalancerDone verifies that the done callback records latency and
// that only non-business gRPC errors (codes other than Unknown/OK) are
// counted as failures.
func TestBalancerDone(t *testing.T) {
	scs := map[resolver.Address]balancer.SubConn{}
	sc1 := &testSubConn{
		addr: resolver.Address{
			Addr: "test1",
			Metadata: wmeta.MD{
				Weight: 8,
			},
		},
	}
	scs[sc1.addr] = sc1
	b := &wrrPickerBuilder{}
	picker := b.Build(scs)
	_, done, _ := picker.Pick(context.Background(), balancer.PickOptions{})
	time.Sleep(100 * time.Millisecond)
	// codes.Unknown is treated as a business error — not counted.
	done(balancer.DoneInfo{Err: status.Errorf(codes.Unknown, "test")})
	err, req := picker.(*wrrPicker).subConns[0].errSummary()
	assert.Equal(t, int64(0), err)
	assert.Equal(t, int64(1), req)
	latency, count := picker.(*wrrPicker).subConns[0].latencySummary()
	// Latency is recorded in 0.1ms units; expect roughly the 100ms sleep.
	expectLatency := float64(100*time.Millisecond) / 1e5
	if !(expectLatency < latency && latency < (expectLatency+100)) {
		t.Fatalf("latency is less than 100ms or greter than 100ms, %f", latency)
	}
	assert.Equal(t, int64(1), count)
	_, done, _ = picker.Pick(context.Background(), balancer.PickOptions{})
	// codes.Aborted is a transport-level error — counted as a failure.
	done(balancer.DoneInfo{Err: status.Errorf(codes.Aborted, "test")})
	err, req = picker.(*wrrPicker).subConns[0].errSummary()
	assert.Equal(t, int64(1), err)
	assert.Equal(t, int64(2), req)
}
// TestErrSummary checks that errSummary counts every Add as a request
// and sums the added values as errors.
func TestErrSummary(t *testing.T) {
	sc := &subConn{
		err: metric.NewRollingCounter(metric.RollingCounterOpts{
			Size:           10,
			BucketDuration: time.Millisecond * 100,
		}),
		latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
			Size:           10,
			BucketDuration: time.Millisecond * 100,
		}),
	}
	// Interleave ten successes (0) and ten failures (1).
	for i := 0; i < 20; i++ {
		sc.err.Add(int64(i % 2))
	}
	errs, reqs := sc.errSummary()
	assert.Equal(t, int64(10), errs)
	assert.Equal(t, int64(20), reqs)
}
// TestLatencySummary checks that latencySummary returns the mean of the
// recorded samples and the sample count.
func TestLatencySummary(t *testing.T) {
	sc := &subConn{
		err: metric.NewRollingCounter(metric.RollingCounterOpts{
			Size:           10,
			BucketDuration: time.Millisecond * 100,
		}),
		latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
			Size:           10,
			BucketDuration: time.Millisecond * 100,
		}),
	}
	// Record 1..100; mean is 50.50.
	for v := int64(1); v <= 100; v++ {
		sc.latency.Add(v)
	}
	avg, n := sc.latencySummary()
	assert.Equal(t, 50.50, avg)
	assert.Equal(t, int64(100), n)
}

@ -0,0 +1,334 @@
package warden
import (
"context"
"fmt"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/conf/flagvar"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/naming"
"github.com/bilibili/kratos/pkg/naming/discovery"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden/balancer/p2c"
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status"
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
"github.com/bilibili/kratos/pkg/net/trace"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
gstatus "google.golang.org/grpc/status"
)
// _grpcTarget holds "appid=address" overrides (set via flag) used by
// Dial to redirect a discovery target to a passthrough/mock address.
var _grpcTarget flagvar.StringVars

var (
	_once           sync.Once
	// _defaultCliConf supplies the fallback client configuration when
	// SetConfig receives nil.
	_defaultCliConf = &ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Millisecond * 250),
		Subset:  50,
	}
	// _defaultClient is lazily created by DefaultClient.
	_defaultClient *Client
)
// baseMetadata builds the outgoing gRPC metadata every call starts from:
// the caller app id plus, when set, the deploy color.
func baseMetadata() metadata.MD {
	md := metadata.MD{nmd.Caller: []string{env.AppID}}
	if color := env.Color; color != "" {
		md[nmd.Color] = []string{color}
	}
	return md
}
// ClientConfig is rpc client conf.
type ClientConfig struct {
	Dial     xtime.Duration           // dial timeout
	Timeout  xtime.Duration           // per-request timeout (shrunk into the context)
	Breaker  *breaker.Config          // circuit-breaker configuration
	Method   map[string]*ClientConfig // per-method config overrides, keyed by full method name
	Clusters []string                 // clusters appended to the discovery query
	Zone     string                   // zone appended to the discovery query
	Subset   int                      // subset size appended to the discovery query
	NonBlock bool                     // when true, Dial does not use grpc.WithBlock
}
// Client is the framework's client side instance, it contains the ctx, opt and interceptors.
// Create an instance of Client, by using NewClient().
type Client struct {
	conf     *ClientConfig
	breaker  *breaker.Group // per-method circuit breakers
	mutex    sync.RWMutex   // guards conf (and breaker swap in SetConfig)

	opt      []grpc.DialOption              // accumulated dial options
	handlers []grpc.UnaryClientInterceptor  // interceptor chain run by chainUnaryClient
}
// handle returns a new unary client interceptor for OpenTracing\Logging\LinkTimeout.
// Per call it: forks the trace span, checks the per-method circuit
// breaker, shrinks the deadline to the configured timeout, copies
// outgoing keys from the request metadata, invokes the RPC, and maps a
// gRPC status error back into an ecode error for the caller.
func (c *Client) handle() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
		var (
			ok     bool
			cmd    nmd.MD
			t      trace.Trace
			gmd    metadata.MD
			conf   *ClientConfig
			cancel context.CancelFunc
			addr   string
			p      peer.Peer
		)
		var ec ecode.Codes = ecode.OK
		// apm tracing
		if t, ok = trace.FromContext(ctx); ok {
			t = t.Fork("", method)
			defer t.Finish(&err)
		}
		// setup metadata
		gmd = baseMetadata()
		trace.Inject(t, trace.GRPCFormat, gmd)
		// Per-method config override, falling back to the client config.
		c.mutex.RLock()
		if conf, ok = c.conf.Method[method]; !ok {
			conf = c.conf
		}
		c.mutex.RUnlock()
		// Circuit breaker: reject fast when the method is tripped.
		brk := c.breaker.Get(method)
		if err = brk.Allow(); err != nil {
			statsClient.Incr(method, "breaker")
			return
		}
		defer onBreaker(brk, &err)
		_, ctx, cancel = conf.Timeout.Shrink(ctx)
		defer cancel()
		// Copy whitelisted outgoing keys from request metadata into the
		// gRPC metadata sent to the server.
		if cmd, ok = nmd.FromContext(ctx); ok {
			for netKey, val := range cmd {
				if !nmd.IsOutgoingKey(netKey) {
					continue
				}
				valstr, ok := val.(string)
				if ok {
					gmd[netKey] = []string{valstr}
				}
			}
		}
		// merge with old matadata if exists
		if oldmd, ok := metadata.FromOutgoingContext(ctx); ok {
			gmd = metadata.Join(gmd, oldmd)
		}
		ctx = metadata.NewOutgoingContext(ctx, gmd)
		opts = append(opts, grpc.Peer(&p))
		if err = invoker(ctx, method, req, reply, cc, opts...); err != nil {
			// Convert the gRPC status into the framework ecode, keeping
			// the original message.
			gst, _ := gstatus.FromError(err)
			ec = status.ToEcode(gst)
			err = errors.WithMessage(ec, gst.Message())
		}
		if p.Addr != nil {
			addr = p.Addr.String()
		}
		if t != nil {
			t.SetTag(trace.String(trace.TagAddress, addr), trace.String(trace.TagComment, ""))
		}
		return
	}
}
// onBreaker reports the call outcome to the circuit breaker: only
// system-level errors (server error, unavailable, deadline, rate limit)
// count as failures; everything else marks a success.
func onBreaker(breaker breaker.Breaker, err *error) {
	if err == nil || *err == nil {
		breaker.MarkSuccess()
		return
	}
	e := *err
	if ecode.ServerErr.Equal(e) || ecode.ServiceUnavailable.Equal(e) || ecode.Deadline.Equal(e) || ecode.LimitExceed.Equal(e) {
		breaker.MarkFailed()
		return
	}
	breaker.MarkSuccess()
}
// NewConn will create a grpc conn by default config.
// It is a convenience wrapper over DefaultClient().Dial with a
// background context.
func NewConn(target string, opt ...grpc.DialOption) (*grpc.ClientConn, error) {
	return DefaultClient().Dial(context.Background(), target, opt...)
}
// NewClient returns a new blank Client instance with a default client interceptor.
// opt can be used to add grpc dial options.
// A nil conf selects the package defaults; an invalid conf panics.
func NewClient(conf *ClientConfig, opt ...grpc.DialOption) *Client {
	resolver.Register(discovery.Builder())
	c := new(Client)
	if err := c.SetConfig(conf); err != nil {
		panic(err)
	}
	// p2c is the default load-balancing policy.
	c.UseOpt(grpc.WithBalancerName(p2c.Name))
	c.UseOpt(opt...)
	// Interceptor order: recovery first, then logging, then the
	// trace/breaker/timeout handler.
	c.Use(c.recovery(), clientLogging(), c.handle())
	return c
}
// DefaultClient returns a new default Client instance with a default client interceptor and default dialoption.
// opt can be used to add grpc dial options.
// The instance is created once and shared by all callers.
func DefaultClient() *Client {
	resolver.Register(discovery.Builder())
	_once.Do(func() {
		_defaultClient = NewClient(nil)
	})
	return _defaultClient
}
// SetConfig hot reloads client config. A nil conf selects the package
// defaults; unset fields are backfilled with sane values. The breaker
// group is created on first use and reloaded afterwards.
func (c *Client) SetConfig(conf *ClientConfig) (err error) {
	if conf == nil {
		conf = _defaultCliConf
	}
	if conf.Dial <= 0 {
		conf.Dial = xtime.Duration(10 * time.Second)
	}
	if conf.Timeout <= 0 {
		conf.Timeout = xtime.Duration(250 * time.Millisecond)
	}
	if conf.Subset <= 0 {
		conf.Subset = 50
	}
	// FIXME(maojian) check Method dial/timeout
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.conf = conf
	if c.breaker == nil {
		c.breaker = breaker.NewGroup(conf.Breaker)
		return nil
	}
	c.breaker.Reload(conf.Breaker)
	return nil
}
// Use attaches global interceptors to the Client.
// For example, this is the right place for a circuit breaker or error
// management interceptor.
func (c *Client) Use(handlers ...grpc.UnaryClientInterceptor) *Client {
	total := len(c.handlers) + len(handlers)
	if total >= int(_abortIndex) {
		panic("warden: client use too many handlers")
	}
	// Build a fresh slice so previously captured chains keep their
	// snapshot of the handler list.
	merged := make([]grpc.UnaryClientInterceptor, 0, total)
	merged = append(merged, c.handlers...)
	merged = append(merged, handlers...)
	c.handlers = merged
	return c
}
// UseOpt attachs a global grpc DialOption to the Client.
// Options accumulate and are applied by every subsequent Dial/DialTLS.
func (c *Client) UseOpt(opt ...grpc.DialOption) *Client {
	c.opt = append(c.opt, opt...)
	return c
}
// Dial creates a client connection to the given target.
// Target format is scheme://authority/endpoint?query_arg=value
// example: discovery://default/account.account.service?cluster=shfy01&cluster=shfy02
func (c *Client) Dial(ctx context.Context, target string, opt ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
	c.mutex.RLock()
	conf := c.conf
	c.mutex.RUnlock()
	// FIX: assemble the per-call option list in a local slice. The
	// original appended to c.opt, so every Dial permanently accumulated
	// another WithBlock/WithInsecure/interceptor option on the shared
	// client (and raced with concurrent Dials).
	dialOptions := make([]grpc.DialOption, 0, len(c.opt)+len(opt)+3)
	dialOptions = append(dialOptions, c.opt...)
	if !conf.NonBlock {
		dialOptions = append(dialOptions, grpc.WithBlock())
	}
	dialOptions = append(dialOptions, grpc.WithInsecure())
	dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(c.chainUnaryClient()))
	dialOptions = append(dialOptions, opt...)
	if conf.Dial > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial))
		defer cancel()
	}
	if u, e := url.Parse(target); e == nil {
		// Propagate cluster/zone/subset selection via the query string.
		v := u.Query()
		for _, cluster := range conf.Clusters {
			v.Add(naming.MetaCluster, cluster)
		}
		if conf.Zone != "" {
			v.Add(naming.MetaZone, conf.Zone)
		}
		if v.Get("subset") == "" && conf.Subset > 0 {
			v.Add("subset", strconv.FormatInt(int64(conf.Subset), 10))
		}
		u.RawQuery = v.Encode()
		// Compare the appid in _grpcTarget with the appid in u.Path and,
		// on a match, rewrite the target to the mock passthrough address.
		for _, t := range _grpcTarget {
			strs := strings.SplitN(t, "=", 2)
			if len(strs) == 2 && ("/"+strs[0]) == u.Path {
				u.Path = "/" + strs[1]
				u.Scheme = "passthrough"
				u.RawQuery = ""
				break
			}
		}
		target = u.String()
	}
	if conn, err = grpc.DialContext(ctx, target, dialOptions...); err != nil {
		fmt.Fprintf(os.Stderr, "warden client: dial %s error %v!", target, err)
	}
	err = errors.WithStack(err)
	return
}
// DialTLS creates a client connection over tls transport to the given target.
// file/name identify the certificate file and the server name override.
func (c *Client) DialTLS(ctx context.Context, target string, file string, name string) (conn *grpc.ClientConn, err error) {
	var creds credentials.TransportCredentials
	creds, err = credentials.NewClientTLSFromFile(file, name)
	if err != nil {
		err = errors.WithStack(err)
		return
	}
	c.mutex.RLock()
	conf := c.conf
	c.mutex.RUnlock()
	// FIX: use a local option slice instead of mutating c.opt, which
	// made repeated calls accumulate duplicate options (see Dial).
	dialOptions := make([]grpc.DialOption, 0, len(c.opt)+3)
	dialOptions = append(dialOptions, c.opt...)
	dialOptions = append(dialOptions, grpc.WithBlock())
	dialOptions = append(dialOptions, grpc.WithTransportCredentials(creds))
	dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(c.chainUnaryClient()))
	if conf.Dial > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial))
		defer cancel()
	}
	conn, err = grpc.DialContext(ctx, target, dialOptions...)
	err = errors.WithStack(err)
	return
}
// chainUnaryClient creates a single interceptor out of a chain of many interceptors.
//
// Execution is done in left-to-right order, including passing of context.
// For example ChainUnaryClient(one, two, three) will execute one before two before three.
func (c *Client) chainUnaryClient() grpc.UnaryClientInterceptor {
	n := len(c.handlers)
	// Fast path: no interceptors configured, call the invoker directly.
	if n == 0 {
		return func(ctx context.Context, method string, req, reply interface{},
			cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
			return invoker(ctx, method, req, reply, cc, opts...)
		}
	}
	return func(ctx context.Context, method string, req, reply interface{},
		cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		var (
			i            int
			chainHandler grpc.UnaryInvoker
		)
		// chainHandler advances through c.handlers; i is per-call state,
		// so each RPC walks the chain independently. The last handler's
		// "next" is the real invoker.
		chainHandler = func(ictx context.Context, imethod string, ireq, ireply interface{}, ic *grpc.ClientConn, iopts ...grpc.CallOption) error {
			if i == n-1 {
				return invoker(ictx, imethod, ireq, ireply, ic, iopts...)
			}
			i++
			return c.handlers[i](ictx, imethod, ireq, ireply, ic, chainHandler, iopts...)
		}
		return c.handlers[0](ctx, method, req, reply, cc, chainHandler, opts...)
	}
}

@ -0,0 +1,91 @@
package warden_test
import (
"context"
"fmt"
"io"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
xtime "github.com/bilibili/kratos/pkg/time"
"google.golang.org/grpc"
)
// helloServer is a minimal Greeter implementation used by the examples.
type helloServer struct {
}
// SayHello echoes a greeting for the request name.
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
}
// StreamHello receives up to 3 requests on the stream and replies to
// each with a greeting; it ends cleanly when the client closes first.
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
	for i := 0; i < 3; i++ {
		in, err := ss.Recv()
		// Client finished sending: normal termination.
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true}
		err = ss.Send(ret)
		if err != nil {
			return err
		}
	}
	return nil
}
// ExampleServer shows how to create a warden server, attach a server
// interceptor, register a service, and start listening.
func ExampleServer() {
	s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second), Addr: ":8080"})
	// apply server interceptor middleware
	s.Use(func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		newctx, cancel := context.WithTimeout(ctx, time.Second*10)
		defer cancel()
		resp, err := handler(newctx, req)
		return resp, err
	})
	pb.RegisterGreeterServer(s.Server(), &helloServer{})
	s.Start()
}
// ExampleClient shows how to create a warden client with breaker
// configuration, attach a client interceptor, dial, and issue a unary
// call.
func ExampleClient() {
	client := warden.NewClient(&warden.ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	})
	// apply client interceptor middleware
	client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) {
		newctx, cancel := context.WithTimeout(ctx, time.Second*5)
		defer cancel()
		ret = invoker(newctx, method, req, reply, cc, opts...)
		return
	})
	conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
	if err != nil {
		log.Error("did not connect: %v", err)
		return
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)
	name := "2233"
	rp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name, Age: 18})
	if err != nil {
		log.Error("could not greet: %v", err)
		return
	}
	fmt.Println("rp", *rp)
}

@ -0,0 +1,189 @@
package main
import (
"flag"
"log"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto"
xtime "github.com/bilibili/kratos/pkg/time"
goproto "github.com/gogo/protobuf/proto"
"github.com/montanaflynn/stats"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// Transport tuning shared by the warden and plain-gRPC benchmark
// clients: window sizes and read/write buffer sizes.
const (
	iws         = 65535 * 1000
	iwsc        = 65535 * 10000
	readBuffer  = 32 * 1024
	writeBuffer = 32 * 1024
)

// Benchmark knobs.
var concurrency = flag.Int("c", 50, "concurrency")
var total = flag.Int("t", 500000, "total requests for all clients")
var host = flag.String("s", "127.0.0.1:8972", "server ip and port")
var isWarden = flag.Bool("w", true, "is warden or grpc client")
var strLen = flag.Int("l", 600, "the length of the str")
// wardenCli dials the benchmark server through the warden client stack
// (breaker + interceptors) with tuned transport options, exiting on
// failure.
func wardenCli() proto.HelloClient {
	log.Println("start warden cli")
	client := warden.NewClient(&warden.ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	},
		grpc.WithInitialWindowSize(iws),
		grpc.WithInitialConnWindowSize(iwsc),
		grpc.WithReadBufferSize(readBuffer),
		grpc.WithWriteBufferSize(writeBuffer))
	conn, err := client.Dial(context.Background(), *host)
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	cli := proto.NewHelloClient(conn)
	return cli
}
// grpcCli dials the benchmark server with a plain gRPC client using the
// same transport tuning as wardenCli, for a baseline comparison.
func grpcCli() proto.HelloClient {
	log.Println("start grpc cli")
	conn, err := grpc.Dial(*host, grpc.WithInsecure(),
		grpc.WithInitialWindowSize(iws),
		grpc.WithInitialConnWindowSize(iwsc),
		grpc.WithReadBufferSize(readBuffer),
		grpc.WithWriteBufferSize(writeBuffer))
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	cli := proto.NewHelloClient(conn)
	return cli
}
// main runs the latency/throughput benchmark: c goroutines each issue m
// requests against either the warden or plain-gRPC client, then prints
// TPS and latency percentiles.
func main() {
	flag.Parse()
	c := *concurrency
	m := *total / c
	var wg sync.WaitGroup
	wg.Add(c)
	log.Printf("concurrency: %d\nrequests per client: %d\n\n", c, m)
	args := prepareArgs()
	b, _ := goproto.Marshal(args)
	log.Printf("message size: %d bytes\n\n", len(b))
	var trans uint64
	var transOK uint64
	// FIX: the original did make([][]int64, c) followed by append, which
	// produced a slice of length 2c whose first c entries (the ones the
	// goroutines actually index) were nil, discarding the pre-sized
	// capacity. Pre-size each per-goroutine bucket in place instead.
	d := make([][]int64, c)
	for i := range d {
		d[i] = make([]int64, 0, m)
	}
	var cli proto.HelloClient
	if *isWarden {
		cli = wardenCli()
	} else {
		cli = grpcCli()
	}
	//warmup (result intentionally ignored)
	cli.Say(context.Background(), args)
	totalT := time.Now().UnixNano()
	for i := 0; i < c; i++ {
		go func(i int) {
			for j := 0; j < m; j++ {
				t := time.Now().UnixNano()
				reply, err := cli.Say(context.Background(), args)
				t = time.Now().UnixNano() - t
				d[i] = append(d[i], t)
				if err == nil && reply.Field1 == "OK" {
					atomic.AddUint64(&transOK, 1)
				}
				atomic.AddUint64(&trans, 1)
			}
			wg.Done()
		}(i)
	}
	wg.Wait()
	totalT = time.Now().UnixNano() - totalT
	totalT = totalT / 1e6
	log.Printf("took %d ms for %d requests\n", totalT, *total)
	// Flatten the per-goroutine samples for percentile computation.
	totalD := make([]int64, 0, *total)
	for _, k := range d {
		totalD = append(totalD, k...)
	}
	totalD2 := make([]float64, 0, *total)
	for _, k := range totalD {
		totalD2 = append(totalD2, float64(k))
	}
	mean, _ := stats.Mean(totalD2)
	median, _ := stats.Median(totalD2)
	max, _ := stats.Max(totalD2)
	min, _ := stats.Min(totalD2)
	tp99, _ := stats.Percentile(totalD2, 99)
	tp999, _ := stats.Percentile(totalD2, 99.9)
	log.Printf("sent     requests    : %d\n", *total)
	log.Printf("received requests_OK : %d\n", atomic.LoadUint64(&transOK))
	log.Printf("throughput  (TPS)    : %d\n", int64(c*m)*1000/totalT)
	log.Printf("mean: %v ms, median: %v ms, max: %v ms, min: %v ms, p99: %v ms, p999:%v ms\n", mean/1e6, median/1e6, max/1e6, min/1e6, tp99/1e6, tp999/1e6)
}
// prepareArgs builds a BenchmarkMessage with every field populated via
// reflection: pointer fields point at shared sample values, scalar
// fields get fixed constants, and string fields carry a payload of
// roughly *strLen bytes.
func prepareArgs() *proto.BenchmarkMessage {
	b := true
	var i int32 = 120000
	var i64 int64 = 98765432101234
	var s = "许多往事在眼前一幕一幕,变的那麼模糊"
	// Repeat the sample string until it is about *strLen bytes long
	// (the string is 54 bytes of UTF-8; at least one repetition).
	repeat := *strLen / (8 * 54)
	if repeat == 0 {
		repeat = 1
	}
	var str string
	for i := 0; i < repeat; i++ {
		str += s
	}
	var args proto.BenchmarkMessage
	v := reflect.ValueOf(&args).Elem()
	num := v.NumField()
	for k := 0; k < num; k++ {
		field := v.Field(k)
		if field.Type().Kind() == reflect.Ptr {
			// Optional (pointer) fields share the sample values above.
			switch v.Field(k).Type().Elem().Kind() {
			case reflect.Int, reflect.Int32:
				field.Set(reflect.ValueOf(&i))
			case reflect.Int64:
				field.Set(reflect.ValueOf(&i64))
			case reflect.Bool:
				field.Set(reflect.ValueOf(&b))
			case reflect.String:
				field.Set(reflect.ValueOf(&str))
			}
		} else {
			switch field.Kind() {
			case reflect.Int, reflect.Int32, reflect.Int64:
				field.SetInt(9876543)
			case reflect.Bool:
				field.SetBool(true)
			case reflect.String:
				field.SetString(str)
			}
		}
	}
	return &args
}

@ -0,0 +1,60 @@
syntax = "proto3";

package proto;

import "github.com/gogo/protobuf/gogoproto/gogo.proto";

option optimize_for = SPEED;
// gogoproto options: plain field access (no getters/prefixes) and
// generated fast-path marshal/unmarshal/size code.
option (gogoproto.goproto_enum_prefix_all) = false;
option (gogoproto.goproto_getters_all) = false;
option (gogoproto.unmarshaler_all) = true;
option (gogoproto.marshaler_all) = true;
option (gogoproto.sizer_all) = true;

// Hello is the benchmark service: a single unary echo-style RPC.
service Hello {
  // Sends a greeting
  rpc Say (BenchmarkMessage) returns (BenchmarkMessage) {}
}

// BenchmarkMessage is a wide synthetic payload with many scalar fields,
// used to exercise serialization cost in the benchmarks.
message BenchmarkMessage {
  string field1 = 1;
  string field9 = 9;
  string field18 = 18;
  bool field80 = 80;
  bool field81 = 81;
  int32 field2 = 2;
  int32 field3 = 3;
  int32 field280 = 280;
  int32 field6 = 6;
  int64 field22 = 22;
  string field4 = 4;
  fixed64 field5 = 5;
  bool field59 = 59;
  string field7 = 7;
  int32 field16 = 16;
  int32 field130 = 130;
  bool field12 = 12;
  bool field17 = 17;
  bool field13 = 13;
  bool field14 = 14;
  int32 field104 = 104;
  int32 field100 = 100;
  int32 field101 = 101;
  string field102 = 102;
  string field103 = 103;
  int32 field29 = 29;
  bool field30 = 30;
  int32 field60 = 60;
  int32 field271 = 271;
  int32 field272 = 272;
  int32 field150 = 150;
  int32 field23 = 23;
  bool field24 = 24 ;
  int32 field25 = 25 ;
  bool field78 = 78;
  int32 field67 = 67;
  int32 field68 = 68;
  int32 field128 = 128;
  string field129 = 129;
  int32 field131 = 131;
}

@ -0,0 +1,103 @@
package main
import (
"context"
"flag"
"log"
"net"
"net/http"
_ "net/http/pprof"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"google.golang.org/grpc"
)
// Transport tuning mirroring the benchmark client's settings.
const (
	iws         = 65535 * 1000
	iwsc        = 65535 * 10000
	readBuffer  = 32 * 1024
	writeBuffer = 32 * 1024
)

// reqNum counts handled requests; read/written atomically by Say and stat.
var reqNum uint64
type Hello struct{}
// Say echoes the request back with Field1="OK" and Field2=100, and
// bumps the global request counter.
func (t *Hello) Say(ctx context.Context, args *proto.BenchmarkMessage) (reply *proto.BenchmarkMessage, err error) {
	s := "OK"
	var i int32 = 100
	args.Field1 = s
	args.Field2 = i
	atomic.AddUint64(&reqNum, 1)
	return args, nil
}
// Server knobs: listen address and which stack (warden vs plain gRPC) to run.
var host = flag.String("s", "0.0.0.0:8972", "listened ip and port")
var isWarden = flag.Bool("w", true, "is warden or grpc client")
// main starts the metrics endpoint and QPS reporter, then serves the
// benchmark Hello service with either warden or raw gRPC.
func main() {
	flag.Parse()
	// Expose prometheus metrics for the benchmark run.
	go func() {
		log.Println("run http at :6060")
		http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
			promhttp.Handler().ServeHTTP(w, r)
		})
		log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
	}()
	go stat()
	if *isWarden {
		runWarden()
		return
	}
	runGrpc()
}
// runGrpc serves the Hello service with a plain gRPC server (blocks).
// Window/buffer sizes mirror runWarden so the two are comparable.
func runGrpc() {
	log.Println("run grpc")
	lis, err := net.Listen("tcp", *host)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	s := grpc.NewServer(grpc.InitialWindowSize(iws),
		grpc.InitialConnWindowSize(iwsc),
		grpc.ReadBufferSize(readBuffer),
		grpc.WriteBufferSize(writeBuffer))
	proto.RegisterHelloServer(s, &Hello{})
	// Serve blocks until the listener fails; report the error instead of
	// silently dropping it (the original ignored Serve's return value).
	if err := s.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}
// runWarden serves the Hello service with a warden server (blocks).
// Window/buffer sizes mirror runGrpc so the two are comparable.
func runWarden() {
	log.Println("run warden")
	s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second * 3)},
		grpc.InitialWindowSize(iws),
		grpc.InitialConnWindowSize(iwsc),
		grpc.ReadBufferSize(readBuffer),
		grpc.WriteBufferSize(writeBuffer))
	proto.RegisterHelloServer(s.Server(), &Hello{})
	// Run blocks until the server stops; report the error instead of
	// silently dropping it (the original ignored Run's return value).
	if err := s.Run(*host); err != nil {
		log.Fatalf("failed to run warden server: %v", err)
	}
}
// stat logs the observed QPS every five seconds based on the global
// request counter. It never returns.
func stat() {
	ticker := time.NewTicker(time.Second * 5)
	defer ticker.Stop()
	var (
		prevCount uint64
		prevNanos = uint64(time.Now().UnixNano())
	)
	for range ticker.C {
		curCount := atomic.LoadUint64(&reqNum)
		curNanos := uint64(time.Now().UnixNano())
		// qps = delta requests * 1e6 / elapsed microseconds
		qps := (curCount - prevCount) * 1e6 / ((curNanos - prevNanos) / 1e3)
		prevCount, prevNanos = curCount, curNanos
		log.Println("qps:", qps)
	}
}

@ -0,0 +1,15 @@
#!/bin/bash
# Build the benchmark client, then run it across a matrix of payload sizes
# (-s, characters) and concurrency levels (-c, goroutines). The binary is
# removed afterwards so the script leaves the tree clean.
go build -o client greeter_client.go
echo size 100 concurrent 30
./client -s 100 -c 30
echo size 1000 concurrent 30
./client -s 1000 -c 30
echo size 10000 concurrent 30
./client -s 10000 -c 30
echo size 100 concurrent 300
./client -s 100 -c 300
echo size 1000 concurrent 300
./client -s 1000 -c 300
echo size 10000 concurrent 300
./client -s 10000 -c 300
rm client

@ -0,0 +1,85 @@
package main
import (
"context"
"flag"
"fmt"
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
xtime "github.com/bilibili/kratos/pkg/time"
)
var (
	// ccf configures the warden client used by the benchmark: generous
	// dial/call timeouts plus a circuit breaker so a dead server does not
	// hang the run.
	ccf = &warden.ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	}
	cli         pb.GreeterClient // shared client stub, set once in main
	wg          sync.WaitGroup   // tracks the worker goroutines
	reqSize     int              // payload size in characters (-s)
	concurrency int              // number of worker goroutines (-c)
	request     int              // requests issued per goroutine (-r)
	all         int64            // total elapsed milliseconds across workers
)
// init registers the benchmark's command-line flags.
func init() {
	flag.IntVar(&reqSize, "s", 10, "request size")
	flag.IntVar(&concurrency, "c", 10, "concurrency")
	flag.IntVar(&request, "r", 1000, "request per routine")
}
// main fans out `concurrency` goroutines, each issuing `request` RPCs with
// a random payload of `reqSize` characters, then prints the mean
// per-request latency.
func main() {
	flag.Parse()
	name := randSeq(reqSize)
	cli = newClient()
	for n := 0; n < concurrency; n++ {
		wg.Add(1)
		go sayHello(&pb.HelloRequest{Name: name})
	}
	wg.Wait()
	fmt.Printf("per request cost %v\n", all/int64(request*concurrency))
}
// sayHello issues `request` sequential SayHello RPCs and folds the elapsed
// milliseconds into the global total. Errors are deliberately ignored: the
// benchmark measures latency only.
func sayHello(in *pb.HelloRequest) {
	defer wg.Done()
	start := time.Now()
	for n := 0; n < request; n++ {
		cli.SayHello(context.TODO(), in)
	}
	atomic.AddInt64(&all, int64(time.Since(start)/time.Millisecond))
}
// letters is the alphabet used to build random request payloads.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randSeq returns a random string of n characters drawn from letters.
func randSeq(n int) string {
	out := make([]rune, n)
	for i := 0; i < len(out); i++ {
		out[i] = letters[rand.Intn(len(letters))]
	}
	return string(out)
}
// newClient dials the local benchmark server and returns a Greeter stub.
// A dial failure aborts the run immediately: the original silently returned
// a nil client, which would only surface later as a confusing nil-pointer
// panic inside sayHello.
func newClient() (cli pb.GreeterClient) {
	client := warden.NewClient(ccf)
	conn, err := client.Dial(context.TODO(), "127.0.0.1:9999")
	if err != nil {
		panic(err)
	}
	return pb.NewGreeterClient(conn)
}

@ -0,0 +1,50 @@
package main
import (
"context"
"net/http"
"time"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
	// config gives every request handled by the example server a
	// one-second timeout.
	config = &warden.ServerConfig{Timeout: xtime.Duration(time.Second)}
)
// main boots the example warden gRPC server; newServer blocks until the
// server stops.
func main() {
	newServer()
}
// hello implements pb.GreeterServer for the example server.
type hello struct {
}
// SayHello replies with the request name echoed back as the message.
func (s *hello) SayHello(c context.Context, in *pb.HelloRequest) (out *pb.HelloReply, err error) {
	return &pb.HelloReply{Message: in.Name}, nil
}
// StreamHello is a no-op placeholder for the bidirectional streaming RPC.
func (s *hello) StreamHello(ss pb.Greeter_StreamHelloServer) error {
	return nil
}
// newServer registers the Greeter service, exposes prometheus metrics on
// :9998 in the background, and serves gRPC on :9999 (blocks until the
// server stops).
func newServer() {
	server := warden.NewServer(config)
	pb.RegisterGreeterServer(server.Server(), &hello{})
	go func() {
		http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
			h := promhttp.Handler()
			h.ServeHTTP(w, r)
		})
		// Best-effort metrics endpoint; its failure must not kill the server.
		http.ListenAndServe("0.0.0.0:9998", nil)
	}()
	// Surface a serve failure instead of swallowing it silently (the
	// original returned without any report).
	if err := server.Run(":9999"); err != nil {
		panic(err)
	}
}

@ -0,0 +1,53 @@
package codec
import (
"bytes"
"encoding/json"
"github.com/gogo/protobuf/jsonpb"
"github.com/gogo/protobuf/proto"
"google.golang.org/grpc/encoding"
)
// Reference https://jbrandhorst.com/post/grpc-json/
// init registers this codec with gRPC under the name "json" so callers can
// select it via grpc.CallContentSubtype.
func init() {
	encoding.RegisterCodec(JSON{
		Marshaler: jsonpb.Marshaler{
			EmitDefaults: true,
			OrigName:     true,
		},
	})
}
// JSON is an encoding.Codec implementation that (un)marshals protobuf
// messages as JSON via gogo/jsonpb and falls back to encoding/json for
// plain Go values.
type JSON struct {
	jsonpb.Marshaler
	jsonpb.Unmarshaler
}
// Name returns the codec's registered name ("json"), as required by
// encoding.Codec.
func (j JSON) Name() string {
	return "json"
}
// Marshal renders v as JSON: protobuf messages go through jsonpb (so the
// registered marshal options apply); everything else uses encoding/json.
func (j JSON) Marshal(v interface{}) (out []byte, err error) {
	pm, ok := v.(proto.Message)
	if !ok {
		return json.Marshal(v)
	}
	var buf bytes.Buffer
	if err = j.Marshaler.Marshal(&buf, pm); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// Unmarshal parses JSON from data into v, using jsonpb for protobuf
// messages and encoding/json otherwise.
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) {
	pm, ok := v.(proto.Message)
	if !ok {
		return json.Unmarshal(data, v)
	}
	return j.Unmarshaler.Unmarshal(bytes.NewBuffer(data), pm)
}

@ -0,0 +1,31 @@
package main
import (
"context"
"flag"
"fmt"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
)
// usage: ./client -grpc.target=test.service=127.0.0.1:8080
// main dials the example server and performs one unary call.
func main() {
	log.Init(&log.Config{Stdout: true})
	flag.Parse()
	conn, err := warden.NewClient(nil).Dial(context.Background(), "127.0.0.1:8081")
	if err != nil {
		panic(err)
	}
	normalCall(pb.NewGreeterClient(conn))
}
// normalCall performs a single unary SayHello and prints the reply;
// any RPC error aborts the program.
func normalCall(cli pb.GreeterClient) {
	resp, err := cli.SayHello(context.Background(), &pb.HelloRequest{Name: "tom", Age: 23})
	if err != nil {
		panic(err)
	}
	fmt.Println("get reply:", *resp)
}

@ -0,0 +1,191 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"os"
"strings"
"github.com/gogo/protobuf/jsonpb"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
)
// Reply for test; res holds the raw JSON bytes the server returned,
// captured verbatim by the pass-through JSON codec below.
type Reply struct {
	res []byte
}
// Discovery resolves an appid to a gRPC address via the discovery service.
// (Field name HttpClient kept as-is; renaming would break callers.)
type Discovery struct {
	HttpClient *http.Client // client used for discovery HTTP calls
	Nodes      []string     // seed discovery endpoints to query
}
// Command-line flags; defaults and example values are set in init.
var (
	data          string // inline JSON request body
	file          string // optional file holding the JSON request body (overrides -data)
	method        string // full gRPC method, e.g. /testproto.Greeter/SayHello
	addr          string // target host:port (overridden when -appid is set)
	tlsCert       string // optional TLS certificate file
	tlsServerName string // expected server name for TLS verification
	appID         string // discovery appid to resolve into an address
	env           string // discovery environment (DEPLOY_ENV wins when set)
)
// Reference https://jbrandhorst.com/post/grpc-json/
// init registers the pass-through JSON codec and declares the tool's flags;
// each flag's usage string doubles as an example value.
func init() {
	encoding.RegisterCodec(JSON{
		Marshaler: jsonpb.Marshaler{
			EmitDefaults: true,
			OrigName:     true,
		},
	})
	flag.StringVar(&data, "data", `{"name":"longxia","age":19}`, `{"name":"longxia","age":19}`)
	flag.StringVar(&file, "file", ``, `./data.json`)
	flag.StringVar(&method, "method", "/testproto.Greeter/SayHello", `/testproto.Greeter/SayHello`)
	flag.StringVar(&addr, "addr", "127.0.0.1:8080", `127.0.0.1:8080`)
	flag.StringVar(&tlsCert, "cert", "", `./cert.pem`)
	flag.StringVar(&tlsServerName, "server_name", "", `hello_server`)
	flag.StringVar(&appID, "appid", "", `appid`)
	flag.StringVar(&env, "env", "", `env`)
}
// Because this example uses JSON as the transport format it is only meant
// for debugging and testing; using it in production degrades performance.
// Usage:
//
//	./grpcDebug -data='{"name":"xia","age":19}' -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
//	./grpcDebug -file=data.json -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
//	DEPLOY_ENV=uat ./grpcDebug -appid=main.community.reply-service -method=/reply.service.v1.Reply/ReplyInfoCache -data='{"rp_id"=1493769244}'
func main() {
	flag.Parse()
	opts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithDefaultCallOptions(grpc.CallContentSubtype(JSON{}.Name())),
	}
	if tlsCert != "" {
		creds, err := credentials.NewClientTLSFromFile(tlsCert, tlsServerName)
		if err != nil {
			panic(err)
		}
		opts = append(opts, grpc.WithTransportCredentials(creds))
	}
	if file != "" {
		content, err := ioutil.ReadFile(file)
		if err != nil {
			// Was fmt.Println with format verbs, which printed the verbs
			// literally; Printf is required for formatting.
			fmt.Printf("ioutil.ReadFile %s failed!err:=%v\n", file, err)
			os.Exit(1)
		}
		if len(content) > 0 {
			data = string(content)
		}
	}
	if appID != "" {
		addr = ipFromDiscovery(appID, env)
	}
	conn, err := grpc.Dial(addr, opts...)
	if err != nil {
		panic(err)
	}
	var reply Reply
	// The registered JSON codec passes the raw bytes through unmodified.
	err = grpc.Invoke(context.Background(), method, []byte(data), &reply, conn)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(reply.res))
}
// ipFromDiscovery resolves appID to a gRPC address through the discovery
// service; the DEPLOY_ENV environment variable overrides the env argument
// when set.
func ipFromDiscovery(appID, env string) string {
	d := &Discovery{
		Nodes:      []string{"discovery.bilibili.co", "api.bilibili.co"},
		HttpClient: http.DefaultClient,
	}
	if deployEnv := os.Getenv("DEPLOY_ENV"); deployEnv != "" {
		env = deployEnv
	}
	return d.addr(appID, env, d.nodes())
}
// nodes fetches the full list of discovery node addresses from a randomly
// chosen seed node. Any HTTP or decode failure aborts the tool (this is a
// debug utility, so panics are acceptable).
func (d *Discovery) nodes() (addrs []string) {
	var body struct {
		Code int `json:"code"`
		Data []struct {
			Addr string `json:"addr"`
		} `json:"data"`
	}
	seed := d.Nodes[rand.Intn(len(d.Nodes))]
	resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/nodes", seed))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	if err = json.NewDecoder(resp.Body).Decode(&body); err != nil {
		panic(err)
	}
	for _, node := range body.Data {
		addrs = append(addrs, node.Addr)
	}
	return
}
// addr polls discovery for instances of appID in env (via a random node)
// and returns the first advertised address with a grpc:// scheme, scheme
// stripped. Returns "" when no gRPC address is found. Panics on HTTP or
// decode failure, like the rest of this debug tool.
func (d *Discovery) addr(appID, env string, nodes []string) (ip string) {
	res := new(struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    map[string]*struct {
			ZoneInstances map[string][]*struct {
				AppID string   `json:"appid"`
				Addrs []string `json:"addrs"`
			} `json:"zone_instances"`
		} `json:"data"`
	})
	host, _ := os.Hostname()
	resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/polls?appid=%s&env=%s&hostname=%s", nodes[rand.Intn(len(nodes))], appID, env, host))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	if err = json.NewDecoder(resp.Body).Decode(&res); err != nil {
		panic(err)
	}
	// Scan every zone of every returned entry for an instance matching appID.
	for _, data := range res.Data {
		for _, zoneInstance := range data.ZoneInstances {
			for _, instance := range zoneInstance {
				if instance.AppID == appID {
					for _, addr := range instance.Addrs {
						if strings.Contains(addr, "grpc://") {
							return strings.Replace(addr, "grpc://", "", -1)
						}
					}
				}
			}
		}
	}
	return
}
// JSON is an encoding.Codec that passes pre-encoded JSON bytes straight
// through in both directions: requests must already be []byte, responses
// are captured raw into *Reply. The embedded jsonpb types are only needed
// to satisfy registration with marshal options.
type JSON struct {
	jsonpb.Marshaler
	jsonpb.Unmarshaler
}
// Name returns the codec's registered name ("json"), as required by
// encoding.Codec.
func (j JSON) Name() string {
	return "json"
}
// Marshal passes the pre-encoded JSON request bytes straight through.
// The original did an unchecked type assertion that panicked when v was
// not []byte; return a descriptive error instead.
func (j JSON) Marshal(v interface{}) (out []byte, err error) {
	b, ok := v.([]byte)
	if !ok {
		return nil, fmt.Errorf("json codec: expected []byte, got %T", v)
	}
	return b, nil
}
// Unmarshal stores the raw response bytes into the *Reply receiver.
// The original did an unchecked type assertion that panicked when v was
// not *Reply; return a descriptive error instead.
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) {
	r, ok := v.(*Reply)
	if !ok {
		return fmt.Errorf("json codec: expected *Reply, got %T", v)
	}
	r.res = data
	return nil
}

@ -0,0 +1,108 @@
package main
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"syscall"
"time"
"github.com/bilibili/kratos/pkg/ecode"
epb "github.com/bilibili/kratos/pkg/ecode/pb"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/golang/protobuf/ptypes"
"google.golang.org/grpc"
)
type helloServer struct {
addr string
}
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
if in.Name == "err_detail_test" {
any, _ := ptypes.MarshalAny(&pb.HelloReply{Success: true, Message: "this is test detail"})
err := epb.From(ecode.AccessDenied)
err.ErrDetail = any
return nil, err
}
return &pb.HelloReply{Message: fmt.Sprintf("hello %s from %s", in.Name, s.addr)}, nil
}
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
for i := 0; i < 3; i++ {
in, err := ss.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true}
err = ss.Send(ret)
if err != nil {
return err
}
}
return nil
}
func runServer(addr string) *warden.Server {
server := warden.NewServer(&warden.ServerConfig{
//服务端每个请求的默认超时时间
Timeout: xtime.Duration(time.Second),
})
server.Use(middleware())
pb.RegisterGreeterServer(server.Server(), &helloServer{addr: addr})
go func() {
err := server.Run(addr)
if err != nil {
panic("run server failed!" + err.Error())
}
}()
return server
}
func main() {
log.Init(&log.Config{Stdout: true})
server := runServer("0.0.0.0:8081")
signalHandler(server)
}
//类似于中间件
func middleware() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
//记录调用方法
log.Info("method:%s", info.FullMethod)
//call chain
resp, err = handler(ctx, req)
return
}
}
func signalHandler(s *warden.Server) {
var (
ch = make(chan os.Signal, 1)
)
signal.Notify(ch, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
for {
si := <-ch
switch si {
case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
log.Info("get a signal %s, stop the consume process", si.String())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
//gracefully shutdown with timeout
s.Shutdown(ctx)
return
case syscall.SIGHUP:
default:
return
}
}
}

@ -0,0 +1,11 @@
package metadata
const (
	// CPUUsage is the metadata key under which a server reports its CPU
	// usage; presumably consumed by the client-side balancer — TODO confirm.
	CPUUsage = "cpu_usage"
)
// MD is context metadata for balancer and resolver.
type MD struct {
	Weight uint64 // instance weight for weighted balancing
	Color  string // traffic-group color; assumed used for routing — TODO confirm
}

@ -0,0 +1,642 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: hello.proto
/*
Package testproto is a generated protocol buffer package.
It is generated from these files:
hello.proto
It has these top-level messages:
HelloRequest
HelloReply
*/
package testproto
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import _ "github.com/gogo/protobuf/gogoproto"
import context "golang.org/x/net/context"
import grpc "google.golang.org/grpc"
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The request message containing the user's name.
// NOTE: generated by protoc-gen-gogo from hello.proto — do not edit by hand.
type HelloRequest struct {
	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name" validate:"required"`
	Age  int32  `protobuf:"varint,2,opt,name=age,proto3" json:"age" validate:"min=0"`
}

func (m *HelloRequest) Reset()                    { *m = HelloRequest{} }
func (m *HelloRequest) String() string            { return proto.CompactTextString(m) }
func (*HelloRequest) ProtoMessage()               {}
func (*HelloRequest) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{0} }

// The response message containing the greetings
type HelloReply struct {
	Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
	Success bool   `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"`
}

func (m *HelloReply) Reset()                    { *m = HelloReply{} }
func (m *HelloReply) String() string            { return proto.CompactTextString(m) }
func (*HelloReply) ProtoMessage()               {}
func (*HelloReply) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{1} }

// init registers the message types with the proto type registry.
func init() {
	proto.RegisterType((*HelloRequest)(nil), "testproto.HelloRequest")
	proto.RegisterType((*HelloReply)(nil), "testproto.HelloReply")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4

// Client API for Greeter service
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
type GreeterClient interface {
	// Sends a greeting
	SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
	// A bidirectional streaming RPC call: receive HelloRequest, return HelloReply
	StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error)
}

type greeterClient struct {
	cc *grpc.ClientConn
}

// NewGreeterClient wraps a client connection in the generated Greeter stub.
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
	return &greeterClient{cc}
}

func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) {
	out := new(HelloReply)
	err := grpc.Invoke(ctx, "/testproto.Greeter/SayHello", in, out, c.cc, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *greeterClient) StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error) {
	stream, err := grpc.NewClientStream(ctx, &_Greeter_serviceDesc.Streams[0], c.cc, "/testproto.Greeter/StreamHello", opts...)
	if err != nil {
		return nil, err
	}
	x := &greeterStreamHelloClient{stream}
	return x, nil
}

type Greeter_StreamHelloClient interface {
	Send(*HelloRequest) error
	Recv() (*HelloReply, error)
	grpc.ClientStream
}

type greeterStreamHelloClient struct {
	grpc.ClientStream
}

func (x *greeterStreamHelloClient) Send(m *HelloRequest) error {
	return x.ClientStream.SendMsg(m)
}

func (x *greeterStreamHelloClient) Recv() (*HelloReply, error) {
	m := new(HelloReply)
	if err := x.ClientStream.RecvMsg(m); err != nil {
		return nil, err
	}
	return m, nil
}
// Server API for Greeter service
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
type GreeterServer interface {
	// Sends a greeting
	SayHello(context.Context, *HelloRequest) (*HelloReply, error)
	// A bidirectional streaming RPC call: receive HelloRequest, return HelloReply
	StreamHello(Greeter_StreamHelloServer) error
}

// RegisterGreeterServer registers srv's handlers with the gRPC server s.
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
	s.RegisterService(&_Greeter_serviceDesc, srv)
}

func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(HelloRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(GreeterServer).SayHello(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/testproto.Greeter/SayHello",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _Greeter_StreamHello_Handler(srv interface{}, stream grpc.ServerStream) error {
	return srv.(GreeterServer).StreamHello(&greeterStreamHelloServer{stream})
}

type Greeter_StreamHelloServer interface {
	Send(*HelloReply) error
	Recv() (*HelloRequest, error)
	grpc.ServerStream
}

type greeterStreamHelloServer struct {
	grpc.ServerStream
}

func (x *greeterStreamHelloServer) Send(m *HelloReply) error {
	return x.ServerStream.SendMsg(m)
}

func (x *greeterStreamHelloServer) Recv() (*HelloRequest, error) {
	m := new(HelloRequest)
	if err := x.ServerStream.RecvMsg(m); err != nil {
		return nil, err
	}
	return m, nil
}

// _Greeter_serviceDesc describes the Greeter service for the gRPC runtime.
var _Greeter_serviceDesc = grpc.ServiceDesc{
	ServiceName: "testproto.Greeter",
	HandlerType: (*GreeterServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "SayHello",
			Handler:    _Greeter_SayHello_Handler,
		},
	},
	Streams: []grpc.StreamDesc{
		{
			StreamName:    "StreamHello",
			Handler:       _Greeter_StreamHello_Handler,
			ServerStreams: true,
			ClientStreams: true,
		},
	},
	Metadata: "hello.proto",
}
// Marshal encodes m in the gogo/protobuf wire format.
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
func (m *HelloRequest) Marshal() (dAtA []byte, err error) {
	size := m.Size()
	dAtA = make([]byte, size)
	n, err := m.MarshalTo(dAtA)
	if err != nil {
		return nil, err
	}
	return dAtA[:n], nil
}

// MarshalTo writes m into dAtA and returns the number of bytes written.
func (m *HelloRequest) MarshalTo(dAtA []byte) (int, error) {
	var i int
	_ = i
	var l int
	_ = l
	if len(m.Name) > 0 {
		dAtA[i] = 0xa // tag: field 1, wire type 2 (length-delimited)
		i++
		i = encodeVarintHello(dAtA, i, uint64(len(m.Name)))
		i += copy(dAtA[i:], m.Name)
	}
	if m.Age != 0 {
		dAtA[i] = 0x10 // tag: field 2, wire type 0 (varint)
		i++
		i = encodeVarintHello(dAtA, i, uint64(m.Age))
	}
	return i, nil
}

// Marshal encodes m in the gogo/protobuf wire format.
func (m *HelloReply) Marshal() (dAtA []byte, err error) {
	size := m.Size()
	dAtA = make([]byte, size)
	n, err := m.MarshalTo(dAtA)
	if err != nil {
		return nil, err
	}
	return dAtA[:n], nil
}

// MarshalTo writes m into dAtA and returns the number of bytes written.
func (m *HelloReply) MarshalTo(dAtA []byte) (int, error) {
	var i int
	_ = i
	var l int
	_ = l
	if len(m.Message) > 0 {
		dAtA[i] = 0xa // tag: field 1, wire type 2 (length-delimited)
		i++
		i = encodeVarintHello(dAtA, i, uint64(len(m.Message)))
		i += copy(dAtA[i:], m.Message)
	}
	if m.Success {
		dAtA[i] = 0x10 // tag: field 2, wire type 0 (varint)
		i++
		if m.Success {
			dAtA[i] = 1
		} else {
			dAtA[i] = 0
		}
		i++
	}
	return i, nil
}
// encodeVarintHello writes v as a protobuf varint at offset and returns the
// offset just past the encoded bytes.
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
func encodeVarintHello(dAtA []byte, offset int, v uint64) int {
	for v >= 1<<7 {
		dAtA[offset] = uint8(v&0x7f | 0x80)
		v >>= 7
		offset++
	}
	dAtA[offset] = uint8(v)
	return offset + 1
}

// Size returns the encoded wire-format size of m in bytes.
func (m *HelloRequest) Size() (n int) {
	var l int
	_ = l
	l = len(m.Name)
	if l > 0 {
		n += 1 + l + sovHello(uint64(l))
	}
	if m.Age != 0 {
		n += 1 + sovHello(uint64(m.Age))
	}
	return n
}

// Size returns the encoded wire-format size of m in bytes.
func (m *HelloReply) Size() (n int) {
	var l int
	_ = l
	l = len(m.Message)
	if l > 0 {
		n += 1 + l + sovHello(uint64(l))
	}
	if m.Success {
		n += 2
	}
	return n
}

// sovHello returns the number of bytes needed to varint-encode x.
func sovHello(x uint64) (n int) {
	for {
		n++
		x >>= 7
		if x == 0 {
			break
		}
	}
	return n
}

// sozHello returns the varint size of x after zig-zag encoding.
func sozHello(x uint64) (n int) {
	return sovHello(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
// Unmarshal decodes m from the gogo/protobuf wire format.
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
func (m *HelloRequest) Unmarshal(dAtA []byte) error {
	l := len(dAtA)
	iNdEx := 0
	for iNdEx < l {
		preIndex := iNdEx
		var wire uint64
		// Read the field tag (varint).
		for shift := uint(0); ; shift += 7 {
			if shift >= 64 {
				return ErrIntOverflowHello
			}
			if iNdEx >= l {
				return io.ErrUnexpectedEOF
			}
			b := dAtA[iNdEx]
			iNdEx++
			wire |= (uint64(b) & 0x7F) << shift
			if b < 0x80 {
				break
			}
		}
		fieldNum := int32(wire >> 3)
		wireType := int(wire & 0x7)
		if wireType == 4 {
			return fmt.Errorf("proto: HelloRequest: wiretype end group for non-group")
		}
		if fieldNum <= 0 {
			return fmt.Errorf("proto: HelloRequest: illegal tag %d (wire type %d)", fieldNum, wire)
		}
		switch fieldNum {
		case 1:
			if wireType != 2 {
				return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType)
			}
			var stringLen uint64
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return ErrIntOverflowHello
				}
				if iNdEx >= l {
					return io.ErrUnexpectedEOF
				}
				b := dAtA[iNdEx]
				iNdEx++
				stringLen |= (uint64(b) & 0x7F) << shift
				if b < 0x80 {
					break
				}
			}
			intStringLen := int(stringLen)
			if intStringLen < 0 {
				return ErrInvalidLengthHello
			}
			postIndex := iNdEx + intStringLen
			if postIndex > l {
				return io.ErrUnexpectedEOF
			}
			m.Name = string(dAtA[iNdEx:postIndex])
			iNdEx = postIndex
		case 2:
			if wireType != 0 {
				return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType)
			}
			m.Age = 0
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return ErrIntOverflowHello
				}
				if iNdEx >= l {
					return io.ErrUnexpectedEOF
				}
				b := dAtA[iNdEx]
				iNdEx++
				m.Age |= (int32(b) & 0x7F) << shift
				if b < 0x80 {
					break
				}
			}
		default:
			// Unknown field: skip it so newer messages still decode.
			iNdEx = preIndex
			skippy, err := skipHello(dAtA[iNdEx:])
			if err != nil {
				return err
			}
			if skippy < 0 {
				return ErrInvalidLengthHello
			}
			if (iNdEx + skippy) > l {
				return io.ErrUnexpectedEOF
			}
			iNdEx += skippy
		}
	}

	if iNdEx > l {
		return io.ErrUnexpectedEOF
	}
	return nil
}
// Unmarshal decodes m from the gogo/protobuf wire format.
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
func (m *HelloReply) Unmarshal(dAtA []byte) error {
	l := len(dAtA)
	iNdEx := 0
	for iNdEx < l {
		preIndex := iNdEx
		var wire uint64
		// Read the field tag (varint).
		for shift := uint(0); ; shift += 7 {
			if shift >= 64 {
				return ErrIntOverflowHello
			}
			if iNdEx >= l {
				return io.ErrUnexpectedEOF
			}
			b := dAtA[iNdEx]
			iNdEx++
			wire |= (uint64(b) & 0x7F) << shift
			if b < 0x80 {
				break
			}
		}
		fieldNum := int32(wire >> 3)
		wireType := int(wire & 0x7)
		if wireType == 4 {
			return fmt.Errorf("proto: HelloReply: wiretype end group for non-group")
		}
		if fieldNum <= 0 {
			return fmt.Errorf("proto: HelloReply: illegal tag %d (wire type %d)", fieldNum, wire)
		}
		switch fieldNum {
		case 1:
			if wireType != 2 {
				return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType)
			}
			var stringLen uint64
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return ErrIntOverflowHello
				}
				if iNdEx >= l {
					return io.ErrUnexpectedEOF
				}
				b := dAtA[iNdEx]
				iNdEx++
				stringLen |= (uint64(b) & 0x7F) << shift
				if b < 0x80 {
					break
				}
			}
			intStringLen := int(stringLen)
			if intStringLen < 0 {
				return ErrInvalidLengthHello
			}
			postIndex := iNdEx + intStringLen
			if postIndex > l {
				return io.ErrUnexpectedEOF
			}
			m.Message = string(dAtA[iNdEx:postIndex])
			iNdEx = postIndex
		case 2:
			if wireType != 0 {
				return fmt.Errorf("proto: wrong wireType = %d for field Success", wireType)
			}
			var v int
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return ErrIntOverflowHello
				}
				if iNdEx >= l {
					return io.ErrUnexpectedEOF
				}
				b := dAtA[iNdEx]
				iNdEx++
				v |= (int(b) & 0x7F) << shift
				if b < 0x80 {
					break
				}
			}
			m.Success = bool(v != 0)
		default:
			// Unknown field: skip it so newer messages still decode.
			iNdEx = preIndex
			skippy, err := skipHello(dAtA[iNdEx:])
			if err != nil {
				return err
			}
			if skippy < 0 {
				return ErrInvalidLengthHello
			}
			if (iNdEx + skippy) > l {
				return io.ErrUnexpectedEOF
			}
			iNdEx += skippy
		}
	}

	if iNdEx > l {
		return io.ErrUnexpectedEOF
	}
	return nil
}
// skipHello returns the number of bytes occupied by the wire-format field
// starting at dAtA[0] (tag included), used to step over unknown fields.
// NOTE: generated by protoc-gen-gogo — do not edit by hand.
func skipHello(dAtA []byte) (n int, err error) {
	l := len(dAtA)
	iNdEx := 0
	for iNdEx < l {
		var wire uint64
		// Read the field tag (varint).
		for shift := uint(0); ; shift += 7 {
			if shift >= 64 {
				return 0, ErrIntOverflowHello
			}
			if iNdEx >= l {
				return 0, io.ErrUnexpectedEOF
			}
			b := dAtA[iNdEx]
			iNdEx++
			wire |= (uint64(b) & 0x7F) << shift
			if b < 0x80 {
				break
			}
		}
		wireType := int(wire & 0x7)
		switch wireType {
		case 0: // varint
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return 0, ErrIntOverflowHello
				}
				if iNdEx >= l {
					return 0, io.ErrUnexpectedEOF
				}
				iNdEx++
				if dAtA[iNdEx-1] < 0x80 {
					break
				}
			}
			return iNdEx, nil
		case 1: // fixed64
			iNdEx += 8
			return iNdEx, nil
		case 2: // length-delimited
			var length int
			for shift := uint(0); ; shift += 7 {
				if shift >= 64 {
					return 0, ErrIntOverflowHello
				}
				if iNdEx >= l {
					return 0, io.ErrUnexpectedEOF
				}
				b := dAtA[iNdEx]
				iNdEx++
				length |= (int(b) & 0x7F) << shift
				if b < 0x80 {
					break
				}
			}
			iNdEx += length
			if length < 0 {
				return 0, ErrInvalidLengthHello
			}
			return iNdEx, nil
		case 3: // start group: recurse until the matching end-group tag
			for {
				var innerWire uint64
				var start int = iNdEx
				for shift := uint(0); ; shift += 7 {
					if shift >= 64 {
						return 0, ErrIntOverflowHello
					}
					if iNdEx >= l {
						return 0, io.ErrUnexpectedEOF
					}
					b := dAtA[iNdEx]
					iNdEx++
					innerWire |= (uint64(b) & 0x7F) << shift
					if b < 0x80 {
						break
					}
				}
				innerWireType := int(innerWire & 0x7)
				if innerWireType == 4 {
					break
				}
				next, err := skipHello(dAtA[start:])
				if err != nil {
					return 0, err
				}
				iNdEx = start + next
			}
			return iNdEx, nil
		case 4: // end group
			return iNdEx, nil
		case 5: // fixed32
			iNdEx += 4
			return iNdEx, nil
		default:
			return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
		}
	}
	panic("unreachable")
}
// Sentinel errors returned by the generated unmarshal/skip helpers.
var (
	ErrInvalidLengthHello = fmt.Errorf("proto: negative length found during unmarshaling")
	ErrIntOverflowHello   = fmt.Errorf("proto: integer overflow")
)

// init registers the compressed file descriptor with the proto registry.
func init() { proto.RegisterFile("hello.proto", fileDescriptorHello) }
var fileDescriptorHello = []byte{
// 296 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x90, 0x3f, 0x4e, 0xc3, 0x30,
0x14, 0xc6, 0x63, 0xfe, 0xb5, 0x75, 0x19, 0x90, 0x11, 0x22, 0x2a, 0x92, 0x53, 0x79, 0xca, 0xd2,
0xb4, 0xa2, 0x1b, 0x02, 0x09, 0x85, 0x01, 0xe6, 0xf4, 0x04, 0x4e, 0xfa, 0x48, 0x23, 0x25, 0x75,
0x6a, 0x3b, 0x48, 0xb9, 0x03, 0x07, 0xe0, 0x48, 0x1d, 0x7b, 0x82, 0x88, 0x86, 0xad, 0x63, 0x4f,
0x80, 0x62, 0x28, 0x20, 0xb1, 0x75, 0x7b, 0x3f, 0x7f, 0xfa, 0x7e, 0x4f, 0x7e, 0xb8, 0x3b, 0x83,
0x34, 0x15, 0x5e, 0x2e, 0x85, 0x16, 0xa4, 0xa3, 0x41, 0x69, 0x33, 0xf6, 0x06, 0x71, 0xa2, 0x67,
0x45, 0xe8, 0x45, 0x22, 0x1b, 0xc6, 0x22, 0x16, 0x43, 0xf3, 0x1c, 0x16, 0xcf, 0x86, 0x0c, 0x98,
0xe9, 0xab, 0xc9, 0x24, 0x3e, 0x7d, 0x6a, 0x44, 0x01, 0x2c, 0x0a, 0x50, 0x9a, 0x8c, 0xf1, 0xd1,
0x9c, 0x67, 0x60, 0xa3, 0x3e, 0x72, 0x3b, 0xbe, 0xb3, 0xa9, 0x1c, 0xc3, 0xdb, 0xca, 0x39, 0x7f,
0xe1, 0x69, 0x32, 0xe5, 0x1a, 0x6e, 0x98, 0x84, 0x45, 0x91, 0x48, 0x98, 0xb2, 0xc0, 0x84, 0x64,
0x80, 0x0f, 0x79, 0x0c, 0xf6, 0x41, 0x1f, 0xb9, 0xc7, 0xfe, 0xd5, 0xa6, 0x72, 0x1a, 0xdc, 0x56,
0xce, 0xd9, 0x6f, 0x25, 0x4b, 0xe6, 0x77, 0x23, 0x16, 0x34, 0x01, 0xbb, 0xc7, 0xf8, 0x7b, 0x67,
0x9e, 0x96, 0xc4, 0xc6, 0xad, 0x0c, 0x94, 0x6a, 0x04, 0x66, 0x69, 0xb0, 0xc3, 0x26, 0x51, 0x45,
0x14, 0x81, 0x52, 0x46, 0xdd, 0x0e, 0x76, 0x78, 0xfd, 0x8a, 0x70, 0xeb, 0x51, 0x02, 0x68, 0x90,
0xe4, 0x16, 0xb7, 0x27, 0xbc, 0x34, 0x42, 0x72, 0xe9, 0xfd, 0x1c, 0xc2, 0xfb, 0xfb, 0xad, 0xde,
0xc5, 0xff, 0x20, 0x4f, 0x4b, 0x66, 0x91, 0x07, 0xdc, 0x9d, 0x68, 0x09, 0x3c, 0xdb, 0x53, 0xe0,
0xa2, 0x11, 0xf2, 0xed, 0xe5, 0x9a, 0x5a, 0xab, 0x35, 0xb5, 0x96, 0x35, 0x45, 0xab, 0x9a, 0xa2,
0xf7, 0x9a, 0xa2, 0xb7, 0x0f, 0x6a, 0x85, 0x27, 0xa6, 0x31, 0xfe, 0x0c, 0x00, 0x00, 0xff, 0xff,
0x13, 0x57, 0x88, 0x03, 0xae, 0x01, 0x00, 0x00,
}

@ -0,0 +1,33 @@
syntax = "proto3";
package testproto;
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
option (gogoproto.goproto_enum_prefix_all) = false;
option (gogoproto.goproto_getters_all) = false;
option (gogoproto.unmarshaler_all) = true;
option (gogoproto.marshaler_all) = true;
option (gogoproto.sizer_all) = true;
option (gogoproto.goproto_registration) = true;
// The greeting service definition.
service Greeter {
  // Sends a greeting
  rpc SayHello (HelloRequest) returns (HelloReply) {}
  // A bidirectional streaming RPC call: receive HelloRequest, return HelloReply
  rpc StreamHello(stream HelloRequest) returns (stream HelloReply) {}
}
// The request message containing the user's name.
message HelloRequest {
  // name is required (enforced via the generated validate struct tag).
  string name = 1 [(gogoproto.jsontag) = "name", (gogoproto.moretags) = "validate:\"required\""];
  // age must be non-negative (validate:"min=0").
  int32 age = 2 [(gogoproto.jsontag) = "age", (gogoproto.moretags) = "validate:\"min=0\""];
}
// The response message containing the greetings
message HelloReply {
  string message = 1;
  bool success = 2;
}

@ -0,0 +1,151 @@
package status
import (
"context"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/ecode/pb"
)
// togRPCCode converts an ecode.Codes into the closest gRPC codes.Code.
// Business codes without an explicit mapping become codes.Unknown.
func togRPCCode(code ecode.Codes) codes.Code {
	switch code.Code() {
	case ecode.OK.Code():
		return codes.OK
	case ecode.RequestErr.Code():
		return codes.InvalidArgument
	case ecode.NothingFound.Code():
		return codes.NotFound
	case ecode.Unauthorized.Code():
		return codes.Unauthenticated
	case ecode.AccessDenied.Code():
		return codes.PermissionDenied
	case ecode.LimitExceed.Code():
		return codes.ResourceExhausted
	case ecode.MethodNotAllowed.Code():
		return codes.Unimplemented
	case ecode.Deadline.Code():
		return codes.DeadlineExceeded
	case ecode.ServiceUnavailable.Code():
		return codes.Unavailable
	}
	return codes.Unknown
}
// toECode maps a gRPC status back to the corresponding ecode.Code.
// codes.Unknown carries the business code as a string in the status
// message (see gRPCStatusFromEcode); any unrecognized gRPC code becomes
// ecode.ServerErr.
func toECode(gst *status.Status) ecode.Code {
	switch gst.Code() {
	case codes.OK:
		return ecode.OK
	case codes.InvalidArgument:
		return ecode.RequestErr
	case codes.NotFound:
		return ecode.NothingFound
	case codes.PermissionDenied:
		return ecode.AccessDenied
	case codes.Unauthenticated:
		return ecode.Unauthorized
	case codes.ResourceExhausted:
		return ecode.LimitExceed
	case codes.Unimplemented:
		return ecode.MethodNotAllowed
	case codes.DeadlineExceeded:
		return ecode.Deadline
	case codes.Unavailable:
		return ecode.ServiceUnavailable
	case codes.Unknown:
		return ecode.String(gst.Message())
	default:
		return ecode.ServerErr
	}
}
// FromError converts a service error into a *status.Status suitable for a
// gRPC reply. The error is first unwrapped with errors.Cause; an
// ecode.Codes takes priority, then raw context.Canceled /
// context.DeadlineExceeded are mapped to their ecode equivalents, and
// anything else falls back to status.FromError.
func FromError(svrErr error) (gst *status.Status) {
	var err error
	svrErr = errors.Cause(svrErr)
	if code, ok := svrErr.(ecode.Codes); ok {
		// TODO: deal with err (a conversion failure currently falls through
		// to the switch below).
		if gst, err = gRPCStatusFromEcode(code); err == nil {
			return
		}
	}
	// for some special error convert context.Canceled to ecode.Canceled,
	// context.DeadlineExceeded to ecode.DeadlineExceeded only for raw error
	// if err be wrapped will not effect.
	switch svrErr {
	case context.Canceled:
		gst, _ = gRPCStatusFromEcode(ecode.Canceled)
	case context.DeadlineExceeded:
		gst, _ = gRPCStatusFromEcode(ecode.Deadline)
	default:
		gst, _ = status.FromError(svrErr)
	}
	return
}
// gRPCStatusFromEcode builds a grpc *status.Status carrying the business
// error code. For compatibility the gRPC code is always codes.Unknown and
// the status message is the numeric business code as a string (PHP swoole
// clients parse it from there); the real payload travels in the details:
// details[0] is a pb.Error for old clients, details[1] the ecode.Status.
func gRPCStatusFromEcode(code ecode.Codes) (*status.Status, error) {
	var st *ecode.Status
	switch v := code.(type) {
	// compatible old pb.Error remove it after nobody use pb.Error.
	case *pb.Error:
		return status.New(codes.Unknown, v.Error()).WithDetails(v)
	case *ecode.Status:
		st = v
	case ecode.Code:
		st = ecode.FromCode(v)
	default:
		st = ecode.Error(ecode.Code(code.Code()), code.Message())
		for _, detail := range code.Details() {
			if msg, ok := detail.(proto.Message); ok {
				// Keep the status returned by WithDetails: the original
				// discarded it, silently dropping the attached detail
				// (ToEcode below uses the two-value form correctly).
				st, _ = st.WithDetails(msg)
			}
		}
	}
	// NOTE: compatible with PHP swoole gRPC put code in status message as string.
	gst := status.New(codes.Unknown, strconv.Itoa(st.Code()))
	pbe := &pb.Error{ErrCode: int32(st.Code()), ErrMessage: gst.Message()}
	// NOTE: server return ecode.Status will be covert to pb.Error details will be ignored
	// and put it at details[0] for compatible old client
	return gst.WithDetails(pbe, st.Proto())
}
// ToEcode convert grpc.status to ecode.Codes.
//
// Details are scanned in reverse so that, when both a legacy pb.Error and an
// ecode.Status proto are present, the ecode.Status wins. When no usable
// detail exists, the plain gRPC code is mapped via toECode.
//
// Details layout:
//	pb.Error                        [0: pb.Error]
//	both pb.Error and ecode.Status  [0: pb.Error, 1: ecode.Status]
//	ecode.Status                    [0: ecode.Status]
func ToEcode(gst *status.Status) ecode.Codes {
	details := gst.Details()
	for i := len(details) - 1; i >= 0; i-- {
		detail := details[i]
		// compatible with old pb.Error.
		if pe, ok := detail.(*pb.Error); ok {
			st := ecode.Error(ecode.Code(pe.ErrCode), pe.ErrMessage)
			if pe.ErrDetail != nil {
				dynMsg := new(ptypes.DynamicAny)
				// TODO deal with unmarshalAny error.
				if err := ptypes.UnmarshalAny(pe.ErrDetail, dynMsg); err == nil {
					st, _ = st.WithDetails(dynMsg.Message)
				}
			}
			return st
		}
		// convert detail to status, only the first matching detail is used.
		// NOTE: local renamed from "pb" to "msg" — the original shadowed the
		// imported pb package inside this branch.
		if msg, ok := detail.(proto.Message); ok {
			return ecode.FromProto(msg)
		}
	}
	return toECode(gst)
}

@ -0,0 +1,164 @@
package status
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
pkgerr "github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/ecode/pb"
)
// TestCodeConvert verifies the bidirectional mapping between gRPC codes and
// ecode values: toECode for gRPC->ecode and togRPCCode for ecode->gRPC.
// Commented-out entries are gRPC codes that have no dedicated ecode mapping.
func TestCodeConvert(t *testing.T) {
	var table = map[codes.Code]ecode.Code{
		codes.OK: ecode.OK,
		// codes.Canceled
		codes.Unknown:          ecode.ServerErr,
		codes.InvalidArgument:  ecode.RequestErr,
		codes.DeadlineExceeded: ecode.Deadline,
		codes.NotFound:         ecode.NothingFound,
		// codes.AlreadyExists
		codes.PermissionDenied:  ecode.AccessDenied,
		codes.ResourceExhausted: ecode.LimitExceed,
		// codes.FailedPrecondition
		// codes.Aborted
		// codes.OutOfRange
		codes.Unimplemented: ecode.MethodNotAllowed,
		codes.Unavailable:   ecode.ServiceUnavailable,
		// codes.DataLoss
		codes.Unauthenticated: ecode.Unauthorized,
	}
	for k, v := range table {
		assert.Equal(t, toECode(status.New(k, "-500")), v)
	}
	for k, v := range table {
		assert.Equal(t, togRPCCode(v), k, fmt.Sprintf("togRPC code error: %d -> %d", v, k))
	}
}
// TestNoDetailsConvert checks toECode on statuses without details: an
// Unknown status with a numeric message yields that ecode, while other
// codes fall back to the code-table mapping (Internal -> ServerErr).
func TestNoDetailsConvert(t *testing.T) {
	gst := status.New(codes.Unknown, "-2233")
	assert.Equal(t, toECode(gst).Code(), -2233)
	gst = status.New(codes.Internal, "")
	assert.Equal(t, toECode(gst).Code(), -500)
}
// TestFromError covers every input class FromError accepts: plain errors,
// pkg/errors-wrapped ecodes, raw ecode.Code, raw context errors, the legacy
// pb.Error type, and a full ecode.Status with details.
func TestFromError(t *testing.T) {
	t.Run("input general error", func(t *testing.T) {
		err := errors.New("general error")
		gst := FromError(err)
		assert.Equal(t, codes.Unknown, gst.Code())
		assert.Contains(t, gst.Message(), "general")
	})
	t.Run("input wrap error", func(t *testing.T) {
		err := pkgerr.Wrap(ecode.RequestErr, "hh")
		gst := FromError(err)
		assert.Equal(t, "-400", gst.Message())
	})
	t.Run("input ecode.Code", func(t *testing.T) {
		err := ecode.RequestErr
		gst := FromError(err)
		//assert.Equal(t, codes.InvalidArgument, gst.Code())
		// NOTE: set all grpc.status as Unknown when error is ecode.Codes for compatibility
		assert.Equal(t, codes.Unknown, gst.Code())
		// NOTE: gst.Message == str(ecode.Code) for compatibility with PHP legacy code
		assert.Equal(t, err.Message(), gst.Message())
	})
	t.Run("input raw Canceled", func(t *testing.T) {
		gst := FromError(context.Canceled)
		assert.Equal(t, codes.Unknown, gst.Code())
		assert.Equal(t, "-498", gst.Message())
	})
	t.Run("input raw DeadlineExceeded", func(t *testing.T) {
		gst := FromError(context.DeadlineExceeded)
		assert.Equal(t, codes.Unknown, gst.Code())
		assert.Equal(t, "-504", gst.Message())
	})
	t.Run("input pb.Error", func(t *testing.T) {
		m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
		detail, _ := ptypes.MarshalAny(m)
		err := &pb.Error{ErrCode: 2233, ErrMessage: "message", ErrDetail: detail}
		gst := FromError(err)
		assert.Equal(t, codes.Unknown, gst.Code())
		assert.Len(t, gst.Details(), 1)
		assert.Equal(t, "2233", gst.Message())
	})
	t.Run("input ecode.Status", func(t *testing.T) {
		m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
		err, _ := ecode.Error(ecode.Unauthorized, "unauthorized").WithDetails(m)
		gst := FromError(err)
		//assert.Equal(t, codes.Unauthenticated, gst.Code())
		// NOTE: set all grpc.status as Unknown when error is ecode.Codes for compatibility
		assert.Equal(t, codes.Unknown, gst.Code())
		// details[0] is the compat pb.Error, details[1] the ecode.Status proto.
		assert.Len(t, gst.Details(), 2)
		details := gst.Details()
		assert.IsType(t, &pb.Error{}, details[0])
		assert.IsType(t, err.Proto(), details[1])
	})
}
// TestToEcode covers ToEcode for each detail layout: no details (falls back
// to the code-table mapping), a lone legacy pb.Error, pb.Error plus an
// ecode.Status (status wins), and a lone ecode.Status with details.
func TestToEcode(t *testing.T) {
	t.Run("input general grpc.Status", func(t *testing.T) {
		gst := status.New(codes.Unknown, "unknown")
		ec := ToEcode(gst)
		assert.Equal(t, int(ecode.ServerErr), ec.Code())
		assert.Equal(t, "-500", ec.Message())
		assert.Len(t, ec.Details(), 0)
	})
	t.Run("input pb.Error", func(t *testing.T) {
		m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
		detail, _ := ptypes.MarshalAny(m)
		gst := status.New(codes.InvalidArgument, "requesterr")
		gst, _ = gst.WithDetails(&pb.Error{ErrCode: 1122, ErrMessage: "message", ErrDetail: detail})
		ec := ToEcode(gst)
		assert.Equal(t, 1122, ec.Code())
		assert.Equal(t, "message", ec.Message())
		assert.Len(t, ec.Details(), 1)
		assert.IsType(t, m, ec.Details()[0])
	})
	t.Run("input pb.Error and ecode.Status", func(t *testing.T) {
		gst := status.New(codes.InvalidArgument, "requesterr")
		gst, _ = gst.WithDetails(
			&pb.Error{ErrCode: 1122, ErrMessage: "message"},
			ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").Proto(),
		)
		// the ecode.Status detail must take precedence over pb.Error.
		ec := ToEcode(gst)
		assert.Equal(t, int(ecode.AccessKeyErr), ec.Code())
		assert.Equal(t, "AccessKeyErr", ec.Message())
	})
	t.Run("input encode.Status", func(t *testing.T) {
		m := &timestamp.Timestamp{Seconds: time.Now().Unix()}
		st, _ := ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").WithDetails(m)
		gst := status.New(codes.InvalidArgument, "requesterr")
		gst, _ = gst.WithDetails(st.Proto())
		ec := ToEcode(gst)
		assert.Equal(t, int(ecode.AccessKeyErr), ec.Code())
		assert.Equal(t, "AccessKeyErr", ec.Message())
		assert.Len(t, ec.Details(), 1)
		assert.IsType(t, m, ec.Details()[0])
	})
}

@ -0,0 +1,118 @@
package warden
import (
"context"
"fmt"
"strconv"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/stat"
)
var (
	// statsClient and statsServer collect RPC latency timings and per-code
	// counters for the client and server interceptors respectively.
	statsClient = stat.RPCClient
	statsServer = stat.RPCServer
)
// logFn selects the log level for an access-log entry: negative codes
// (system errors) use Error; slow requests (>=500ms) and positive business
// codes use Warn; everything else uses Info.
func logFn(code int, dt time.Duration) func(context.Context, ...log.D) {
	if code < 0 {
		return log.Errorv
	}
	// TODO: slowlog make it configurable.
	if dt >= 500*time.Millisecond {
		return log.Warnv
	}
	if code > 0 {
		return log.Warnv
	}
	return log.Infov
}
// clientLogging warden grpc logging.
// It is a unary client interceptor that captures the peer address, records
// latency/result-code metrics, and emits a structured access-log entry for
// every outgoing call.
func clientLogging() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		startTime := time.Now()
		var peerInfo peer.Peer
		// grpc.Peer fills peerInfo with the remote address after the call.
		opts = append(opts, grpc.Peer(&peerInfo))
		// invoker requests
		err := invoker(ctx, method, req, reply, cc, opts...)
		// after request
		code := ecode.Cause(err).Code()
		duration := time.Since(startTime)
		// monitor
		statsClient.Timing(method, int64(duration/time.Millisecond))
		statsClient.Incr(method, strconv.Itoa(code))
		var ip string
		if peerInfo.Addr != nil {
			ip = peerInfo.Addr.String()
		}
		logFields := []log.D{
			log.KVString("ip", ip),
			log.KVString("path", method),
			log.KVInt("ret", code),
			// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
			log.KVString("args", req.(fmt.Stringer).String()),
			log.KVFloat64("ts", duration.Seconds()),
			log.KVString("source", "grpc-access-log"),
		}
		if err != nil {
			logFields = append(logFields, log.KV("error", err.Error()), log.KVString("stack", fmt.Sprintf("%+v", err)))
		}
		// level is derived from the result code and latency (see logFn).
		logFn(code, duration)(ctx, logFields...)
		return err
	}
}
// serverLogging warden grpc logging.
// It is a unary server interceptor that records access logs and metrics
// (latency, result code, caller identity, remaining deadline quota) for
// every incoming request.
func serverLogging() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		startTime := time.Now()
		caller := metadata.String(ctx, metadata.Caller)
		if caller == "" {
			caller = "no_user"
		}
		var remoteIP string
		if peerInfo, ok := peer.FromContext(ctx); ok {
			remoteIP = peerInfo.Addr.String()
		}
		// quota is the seconds left before the caller's deadline when the
		// request reached this server.
		var quota float64
		if deadline, ok := ctx.Deadline(); ok {
			quota = time.Until(deadline).Seconds()
		}
		// call server handler
		resp, err := handler(ctx, req)
		// after server response
		code := ecode.Cause(err).Code()
		duration := time.Since(startTime)
		// monitor
		statsServer.Timing(caller, int64(duration/time.Millisecond), info.FullMethod)
		statsServer.Incr(caller, info.FullMethod, strconv.Itoa(code))
		logFields := []log.D{
			log.KVString("user", caller),
			log.KVString("ip", remoteIP),
			log.KVString("path", info.FullMethod),
			log.KVInt("ret", code),
			// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
			log.KVString("args", req.(fmt.Stringer).String()),
			log.KVFloat64("ts", duration.Seconds()),
			log.KVFloat64("timeout_quota", quota),
			log.KVString("source", "grpc-access-log"),
		}
		if err != nil {
			// NOTE: KVString for the formatted stack, consistent with clientLogging.
			logFields = append(logFields, log.KV("error", err.Error()), log.KVString("stack", fmt.Sprintf("%+v", err)))
		}
		logFn(code, duration)(ctx, logFields...)
		return resp, err
	}
}

@ -0,0 +1,55 @@
package warden
import (
"context"
"reflect"
"testing"
"time"
"github.com/bilibili/kratos/pkg/log"
)
// Test_logFn checks that logFn maps each (code, duration) pair to the
// expected log-level function; function identity is compared by pointer.
func Test_logFn(t *testing.T) {
	cases := []struct {
		name string
		code int
		dt   time.Duration
		want func(context.Context, ...log.D)
	}{
		{name: "ok", code: 0, dt: time.Millisecond, want: log.Infov},
		{name: "slowlog", code: 0, dt: time.Second, want: log.Warnv},
		{name: "business error", code: 2233, dt: time.Millisecond, want: log.Warnv},
		{name: "system error", code: -1, dt: 0, want: log.Errorv},
		{name: "system error and slowlog", code: -1, dt: time.Second, want: log.Errorv},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := logFn(tc.code, tc.dt)
			if reflect.ValueOf(got).Pointer() != reflect.ValueOf(tc.want).Pointer() {
				t.Errorf("unexpect log function!")
			}
		})
	}
}

@ -0,0 +1,61 @@
package warden
import (
"context"
"fmt"
"os"
"runtime"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// recovery is a server interceptor that recovers from any panics.
// The panic value, the request and a truncated stack trace are written to
// stderr and the error log, and a generic ServerErr is returned to the
// caller so internals never leak across the RPC boundary.
func (s *Server) recovery() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			if rerr := recover(); rerr != nil {
				// cap the captured stack at 64 KiB.
				const size = 64 << 10
				buf := make([]byte, size)
				rs := runtime.Stack(buf, false)
				if rs > size {
					rs = size
				}
				buf = buf[:rs]
				pl := fmt.Sprintf("grpc server panic: %v\n%v\n%s\n", req, rerr, buf)
				// NOTE: Fprint, not Fprintf — pl is not a format string and may
				// contain % verbs from the request payload (go vet printf).
				fmt.Fprint(os.Stderr, pl)
				log.Error(pl)
				// NOTE: status.Error, not status.Errorf — the message is not a
				// constant format string.
				err = status.Error(codes.Unknown, ecode.ServerErr.Error())
			}
		}()
		resp, err = handler(ctx, req)
		return
	}
}
// recovery return a client interceptor that recovers from any panics.
// The panic value, request, reply and a truncated stack trace are written
// to stderr and the error log, and ecode.ServerErr is returned instead of
// letting the panic propagate.
func (c *Client) recovery() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
		defer func() {
			if rerr := recover(); rerr != nil {
				// cap the captured stack at 64 KiB.
				const size = 64 << 10
				buf := make([]byte, size)
				rs := runtime.Stack(buf, false)
				if rs > size {
					rs = size
				}
				buf = buf[:rs]
				pl := fmt.Sprintf("grpc client panic: %v\n%v\n%v\n%s\n", req, reply, rerr, buf)
				// NOTE: Fprint, not Fprintf — pl is not a format string and may
				// contain % verbs from the request payload (go vet printf).
				fmt.Fprint(os.Stderr, pl)
				log.Error(pl)
				err = ecode.ServerErr
			}
		}()
		err = invoker(ctx, method, req, reply, cc, opts...)
		return
	}
}

@ -0,0 +1,17 @@
### business/warden/resolver
##### Version 1.1.1
1. add dial helper
##### Version 1.1.0
1. 增加了子集选择算法
##### Version 1.0.2
1. 增加GET接口
##### Version 1.0.1
1. 支持zone和clusters
##### Version 1.0.0
1. 实现了基本的服务发现功能

@ -0,0 +1,9 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
labels:
- library
reviewers:
- caoguoliang
- maojian

@ -0,0 +1,13 @@
#### business/warden/resolver
##### 项目简介
warden 的 服务发现模块,用于从底层的注册中心中获取Server节点列表并返回给GRPC
##### 编译环境
- **请只用 Golang v1.9.x 以上版本编译执行**
##### 依赖包
- [grpc](https://google.golang.org/grpc)

@ -0,0 +1,6 @@
### business/warden/resolver/direct
##### Version 1.0.0
1. 实现了基本的服务发现直连功能

@ -0,0 +1,14 @@
#### business/warden/resolver/direct
##### 项目简介
warden 的直连服务模块,用于通过IP地址列表直接连接后端服务
连接字符串格式: direct://default/192.168.1.1:8080,192.168.1.2:8081
##### 编译环境
- **请只用 Golang v1.9.x 以上版本编译执行**
##### 依赖包
- [grpc](https://google.golang.org/grpc)

@ -0,0 +1,77 @@
package direct
import (
"context"
"fmt"
"strings"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/naming"
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
)
const (
// Name is the name of direct resolver
Name = "direct"
)
// compile-time check: *Direct satisfies naming.Resolver.
var _ naming.Resolver = &Direct{}

// New return Direct
func New() *Direct {
	d := new(Direct)
	return d
}
// Build build direct.
func Build(id string) *Direct {
	d := new(Direct)
	d.id = id
	return d
}
// Direct is a resolver for connecting to endpoints directly, bypassing
// service discovery.
// example format: direct://default/192.168.1.1:8080,192.168.1.2:8081
type Direct struct {
	// id holds the comma-separated host:port list taken from the target.
	id string
}
// Build direct build.
// It returns a new Direct bound to the given id (the address list).
func (d *Direct) Build(id string) naming.Resolver {
	return &Direct{id: id}
}
// Scheme return the Scheme of Direct ("direct").
func (d *Direct) Scheme() string {
	return Name
}
// Watch a tree.
// The returned channel is buffered and pre-loaded with one event so the
// consumer performs an initial Fetch immediately; no further events are
// ever sent because direct addresses never change.
func (d *Direct) Watch() <-chan struct{} {
	ch := make(chan struct{}, 1)
	ch <- struct{}{}
	return ch
}
// Unwatch a tree. No-op: Direct has no per-id watch state to release.
func (d *Direct) Unwatch(id string) {
}
// Fetch fetch instances.
// It turns the comma-separated address list into naming instances under the
// current env.Zone; found reports whether at least one valid address exists.
func (d *Direct) Fetch(ctx context.Context) (insMap map[string][]*naming.Instance, found bool) {
	addrs := strings.Split(d.id, ",")
	// pre-size: at most one instance per configured address.
	ins := make([]*naming.Instance, 0, len(addrs))
	for _, addr := range addrs {
		// skip empty entries (empty id, trailing comma) which previously
		// produced a bogus "grpc://" instance address.
		if addr == "" {
			continue
		}
		ins = append(ins, &naming.Instance{
			Addrs: []string{fmt.Sprintf("%s://%s", resolver.Scheme, addr)},
		})
	}
	if len(ins) > 0 {
		found = true
	}
	insMap = map[string][]*naming.Instance{env.Zone: ins}
	return
}
// Close close Direct. No-op: there are no resources to release.
func (d *Direct) Close() error {
	return nil
}

@ -0,0 +1,85 @@
package direct
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
xtime "github.com/bilibili/kratos/pkg/time"
)
// testServer is a Greeter implementation that replies with its own name,
// letting tests identify which backend served a request.
type testServer struct {
	name string
}
// SayHello replies with the server's name so tests can tell backends apart.
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) {
	return &pb.HelloReply{Message: ts.name, Success: true}, nil
}
// StreamHello is unused by these tests and deliberately panics if called.
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
	panic("not implement error")
}
// createServer starts a warden server named name on listen in a background
// goroutine and returns it; a failed Run panics the process.
func createServer(name, listen string) *warden.Server {
	s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)})
	ts := &testServer{name}
	pb.RegisterGreeterServer(s.Server(), ts)
	go func() {
		if err := s.Run(listen); err != nil {
			panic(fmt.Sprintf("run warden server fail! err: %s", err))
		}
	}()
	return s
}
// TestMain registers the direct resolver, starts two backend servers for
// the package tests and shuts them down after the run.
//
// NOTE: the original deferred Shutdown before calling os.Exit; os.Exit never
// runs deferred functions (go vet: exitAfterDefer), so the servers were
// never shut down. Shut down explicitly before exiting instead.
func TestMain(m *testing.M) {
	resolver.Register(New())
	ctx := context.TODO()
	s1 := createServer("server1", "127.0.0.1:18081")
	s2 := createServer("server2", "127.0.0.1:18082")
	code := m.Run()
	s1.Shutdown(ctx)
	s2.Shutdown(ctx)
	os.Exit(code)
}
// createTestClient dials connStr with a standard test client config
// (generous timeouts, a small circuit-breaker window) and returns a
// GreeterClient; any dial failure fails the test immediately.
func createTestClient(t *testing.T, connStr string) pb.GreeterClient {
	client := warden.NewClient(&warden.ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	})
	conn, err := client.Dial(context.TODO(), connStr)
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	return pb.NewGreeterClient(conn)
}
// TestDirect dials a two-address direct target where only the second
// address (server2) is actually listening, and verifies all ten calls are
// served by server2.
func TestDirect(t *testing.T) {
	// 18083 has no listener in this package; only 18082 (server2) is up.
	cli := createTestClient(t, "direct://default/127.0.0.1:18083,127.0.0.1:18082")
	count := 0
	for i := 0; i < 10; i++ {
		if resp, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
			t.Fatalf("TestDirect: SayHello failed!err:=%v", err)
		} else {
			if resp.Message == "server2" {
				count++
			}
		}
	}
	if count != 10 {
		t.Fatalf("TestDirect: get server2 times must be 10")
	}
}

@ -0,0 +1,204 @@
package resolver
import (
"context"
"fmt"
"math/rand"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/naming"
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"github.com/dgryski/go-farm"
"github.com/pkg/errors"
"google.golang.org/grpc/resolver"
)
const (
// Scheme is the scheme of discovery address
Scheme = "grpc"
)
var (
_ resolver.Resolver = &Resolver{}
_ resolver.Builder = &Builder{}
mu sync.Mutex
)
// Register register resolver builder if nil.
// The builder is only installed when no builder is registered yet for its
// scheme; use Set to override unconditionally.
func Register(b naming.Builder) {
	mu.Lock()
	defer mu.Unlock()
	if resolver.Get(b.Scheme()) == nil {
		resolver.Register(&Builder{b})
	}
}
// Set override any registered builder for the same scheme.
func Set(b naming.Builder) {
	mu.Lock()
	defer mu.Unlock()
	resolver.Register(&Builder{b})
}
// Builder is also a resolver builder.
// Its Build() function always returns itself.
// It wraps a naming.Builder to adapt it to the gRPC resolver interface.
type Builder struct {
	naming.Builder
}
// Build creates a Resolver for the given target and starts its background
// update goroutine.
//
// The target endpoint may carry query parameters:
//
//	cluster=... restrict instances to the named cluster(s)
//	zone=...    override the default env.Zone
//	subset=N    subset size (default 50)
func (b *Builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
	var zone = env.Zone
	ss := int64(50) // default subset size
	clusters := map[string]struct{}{}
	str := strings.SplitN(target.Endpoint, "?", 2)
	// NOTE: strings.SplitN never returns an empty slice, so the original
	// len(str) == 0 check was dead code and an empty endpoint slipped
	// through; check the endpoint component itself instead.
	if str[0] == "" {
		return nil, errors.Errorf("warden resolver: parse target.Endpoint(%s) failed!err:=endpoint is empty", target.Endpoint)
	} else if len(str) == 2 {
		m, err := url.ParseQuery(str[1])
		if err == nil {
			for _, c := range m[naming.MetaCluster] {
				clusters[c] = struct{}{}
			}
			zones := m[naming.MetaZone]
			if len(zones) > 0 {
				zone = zones[0]
			}
			if sub, ok := m["subset"]; ok {
				if t, err := strconv.ParseInt(sub[0], 10, 64); err == nil {
					ss = t
				}
			}
		}
	}
	r := &Resolver{
		nr:         b.Builder.Build(str[0]),
		cc:         cc,
		quit:       make(chan struct{}, 1),
		clusters:   clusters,
		zone:       zone,
		subsetSize: ss,
	}
	go r.updateproc()
	return r, nil
}
// Resolver watches for the updates on the specified target.
// Updates include address updates and service config updates.
type Resolver struct {
	nr   naming.Resolver   // underlying discovery resolver
	cc   resolver.ClientConn // gRPC conn that receives address updates
	quit chan struct{}     // buffered(1); signals updateproc to stop

	clusters   map[string]struct{} // allowed clusters; empty means all
	zone       string              // preferred zone for instance selection
	subsetSize int64               // max backends kept after subsetting
}
// Close stops the update goroutine and closes the underlying resolver.
// The non-blocking send makes repeated Close calls safe: only the first
// one delivers the quit signal and closes nr.
func (r *Resolver) Close() {
	select {
	case r.quit <- struct{}{}:
		r.nr.Close()
	default:
	}
}
// ResolveNow is a noop for Resolver; updates are pushed by updateproc.
func (r *Resolver) ResolveNow(o resolver.ResolveNowOption) {
}
// updateproc is the background loop: on every watch event it fetches the
// current instances, prefers the configured zone (falling back to all
// zones), applies subsetting, and pushes the addresses to the ClientConn.
// It exits on quit or when the watch channel is closed.
func (r *Resolver) updateproc() {
	event := r.nr.Watch()
	for {
		select {
		case <-r.quit:
			return
		case _, ok := <-event:
			if !ok {
				return
			}
		}
		if insInfo, ok := r.nr.Fetch(context.Background()); ok {
			instances, ok := insInfo.Instances[r.zone]
			if !ok {
				// no instances in the preferred zone: merge every zone.
				for _, value := range insInfo.Instances {
					instances = append(instances, value...)
				}
			}
			if r.subsetSize > 0 && len(instances) > 0 {
				instances = r.subset(instances, env.Hostname, r.subsetSize)
			}
			r.newAddress(instances)
		}
	}
}
// subset deterministically selects up to size backends for this client.
// Backends are sorted by hostname, then shuffled with a seed derived from
// the client id's fingerprint so each client sees a stable slice while
// different clients spread across different subsets (subset-selection
// algorithm; presumably after SRE subsetting — confirm against design doc).
func (r *Resolver) subset(backends []*naming.Instance, clientID string, size int64) []*naming.Instance {
	if len(backends) <= int(size) {
		return backends
	}
	sort.Slice(backends, func(i, j int) bool {
		return backends[i].Hostname < backends[j].Hostname
	})
	count := int64(len(backends)) / size
	// id splits into a shuffle round (seed) and a slot within the round.
	id := farm.Fingerprint64([]byte(clientID))
	round := int64(id / uint64(count))
	s := rand.NewSource(round)
	ra := rand.New(s)
	ra.Shuffle(len(backends), func(i, j int) {
		backends[i], backends[j] = backends[j], backends[i]
	})
	start := (id % uint64(count)) * uint64(size)
	return backends[int(start) : int(start)+int(size)]
}
// newAddress converts instances to gRPC resolver addresses and pushes them
// to the ClientConn. Instances outside the configured clusters or without a
// valid grpc:// address are skipped; weight defaults to 10 when missing or
// non-positive. An empty instance list is ignored (previous addresses kept).
func (r *Resolver) newAddress(instances []*naming.Instance) {
	if len(instances) <= 0 {
		return
	}
	addrs := make([]resolver.Address, 0, len(instances))
	for _, ins := range instances {
		if len(r.clusters) > 0 {
			if _, ok := r.clusters[ins.Metadata[naming.MetaCluster]]; !ok {
				continue
			}
		}
		var weight int64
		if weight, _ = strconv.ParseInt(ins.Metadata[naming.MetaWeight], 10, 64); weight <= 0 {
			weight = 10
		}
		// pick the first address whose scheme is grpc.
		var rpc string
		for _, a := range ins.Addrs {
			u, err := url.Parse(a)
			if err == nil && u.Scheme == Scheme {
				rpc = u.Host
			}
		}
		if rpc == "" {
			fmt.Fprintf(os.Stderr, "warden/resolver: app(%s,%s) no valid grpc address(%v) found!", ins.AppID, ins.Hostname, ins.Addrs)
			log.Warn("warden/resolver: invalid rpc address(%s,%s,%v) found!", ins.AppID, ins.Hostname, ins.Addrs)
			continue
		}
		addr := resolver.Address{
			Addr:       rpc,
			Type:       resolver.Backend,
			ServerName: ins.AppID,
			Metadata:   wmeta.MD{Weight: uint64(weight), Color: ins.Metadata[naming.MetaColor]},
		}
		addrs = append(addrs, addr)
	}
	r.cc.NewAddress(addrs)
}

@ -0,0 +1,87 @@
package resolver
import (
"context"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/naming"
)
// mockDiscoveryBuilder is an in-memory discovery service for tests:
// instances maps hostname -> registered instance, watchch maps app id ->
// resolvers to notify on registry/cancel.
type mockDiscoveryBuilder struct {
	instances map[string]*naming.Instance
	watchch   map[string][]*mockDiscoveryResolver
}
// Build creates a resolver for the given app id, records it for future
// notifications, and pre-loads one watch event so the consumer fetches
// immediately.
func (mb *mockDiscoveryBuilder) Build(id string) naming.Resolver {
	mr := &mockDiscoveryResolver{
		d:       mb,
		watchch: make(chan struct{}, 1),
	}
	mb.watchch[id] = append(mb.watchch[id], mr)
	mr.watchch <- struct{}{}
	return mr
}
// Scheme returns the scheme tests use in their dial targets.
func (mb *mockDiscoveryBuilder) Scheme() string {
	return "mockdiscovery"
}
// mockDiscoveryResolver is the per-client handle into the mock discovery:
// it fetches from the shared builder state and is notified via watchch.
type mockDiscoveryResolver struct {
	//instances map[string]*naming.Instance
	d       *mockDiscoveryBuilder
	watchch chan struct{}
}

// compile-time check: *mockDiscoveryResolver satisfies naming.Resolver.
var _ naming.Resolver = &mockDiscoveryResolver{}
// Fetch groups the registered instances by zone; the bool reports whether
// any zone has instances.
func (md *mockDiscoveryResolver) Fetch(ctx context.Context) (map[string][]*naming.Instance, bool) {
	zones := make(map[string][]*naming.Instance)
	for _, v := range md.d.instances {
		zones[v.Zone] = append(zones[v.Zone], v)
	}
	return zones, len(zones) > 0
}
// Watch returns the notification channel fed by registry/cancel.
func (md *mockDiscoveryResolver) Watch() <-chan struct{} {
	return md.watchch
}
// Close closes the watch channel, ending the consumer's update loop.
func (md *mockDiscoveryResolver) Close() error {
	close(md.watchch)
	return nil
}
// Scheme mirrors the builder's scheme.
func (md *mockDiscoveryResolver) Scheme() string {
	return "mockdiscovery"
}
// registry adds (or replaces, keyed by hostname) an instance for appID and
// notifies every resolver watching that app.
// NOTE(review): the sends below block when a watcher's buffer (size 1) is
// already full and nobody is draining it — fine for these tests, but could
// deadlock if registry is called in a tight loop; confirm.
func (mb *mockDiscoveryBuilder) registry(appID string, hostname, rpc string, metadata map[string]string) {
	ins := &naming.Instance{
		AppID:    appID,
		Env:      "hello=world",
		Hostname: hostname,
		Addrs:    []string{"grpc://" + rpc},
		Version:  "1.1",
		Zone:     env.Zone,
		Metadata: metadata,
	}
	mb.instances[hostname] = ins
	if ch, ok := mb.watchch[appID]; ok {
		var bullet struct{}
		for _, c := range ch {
			c.watchch <- bullet
		}
	}
}
// cancel removes the instance registered under hostname (no-op when
// unknown) and notifies every resolver watching its app.
func (mb *mockDiscoveryBuilder) cancel(hostname string) {
	ins, ok := mb.instances[hostname]
	if !ok {
		return
	}
	delete(mb.instances, hostname)
	if ch, ok := mb.watchch[ins.AppID]; ok {
		var bullet struct{}
		for _, c := range ch {
			c.watchch <- bullet
		}
	}
}

@ -0,0 +1,312 @@
package resolver
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/naming"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/rpc/warden"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/stretchr/testify/assert"
)
// testServerMap tracks every test backend by name so tests can assert
// per-server SayHello call counts.
// NOTE: initialized at the declaration — the original used an init()
// containing only a make(), which is unnecessary (idiomatic Go avoids
// init for simple initialization).
var testServerMap = make(map[string]*testServer)

// testAppID is the discovery app id every mock server registers under.
const testAppID = "main.test"
// testServer is a Greeter implementation that counts SayHello calls so
// tests can verify load distribution across backends.
type testServer struct {
	// SayHelloCount is incremented on every SayHello call.
	SayHelloCount int
}
// resetCount zeroes every registered test server's SayHello counter.
func resetCount() {
	for _, srv := range testServerMap {
		srv.SayHelloCount = 0
	}
}
// SayHello counts the call and returns a fixed greeting.
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) {
	ts.SayHelloCount++
	return &pb.HelloReply{Message: "hello", Success: true}, nil
}
// StreamHello is unused by these tests and deliberately panics if called.
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
	panic("not implement error")
}
// createServer starts a warden server on listen in a background goroutine,
// records its counting testServer in testServerMap under name, and returns
// the server; a failed Run panics the process.
func createServer(name, listen string) *warden.Server {
	s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)})
	ts := &testServer{}
	testServerMap[name] = ts
	pb.RegisterGreeterServer(s.Server(), ts)
	go func() {
		if err := s.Run(listen); err != nil {
			panic(fmt.Sprintf("run warden server fail! err: %s", err))
		}
	}()
	return s
}
// NSayHello returns a subtest that issues n SayHello calls through c and
// fails the test on the first error.
func NSayHello(c pb.GreeterClient, n int) func(*testing.T) {
	return func(t *testing.T) {
		for call := 0; call < n; call++ {
			_, err := c.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"})
			if err != nil {
				t.Fatalf("call sayhello fail! err: %s", err)
			}
		}
	}
}
// createTestClient dials the mock-discovery target for testAppID with a
// standard test config (generous timeouts, small breaker window) and
// returns a GreeterClient; dial failure fails the test immediately.
func createTestClient(t *testing.T) pb.GreeterClient {
	client := warden.NewClient(&warden.ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	})
	conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	return pb.NewGreeterClient(conn)
}
// mockResolver is the shared mock discovery installed by TestMain; some
// tests replace it with a fresh instance via resolver.Set.
var mockResolver *mockDiscoveryBuilder

// newMockDiscoveryBuilder returns an empty mock discovery service.
func newMockDiscoveryBuilder() *mockDiscoveryBuilder {
	return &mockDiscoveryBuilder{
		instances: make(map[string]*naming.Instance),
		watchch:   make(map[string][]*mockDiscoveryResolver),
	}
}
// TestMain installs the mock discovery, starts five backend servers for the
// resolver tests and shuts them down after the run.
//
// NOTE: the original deferred Shutdown before calling os.Exit; os.Exit never
// runs deferred functions (go vet: exitAfterDefer), so the servers were
// never shut down. Shut down explicitly before exiting instead.
func TestMain(m *testing.M) {
	ctx := context.TODO()
	mockResolver = newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	s1 := createServer("server1", "127.0.0.1:18081")
	s2 := createServer("server2", "127.0.0.1:18082")
	s3 := createServer("server3", "127.0.0.1:18083")
	s4 := createServer("server4", "127.0.0.1:18084")
	s5 := createServer("server5", "127.0.0.1:18085")
	code := m.Run()
	s1.Shutdown(ctx)
	s2.Shutdown(ctx)
	s3.Shutdown(ctx)
	s4.Shutdown(ctx)
	s5.Shutdown(ctx)
	os.Exit(code)
}
// TestAddResolver registers a single backend and verifies all ten calls
// land on it.
func TestAddResolver(t *testing.T) {
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	c := createTestClient(t)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 10, testServerMap["server1"].SayHelloCount)
	resetCount()
}
// TestDeleteResolver registers two backends, then cancels one and verifies
// the cancelled backend receives no further traffic.
func TestDeleteResolver(t *testing.T) {
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
	c := createTestClient(t)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
	mockResolver.cancel("server1")
	resetCount()
	// give the resolver a moment to propagate the address update.
	time.Sleep(time.Millisecond * 10)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 0, testServerMap["server1"].SayHelloCount)
	resetCount()
}
// TestUpdateResolver re-registers the same hostnames with new addresses and
// verifies traffic moves entirely to the new backends (18083/18084, which
// are served by server3/server4's listeners).
func TestUpdateResolver(t *testing.T) {
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
	c := createTestClient(t)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
	// same hostnames, new addresses.
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18083", map[string]string{})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18084", map[string]string{})
	resetCount()
	// give the resolver a moment to propagate the address update.
	time.Sleep(time.Millisecond * 10)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 0, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 10, testServerMap["server3"].SayHelloCount+testServerMap["server4"].SayHelloCount)
	resetCount()
}
// TestErrorResolver registers one live backend and one dead address
// (18086 has no listener) and verifies all traffic reaches the live one.
func TestErrorResolver(t *testing.T) {
	mockResolver := newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	mockResolver.registry(testAppID, "server6", "127.0.0.1:18086", map[string]string{})
	c := createTestClient(t)
	t.Run("test_say_hello", NSayHello(c, 10))
	assert.Equal(t, 10, testServerMap["server1"].SayHelloCount)
	resetCount()
}
// TestClusterResolver combines the client-config cluster (c1) with the
// dial-target cluster (c2) and verifies only backends in those clusters
// (server1/server2/server3) receive traffic.
// NOTE(review): server5 registers address 127.0.0.1:18084 — the same as
// server4 rather than its own 18085 listener; possibly intentional, confirm.
func TestClusterResolver(t *testing.T) {
	mockResolver := newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"})
	mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"})
	mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
	mockResolver.registry(testAppID, "server5", "127.0.0.1:18084", map[string]string{})
	client := warden.NewClient(&warden.ClientConfig{Clusters: []string{"c1"}})
	conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?cluster=c2")
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	// give the resolver a moment to push the initial address set.
	time.Sleep(time.Millisecond * 10)
	cli := pb.NewGreeterClient(conn)
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	assert.Equal(t, 1, testServerMap["server1"].SayHelloCount)
	assert.Equal(t, 1, testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 1, testServerMap["server3"].SayHelloCount)
	resetCount()
}
// TestNoClusterResolver dials without any cluster restriction and verifies
// all four backends, regardless of their cluster metadata, each serve one
// of the four calls.
func TestNoClusterResolver(t *testing.T) {
	mockResolver := newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"})
	mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"})
	mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
	client := warden.NewClient(&warden.ClientConfig{})
	conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	// give the resolver a moment to push the initial address set.
	time.Sleep(time.Millisecond * 20)
	cli := pb.NewGreeterClient(conn)
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	assert.Equal(t, 1, testServerMap["server1"].SayHelloCount)
	assert.Equal(t, 1, testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 1, testServerMap["server3"].SayHelloCount)
	assert.Equal(t, 1, testServerMap["server4"].SayHelloCount)
	resetCount()
}
// TestZoneResolver registers server1 and server2 under different zones
// (by mutating env.Zone between registrations) and verifies a client
// configured with Zone "testsh" only reaches server2.
func TestZoneResolver(t *testing.T) {
	mockResolver := newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	// registry captures env.Zone, so flip it per registration.
	env.Zone = "testsh"
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
	env.Zone = "hhhh"
	client := warden.NewClient(&warden.ClientConfig{Zone: "testsh"})
	conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	// give the resolver a moment to push the initial address set.
	time.Sleep(time.Millisecond * 10)
	cli := pb.NewGreeterClient(conn)
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
		t.Fatalf("call sayhello fail! err: %s", err)
	}
	assert.Equal(t, 0, testServerMap["server1"].SayHelloCount)
	assert.Equal(t, 3, testServerMap["server2"].SayHelloCount)
	resetCount()
}
// TestSubsetConn verifies that with subset=3 only three of five registered
// instances receive traffic, and that the subset is rebalanced when an
// instance leaves and rejoins.
func TestSubsetConn(t *testing.T) {
	mockResolver := newMockDiscoveryBuilder()
	resolver.Set(mockResolver)
	mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
	mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
	mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{})
	mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
	mockResolver.registry(testAppID, "server5", "127.0.0.1:18085", map[string]string{})
	client := warden.NewClient(nil)
	conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?subset=3")
	if err != nil {
		t.Fatalf("create client fail!err%s", err)
	}
	time.Sleep(time.Millisecond * 20)
	cli := pb.NewGreeterClient(conn)
	// sayHello issues n sequential calls, failing the test on any error
	sayHello := func(n int) {
		for call := 0; call < n; call++ {
			if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
				t.Fatalf("call sayhello fail! err: %s", err)
			}
		}
	}
	// initial subset: servers 2, 4, 5 share six calls evenly
	sayHello(6)
	assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server4"].SayHelloCount)
	resetCount()
	// drop server4: server3 takes its slot in the subset
	mockResolver.cancel("server4")
	time.Sleep(time.Millisecond * 20)
	sayHello(6)
	assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server3"].SayHelloCount)
	resetCount()
	// re-register server4: the original subset is restored
	mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
	time.Sleep(time.Millisecond * 20)
	sayHello(6)
	assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
	assert.Equal(t, 2, testServerMap["server4"].SayHelloCount)
}

@ -0,0 +1,16 @@
package resolver
import (
"flag"
"fmt"
)
// RegisterTarget will register grpc discovery mock address flag
func RegisterTarget(target *string, discoveryID string) {
flag.CommandLine.StringVar(
target,
fmt.Sprintf("grpc.%s", discoveryID),
fmt.Sprintf("discovery://default/%s", discoveryID),
fmt.Sprintf("App's grpc target.\n example: -grpc.%s=\"127.0.0.1:9090\"", discoveryID),
)
}

@ -0,0 +1,332 @@
package warden
import (
"context"
"flag"
"fmt"
"math"
"net"
"os"
"sync"
"time"
"github.com/bilibili/kratos/pkg/conf/dsn"
"github.com/bilibili/kratos/pkg/log"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/net/trace"
xtime "github.com/bilibili/kratos/pkg/time"
//this package is for json format response
_ "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/encoding/json"
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
)
var (
	// _grpcDSN holds the listen DSN supplied via the -grpc flag or GRPC env (see addFlag).
	_grpcDSN string
	// _defaultSerConf is the fallback configuration used by SetConfig when conf is nil.
	_defaultSerConf = &ServerConfig{
		Network:           "tcp",
		Addr:              "0.0.0.0:9000",
		Timeout:           xtime.Duration(time.Second),
		IdleTimeout:       xtime.Duration(time.Second * 60),
		MaxLifeTime:       xtime.Duration(time.Hour * 2),
		ForceCloseWait:    xtime.Duration(time.Second * 20),
		KeepAliveInterval: xtime.Duration(time.Second * 60),
		KeepAliveTimeout:  xtime.Duration(time.Second * 20),
	}
	// _abortIndex caps the interceptor chain length accepted by Server.Use.
	_abortIndex int8 = math.MaxInt8 / 2
)
// ServerConfig is rpc server conf.
// Each field maps to a DSN component via its dsn struct tag; see parseDSN for
// how a "tcp://host:port/?timeout=..." string is bound onto this struct.
type ServerConfig struct {
	// Network is grpc listen network,default value is tcp
	Network string `dsn:"network"`
	// Addr is grpc listen addr,default value is 0.0.0.0:9000
	Addr string `dsn:"address"`
	// Timeout is context timeout for per rpc call.
	Timeout xtime.Duration `dsn:"query.timeout"`
	// IdleTimeout is a duration for the amount of time after which an idle connection would be closed by sending a GoAway.
	// Idleness duration is defined since the most recent time the number of outstanding RPCs became zero or the connection establishment.
	IdleTimeout xtime.Duration `dsn:"query.idleTimeout"`
	// MaxLifeTime is a duration for the maximum amount of time a connection may exist before it will be closed by sending a GoAway.
	// A random jitter of +/-10% will be added to MaxConnectionAge to spread out connection storms.
	MaxLifeTime xtime.Duration `dsn:"query.maxLife"`
	// ForceCloseWait is an additive period after MaxLifeTime after which the connection will be forcibly closed.
	ForceCloseWait xtime.Duration `dsn:"query.closeWait"`
	// KeepAliveInterval is after a duration of this time if the server doesn't see any activity it pings the client to see if the transport is still alive.
	KeepAliveInterval xtime.Duration `dsn:"query.keepaliveInterval"`
	// KeepAliveTimeout is After having pinged for keepalive check, the server waits for a duration of Timeout and if no activity is seen even after that
	// the connection is closed.
	KeepAliveTimeout xtime.Duration `dsn:"query.keepaliveTimeout"`
}
// Server is the framework's server side instance, it contains the GrpcServer, interceptor and interceptors.
// Create an instance of Server, by using NewServer().
type Server struct {
	conf     *ServerConfig // current config; guarded by mutex, hot-reloaded via SetConfig
	mutex    sync.RWMutex
	server   *grpc.Server
	handlers []grpc.UnaryServerInterceptor // chain executed left-to-right by interceptor()
}
// handle return a new unary server interceptor for OpenTracing\Logging\LinkTimeout.
// It derives a per-call timeout (the smaller of the configured one and the caller's
// remaining deadline minus a margin), copies whitelisted inbound gRPC metadata into
// the app-level metadata context, opens a trace span, and normalizes the returned error.
func (s *Server) handle() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		var (
			cancel func()
			addr   string
		)
		// snapshot the config under the read lock: SetConfig may swap it concurrently
		s.mutex.RLock()
		conf := s.conf
		s.mutex.RUnlock()
		// get derived timeout from grpc context,
		// compare with the warden configured,
		// and use the minimum one
		timeout := time.Duration(conf.Timeout)
		if dl, ok := ctx.Deadline(); ok {
			ctimeout := time.Until(dl)
			// reserve a 20ms margin so the response can reach the caller before its deadline
			if ctimeout-time.Millisecond*20 > 0 {
				ctimeout = ctimeout - time.Millisecond*20
			}
			if timeout > ctimeout {
				timeout = ctimeout
			}
		}
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
		// get grpc metadata(trace & remote_ip & color)
		var t trace.Trace
		cmd := nmd.MD{}
		if gmd, ok := metadata.FromIncomingContext(ctx); ok {
			for key, vals := range gmd {
				// propagate only whitelisted inbound keys (first value wins)
				if nmd.IsIncomingKey(key) {
					cmd[key] = vals[0]
				}
			}
		}
		// NOTE(review): t is never assigned before this check, so it is always nil
		// here and the SetTitle branch is dead — extraction of a parent trace from
		// the incoming metadata appears to be missing; confirm against the trace package.
		if t == nil {
			t = trace.New(args.FullMethod)
		} else {
			t.SetTitle(args.FullMethod)
		}
		if pr, ok := peer.FromContext(ctx); ok {
			addr = pr.Addr.String()
			t.SetTag(trace.String(trace.TagAddress, addr))
		}
		// finish the span with the handler's (possibly nil) error on return
		defer t.Finish(&err)
		// use common meta data context instead of grpc context
		ctx = nmd.NewContext(ctx, cmd)
		ctx = trace.NewContext(ctx, t)
		resp, err = handler(ctx, req)
		// normalize any error into the warden status representation
		return resp, status.FromError(err).Err()
	}
}
// init registers the warden server flags on the default flag set.
func init() {
	addFlag(flag.CommandLine)
}
// addFlag registers the -grpc (listen DSN, overridable via the GRPC env
// variable) and -grpc.target flags on fs.
func addFlag(fs *flag.FlagSet) {
	dsnDefault := os.Getenv("GRPC")
	if dsnDefault == "" {
		dsnDefault = "tcp://0.0.0.0:9000/?timeout=1s&idle_timeout=60s"
	}
	fs.StringVar(&_grpcDSN, "grpc", dsnDefault, "listen grpc dsn, or use GRPC env variable.")
	fs.Var(&_grpcTarget, "grpc.target", "usage: -grpc.target=seq.service=127.0.0.1:9000 -grpc.target=fav.service=192.168.10.1:9000")
}
// parseDSN parses rawdsn (e.g. "tcp://0.0.0.0:9000/?timeout=1s") into a
// ServerConfig, panicking on any parse or bind failure.
func parseDSN(rawdsn string) *ServerConfig {
	cfg := &ServerConfig{}
	parsed, err := dsn.Parse(rawdsn)
	if err == nil {
		_, err = parsed.Bind(cfg)
	}
	if err != nil {
		panic(errors.WithMessage(err, fmt.Sprintf("warden: invalid dsn: %s", rawdsn)))
	}
	return cfg
}
// NewServer returns a new blank Server instance with a default server interceptor.
// When conf is nil the configuration is parsed from the -grpc flag / GRPC env DSN,
// so flag.Parse() should have been called first.
func NewServer(conf *ServerConfig, opt ...grpc.ServerOption) (s *Server) {
	if conf == nil {
		if !flag.Parsed() {
			fmt.Fprint(os.Stderr, "[warden] please call flag.Parse() before Init warden server, some configure may not effect\n")
		}
		conf = parseDSN(_grpcDSN)
	}
	s = new(Server)
	if err := s.SetConfig(conf); err != nil {
		panic(errors.Errorf("warden: set config failed!err: %s", err.Error()))
	}
	// keepalive / connection-age policy derived from the validated config
	keepParam := grpc.KeepaliveParams(keepalive.ServerParameters{
		MaxConnectionIdle:     time.Duration(s.conf.IdleTimeout),
		MaxConnectionAgeGrace: time.Duration(s.conf.ForceCloseWait),
		Time:                  time.Duration(s.conf.KeepAliveInterval),
		Timeout:               time.Duration(s.conf.KeepAliveTimeout),
		MaxConnectionAge:      time.Duration(s.conf.MaxLifeTime),
	})
	opt = append(opt, keepParam, grpc.UnaryInterceptor(s.interceptor))
	s.server = grpc.NewServer(opt...)
	// default chain: recovery first, then trace/timeout, logging, stats, validation
	s.Use(s.recovery(), s.handle(), serverLogging(), s.stats(), s.validate())
	return
}
// SetConfig hot reloads server config.
// Non-positive durations and empty Addr/Network fall back to their defaults;
// note that conf itself is mutated in the process. It never returns an error.
func (s *Server) SetConfig(conf *ServerConfig) (err error) {
	if conf == nil {
		conf = _defaultSerConf
	}
	// orDefault replaces a non-positive duration with its default value.
	orDefault := func(d *xtime.Duration, def time.Duration) {
		if *d <= 0 {
			*d = xtime.Duration(def)
		}
	}
	orDefault(&conf.Timeout, time.Second)
	orDefault(&conf.IdleTimeout, time.Second*60)
	orDefault(&conf.MaxLifeTime, time.Hour*2)
	orDefault(&conf.ForceCloseWait, time.Second*20)
	orDefault(&conf.KeepAliveInterval, time.Second*60)
	orDefault(&conf.KeepAliveTimeout, time.Second*20)
	if conf.Addr == "" {
		conf.Addr = "0.0.0.0:9000"
	}
	if conf.Network == "" {
		conf.Network = "tcp"
	}
	s.mutex.Lock()
	s.conf = conf
	s.mutex.Unlock()
	return nil
}
// interceptor is a single interceptor out of a chain of many interceptors.
// Execution is done in left-to-right order, including passing of context.
// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
// will see context changes of one and two.
func (s *Server) interceptor(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	var (
		i     int
		chain grpc.UnaryHandler
	)
	n := len(s.handlers)
	if n == 0 {
		// no interceptors registered: dispatch straight to the service method
		return handler(ctx, req)
	}
	// chain advances the shared index i on each invocation; the last position
	// dispatches to the real service method.
	// NOTE(review): an interceptor that invokes its next handler more than once
	// would advance i past its own slot — this assumes single invocation per level.
	chain = func(ic context.Context, ir interface{}) (interface{}, error) {
		if i == n-1 {
			return handler(ic, ir)
		}
		i++
		return s.handlers[i](ic, ir, args, chain)
	}
	return s.handlers[0](ctx, req, args, chain)
}
// Server return the grpc server for registering service.
func (s *Server) Server() *grpc.Server {
	return s.server
}
// Use attaches global interceptors to the server, appending them after any
// already registered. It panics when the combined chain would exceed the
// _abortIndex limit. For example, this is the right place for a rate limiter
// or error management interceptor. Returns the server for chaining.
func (s *Server) Use(handlers ...grpc.UnaryServerInterceptor) *Server {
	total := len(s.handlers) + len(handlers)
	if total >= int(_abortIndex) {
		panic("warden: server use too many handlers")
	}
	// build a fresh slice so earlier snapshots of the chain are never mutated
	merged := make([]grpc.UnaryServerInterceptor, 0, total)
	merged = append(merged, s.handlers...)
	merged = append(merged, handlers...)
	s.handlers = merged
	return s
}
// Run create a tcp listener and start goroutine for serving each incoming request.
// Run will return a non-nil error unless Stop or GracefulStop is called.
// It also registers the gRPC reflection service for tooling support.
func (s *Server) Run(addr string) error {
	lis, err := net.Listen("tcp", addr)
	if err != nil {
		err = errors.WithStack(err)
		log.Error("failed to listen: %v", err)
		return err
	}
	reflection.Register(s.server)
	return s.Serve(lis)
}
// RunUnix create a unix listener and start goroutine for serving each incoming request.
// RunUnix will return a non-nil error unless Stop or GracefulStop is called.
// It also registers the gRPC reflection service for tooling support.
func (s *Server) RunUnix(file string) error {
	lis, err := net.Listen("unix", file)
	if err != nil {
		err = errors.WithStack(err)
		log.Error("failed to listen: %v", err)
		return err
	}
	reflection.Register(s.server)
	return s.Serve(lis)
}
// Start create a new goroutine run server with configured listen addr
// will panic if any error happend
// return server itself
// NOTE(review): a Serve error after startup panics inside the goroutine and
// cannot be recovered by the caller — confirm that is the intended behavior.
func (s *Server) Start() (*Server, error) {
	lis, err := net.Listen(s.conf.Network, s.conf.Addr)
	if err != nil {
		return nil, err
	}
	reflection.Register(s.server)
	go func() {
		if err := s.Serve(lis); err != nil {
			panic(err)
		}
	}()
	return s, nil
}
// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
func (s *Server) Serve(lis net.Listener) error {
	return s.server.Serve(lis)
}
// Shutdown stops the server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished or the context deadline is reached.
func (s *Server) Shutdown(ctx context.Context) (err error) {
	ch := make(chan struct{})
	go func() {
		// GracefulStop blocks until all in-flight RPCs drain
		s.server.GracefulStop()
		close(ch)
	}()
	select {
	case <-ctx.Done():
		// deadline reached first: force-stop remaining RPCs and report the context error
		s.server.Stop()
		err = ctx.Err()
	case <-ch:
	}
	return
}

@ -0,0 +1,570 @@
package warden
import (
"context"
"fmt"
"io"
"math/rand"
"net"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/log"
nmd "github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
xtrace "github.com/bilibili/kratos/pkg/net/trace"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
	// _separator delimits fields within one trace report datagram;
	// testTrace splits the received UDP payload on it.
	_separator = "\001"
)
var (
	// outPut records interceptor enter/leave markers; asserted by testInterceptorChain.
	outPut []string
	// _testOnce guards the one-time startup of the shared test server.
	_testOnce sync.Once
	// server is the shared warden server on 127.0.0.1:8080, started by runServer.
	server *Server
	// clientConfig is the default test client: 10s dial/timeout plus a breaker.
	clientConfig = ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
	}
	// clientConfig2 additionally caps SayHello with a 200ms per-method timeout.
	clientConfig2 = ClientConfig{
		Dial:    xtime.Duration(time.Second * 10),
		Timeout: xtime.Duration(time.Second * 10),
		Breaker: &breaker.Config{
			Window:  xtime.Duration(3 * time.Second),
			Sleep:   xtime.Duration(3 * time.Second),
			Bucket:  10,
			Ratio:   0.3,
			Request: 20,
		},
		Method: map[string]*ClientConfig{`/testproto.Greeter/SayHello`: {Timeout: xtime.Duration(time.Millisecond * 200)}},
	}
)
// helloServer is the test Greeter implementation; SayHello branches on the
// request name to simulate the scenario under test.
type helloServer struct {
	t *testing.T // the driving test, used to fail from inside handlers
}
// SayHello dispatches on in.Name to emulate one test scenario per request:
// tracing, recovery, graceful shutdown, link timeouts, metadata propagation,
// breaker trips and the various error encodings. Unknown names get a plain
// greeting reply.
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	switch in.Name {
	case "trace_test":
		t, isok := xtrace.FromContext(ctx)
		if !isok {
			t = xtrace.New("test title")
			s.t.Fatalf("no trace extracted from server context")
		}
		newCtx := xtrace.NewContext(ctx, t)
		// Age 0 marks the first hop: issue one nested client call to exercise propagation
		if in.Age == 0 {
			runClient(newCtx, &clientConfig, s.t, "trace_test", 1)
		}
	case "recovery_test":
		panic("test recovery")
	case "graceful_shutdown":
		time.Sleep(time.Second * 3)
	case "timeout_test":
		if in.Age > 10 {
			s.t.Fatalf("can not deliver requests over 10 times because of link timeout")
			return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
		}
		// recurse through the client; the shrinking link timeout must stop the chain
		time.Sleep(time.Millisecond * 10)
		_, err := runClient(ctx, &clientConfig, s.t, "timeout_test", in.Age+1)
		return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err
	case "timeout_test2":
		if in.Age > 10 {
			s.t.Fatalf("can not deliver requests over 10 times because of link timeout")
			return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
		}
		time.Sleep(time.Millisecond * 10)
		_, err := runClient(ctx, &clientConfig2, s.t, "timeout_test2", in.Age+1)
		return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err
	case "color_test":
		if in.Age == 0 {
			// first hop: forward through a nested call so color must survive two hops
			return runClient(ctx, &clientConfig, s.t, "color_test", in.Age+1)
		}
		color := nmd.String(ctx, nmd.Color)
		return &pb.HelloReply{Message: "Hello " + color, Success: true}, nil
	case "breaker_test":
		// fail roughly half the calls so the client breaker eventually trips
		if rand.Intn(100) <= 50 {
			return nil, status.Errorf(codes.ResourceExhausted, "test")
		}
		return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
	case "error_detail":
		err, _ := ecode.Error(ecode.Code(123456), "test_error_detail").WithDetails(&pb.HelloReply{Success: true})
		return nil, err
	case "ecode_status":
		reply := &pb.HelloReply{Message: "status", Success: true}
		st, _ := ecode.Error(ecode.RequestErr, "RequestErr").WithDetails(reply)
		return nil, st
	case "general_error":
		return nil, fmt.Errorf("haha is error")
	case "ecode_code_error":
		return nil, ecode.Conflict
	case "pb_error_error":
		return nil, ecode.Error(ecode.Code(11122), "haha")
	case "ecode_status_error":
		return nil, ecode.Error(ecode.RequestErr, "RequestErr")
	case "test_remote_port":
		// the caller encodes the expected remote port in Age
		if strconv.Itoa(int(in.Age)) != nmd.String(ctx, nmd.RemotePort) {
			return nil, fmt.Errorf("error port %d", in.Age)
		}
		return &pb.HelloReply{Message: "status", Success: true}, nil
	}
	return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
}
// StreamHello echoes up to three greetings back on the stream, stopping early
// on EOF or any transport error.
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
	for round := 0; round < 3; round++ {
		req, err := ss.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		if err = ss.Send(&pb.HelloReply{Message: "Hello " + req.Name, Success: true}); err != nil {
			return err
		}
	}
	return nil
}
// runServer returns a func (suitable for sync.Once) that starts the shared test
// server on 127.0.0.1:8080 with two marker interceptors recording into outPut.
// NOTE(review): the interceptors parameter is currently unused — confirm whether
// callers were meant to be able to extend the chain.
func runServer(t *testing.T, interceptors ...grpc.UnaryServerInterceptor) func() {
	return func() {
		server = NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)})
		pb.RegisterGreeterServer(server.Server(), &helloServer{t})
		server.Use(
			// outer marker: "1" on entry, "2" on exit
			func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
				outPut = append(outPut, "1")
				resp, err := handler(ctx, req)
				outPut = append(outPut, "2")
				return resp, err
			},
			// inner marker: "3" on entry, "4" on exit
			func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
				outPut = append(outPut, "3")
				resp, err := handler(ctx, req)
				outPut = append(outPut, "4")
				return resp, err
			})
		if _, err := server.Start(); err != nil {
			t.Fatal(err)
		}
	}
}
// runClient dials the shared test server with a fresh client (plus any extra
// interceptors), performs one SayHello(name, age) on ctx and returns the result.
// It panics if the dial itself fails.
// NOTE(review): Dial uses context.Background() rather than ctx — confirm intended.
func runClient(ctx context.Context, cc *ClientConfig, t *testing.T, name string, age int32, interceptors ...grpc.UnaryClientInterceptor) (resp *pb.HelloReply, err error) {
	client := NewClient(cc)
	client.Use(interceptors...)
	conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
	if err != nil {
		panic(fmt.Errorf("did not connect: %v,req: %v %v", err, name, age))
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)
	resp, err = c.SayHello(ctx, &pb.HelloRequest{Name: name, Age: age})
	return
}
// TestMain only initializes logging for the package's tests.
// NOTE(review): the testing package's special entry point is
// func TestMain(m *testing.M); with a *testing.T parameter this name is not
// treated specially — confirm the signature is intended.
func TestMain(t *testing.T) {
	log.Init(nil)
}
// Test_Warden drives every sub-scenario against one shared server.
// The sub-tests run sequentially and testInterceptorChain depends on the two
// warm-up RPCs (the async runClient below plus its nested hop) having filled outPut.
func Test_Warden(t *testing.T) {
	xtrace.Init(&xtrace.Config{Addr: "127.0.0.1:9982", Timeout: xtime.Duration(time.Second * 3)})
	// start the shared server once and fire an async trace warm-up call
	go _testOnce.Do(runServer(t))
	go runClient(context.Background(), &clientConfig, t, "trace_test", 0)
	testTrace(t, 9982, false)
	testInterceptorChain(t)
	testValidation(t)
	testServerRecovery(t)
	testClientRecovery(t)
	testErrorDetail(t)
	testECodeStatus(t)
	testColorPass(t)
	testRemotePort(t)
	testLinkTimeout(t)
	testClientConfig(t)
	testBreaker(t)
	testAllErrorCase(t)
	// graceful shutdown last: it stops the shared server
	testGracefulShutDown(t)
}
// testValidation sends an empty Name (rejected by the request validator) and
// expects ecode.RequestErr back.
func testValidation(t *testing.T) {
	if _, err := runClient(context.Background(), &clientConfig, t, "", 0); !ecode.RequestErr.Equal(err) {
		t.Fatalf("testValidation should return ecode.RequestErr,but is %v", err)
	}
}
// testAllErrorCase exercises each error-producing branch of helloServer.SayHello
// (plain error, ecode constant, numeric ecode, ecode status) and checks what
// the client-side ecode.Cause reconstructs for each.
func testAllErrorCase(t *testing.T) {
	ctx := context.Background()
	t.Run("general_error", func(t *testing.T) {
		_, err := runClient(ctx, &clientConfig, t, "general_error", 0)
		assert.Contains(t, err.Error(), "haha")
		ec := ecode.Cause(err)
		// a non-ecode server error maps to the generic -500
		assert.Equal(t, -500, ec.Code())
		// remove this assert in future
		assert.Equal(t, "-500", ec.Message())
	})
	t.Run("ecode_code_error", func(t *testing.T) {
		_, err := runClient(ctx, &clientConfig, t, "ecode_code_error", 0)
		ec := ecode.Cause(err)
		assert.Equal(t, ecode.Conflict.Code(), ec.Code())
		// remove this assert in future
		assert.Equal(t, "20024", ec.Message())
	})
	t.Run("pb_error_error", func(t *testing.T) {
		_, err := runClient(ctx, &clientConfig, t, "pb_error_error", 0)
		ec := ecode.Cause(err)
		assert.Equal(t, 11122, ec.Code())
		assert.Equal(t, "haha", ec.Message())
	})
	t.Run("ecode_status_error", func(t *testing.T) {
		_, err := runClient(ctx, &clientConfig, t, "ecode_status_error", 0)
		ec := ecode.Cause(err)
		assert.Equal(t, ecode.RequestErr.Code(), ec.Code())
		assert.Equal(t, "RequestErr", ec.Message())
	})
}
// testBreaker hammers the half-failing breaker_test endpoint until the client
// breaker opens (ServiceUnavailable); 35 attempts must be enough to trip it.
func testBreaker(t *testing.T) {
	client := NewClient(&clientConfig)
	conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)
	for i := 0; i < 35; i++ {
		_, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "breaker_test"})
		if err != nil {
			// breaker open: test passed
			if ecode.ServiceUnavailable.Equal(err) {
				return
			}
		}
	}
	t.Fatalf("testBreaker failed!No breaker was triggered")
}
// testColorPass checks that the color metadata set on the client context
// survives the nested server->client hop and is echoed back by the server.
func testColorPass(t *testing.T) {
	ctx := nmd.NewContext(context.Background(), nmd.MD{nmd.Color: "red"})
	resp, err := runClient(ctx, &clientConfig, t, "color_test", 0)
	if err != nil {
		t.Fatalf("testColorPass return error %v", err)
	}
	if resp == nil || resp.Message != "Hello red" {
		t.Fatalf("testColorPass resp.Message must be red,%v", *resp)
	}
}
// testRemotePort checks that the RemotePort metadata set client-side reaches
// the server (which compares it against the Age argument).
func testRemotePort(t *testing.T) {
	md := nmd.MD{nmd.RemotePort: "8000"}
	if _, err := runClient(nmd.NewContext(context.Background(), md), &clientConfig, t, "test_remote_port", 8000); err != nil {
		t.Fatalf("testRemotePort return error %v", err)
	}
}
// testLinkTimeout starts the recursive timeout_test chain under a 200ms caller
// deadline and expects the shrinking link timeout to surface ecode.Deadline.
func testLinkTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
	defer cancel()
	_, err := runClient(ctx, &clientConfig, t, "timeout_test", 0)
	switch {
	case err == nil:
		t.Fatalf("testLinkTimeout must return error")
	case !ecode.Deadline.Equal(err):
		t.Fatalf("testLinkTimeout must return error RPCDeadline,err:%v", err)
	}
}
// testClientConfig verifies that the per-method client timeout in clientConfig2
// (200ms on /testproto.Greeter/SayHello) aborts the recursive timeout_test2
// chain with ecode.Deadline.
func testClientConfig(t *testing.T) {
	_, err := runClient(context.Background(), &clientConfig2, t, "timeout_test2", 0)
	if err == nil {
		// failure messages previously said "testLinkTimeout" — copy/paste from that test
		t.Fatalf("testClientConfig must return error")
	}
	if !ecode.Deadline.Equal(err) {
		t.Fatalf("testClientConfig must return error RPCDeadline,err:%v", err)
	}
}
// testGracefulShutDown fires ten slow (3s) RPCs, then shuts the server down
// after 1s with a 3s grace period: every in-flight call must still complete
// successfully before Shutdown returns.
func testGracefulShutDown(t *testing.T) {
	wg := sync.WaitGroup{}
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			resp, err := runClient(context.Background(), &clientConfig, t, "graceful_shutdown", 0)
			if err != nil {
				// panic (not t.Fatalf) because this runs off the test goroutine
				panic(fmt.Errorf("run graceful_shutdown client return(%v)", err))
			}
			if !resp.Success || resp.Message != "Hello graceful_shutdown" {
				panic(fmt.Errorf("run graceful_shutdown client return(%v,%v)", err, *resp))
			}
		}()
	}
	go func() {
		// begin shutdown while the slow RPCs are still in flight
		time.Sleep(time.Second)
		ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
		defer cancel()
		server.Shutdown(ctx)
	}()
	wg.Wait()
}
// testClientRecovery installs a client interceptor that panics after invoking
// the call and verifies the client recovery interceptor converts the panic
// into an ecode.ServerErr error.
func testClientRecovery(t *testing.T) {
	ctx := context.Background()
	client := NewClient(&clientConfig)
	client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) {
		// deliberately panic after the call so the recovery interceptor must catch it
		invoker(ctx, method, req, reply, cc, opts...)
		panic("client recovery test")
	})
	conn, err := client.Dial(ctx, "127.0.0.1:8080")
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)
	_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "other_test", Age: 0})
	if err == nil {
		t.Fatalf("recovery must return error")
	}
	e, ok := errors.Cause(err).(ecode.Codes)
	if !ok {
		t.Fatalf("recovery must return ecode error")
	}
	if !ecode.ServerErr.Equal(e) {
		// message previously said "ecode.RPCClientErr" while the check is against ServerErr
		t.Fatalf("recovery must return ecode.ServerErr")
	}
}
// testServerRecovery triggers the recovery_test panic inside the server handler
// and verifies the server recovery interceptor maps it to ecode.ServerErr.
func testServerRecovery(t *testing.T) {
	ctx := context.Background()
	client := NewClient(&clientConfig)
	conn, err := client.Dial(ctx, "127.0.0.1:8080")
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)
	_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "recovery_test", Age: 0})
	if err == nil {
		t.Fatalf("recovery must return error")
	}
	e, ok := errors.Cause(err).(ecode.Codes)
	if !ok {
		t.Fatalf("recovery must return ecode error")
	}
	if e.Code() != ecode.ServerErr.Code() {
		t.Fatalf("recovery must return ecode.ServerErr")
	}
}
// testInterceptorChain asserts the onion ordering of the two marker
// interceptors added in runServer across the two warm-up RPCs:
// both enter (1,3) before either leaves (4,2), twice.
func testInterceptorChain(t *testing.T) {
	// NOTE: don't delete this sleep: it lets the async warm-up RPCs finish filling outPut
	time.Sleep(time.Millisecond)
	want := []string{"1", "3", "1", "3", "4", "2", "4", "2"}
	for i, w := range want {
		if outPut[i] != w {
			// was: "outPut shoud be ..." (typo)
			t.Fatalf("outPut should be [1 3 1 3 4 2 4 2]!")
		}
	}
}
// testErrorDetail verifies that an ecode error carrying a proto detail
// (code 123456 + pb.HelloReply) survives the round trip to the client intact.
func testErrorDetail(t *testing.T) {
	_, err := runClient(context.Background(), &clientConfig2, t, "error_detail", 0)
	if err == nil {
		t.Fatalf("testErrorDetail must return error")
	}
	if ec, ok := errors.Cause(err).(ecode.Codes); !ok {
		t.Fatalf("testErrorDetail must return ecode error")
	} else if ec.Code() != 123456 || ec.Message() != "test_error_detail" || len(ec.Details()) == 0 {
		t.Fatalf("testErrorDetail must return code:123456 and message:test_error_detail, code: %d, message: %s, details length: %d", ec.Code(), ec.Message(), len(ec.Details()))
	} else if _, ok := ec.Details()[len(ec.Details())-1].(*pb.HelloReply); !ok {
		t.Fatalf("expect get pb.HelloReply %#v", ec.Details()[len(ec.Details())-1])
	}
}
// testECodeStatus verifies that an *ecode.Status (RequestErr + HelloReply
// detail) returned by the server is reconstructed as *ecode.Status client-side
// with its code, message and detail intact.
func testECodeStatus(t *testing.T) {
	_, err := runClient(context.Background(), &clientConfig2, t, "ecode_status", 0)
	if err == nil {
		t.Fatalf("testECodeStatus must return error")
	}
	st, ok := errors.Cause(err).(*ecode.Status)
	if !ok {
		t.Fatalf("testECodeStatus must return *ecode.Status")
	}
	// was `&&`, which only failed when BOTH code and message were wrong;
	// either mismatch must fail the test.
	if st.Code() != int(ecode.RequestErr) || st.Message() != "RequestErr" {
		t.Fatalf("testECodeStatus must return code: -400, message: RequestErr get: code: %d, message: %s", st.Code(), st.Message())
	}
	detail := st.Details()[0].(*pb.HelloReply)
	if !detail.Success || detail.Message != "status" {
		t.Fatalf("wrong detail")
	}
}
// testTrace listens on the trace-report UDP port and waits until two report
// datagrams (server span + nested client span) arrive, checking each splits
// into at least one field on _separator.
// NOTE(review): isStream is currently unused — kept for interface stability.
func testTrace(t *testing.T, port int, isStream bool) {
	listener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port})
	if err != nil {
		t.Fatalf("listent udp failed, %v", err)
		return
	}
	// was missing: the UDP socket leaked across calls and kept the port bound
	defer listener.Close()
	data := make([]byte, 1024)
	strs := make([][]string, 0)
	for {
		var n int
		n, _, err = listener.ReadFromUDP(data)
		if err != nil {
			t.Fatalf("read from udp faild, %v", err)
		}
		str := strings.Split(string(data[:n]), _separator)
		strs = append(strs, str)
		if len(strs) == 2 {
			break
		}
	}
	if len(strs[0]) == 0 || len(strs[1]) == 0 {
		t.Fatalf("trace str's length must be greater than 0")
	}
}
// BenchmarkServer measures end-to-end unary SayHello throughput against an
// in-process warden server on 127.0.0.1:8080.
func BenchmarkServer(b *testing.B) {
	server := NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)})
	go func() {
		pb.RegisterGreeterServer(server.Server(), &helloServer{})
		if _, err := server.Start(); err != nil {
			os.Exit(0)
			return
		}
	}()
	defer func() {
		server.Server().Stop()
	}()
	client := NewClient(&clientConfig)
	conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
	if err != nil {
		// was: conn.Close() before Fatalf — conn is nil when Dial fails, which panics
		b.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	b.ResetTimer()
	b.RunParallel(func(parab *testing.PB) {
		// one stub per worker; creating it inside the loop allocated per iteration
		c := pb.NewGreeterClient(conn)
		for parab.Next() {
			resp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "benchmark_test", Age: 1})
			if err != nil {
				// Fatal/FailNow must not be called from RunParallel worker goroutines
				b.Errorf("c.SayHello failed: %v,req: %v %v", err, "benchmark", 1)
				return
			}
			if !resp.Success {
				b.Error("response not success!")
			}
		}
	})
}
// TestParseDSN checks that parseDSN binds every query parameter of a tcp DSN
// onto ServerConfig and that the unix scheme maps host+path to Addr.
func TestParseDSN(t *testing.T) {
	dsn := "tcp://0.0.0.0:80/?timeout=100ms&idleTimeout=120s&keepaliveInterval=120s&keepaliveTimeout=20s&maxLife=4h&closeWait=3s"
	config := parseDSN(dsn)
	if config.Network != "tcp" || config.Addr != "0.0.0.0:80" || time.Duration(config.Timeout) != time.Millisecond*100 ||
		time.Duration(config.IdleTimeout) != time.Second*120 || time.Duration(config.KeepAliveInterval) != time.Second*120 ||
		time.Duration(config.MaxLifeTime) != time.Hour*4 || time.Duration(config.ForceCloseWait) != time.Second*3 || time.Duration(config.KeepAliveTimeout) != time.Second*20 {
		t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config)
	}
	dsn = "unix:///temp/warden.sock?timeout=300ms"
	config = parseDSN(dsn)
	if config.Network != "unix" || config.Addr != "/temp/warden.sock" || time.Duration(config.Timeout) != time.Millisecond*300 {
		t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config)
	}
}
// testServer is a Greeter implementation whose unary behavior is injected
// per test via helloFn (see NewTestServerClient).
type testServer struct {
	helloFn func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) // invoked by SayHello
}
// SayHello delegates to the injected helloFn.
func (t *testServer) SayHello(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) {
	return t.helloFn(ctx, req)
}
func (t *testServer) StreamHello(pb.Greeter_StreamHelloServer) error { panic("not implemented") }
// NewTestServerClient starts an in-process warden server on a random local
// port with the given unary handler and returns a client stub plus a shutdown
// func that gracefully stops the server.
// NOTE(review): clicfg is currently unused — NewConn dials with defaults.
func NewTestServerClient(invoker func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error), svrcfg *ServerConfig, clicfg *ClientConfig) (pb.GreeterClient, func() error) {
	srv := NewServer(svrcfg)
	pb.RegisterGreeterServer(srv.Server(), &testServer{helloFn: invoker})
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	ch := make(chan bool, 1)
	go func() {
		ch <- true
		srv.Serve(lis)
	}()
	<-ch
	// removed stray debug output: println(lis.Addr().String())
	conn, err := NewConn(lis.Addr().String())
	if err != nil {
		panic(err)
	}
	return pb.NewGreeterClient(conn), func() error { return srv.Shutdown(context.Background()) }
}
// TestMetadata verifies that Color, RemoteIP and RemotePort set on the client
// metadata context arrive unchanged in the server handler's context.
func TestMetadata(t *testing.T) {
	cli, cancel := NewTestServerClient(func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) {
		assert.Equal(t, "red", nmd.String(ctx, nmd.Color))
		assert.Equal(t, "2.2.3.3", nmd.String(ctx, nmd.RemoteIP))
		assert.Equal(t, "2233", nmd.String(ctx, nmd.RemotePort))
		return &pb.HelloReply{}, nil
	}, nil, nil)
	defer cancel()
	ctx := nmd.NewContext(context.Background(), nmd.MD{
		nmd.Color:      "red",
		nmd.RemoteIP:   "2.2.3.3",
		nmd.RemotePort: "2233",
	})
	_, err := cli.SayHello(ctx, &pb.HelloRequest{Name: "test"})
	assert.Nil(t, err)
}

@ -0,0 +1,25 @@
package warden
import (
"context"
"strconv"
nmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
"github.com/bilibili/kratos/pkg/stat/sys/cpu"
"google.golang.org/grpc"
gmd "google.golang.org/grpc/metadata"
)
// stats returns a unary server interceptor that, after the handler runs,
// attaches the current CPU usage reading as a gRPC trailer (skipped when the
// reading is zero) so clients can feed it into load balancing.
func (s *Server) stats() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		resp, err = handler(ctx, req)
		var usage cpu.Stat
		cpu.ReadStat(&usage)
		if usage.Usage == 0 {
			// no meaningful reading: send no trailer
			return
		}
		grpc.SetTrailer(ctx, gmd.Pairs(nmd.CPUUsage, strconv.FormatInt(int64(usage.Usage), 10)))
		return
	}
}

@ -0,0 +1,31 @@
package warden
import (
"context"
"github.com/bilibili/kratos/pkg/ecode"
"google.golang.org/grpc"
"gopkg.in/go-playground/validator.v9"
)
var validate = validator.New()
// validate returns a unary server interceptor that runs struct validation on
// every incoming request and rejects invalid ones with ecode.RequestErr
// before the handler is invoked.
// (The previous comment called this a client interceptor; it is installed
// server-side by NewServer.)
func (s *Server) validate() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if verr := validate.Struct(req); verr != nil {
			return nil, ecode.Error(ecode.RequestErr, verr.Error())
		}
		return handler(ctx, req)
	}
}
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
// NOTE: if the key already exists, the previous validation function will be replaced.
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
// Note that the registry is package-level, so it is shared across all Server instances.
func (s *Server) RegisterValidation(key string, fn validator.Func) error {
	return validate.RegisterValidation(key, fn)
}
Loading…
Cancel
Save